]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
ensure compiler is not optional in create_for_statement()
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 3 Mar 2025 22:01:15 +0000 (17:01 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 4 Mar 2025 21:40:20 +0000 (16:40 -0500)
this involved moving some methods around and changing the
target of legacy orm/query.py calling upon this method to
use an ORM-specific method instead

(cherry picked from commit d9b4d8ff3aae504402d324f3ebf0b8faff78f5dc)
Change-Id: I6f83a5b0e8f43a3eb633216c2f2fe2d28345e9bd

lib/sqlalchemy/orm/context.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/sql/base.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/selectable.py
test/ext/test_hybrid.py
test/orm/test_froms.py

index 5e91cdf9e144636865a3523a17b30b044116e811..b04d6d48c28f7523b0357e2d6a86c0d5df58d2f3 100644 (file)
@@ -265,10 +265,10 @@ class AbstractORMCompileState(CompileState):
     @classmethod
     def create_for_statement(
         cls,
-        statement: Union[Select, FromStatement],
-        compiler: Optional[SQLCompiler],
+        statement: Executable,
+        compiler: SQLCompiler,
         **kw: Any,
-    ) -> AbstractORMCompileState:
+    ) -> CompileState:
         """Create a context for a statement given a :class:`.Compiler`.
 
         This method is always invoked in the context of SQLCompiler.process().
@@ -437,15 +437,30 @@ class ORMCompileState(AbstractORMCompileState):
     def __init__(self, *arg, **kw):
         raise NotImplementedError()
 
-    if TYPE_CHECKING:
+    @classmethod
+    def create_for_statement(
+        cls,
+        statement: Executable,
+        compiler: SQLCompiler,
+        **kw: Any,
+    ) -> ORMCompileState:
+        return cls._create_orm_context(
+            cast("Union[Select, FromStatement]", statement),
+            toplevel=not compiler.stack,
+            compiler=compiler,
+            **kw,
+        )
 
-        @classmethod
-        def create_for_statement(
-            cls,
-            statement: Union[Select, FromStatement],
-            compiler: Optional[SQLCompiler],
-            **kw: Any,
-        ) -> ORMCompileState: ...
+    @classmethod
+    def _create_orm_context(
+        cls,
+        statement: Union[Select, FromStatement],
+        *,
+        toplevel: bool,
+        compiler: Optional[SQLCompiler],
+        **kw: Any,
+    ) -> ORMCompileState:
+        raise NotImplementedError()
 
     def _append_dedupe_col_collection(self, obj, col_collection):
         dedupe = self.dedupe_columns
@@ -755,12 +770,16 @@ class ORMFromStatementCompileState(ORMCompileState):
     eager_joins = _EMPTY_DICT
 
     @classmethod
-    def create_for_statement(
+    def _create_orm_context(
         cls,
-        statement_container: Union[Select, FromStatement],
+        statement: Union[Select, FromStatement],
+        *,
+        toplevel: bool,
         compiler: Optional[SQLCompiler],
         **kw: Any,
     ) -> ORMFromStatementCompileState:
+        statement_container = statement
+
         assert isinstance(statement_container, FromStatement)
 
         if compiler is not None and compiler.stack:
@@ -1067,21 +1086,17 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
     _having_criteria = ()
 
     @classmethod
-    def create_for_statement(
+    def _create_orm_context(
         cls,
         statement: Union[Select, FromStatement],
+        *,
+        toplevel: bool,
         compiler: Optional[SQLCompiler],
         **kw: Any,
     ) -> ORMSelectCompileState:
-        """compiler hook, we arrive here from compiler.visit_select() only."""
 
         self = cls.__new__(cls)
 
-        if compiler is not None:
-            toplevel = not compiler.stack
-        else:
-            toplevel = True
-
         select_statement = statement
 
         # if we are a select() that was never a legacy Query, we won't
index 4dbb3009b39359b18965b7253472ae969cf922fc..af496b245f481bd2ce79667f10b2b5ef6fe86cce 100644 (file)
@@ -3340,7 +3340,9 @@ class Query(
             ORMCompileState._get_plugin_class_for_plugin(stmt, "orm"),
         )
 
-        return compile_state_cls.create_for_statement(stmt, None)
+        return compile_state_cls._create_orm_context(
+            stmt, toplevel=True, compiler=None
+        )
 
     def _compile_context(self, for_statement: bool = False) -> QueryContext:
         compile_state = self._compile_state(for_statement=for_statement)
index 6d409a9fb7ed04e1e913481265dbd8bf3183c2f3..7ccef84e0d54902d1b450077e22d14f498a0cc74 100644 (file)
@@ -68,6 +68,7 @@ if TYPE_CHECKING:
     from ._orm_types import DMLStrategyArgument
     from ._orm_types import SynchronizeSessionArgument
     from ._typing import _CLE
+    from .compiler import SQLCompiler
     from .elements import BindParameter
     from .elements import ClauseList
     from .elements import ColumnClause  # noqa
@@ -657,7 +658,9 @@ class CompileState:
     _ambiguous_table_name_map: Optional[_AmbiguousTableNameMap]
 
     @classmethod
-    def create_for_statement(cls, statement, compiler, **kw):
+    def create_for_statement(
+        cls, statement: Executable, compiler: SQLCompiler, **kw: Any
+    ) -> CompileState:
         # factory construction.
 
         if statement._propagate_attrs:
index fde503aaf9bdf4f07cd6cf9ab15978b0a193ca87..cd1dc34e0a1d80e4273c2013c97ef2284fa5eb5e 100644 (file)
@@ -298,8 +298,7 @@ class CompilerElement(Visitable):
             if bind:
                 dialect = bind.dialect
             elif self.stringify_dialect == "default":
-                default = util.preloaded.engine_default
-                dialect = default.StrCompileDialect()
+                dialect = self._default_dialect()
             else:
                 url = util.preloaded.engine_url
                 dialect = url.URL.create(
@@ -308,6 +307,10 @@ class CompilerElement(Visitable):
 
         return self._compiler(dialect, **kw)
 
+    def _default_dialect(self):
+        default = util.preloaded.engine_default
+        return default.StrCompileDialect()
+
     def _compiler(self, dialect: Dialect, **kw: Any) -> Compiled:
         """Return a compiler appropriate for this ClauseElement, given a
         Dialect."""
@@ -404,6 +407,10 @@ class ClauseElement(
         self._propagate_attrs = util.immutabledict(values)
         return self
 
+    def _default_compiler(self) -> SQLCompiler:
+        dialect = self._default_dialect()
+        return dialect.statement_compiler(dialect, self)  # type: ignore
+
     def _clone(self, **kw: Any) -> Self:
         """Create a shallow copy of this ClauseElement.
 
index 5db1e729e7a3bb4251bbdfdce7d931e854331ded..d137ab504ea19d3544e1af542e666645fc1cd027 100644 (file)
@@ -4694,7 +4694,7 @@ class SelectState(util.MemoizedSlots, CompileState):
     def __init__(
         self,
         statement: Select[Any],
-        compiler: Optional[SQLCompiler],
+        compiler: SQLCompiler,
         **kw: Any,
     ):
         self.statement = statement
@@ -5742,8 +5742,9 @@ class Select(
             :attr:`_sql.Select.columns_clause_froms`
 
         """
+        compiler = self._default_compiler()
 
-        return self._compile_state_factory(self, None)._get_display_froms()
+        return self._compile_state_factory(self, compiler)._get_display_froms()
 
     @property
     @util.deprecated(
index 8e3d7e9cd57a193c829d42c68eff4408e6bad01d..f6ad0de8d4da11275421a5635030f4853ea77c08 100644 (file)
@@ -22,6 +22,7 @@ from sqlalchemy.orm import declared_attr
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import Session
 from sqlalchemy.orm import synonym
+from sqlalchemy.orm.context import ORMSelectCompileState
 from sqlalchemy.sql import coercions
 from sqlalchemy.sql import operators
 from sqlalchemy.sql import roles
@@ -531,7 +532,9 @@ class PropertyExpressionTest(fixtures.TestBase, AssertsCompiledSQL):
             "SELECT a.id, a.foo FROM a",
         )
 
-        compile_state = stmt._compile_state_factory(stmt, None)
+        compile_state = ORMSelectCompileState._create_orm_context(
+            stmt, toplevel=True, compiler=None
+        )
         eq_(
             compile_state._column_naming_convention(
                 LABEL_STYLE_DISAMBIGUATE_ONLY, legacy=False
index 51c86a5f1dad90dbc35e5603b5d147e46d46ffa0..e0d75db7e16e1b515a6307e5bd61fa7f98996d15 100644 (file)
@@ -1893,7 +1893,9 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL):
                 .order_by(User.id)
             )
 
-        compile_state = ORMSelectCompileState.create_for_statement(stmt, None)
+        compile_state = ORMSelectCompileState._create_orm_context(
+            stmt, toplevel=True, compiler=None
+        )
         is_(compile_state._primary_entity, None)
 
     def test_column_queries_one(self):