]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
implement is_derived_from() for DML
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 8 Feb 2025 16:38:53 +0000 (11:38 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 8 Feb 2025 19:43:24 +0000 (14:43 -0500)
Fixed bug where using DML returning such as :meth:`.Insert.returning` with
an ORM model that has :func:`_orm.column_property` constructs that contain
subqueries would fail with an internal error.

Fixes: #12326
Change-Id: I419f645769a346c229944b30ac8fd4a0efe1646d
(cherry picked from commit b281402140683279c2aca2363f2acdb94929507f)

doc/build/changelog/unreleased_20/12326.rst [new file with mode: 0644]
lib/sqlalchemy/sql/dml.py
test/orm/dml/test_bulk_statements.py

diff --git a/doc/build/changelog/unreleased_20/12326.rst b/doc/build/changelog/unreleased_20/12326.rst
new file mode 100644 (file)
index 0000000..88e5de2
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 12326
+
+    Fixed bug where using DML returning such as :meth:`.Insert.returning` with
+    an ORM model that has :func:`_orm.column_property` constructs that contain
+    subqueries would fail with an internal error.
index f0e6edbb560324525e1a30286d9c77e910c9df5b..f5071146be2464e2def98f73f4372c8a35ce2999 100644 (file)
@@ -695,6 +695,16 @@ class UpdateBase(
 
         return self
 
+    def is_derived_from(self, fromclause: Optional[FromClause]) -> bool:
+        """Return ``True`` if this :class:`.ReturnsRows` is
+        'derived' from the given :class:`.FromClause`.
+
+        Since these are DMLs, we dont want such statements ever being adapted
+        so we return False for derives.
+
+        """
+        return False
+
     @_generative
     def returning(
         self,
index 992a18947b793b1583b3ca02d8d0833fca7321c0..6d69b2250c3fab6033b097800dfbc06d21df3e3b 100644 (file)
@@ -277,6 +277,86 @@ class InsertStmtTest(testing.AssertsExecutionResults, fixtures.TestBase):
             ),
         )
 
+    @testing.requires.insert_returning
+    @testing.variation(
+        "insert_type",
+        [("values", testing.requires.multivalues_inserts), "bulk"],
+    )
+    def test_returning_col_property(
+        self, decl_base, insert_type: testing.Variation
+    ):
+        """test #12326"""
+
+        class User(ComparableEntity, decl_base):
+            __tablename__ = "user"
+
+            id: Mapped[int] = mapped_column(
+                primary_key=True, autoincrement=False
+            )
+            name: Mapped[str]
+            age: Mapped[int]
+
+        decl_base.metadata.create_all(testing.db)
+
+        a_alias = aliased(User)
+        User.colprop = column_property(
+            select(func.max(a_alias.age))
+            .where(a_alias.id != User.id)
+            .scalar_subquery()
+        )
+
+        sess = fixture_session()
+
+        if insert_type.values:
+            stmt = insert(User).values(
+                [
+                    dict(id=1, name="john", age=25),
+                    dict(id=2, name="jack", age=47),
+                    dict(id=3, name="jill", age=29),
+                    dict(id=4, name="jane", age=37),
+                ],
+            )
+            params = None
+        elif insert_type.bulk:
+            stmt = insert(User)
+            params = [
+                dict(id=1, name="john", age=25),
+                dict(id=2, name="jack", age=47),
+                dict(id=3, name="jill", age=29),
+                dict(id=4, name="jane", age=37),
+            ]
+        else:
+            insert_type.fail()
+
+        stmt = stmt.returning(User)
+
+        result = sess.execute(stmt, params=params)
+
+        # the RETURNING doesn't have the column property in it.
+        # so to load these, they are all lazy loaded
+        with self.sql_execution_asserter() as asserter:
+            eq_(
+                result.scalars().all(),
+                [
+                    User(id=1, name="john", age=25, colprop=47),
+                    User(id=2, name="jack", age=47, colprop=37),
+                    User(id=3, name="jill", age=29, colprop=47),
+                    User(id=4, name="jane", age=37, colprop=47),
+                ],
+            )
+
+        # assert they're all lazy loaded
+        asserter.assert_(
+            *[
+                CompiledSQL(
+                    'SELECT (SELECT max(user_1.age) AS max_1 FROM "user" '
+                    'AS user_1 WHERE user_1.id != "user".id) AS anon_1 '
+                    'FROM "user" WHERE "user".id = :pk_1'
+                )
+                for i in range(4)
+            ]
+        )
+
     @testing.requires.insert_returning
     @testing.requires.returning_star
     @testing.variation(
@@ -1080,6 +1160,47 @@ class UpdateStmtTest(testing.AssertsExecutionResults, fixtures.TestBase):
             ],
         )
 
+    @testing.requires.update_returning
+    def test_returning_col_property(self, decl_base):
+        """test #12326"""
+
+        class User(ComparableEntity, decl_base):
+            __tablename__ = "user"
+
+            id: Mapped[int] = mapped_column(
+                primary_key=True, autoincrement=False
+            )
+            name: Mapped[str]
+            age: Mapped[int]
+
+        decl_base.metadata.create_all(testing.db)
+
+        a_alias = aliased(User)
+        User.colprop = column_property(
+            select(func.max(a_alias.age))
+            .where(a_alias.id != User.id)
+            .scalar_subquery()
+        )
+
+        sess = fixture_session()
+
+        sess.execute(
+            insert(User),
+            [
+                dict(id=1, name="john", age=25),
+                dict(id=2, name="jack", age=47),
+                dict(id=3, name="jill", age=29),
+                dict(id=4, name="jane", age=37),
+            ],
+        )
+
+        stmt = (
+            update(User).values(age=30).where(User.age == 29).returning(User)
+        )
+
+        row = sess.execute(stmt).one()
+        eq_(row[0], User(id=3, name="jill", age=30, colprop=47))
+
 
 class BulkDMLReturningInhTest:
     use_sentinel = False