From eb0861e8e69f8ce702301c558e552e1aeb2e9eba Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Fri, 10 Feb 2023 08:39:21 -0500 Subject: [PATCH] generalize adapt_on_names to expect non-named elements The fix in #9217 opened up adapt_on_names to more kinds of expressions than it was prepared for; adjust that logic and also refine in the ORM where we are using it, as we dont need it (yet) for the DML RETURNING use case. Fixed regression introduced in version 2.0.2 due to :ticket:`9217` where using DML RETURNING statements, as well as :meth:`_sql.Select.from_statement` constructs as was "fixed" in :ticket:`9217`, in conjunction with ORM mapped classes that used expressions such as with :func:`_orm.column_property`, would lead to an internal error within Core where it would attempt to match the expression by name. The fix repairs the Core issue, and also adjusts the fix in :ticket:`9217` to not take effect for the DML RETURNING use case, where it adds unnecessary overhead. Fixes: #9273 Change-Id: Ie0344efb12ff7df48f21e71e62dc598c76a6a0de --- doc/build/changelog/unreleased_20/9273.rst | 13 ++++++++ lib/sqlalchemy/orm/bulk_persistence.py | 4 ++- lib/sqlalchemy/orm/context.py | 14 +++++--- lib/sqlalchemy/sql/util.py | 10 ++++-- test/orm/dml/test_bulk_statements.py | 37 +++++++++++++++++++++- test/orm/test_froms.py | 35 +++++++++++++++++--- test/sql/test_external_traversal.py | 34 +++++++++++++++++--- 7 files changed, 130 insertions(+), 17 deletions(-) create mode 100644 doc/build/changelog/unreleased_20/9273.rst diff --git a/doc/build/changelog/unreleased_20/9273.rst b/doc/build/changelog/unreleased_20/9273.rst new file mode 100644 index 0000000000..3038783796 --- /dev/null +++ b/doc/build/changelog/unreleased_20/9273.rst @@ -0,0 +1,13 @@ +.. change:: + :tags: bug, orm, regression + :tickets: 9273 + + Fixed regression introduced in version 2.0.2 due to :ticket:`9217` where + using DML RETURNING statements, as well as + :meth:`_sql.Select.from_statement` constructs as was "fixed" in + :ticket:`9217`, in conjunction with ORM mapped classes that used + expressions such as with :func:`_orm.column_property`, would lead to an + internal error within Core where it would attempt to match the expression + by name. The fix repairs the Core issue, and also adjusts the fix in + :ticket:`9217` to not take effect for the DML RETURNING use case, where it + adds unnecessary overhead. diff --git a/lib/sqlalchemy/orm/bulk_persistence.py b/lib/sqlalchemy/orm/bulk_persistence.py index 324da56ec8..924c4cbc8e 100644 --- a/lib/sqlalchemy/orm/bulk_persistence.py +++ b/lib/sqlalchemy/orm/bulk_persistence.py @@ -481,7 +481,9 @@ class ORMDMLState(AbstractORMCompileState): if orm_level_statement._returning: fs = FromStatement( - orm_level_statement._returning, dml_level_statement + orm_level_statement._returning, + dml_level_statement, + _adapt_on_names=False, ) fs = fs.options(*orm_level_statement._with_options) self.select_statement = fs diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index 0e631e66f7..e6f14daadc 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -619,6 +619,8 @@ class ORMFromStatementCompileState(ORMCompileState): **kw: Any, ) -> ORMFromStatementCompileState: + assert isinstance(statement_container, FromStatement) + if compiler is not None: toplevel = not compiler.stack else: @@ -731,13 +733,13 @@ class ORMFromStatementCompileState(ORMCompileState): # those columns completely, don't interfere with the compiler # at all; just in ORM land, use an adapter to convert from # our ORM columns to whatever columns are in the statement, - # before we look in the result row. Always adapt on names - # to accept cases such as issue #9217. - + # before we look in the result row. Adapt on names + # to accept cases such as issue #9217, however also allow + # this to be overridden for cases such as #9273. self._from_obj_alias = ORMStatementAdapter( _TraceAdaptRole.ADAPT_FROM_STATEMENT, self.statement, - adapt_on_names=True, + adapt_on_names=statement_container._adapt_on_names, ) return self @@ -781,6 +783,8 @@ class FromStatement(GroupedElement, Generative, TypedReturnsRows[_TP]): element: Union[ExecutableReturnsRows, TextClause] + _adapt_on_names: bool + _traverse_internals = [ ("_raw_columns", InternalTraversal.dp_clauseelement_list), ("element", InternalTraversal.dp_clauseelement), @@ -794,6 +798,7 @@ class FromStatement(GroupedElement, Generative, TypedReturnsRows[_TP]): self, entities: Iterable[_ColumnsClauseArgument[Any]], element: Union[ExecutableReturnsRows, TextClause], + _adapt_on_names: bool = True, ): self._raw_columns = [ coercions.expect( @@ -809,6 +814,7 @@ class FromStatement(GroupedElement, Generative, TypedReturnsRows[_TP]): self._label_style = ( element._label_style if is_select_base(element) else None ) + self._adapt_on_names = _adapt_on_names def _compiler_dispatch(self, compiler, **kw): diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 1dad9ce684..0a50197a0d 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -54,6 +54,7 @@ from .elements import ColumnElement from .elements import Grouping from .elements import KeyedColumnElement from .elements import Label +from .elements import NamedColumn from .elements import Null from .elements import UnaryExpression from .schema import Column @@ -712,7 +713,6 @@ class _repr_params(_repr_base): return "(%s)" % elements def _get_batches(self, params: Iterable[Any]) -> Any: - lparams = list(params) lenparams = len(lparams) if lenparams > self.max_params: @@ -1122,7 +1122,6 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): def _corresponding_column( self, col, require_embedded, _seen=util.EMPTY_SET ): - newcol = self.selectable.corresponding_column( col, require_embedded=require_embedded ) @@ -1135,7 +1134,12 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): ) if newcol is not None: return newcol - if self.adapt_on_names and newcol is None: + + if ( + self.adapt_on_names + and newcol is None + and isinstance(col, NamedColumn) + ): newcol = self.selectable.exported_columns.get(col.name) return newcol diff --git a/test/orm/dml/test_bulk_statements.py b/test/orm/dml/test_bulk_statements.py index 78607e03d8..0b26786d41 100644 --- a/test/orm/dml/test_bulk_statements.py +++ b/test/orm/dml/test_bulk_statements.py @@ -11,12 +11,14 @@ from sqlalchemy import func from sqlalchemy import Identity from sqlalchemy import insert from sqlalchemy import inspect +from sqlalchemy import literal from sqlalchemy import literal_column from sqlalchemy import select from sqlalchemy import String from sqlalchemy import testing from sqlalchemy import update from sqlalchemy.orm import aliased +from sqlalchemy.orm import column_property from sqlalchemy.orm import load_only from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column @@ -27,10 +29,11 @@ from sqlalchemy.testing import fixtures from sqlalchemy.testing import mock from sqlalchemy.testing import provision from sqlalchemy.testing.assertsql import CompiledSQL +from sqlalchemy.testing.entities import ComparableEntity from sqlalchemy.testing.fixtures import fixture_session -class NoReturningTest(fixtures.TestBase): +class InsertStmtTest(fixtures.TestBase): def test_no_returning_error(self, decl_base): class A(fixtures.ComparableEntity, decl_base): __tablename__ = "a" @@ -86,6 +89,38 @@ class NoReturningTest(fixtures.TestBase): [("d3", 5), ("d4", 6)], ) + def test_insert_from_select_col_property(self, decl_base): + """test #9273""" + + class User(ComparableEntity, decl_base): + __tablename__ = "users" + + id: Mapped[int] = mapped_column(primary_key=True) + + name: Mapped[str] = mapped_column() + age: Mapped[int] = mapped_column() + + is_adult: Mapped[bool] = column_property(age >= 18) + + decl_base.metadata.create_all(testing.db) + + stmt = select( + literal(1).label("id"), + literal("John").label("name"), + literal(30).label("age"), + ) + + insert_stmt = ( + insert(User) + .from_select(["id", "name", "age"], stmt) + .returning(User) + ) + + s = fixture_session() + result = s.scalars(insert_stmt) + + eq_(result.all(), [User(id=1, name="John", age=30)]) + class BulkDMLReturningInhTest: def test_insert_col_key_also_works_currently(self): diff --git a/test/orm/test_froms.py b/test/orm/test_froms.py index e24062469f..c2c237587f 100644 --- a/test/orm/test_froms.py +++ b/test/orm/test_froms.py @@ -6,6 +6,7 @@ from sqlalchemy import exists from sqlalchemy import ForeignKey from sqlalchemy import func from sqlalchemy import Integer +from sqlalchemy import literal from sqlalchemy import literal_column from sqlalchemy import select from sqlalchemy import String @@ -25,6 +26,8 @@ from sqlalchemy.orm import configure_mappers from sqlalchemy.orm import contains_eager from sqlalchemy.orm import declarative_base from sqlalchemy.orm import joinedload +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column from sqlalchemy.orm import relationship from sqlalchemy.orm import Session from sqlalchemy.orm.context import ORMSelectCompileState @@ -36,8 +39,8 @@ from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures -from sqlalchemy.testing import in_ from sqlalchemy.testing import is_ +from sqlalchemy.testing.entities import ComparableEntity from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from test.orm import _fixtures @@ -2728,7 +2731,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): eq_(q.all(), expected) def test_unrelated_column(self): - """Test for 9217""" + """Test for #9217""" User = self.classes.User @@ -2739,8 +2742,32 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): s = select(User).from_statement(q) sess = fixture_session() res = sess.scalars(s).one() - in_("name", res.__dict__) - eq_(res.name, "sandy") + eq_(res, User(name="sandy", id=7)) + + def test_unrelated_column_col_prop(self, decl_base): + """Test for #9217 combined with #9273""" + + class User(ComparableEntity, decl_base): + __tablename__ = "some_user_table" + + id: Mapped[int] = mapped_column(primary_key=True) + + name: Mapped[str] = mapped_column() + age: Mapped[int] = mapped_column() + + is_adult: Mapped[bool] = column_property(age >= 18) + + stmt = select( + literal(1).label("id"), + literal("John").label("name"), + literal(30).label("age"), + ) + + s = select(User).from_statement(stmt) + sess = fixture_session() + res = sess.scalars(s).one() + + eq_(res, User(name="John", age=30, id=1)) def test_expression_selectable_matches_mzero(self): User, Address = self.classes.User, self.classes.Address diff --git a/test/sql/test_external_traversal.py b/test/sql/test_external_traversal.py index 8ccbd8d20e..b8f6e5685a 100644 --- a/test/sql/test_external_traversal.py +++ b/test/sql/test_external_traversal.py @@ -1194,7 +1194,6 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL): ) def test_this_thing_using_setup_joins_three(self): - j = t1.join(t2, t1.c.col1 == t2.c.col2) s1 = select(j) @@ -1239,7 +1238,6 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL): ) def test_this_thing_using_setup_joins_four(self): - j = t1.join(t2, t1.c.col1 == t2.c.col2) s1 = select(j) @@ -1606,6 +1604,36 @@ class ColumnAdapterTest(fixtures.TestBase, AssertsCompiledSQL): # not covered by a1, rejected by a2 is_(a3.columns[c2a1], c2a1) + @testing.combinations(True, False, argnames="colpresent") + @testing.combinations(True, False, argnames="adapt_on_names") + @testing.combinations(True, False, argnames="use_label") + def test_adapt_binary_col(self, colpresent, use_label, adapt_on_names): + """test #9273""" + + if use_label: + stmt = select(t1.c.col1, (t1.c.col2 > 18).label("foo")) + else: + stmt = select(t1.c.col1, (t1.c.col2 > 18)) + + sq = stmt.subquery() + + if colpresent: + s2 = select(sq.c[0], sq.c[1]) + else: + s2 = select(sq.c[0]) + + a1 = sql_util.ColumnAdapter(s2, adapt_on_names=adapt_on_names) + + is_(a1.columns[stmt.selected_columns[0]], s2.selected_columns[0]) + + if colpresent: + is_(a1.columns[stmt.selected_columns[1]], s2.selected_columns[1]) + else: + is_( + a1.columns[stmt.selected_columns[1]], + a1.columns[stmt.selected_columns[1]], + ) + class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = "default" @@ -1735,7 +1763,6 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL): ) def test_adapt_select_w_unlabeled_fn(self): - expr = func.count(t1.c.col1) stmt = select(t1, expr) @@ -2335,7 +2362,6 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL): assert s2.is_derived_from(s1) def test_aliasedselect_to_aliasedselect_straight(self): - # original issue from ticket #904 s1 = select(t1).alias("foo") -- 2.47.2