]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
ensure compiler is not optional in create_for_statement()
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 3 Mar 2025 22:01:15 +0000 (17:01 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 3 Mar 2025 22:38:42 +0000 (17:38 -0500)
this involved moving some methods around and changing the
target of legacy orm/query.py calling upon this method to
use an ORM-specific method instead

Change-Id: Ib977f08e52398d0e082acf7d88abecb9908ca8b6

lib/sqlalchemy/orm/context.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/sql/base.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/selectable.py
test/ext/test_hybrid.py
test/orm/test_froms.py

index a67331fe80a49deadf1c4045e41ec04f14b4d555..158a81712b6ba84a0c568e7381a03d376840585c 100644 (file)
@@ -273,10 +273,10 @@ class _AbstractORMCompileState(CompileState):
     @classmethod
     def create_for_statement(
         cls,
-        statement: Union[Select, FromStatement],
-        compiler: Optional[SQLCompiler],
+        statement: Executable,
+        compiler: SQLCompiler,
         **kw: Any,
-    ) -> _AbstractORMCompileState:
+    ) -> CompileState:
         """Create a context for a statement given a :class:`.Compiler`.
 
         This method is always invoked in the context of SQLCompiler.process().
@@ -449,15 +449,30 @@ class _ORMCompileState(_AbstractORMCompileState):
     def __init__(self, *arg, **kw):
         raise NotImplementedError()
 
-    if TYPE_CHECKING:
+    @classmethod
+    def create_for_statement(
+        cls,
+        statement: Executable,
+        compiler: SQLCompiler,
+        **kw: Any,
+    ) -> _ORMCompileState:
+        return cls._create_orm_context(
+            cast("Union[Select, FromStatement]", statement),
+            toplevel=not compiler.stack,
+            compiler=compiler,
+            **kw,
+        )
 
-        @classmethod
-        def create_for_statement(
-            cls,
-            statement: Union[Select, FromStatement],
-            compiler: Optional[SQLCompiler],
-            **kw: Any,
-        ) -> _ORMCompileState: ...
+    @classmethod
+    def _create_orm_context(
+        cls,
+        statement: Union[Select, FromStatement],
+        *,
+        toplevel: bool,
+        compiler: Optional[SQLCompiler],
+        **kw: Any,
+    ) -> _ORMCompileState:
+        raise NotImplementedError()
 
     def _append_dedupe_col_collection(self, obj, col_collection):
         dedupe = self.dedupe_columns
@@ -767,12 +782,16 @@ class _ORMFromStatementCompileState(_ORMCompileState):
     eager_joins = _EMPTY_DICT
 
     @classmethod
-    def create_for_statement(
+    def _create_orm_context(
         cls,
-        statement_container: Union[Select, FromStatement],
+        statement: Union[Select, FromStatement],
+        *,
+        toplevel: bool,
         compiler: Optional[SQLCompiler],
         **kw: Any,
     ) -> _ORMFromStatementCompileState:
+        statement_container = statement
+
         assert isinstance(statement_container, FromStatement)
 
         if compiler is not None and compiler.stack:
@@ -1079,21 +1098,17 @@ class _ORMSelectCompileState(_ORMCompileState, SelectState):
     _having_criteria = ()
 
     @classmethod
-    def create_for_statement(
+    def _create_orm_context(
         cls,
         statement: Union[Select, FromStatement],
+        *,
+        toplevel: bool,
         compiler: Optional[SQLCompiler],
         **kw: Any,
     ) -> _ORMSelectCompileState:
-        """compiler hook, we arrive here from compiler.visit_select() only."""
 
         self = cls.__new__(cls)
 
-        if compiler is not None:
-            toplevel = not compiler.stack
-        else:
-            toplevel = True
-
         select_statement = statement
 
         # if we are a select() that was never a legacy Query, we won't
index ac6746adba941e3c68e95be33a6b566d34baaa4f..28c282b4872e9199a00a377e103e41f61b36ae2c 100644 (file)
@@ -3361,7 +3361,9 @@ class Query(
             _ORMCompileState._get_plugin_class_for_plugin(stmt, "orm"),
         )
 
-        return compile_state_cls.create_for_statement(stmt, None)
+        return compile_state_cls._create_orm_context(
+            stmt, toplevel=True, compiler=None
+        )
 
     def _compile_context(self, for_statement: bool = False) -> QueryContext:
         compile_state = self._compile_state(for_statement=for_statement)
index a93ea4e42e8a0452385c8af5e5bb80d72ad3e26d..801814f334c91a66bfa45a32b178da4f14f9f047 100644 (file)
@@ -67,6 +67,7 @@ if TYPE_CHECKING:
     from ._orm_types import DMLStrategyArgument
     from ._orm_types import SynchronizeSessionArgument
     from ._typing import _CLE
+    from .compiler import SQLCompiler
     from .elements import BindParameter
     from .elements import ClauseList
     from .elements import ColumnClause  # noqa
@@ -656,7 +657,9 @@ class CompileState:
     _ambiguous_table_name_map: Optional[_AmbiguousTableNameMap]
 
     @classmethod
-    def create_for_statement(cls, statement, compiler, **kw):
+    def create_for_statement(
+        cls, statement: Executable, compiler: SQLCompiler, **kw: Any
+    ) -> CompileState:
         # factory construction.
 
         if statement._propagate_attrs:
index 825123a977eee99242b55165de60e971fbcbb96f..bd92f6aa854e0ae8c64174c26a5e3174c6f67243 100644 (file)
@@ -300,8 +300,7 @@ class CompilerElement(Visitable):
             if bind:
                 dialect = bind.dialect
             elif self.stringify_dialect == "default":
-                default = util.preloaded.engine_default
-                dialect = default.StrCompileDialect()
+                dialect = self._default_dialect()
             else:
                 url = util.preloaded.engine_url
                 dialect = url.URL.create(
@@ -310,6 +309,10 @@ class CompilerElement(Visitable):
 
         return self._compiler(dialect, **kw)
 
+    def _default_dialect(self):
+        default = util.preloaded.engine_default
+        return default.StrCompileDialect()
+
     def _compiler(self, dialect: Dialect, **kw: Any) -> Compiled:
         """Return a compiler appropriate for this ClauseElement, given a
         Dialect."""
@@ -406,6 +409,10 @@ class ClauseElement(
         self._propagate_attrs = util.immutabledict(values)
         return self
 
+    def _default_compiler(self) -> SQLCompiler:
+        dialect = self._default_dialect()
+        return dialect.statement_compiler(dialect, self)  # type: ignore
+
     def _clone(self, **kw: Any) -> Self:
         """Create a shallow copy of this ClauseElement.
 
index c3255a8f1834071bc3533d5ce0cd7cf0fe3cf87b..e53b2bbccc1cc714b0eaa9d8a57718766ddbf711 100644 (file)
@@ -4661,7 +4661,7 @@ class SelectState(util.MemoizedSlots, CompileState):
     def __init__(
         self,
         statement: Select[Unpack[TupleAny]],
-        compiler: Optional[SQLCompiler],
+        compiler: SQLCompiler,
         **kw: Any,
     ):
         self.statement = statement
@@ -5717,8 +5717,9 @@ class Select(
             :attr:`_sql.Select.columns_clause_froms`
 
         """
+        compiler = self._default_compiler()
 
-        return self._compile_state_factory(self, None)._get_display_froms()
+        return self._compile_state_factory(self, compiler)._get_display_froms()
 
     @property
     @util.deprecated(
index 8e3d7e9cd57a193c829d42c68eff4408e6bad01d..09da020743f03389890b492ea5e02cc9b341e0ce 100644 (file)
@@ -22,6 +22,7 @@ from sqlalchemy.orm import declared_attr
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import Session
 from sqlalchemy.orm import synonym
+from sqlalchemy.orm.context import _ORMSelectCompileState
 from sqlalchemy.sql import coercions
 from sqlalchemy.sql import operators
 from sqlalchemy.sql import roles
@@ -531,7 +532,9 @@ class PropertyExpressionTest(fixtures.TestBase, AssertsCompiledSQL):
             "SELECT a.id, a.foo FROM a",
         )
 
-        compile_state = stmt._compile_state_factory(stmt, None)
+        compile_state = _ORMSelectCompileState._create_orm_context(
+            stmt, toplevel=True, compiler=None
+        )
         eq_(
             compile_state._column_naming_convention(
                 LABEL_STYLE_DISAMBIGUATE_ONLY, legacy=False
index 9a1ff1ee442014303810bb37ba5a77dee0c2d4d1..ae0c147c715a0bec12ae659894c7387eb853cfac 100644 (file)
@@ -1893,7 +1893,9 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL):
                 .order_by(User.id)
             )
 
-        compile_state = _ORMSelectCompileState.create_for_statement(stmt, None)
+        compile_state = _ORMSelectCompileState._create_orm_context(
+            stmt, toplevel=True, compiler=None
+        )
         is_(compile_state._primary_entity, None)
 
     def test_column_queries_one(self):