From: Mike Bayer Date: Mon, 3 Mar 2025 22:01:15 +0000 (-0500) Subject: ensure compiler is not optional in create_for_statement() X-Git-Tag: rel_2_0_39~6 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=f998ae83d2c3dcd7f625e3d6a611cf2f7c56907c;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git ensure compiler is not optional in create_for_statement() this involved moving some methods around and changing the target of legacy orm/query.py calling upon this method to use an ORM-specific method instead (cherry picked from commit d9b4d8ff3aae504402d324f3ebf0b8faff78f5dc) Change-Id: I6f83a5b0e8f43a3eb633216c2f2fe2d28345e9bd --- diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index 5e91cdf9e1..b04d6d48c2 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -265,10 +265,10 @@ class AbstractORMCompileState(CompileState): @classmethod def create_for_statement( cls, - statement: Union[Select, FromStatement], - compiler: Optional[SQLCompiler], + statement: Executable, + compiler: SQLCompiler, **kw: Any, - ) -> AbstractORMCompileState: + ) -> CompileState: """Create a context for a statement given a :class:`.Compiler`. This method is always invoked in the context of SQLCompiler.process(). @@ -437,15 +437,30 @@ class ORMCompileState(AbstractORMCompileState): def __init__(self, *arg, **kw): raise NotImplementedError() - if TYPE_CHECKING: + @classmethod + def create_for_statement( + cls, + statement: Executable, + compiler: SQLCompiler, + **kw: Any, + ) -> ORMCompileState: + return cls._create_orm_context( + cast("Union[Select, FromStatement]", statement), + toplevel=not compiler.stack, + compiler=compiler, + **kw, + ) - @classmethod - def create_for_statement( - cls, - statement: Union[Select, FromStatement], - compiler: Optional[SQLCompiler], - **kw: Any, - ) -> ORMCompileState: ... + @classmethod + def _create_orm_context( + cls, + statement: Union[Select, FromStatement], + *, + toplevel: bool, + compiler: Optional[SQLCompiler], + **kw: Any, + ) -> ORMCompileState: + raise NotImplementedError() def _append_dedupe_col_collection(self, obj, col_collection): dedupe = self.dedupe_columns @@ -755,12 +770,16 @@ class ORMFromStatementCompileState(ORMCompileState): eager_joins = _EMPTY_DICT @classmethod - def create_for_statement( + def _create_orm_context( cls, - statement_container: Union[Select, FromStatement], + statement: Union[Select, FromStatement], + *, + toplevel: bool, compiler: Optional[SQLCompiler], **kw: Any, ) -> ORMFromStatementCompileState: + statement_container = statement + assert isinstance(statement_container, FromStatement) if compiler is not None and compiler.stack: @@ -1067,21 +1086,17 @@ class ORMSelectCompileState(ORMCompileState, SelectState): _having_criteria = () @classmethod - def create_for_statement( + def _create_orm_context( cls, statement: Union[Select, FromStatement], + *, + toplevel: bool, compiler: Optional[SQLCompiler], **kw: Any, ) -> ORMSelectCompileState: - """compiler hook, we arrive here from compiler.visit_select() only.""" self = cls.__new__(cls) - if compiler is not None: - toplevel = not compiler.stack - else: - toplevel = True - select_statement = statement # if we are a select() that was never a legacy Query, we won't diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 4dbb3009b3..af496b245f 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -3340,7 +3340,9 @@ class Query( ORMCompileState._get_plugin_class_for_plugin(stmt, "orm"), ) - return compile_state_cls.create_for_statement(stmt, None) + return compile_state_cls._create_orm_context( + stmt, toplevel=True, compiler=None + ) def _compile_context(self, for_statement: bool = False) -> QueryContext: compile_state = self._compile_state(for_statement=for_statement) diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 6d409a9fb7..7ccef84e0d 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -68,6 +68,7 @@ if TYPE_CHECKING: from ._orm_types import DMLStrategyArgument from ._orm_types import SynchronizeSessionArgument from ._typing import _CLE + from .compiler import SQLCompiler from .elements import BindParameter from .elements import ClauseList from .elements import ColumnClause # noqa @@ -657,7 +658,9 @@ class CompileState: _ambiguous_table_name_map: Optional[_AmbiguousTableNameMap] @classmethod - def create_for_statement(cls, statement, compiler, **kw): + def create_for_statement( + cls, statement: Executable, compiler: SQLCompiler, **kw: Any + ) -> CompileState: # factory construction. if statement._propagate_attrs: diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index fde503aaf9..cd1dc34e0a 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -298,8 +298,7 @@ class CompilerElement(Visitable): if bind: dialect = bind.dialect elif self.stringify_dialect == "default": - default = util.preloaded.engine_default - dialect = default.StrCompileDialect() + dialect = self._default_dialect() else: url = util.preloaded.engine_url dialect = url.URL.create( @@ -308,6 +307,10 @@ class CompilerElement(Visitable): return self._compiler(dialect, **kw) + def _default_dialect(self): + default = util.preloaded.engine_default + return default.StrCompileDialect() + def _compiler(self, dialect: Dialect, **kw: Any) -> Compiled: """Return a compiler appropriate for this ClauseElement, given a Dialect.""" @@ -404,6 +407,10 @@ class ClauseElement( self._propagate_attrs = util.immutabledict(values) return self + def _default_compiler(self) -> SQLCompiler: + dialect = self._default_dialect() + return dialect.statement_compiler(dialect, self) # type: ignore + def _clone(self, **kw: Any) -> Self: """Create a shallow copy of this ClauseElement. diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 5db1e729e7..d137ab504e 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -4694,7 +4694,7 @@ class SelectState(util.MemoizedSlots, CompileState): def __init__( self, statement: Select[Any], - compiler: Optional[SQLCompiler], + compiler: SQLCompiler, **kw: Any, ): self.statement = statement @@ -5742,8 +5742,9 @@ class Select( :attr:`_sql.Select.columns_clause_froms` """ + compiler = self._default_compiler() - return self._compile_state_factory(self, None)._get_display_froms() + return self._compile_state_factory(self, compiler)._get_display_froms() @property @util.deprecated( diff --git a/test/ext/test_hybrid.py b/test/ext/test_hybrid.py index 8e3d7e9cd5..f6ad0de8d4 100644 --- a/test/ext/test_hybrid.py +++ b/test/ext/test_hybrid.py @@ -22,6 +22,7 @@ from sqlalchemy.orm import declared_attr from sqlalchemy.orm import relationship from sqlalchemy.orm import Session from sqlalchemy.orm import synonym +from sqlalchemy.orm.context import ORMSelectCompileState from sqlalchemy.sql import coercions from sqlalchemy.sql import operators from sqlalchemy.sql import roles @@ -531,7 +532,9 @@ class PropertyExpressionTest(fixtures.TestBase, AssertsCompiledSQL): "SELECT a.id, a.foo FROM a", ) - compile_state = stmt._compile_state_factory(stmt, None) + compile_state = ORMSelectCompileState._create_orm_context( + stmt, toplevel=True, compiler=None + ) eq_( compile_state._column_naming_convention( LABEL_STYLE_DISAMBIGUATE_ONLY, legacy=False diff --git a/test/orm/test_froms.py b/test/orm/test_froms.py index 51c86a5f1d..e0d75db7e1 100644 --- a/test/orm/test_froms.py +++ b/test/orm/test_froms.py @@ -1893,7 +1893,9 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): .order_by(User.id) ) - compile_state = ORMSelectCompileState.create_for_statement(stmt, None) + compile_state = ORMSelectCompileState._create_orm_context( + stmt, toplevel=True, compiler=None + ) is_(compile_state._primary_entity, None) def test_column_queries_one(self):