]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
match ORM mapped cols to PK in interpret_returning_rows
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 14 Oct 2024 15:15:21 +0000 (11:15 -0400)
committerMichael Bayer <mike_mp@zzzcomputing.com>
Tue, 15 Oct 2024 15:42:20 +0000 (15:42 +0000)
Fixed bug in ORM "update with WHERE clause" feature where an explicit
``.returning()`` would interfere with the "fetch" synchronize strategy due
to an assumption that the ORM mapped class featured the primary key columns
in a specific position within the RETURNING.  This has been fixed to use
appropriate ORM column targeting.

the _interpret_returning_rows method looked to be mostly not used as far
as its joined inheritance features, which appear to have never been
used as joined inheritance mappers are skipped.

Fixes: #11997
Change-Id: I38fe3a84cdeb2eef38fe00d8b9a6a2b56f434bc6

doc/build/changelog/unreleased_20/11997.rst [new file with mode: 0644]
lib/sqlalchemy/orm/bulk_persistence.py
test/orm/dml/test_update_delete_where.py

diff --git a/doc/build/changelog/unreleased_20/11997.rst b/doc/build/changelog/unreleased_20/11997.rst
new file mode 100644 (file)
index 0000000..b239097
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 11997
+
+    Fixed bug in ORM "update with WHERE clause" feature where an explicit
+    ``.returning()`` would interfere with the "fetch" synchronize strategy due
+    to an assumption that the ORM mapped class featured the primary key columns
+    in a specific position within the RETURNING.  This has been fixed to use
+    appropriate ORM column targeting.
index a9408f1cce2cbf786a253e5f13f82fab674adb69..3c033be5850d4238a3065ea46dd30602ce512910 100644 (file)
@@ -864,53 +864,39 @@ class BulkUDCompileState(ORMDMLState):
         return return_crit
 
     @classmethod
-    def _interpret_returning_rows(cls, mapper, rows):
-        """translate from local inherited table columns to base mapper
-        primary key columns.
+    def _interpret_returning_rows(cls, result, mapper, rows):
+        """return rows that indicate PK cols in mapper.primary_key position
+        for RETURNING rows.
 
-        Joined inheritance mappers always establish the primary key in terms of
-        the base table.   When we UPDATE a sub-table, we can only get
-        RETURNING for the sub-table's columns.
+        Prior to 2.0.36, this method seemed to be written for some kind of
+        inheritance scenario but the scenario was unused for actual joined
+        inheritance, and the function instead seemed to perform some kind of
+        partial translation that would remove non-PK cols if the PK cols
+        happened to be first in the row, but not otherwise.  The joined
+        inheritance walk feature here seems to have never been used as it was
+        always skipped by the "local_table" check.
 
-        Here, we create a lookup from the local sub table's primary key
-        columns to the base table PK columns so that we can get identity
-        key values from RETURNING that's against the joined inheritance
-        sub-table.
-
-        the complexity here is to support more than one level deep of
-        inheritance, where we have to link columns to each other across
-        the inheritance hierarchy.
+        As of 2.0.36 the function strips away non-PK cols and provides the
+        PK cols for the table in mapper PK order.
 
         """
 
-        if mapper.local_table is not mapper.base_mapper.local_table:
-            return rows
-
-        # this starts as a mapping of
-        # local_pk_col: local_pk_col.
-        # we will then iteratively rewrite the "value" of the dict with
-        # each successive superclass column
-        local_pk_to_base_pk = {pk: pk for pk in mapper.local_table.primary_key}
-
-        for mp in mapper.iterate_to_root():
-            if mp.inherits is None:
-                break
-            elif mp.local_table is mp.inherits.local_table:
-                continue
-
-            t_to_e = dict(mp._table_to_equated[mp.inherits.local_table])
-            col_to_col = {sub_pk: super_pk for super_pk, sub_pk in t_to_e[mp]}
-            for pk, super_ in local_pk_to_base_pk.items():
-                local_pk_to_base_pk[pk] = col_to_col[super_]
+        try:
+            if mapper.local_table is not mapper.base_mapper.local_table:
+                # TODO: dive more into how a local table PK is used for fetch
+                # sync, not clear if this is correct as it depends on the
+                # downstream routine to fetch rows using
+                # local_table.primary_key order
+                pk_keys = result._tuple_getter(mapper.local_table.primary_key)
+            else:
+                pk_keys = result._tuple_getter(mapper.primary_key)
+        except KeyError:
+            # can't use these rows, they don't have PK cols in them
+            # this is an unusual case where the user would have used
+            # .return_defaults()
+            return []
 
-        lookup = {
-            local_pk_to_base_pk[lpk]: idx
-            for idx, lpk in enumerate(mapper.local_table.primary_key)
-        }
-        primary_key_convert = [
-            lookup[bpk] for bpk in mapper.base_mapper.primary_key
-        ]
-        return [tuple(row[idx] for idx in primary_key_convert) for row in rows]
+        return [pk_keys(row) for row in rows]
 
     @classmethod
     def _get_matched_objects_on_criteria(cls, update_options, states):
@@ -1778,9 +1764,8 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
         returned_defaults_rows = result.returned_defaults_rows
         if returned_defaults_rows:
             pk_rows = cls._interpret_returning_rows(
-                target_mapper, returned_defaults_rows
+                result, target_mapper, returned_defaults_rows
             )
-
             matched_rows = [
                 tuple(row) + (update_options._identity_token,)
                 for row in pk_rows
@@ -2110,7 +2095,7 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState):
 
         if returned_defaults_rows:
             pk_rows = cls._interpret_returning_rows(
-                target_mapper, returned_defaults_rows
+                result, target_mapper, returned_defaults_rows
             )
 
             matched_rows = [
index 8d9feaf63c271a4602cbdbc7b695fb52fe0b7dce..da8efa44fa4f22fa574b41b328c1265a2abf8588 100644 (file)
@@ -3329,6 +3329,7 @@ class OnUpdatePopulationTest(fixtures.TestBase):
         ],
     )
     @testing.variation("synchronize", ["auto", "fetch", "evaluate"])
+    @testing.variation("pk_order", ["first", "middle"])
     def test_update_populate_existing(
         self,
         decl_base,
@@ -3336,15 +3337,20 @@ class OnUpdatePopulationTest(fixtures.TestBase):
         use_onupdate,
         use_returning,
         synchronize,
+        pk_order,
     ):
         """test #11912 and #11917"""
 
         class Employee(ComparableEntity, decl_base):
             __tablename__ = "employee"
 
-            uuid: Mapped[uuid.UUID] = mapped_column(primary_key=True)
+            if pk_order.first:
+                uuid: Mapped[uuid.UUID] = mapped_column(primary_key=True)
             user_name: Mapped[str] = mapped_column(String(200), nullable=False)
 
+            if pk_order.middle:
+                uuid: Mapped[uuid.UUID] = mapped_column(primary_key=True)
+
             if use_onupdate.server:
                 some_server_value: Mapped[str] = mapped_column(
                     server_onupdate=FetchedValue()