]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
generalize adapt_on_names to expect non-named elements
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 10 Feb 2023 13:39:21 +0000 (08:39 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 10 Feb 2023 15:42:11 +0000 (10:42 -0500)
The fix in #9217 opened up adapt_on_names to more kinds of
expressions than it was prepared for; adjust that logic
and also refine in the ORM where we are using it, as we
dont need it (yet) for the DML RETURNING use case.

Fixed regression introduced in version 2.0.2 due to :ticket:`9217` where
using DML RETURNING statements, as well as
:meth:`_sql.Select.from_statement` constructs as was "fixed" in
:ticket:`9217`, in conjunction with ORM mapped classes that used
expressions such as with :func:`_orm.column_property`, would lead to an
internal error within Core where it would attempt to match the expression
by name. The fix repairs the Core issue, and also adjusts the fix in
:ticket:`9217` to not take effect for the DML RETURNING use case, where it
adds unnecessary overhead.

Fixes: #9273
Change-Id: Ie0344efb12ff7df48f21e71e62dc598c76a6a0de

doc/build/changelog/unreleased_20/9273.rst [new file with mode: 0644]
lib/sqlalchemy/orm/bulk_persistence.py
lib/sqlalchemy/orm/context.py
lib/sqlalchemy/sql/util.py
test/orm/dml/test_bulk_statements.py
test/orm/test_froms.py
test/sql/test_external_traversal.py

diff --git a/doc/build/changelog/unreleased_20/9273.rst b/doc/build/changelog/unreleased_20/9273.rst
new file mode 100644 (file)
index 0000000..3038783
--- /dev/null
@@ -0,0 +1,13 @@
+.. change::
+    :tags: bug, orm, regression
+    :tickets: 9273
+
+    Fixed regression introduced in version 2.0.2 due to :ticket:`9217` where
+    using DML RETURNING statements, as well as
+    :meth:`_sql.Select.from_statement` constructs as was "fixed" in
+    :ticket:`9217`, in conjunction with ORM mapped classes that used
+    expressions such as with :func:`_orm.column_property`, would lead to an
+    internal error within Core where it would attempt to match the expression
+    by name. The fix repairs the Core issue, and also adjusts the fix in
+    :ticket:`9217` to not take effect for the DML RETURNING use case, where it
+    adds unnecessary overhead.
index 324da56ec826c8b270678eb67c6520aa54c9093e..924c4cbc8e7b742f5749e60f9a0b0d3fcae308bb 100644 (file)
@@ -481,7 +481,9 @@ class ORMDMLState(AbstractORMCompileState):
         if orm_level_statement._returning:
 
             fs = FromStatement(
-                orm_level_statement._returning, dml_level_statement
+                orm_level_statement._returning,
+                dml_level_statement,
+                _adapt_on_names=False,
             )
             fs = fs.options(*orm_level_statement._with_options)
             self.select_statement = fs
index 0e631e66f73da21f34294c9a72a1aa5a6e491f93..e6f14daadcb390e23e785232ed7030f1d73a295a 100644 (file)
@@ -619,6 +619,8 @@ class ORMFromStatementCompileState(ORMCompileState):
         **kw: Any,
     ) -> ORMFromStatementCompileState:
 
+        assert isinstance(statement_container, FromStatement)
+
         if compiler is not None:
             toplevel = not compiler.stack
         else:
@@ -731,13 +733,13 @@ class ORMFromStatementCompileState(ORMCompileState):
             # those columns completely, don't interfere with the compiler
             # at all; just in ORM land, use an adapter to convert from
             # our ORM columns to whatever columns are in the statement,
-            # before we look in the result row. Always adapt on names
-            # to accept cases such as issue #9217.
-
+            # before we look in the result row. Adapt on names
+            # to accept cases such as issue #9217, however also allow
+            # this to be overridden for cases such as #9273.
             self._from_obj_alias = ORMStatementAdapter(
                 _TraceAdaptRole.ADAPT_FROM_STATEMENT,
                 self.statement,
-                adapt_on_names=True,
+                adapt_on_names=statement_container._adapt_on_names,
             )
 
         return self
@@ -781,6 +783,8 @@ class FromStatement(GroupedElement, Generative, TypedReturnsRows[_TP]):
 
     element: Union[ExecutableReturnsRows, TextClause]
 
+    _adapt_on_names: bool
+
     _traverse_internals = [
         ("_raw_columns", InternalTraversal.dp_clauseelement_list),
         ("element", InternalTraversal.dp_clauseelement),
@@ -794,6 +798,7 @@ class FromStatement(GroupedElement, Generative, TypedReturnsRows[_TP]):
         self,
         entities: Iterable[_ColumnsClauseArgument[Any]],
         element: Union[ExecutableReturnsRows, TextClause],
+        _adapt_on_names: bool = True,
     ):
         self._raw_columns = [
             coercions.expect(
@@ -809,6 +814,7 @@ class FromStatement(GroupedElement, Generative, TypedReturnsRows[_TP]):
         self._label_style = (
             element._label_style if is_select_base(element) else None
         )
+        self._adapt_on_names = _adapt_on_names
 
     def _compiler_dispatch(self, compiler, **kw):
 
index 1dad9ce6846d478c9774eb4aed93f94ab2453390..0a50197a0d4a4493ee5dbed5b2d7f239dd878644 100644 (file)
@@ -54,6 +54,7 @@ from .elements import ColumnElement
 from .elements import Grouping
 from .elements import KeyedColumnElement
 from .elements import Label
+from .elements import NamedColumn
 from .elements import Null
 from .elements import UnaryExpression
 from .schema import Column
@@ -712,7 +713,6 @@ class _repr_params(_repr_base):
             return "(%s)" % elements
 
     def _get_batches(self, params: Iterable[Any]) -> Any:
-
         lparams = list(params)
         lenparams = len(lparams)
         if lenparams > self.max_params:
@@ -1122,7 +1122,6 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
     def _corresponding_column(
         self, col, require_embedded, _seen=util.EMPTY_SET
     ):
-
         newcol = self.selectable.corresponding_column(
             col, require_embedded=require_embedded
         )
@@ -1135,7 +1134,12 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
                 )
                 if newcol is not None:
                     return newcol
-        if self.adapt_on_names and newcol is None:
+
+        if (
+            self.adapt_on_names
+            and newcol is None
+            and isinstance(col, NamedColumn)
+        ):
             newcol = self.selectable.exported_columns.get(col.name)
         return newcol
 
index 78607e03d80354630534ef742ff0af297d6e0a32..0b26786d418147eae30a004ff12ba798654e7fa8 100644 (file)
@@ -11,12 +11,14 @@ from sqlalchemy import func
 from sqlalchemy import Identity
 from sqlalchemy import insert
 from sqlalchemy import inspect
+from sqlalchemy import literal
 from sqlalchemy import literal_column
 from sqlalchemy import select
 from sqlalchemy import String
 from sqlalchemy import testing
 from sqlalchemy import update
 from sqlalchemy.orm import aliased
+from sqlalchemy.orm import column_property
 from sqlalchemy.orm import load_only
 from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
@@ -27,10 +29,11 @@ from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import mock
 from sqlalchemy.testing import provision
 from sqlalchemy.testing.assertsql import CompiledSQL
+from sqlalchemy.testing.entities import ComparableEntity
 from sqlalchemy.testing.fixtures import fixture_session
 
 
-class NoReturningTest(fixtures.TestBase):
+class InsertStmtTest(fixtures.TestBase):
     def test_no_returning_error(self, decl_base):
         class A(fixtures.ComparableEntity, decl_base):
             __tablename__ = "a"
@@ -86,6 +89,38 @@ class NoReturningTest(fixtures.TestBase):
             [("d3", 5), ("d4", 6)],
         )
 
+    def test_insert_from_select_col_property(self, decl_base):
+        """test #9273"""
+
+        class User(ComparableEntity, decl_base):
+            __tablename__ = "users"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+
+            name: Mapped[str] = mapped_column()
+            age: Mapped[int] = mapped_column()
+
+            is_adult: Mapped[bool] = column_property(age >= 18)
+
+        decl_base.metadata.create_all(testing.db)
+
+        stmt = select(
+            literal(1).label("id"),
+            literal("John").label("name"),
+            literal(30).label("age"),
+        )
+
+        insert_stmt = (
+            insert(User)
+            .from_select(["id", "name", "age"], stmt)
+            .returning(User)
+        )
+
+        s = fixture_session()
+        result = s.scalars(insert_stmt)
+
+        eq_(result.all(), [User(id=1, name="John", age=30)])
+
 
 class BulkDMLReturningInhTest:
     def test_insert_col_key_also_works_currently(self):
index e24062469f25b6704f7c3753ba16b29b4056e46e..c2c237587f207d126ea25a435136c4ca7a2a53b1 100644 (file)
@@ -6,6 +6,7 @@ from sqlalchemy import exists
 from sqlalchemy import ForeignKey
 from sqlalchemy import func
 from sqlalchemy import Integer
+from sqlalchemy import literal
 from sqlalchemy import literal_column
 from sqlalchemy import select
 from sqlalchemy import String
@@ -25,6 +26,8 @@ from sqlalchemy.orm import configure_mappers
 from sqlalchemy.orm import contains_eager
 from sqlalchemy.orm import declarative_base
 from sqlalchemy.orm import joinedload
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import Session
 from sqlalchemy.orm.context import ORMSelectCompileState
@@ -36,8 +39,8 @@ from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
-from sqlalchemy.testing import in_
 from sqlalchemy.testing import is_
+from sqlalchemy.testing.entities import ComparableEntity
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.schema import Column
 from test.orm import _fixtures
@@ -2728,7 +2731,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL):
             eq_(q.all(), expected)
 
     def test_unrelated_column(self):
-        """Test for 9217"""
+        """Test for #9217"""
 
         User = self.classes.User
 
@@ -2739,8 +2742,32 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL):
         s = select(User).from_statement(q)
         sess = fixture_session()
         res = sess.scalars(s).one()
-        in_("name", res.__dict__)
-        eq_(res.name, "sandy")
+        eq_(res, User(name="sandy", id=7))
+
+    def test_unrelated_column_col_prop(self, decl_base):
+        """Test for #9217 combined with #9273"""
+
+        class User(ComparableEntity, decl_base):
+            __tablename__ = "some_user_table"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+
+            name: Mapped[str] = mapped_column()
+            age: Mapped[int] = mapped_column()
+
+            is_adult: Mapped[bool] = column_property(age >= 18)
+
+        stmt = select(
+            literal(1).label("id"),
+            literal("John").label("name"),
+            literal(30).label("age"),
+        )
+
+        s = select(User).from_statement(stmt)
+        sess = fixture_session()
+        res = sess.scalars(s).one()
+
+        eq_(res, User(name="John", age=30, id=1))
 
     def test_expression_selectable_matches_mzero(self):
         User, Address = self.classes.User, self.classes.Address
index 8ccbd8d20e737ed7b008697c6304841ce14d9277..b8f6e5685abc4119c23aa79d28811137139b1665 100644 (file)
@@ -1194,7 +1194,6 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL):
         )
 
     def test_this_thing_using_setup_joins_three(self):
-
         j = t1.join(t2, t1.c.col1 == t2.c.col2)
 
         s1 = select(j)
@@ -1239,7 +1238,6 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL):
         )
 
     def test_this_thing_using_setup_joins_four(self):
-
         j = t1.join(t2, t1.c.col1 == t2.c.col2)
 
         s1 = select(j)
@@ -1606,6 +1604,36 @@ class ColumnAdapterTest(fixtures.TestBase, AssertsCompiledSQL):
         # not covered by a1, rejected by a2
         is_(a3.columns[c2a1], c2a1)
 
+    @testing.combinations(True, False, argnames="colpresent")
+    @testing.combinations(True, False, argnames="adapt_on_names")
+    @testing.combinations(True, False, argnames="use_label")
+    def test_adapt_binary_col(self, colpresent, use_label, adapt_on_names):
+        """test #9273"""
+
+        if use_label:
+            stmt = select(t1.c.col1, (t1.c.col2 > 18).label("foo"))
+        else:
+            stmt = select(t1.c.col1, (t1.c.col2 > 18))
+
+        sq = stmt.subquery()
+
+        if colpresent:
+            s2 = select(sq.c[0], sq.c[1])
+        else:
+            s2 = select(sq.c[0])
+
+        a1 = sql_util.ColumnAdapter(s2, adapt_on_names=adapt_on_names)
+
+        is_(a1.columns[stmt.selected_columns[0]], s2.selected_columns[0])
+
+        if colpresent:
+            is_(a1.columns[stmt.selected_columns[1]], s2.selected_columns[1])
+        else:
+            is_(
+                a1.columns[stmt.selected_columns[1]],
+                a1.columns[stmt.selected_columns[1]],
+            )
+
 
 class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL):
     __dialect__ = "default"
@@ -1735,7 +1763,6 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL):
         )
 
     def test_adapt_select_w_unlabeled_fn(self):
-
         expr = func.count(t1.c.col1)
         stmt = select(t1, expr)
 
@@ -2335,7 +2362,6 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL):
         assert s2.is_derived_from(s1)
 
     def test_aliasedselect_to_aliasedselect_straight(self):
-
         # original issue from ticket #904
 
         s1 = select(t1).alias("foo")