From: Mike Bayer Date: Mon, 3 Mar 2025 22:01:15 +0000 (-0500) Subject: ensure compiler is not optional in create_for_statement() X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=d9b4d8ff3aae504402d324f3ebf0b8faff78f5dc;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git ensure compiler is not optional in create_for_statement() this involved moving some methods around and changing the target of legacy orm/query.py calling upon this method to use an ORM-specific method instead Change-Id: Ib977f08e52398d0e082acf7d88abecb9908ca8b6 --- diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index a67331fe80..158a81712b 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -273,10 +273,10 @@ class _AbstractORMCompileState(CompileState): @classmethod def create_for_statement( cls, - statement: Union[Select, FromStatement], - compiler: Optional[SQLCompiler], + statement: Executable, + compiler: SQLCompiler, **kw: Any, - ) -> _AbstractORMCompileState: + ) -> CompileState: """Create a context for a statement given a :class:`.Compiler`. This method is always invoked in the context of SQLCompiler.process(). @@ -449,15 +449,30 @@ class _ORMCompileState(_AbstractORMCompileState): def __init__(self, *arg, **kw): raise NotImplementedError() - if TYPE_CHECKING: + @classmethod + def create_for_statement( + cls, + statement: Executable, + compiler: SQLCompiler, + **kw: Any, + ) -> _ORMCompileState: + return cls._create_orm_context( + cast("Union[Select, FromStatement]", statement), + toplevel=not compiler.stack, + compiler=compiler, + **kw, + ) - @classmethod - def create_for_statement( - cls, - statement: Union[Select, FromStatement], - compiler: Optional[SQLCompiler], - **kw: Any, - ) -> _ORMCompileState: ... + @classmethod + def _create_orm_context( + cls, + statement: Union[Select, FromStatement], + *, + toplevel: bool, + compiler: Optional[SQLCompiler], + **kw: Any, + ) -> _ORMCompileState: + raise NotImplementedError() def _append_dedupe_col_collection(self, obj, col_collection): dedupe = self.dedupe_columns @@ -767,12 +782,16 @@ class _ORMFromStatementCompileState(_ORMCompileState): eager_joins = _EMPTY_DICT @classmethod - def create_for_statement( + def _create_orm_context( cls, - statement_container: Union[Select, FromStatement], + statement: Union[Select, FromStatement], + *, + toplevel: bool, compiler: Optional[SQLCompiler], **kw: Any, ) -> _ORMFromStatementCompileState: + statement_container = statement + assert isinstance(statement_container, FromStatement) if compiler is not None and compiler.stack: @@ -1079,21 +1098,17 @@ class _ORMSelectCompileState(_ORMCompileState, SelectState): _having_criteria = () @classmethod - def create_for_statement( + def _create_orm_context( cls, statement: Union[Select, FromStatement], + *, + toplevel: bool, compiler: Optional[SQLCompiler], **kw: Any, ) -> _ORMSelectCompileState: - """compiler hook, we arrive here from compiler.visit_select() only.""" self = cls.__new__(cls) - if compiler is not None: - toplevel = not compiler.stack - else: - toplevel = True - select_statement = statement # if we are a select() that was never a legacy Query, we won't diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index ac6746adba..28c282b487 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -3361,7 +3361,9 @@ class Query( _ORMCompileState._get_plugin_class_for_plugin(stmt, "orm"), ) - return compile_state_cls.create_for_statement(stmt, None) + return compile_state_cls._create_orm_context( + stmt, toplevel=True, compiler=None + ) def _compile_context(self, for_statement: bool = False) -> QueryContext: compile_state = self._compile_state(for_statement=for_statement) diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index a93ea4e42e..801814f334 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -67,6 +67,7 @@ if TYPE_CHECKING: from ._orm_types import DMLStrategyArgument from ._orm_types import SynchronizeSessionArgument from ._typing import _CLE + from .compiler import SQLCompiler from .elements import BindParameter from .elements import ClauseList from .elements import ColumnClause # noqa @@ -656,7 +657,9 @@ class CompileState: _ambiguous_table_name_map: Optional[_AmbiguousTableNameMap] @classmethod - def create_for_statement(cls, statement, compiler, **kw): + def create_for_statement( + cls, statement: Executable, compiler: SQLCompiler, **kw: Any + ) -> CompileState: # factory construction. if statement._propagate_attrs: diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 825123a977..bd92f6aa85 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -300,8 +300,7 @@ class CompilerElement(Visitable): if bind: dialect = bind.dialect elif self.stringify_dialect == "default": - default = util.preloaded.engine_default - dialect = default.StrCompileDialect() + dialect = self._default_dialect() else: url = util.preloaded.engine_url dialect = url.URL.create( @@ -310,6 +309,10 @@ class CompilerElement(Visitable): return self._compiler(dialect, **kw) + def _default_dialect(self): + default = util.preloaded.engine_default + return default.StrCompileDialect() + def _compiler(self, dialect: Dialect, **kw: Any) -> Compiled: """Return a compiler appropriate for this ClauseElement, given a Dialect.""" @@ -406,6 +409,10 @@ class ClauseElement( self._propagate_attrs = util.immutabledict(values) return self + def _default_compiler(self) -> SQLCompiler: + dialect = self._default_dialect() + return dialect.statement_compiler(dialect, self) # type: ignore + def _clone(self, **kw: Any) -> Self: """Create a shallow copy of this ClauseElement. diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index c3255a8f18..e53b2bbccc 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -4661,7 +4661,7 @@ class SelectState(util.MemoizedSlots, CompileState): def __init__( self, statement: Select[Unpack[TupleAny]], - compiler: Optional[SQLCompiler], + compiler: SQLCompiler, **kw: Any, ): self.statement = statement @@ -5717,8 +5717,9 @@ class Select( :attr:`_sql.Select.columns_clause_froms` """ + compiler = self._default_compiler() - return self._compile_state_factory(self, None)._get_display_froms() + return self._compile_state_factory(self, compiler)._get_display_froms() @property @util.deprecated( diff --git a/test/ext/test_hybrid.py b/test/ext/test_hybrid.py index 8e3d7e9cd5..09da020743 100644 --- a/test/ext/test_hybrid.py +++ b/test/ext/test_hybrid.py @@ -22,6 +22,7 @@ from sqlalchemy.orm import declared_attr from sqlalchemy.orm import relationship from sqlalchemy.orm import Session from sqlalchemy.orm import synonym +from sqlalchemy.orm.context import _ORMSelectCompileState from sqlalchemy.sql import coercions from sqlalchemy.sql import operators from sqlalchemy.sql import roles @@ -531,7 +532,9 @@ class PropertyExpressionTest(fixtures.TestBase, AssertsCompiledSQL): "SELECT a.id, a.foo FROM a", ) - compile_state = stmt._compile_state_factory(stmt, None) + compile_state = _ORMSelectCompileState._create_orm_context( + stmt, toplevel=True, compiler=None + ) eq_( compile_state._column_naming_convention( LABEL_STYLE_DISAMBIGUATE_ONLY, legacy=False diff --git a/test/orm/test_froms.py b/test/orm/test_froms.py index 9a1ff1ee44..ae0c147c71 100644 --- a/test/orm/test_froms.py +++ b/test/orm/test_froms.py @@ -1893,7 +1893,9 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): .order_by(User.id) ) - compile_state = _ORMSelectCompileState.create_for_statement(stmt, None) + compile_state = _ORMSelectCompileState._create_orm_context( + stmt, toplevel=True, compiler=None + ) is_(compile_state._primary_entity, None) def test_column_queries_one(self):