]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
fix regression based on mis-match of set/frozenset
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 27 Jan 2023 23:15:04 +0000 (18:15 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 28 Jan 2023 03:02:36 +0000 (22:02 -0500)
Fixed regression where ORM models that used joined table inheritance with a
composite foreign key would encounter an internal error in the mapper
internals.

Fixes: #9164
Change-Id: I8fdcdf6d72f3304bee191498d5554555b0ab7855

doc/build/changelog/unreleased_20/9164.rst [new file with mode: 0644]
lib/sqlalchemy/orm/mapper.py
test/orm/inheritance/test_basic.py

diff --git a/doc/build/changelog/unreleased_20/9164.rst b/doc/build/changelog/unreleased_20/9164.rst
new file mode 100644 (file)
index 0000000..b11a27f
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: bug, orm, regression
+    :tickets: 9164
+
+    Fixed regression where ORM models that used joined table inheritance with a
+    composite foreign key would encounter an internal error in the mapper
+    internals.
+
+
index dd3aef8d69785a029662e4ee34bebff3e5e6be68..a3b209e4a6507042d40de7394df0bc892c4f6375 100644 (file)
@@ -3920,12 +3920,16 @@ class Mapper(
             ],
         ] = util.defaultdict(list)
 
+        def set_union(x, y):
+            return x.union(y)
+
         for table in self._sorted_tables:
             cols = set(table.c)
+
             for m in self.iterate_to_root():
                 if m._inherits_equated_pairs and cols.intersection(
                     reduce(
-                        set.union,  # type: ignore
+                        set_union,
                         [l.proxy_set for l, r in m._inherits_equated_pairs],
                     )
                 ):
index 905f0c50d6b681e2bdf711a8f1af8d59eb2d6d3f..37368f3ad63ca2d56faf12a516e9369539b5dcae 100644 (file)
@@ -4,6 +4,7 @@ from sqlalchemy import column
 from sqlalchemy import event
 from sqlalchemy import exc as sa_exc
 from sqlalchemy import ForeignKey
+from sqlalchemy import ForeignKeyConstraint
 from sqlalchemy import func
 from sqlalchemy import inspect
 from sqlalchemy import Integer
@@ -22,6 +23,8 @@ from sqlalchemy.orm import declarative_base
 from sqlalchemy.orm import deferred
 from sqlalchemy.orm import exc as orm_exc
 from sqlalchemy.orm import joinedload
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
 from sqlalchemy.orm import object_mapper
 from sqlalchemy.orm import polymorphic_union
 from sqlalchemy.orm import relationship
@@ -4056,6 +4059,59 @@ class UnexpectedPolymorphicIdentityTest(fixtures.DeclarativeMappedTest):
         )
 
 
+class CompositeJoinedInTest(fixtures.DeclarativeMappedTest):
+    """test #9164"""
+
+    run_setup_mappers = "once"
+    __dialect__ = "default"
+
+    @classmethod
+    def setup_classes(cls):
+        Base = cls.DeclarativeBasic
+
+        class A(fixtures.ComparableEntity, Base):
+            __tablename__ = "table_a"
+
+            order_id: Mapped[str] = mapped_column(primary_key=True)
+            _sku: Mapped[str] = mapped_column(primary_key=True)
+
+            __mapper_args__ = {
+                "polymorphic_identity": "a",
+                "polymorphic_on": "type",
+            }
+
+            type: Mapped[str]
+
+            def __init__(self, order_id: str, sku: str):
+                self.order_id = order_id
+                self._sku = sku
+
+        class B(A):
+            __tablename__ = "table_b"
+
+            _increment_id: Mapped[str] = mapped_column(primary_key=True)
+            _sku: Mapped[str] = mapped_column(primary_key=True)
+
+            __table_args__ = (
+                ForeignKeyConstraint(
+                    ["_increment_id", "_sku"],
+                    ["table_a.order_id", "table_a._sku"],
+                ),
+            )
+
+            __mapper_args__ = {"polymorphic_identity": "b"}
+
+    def test_round_trip(self):
+        B = self.classes.B
+
+        sess = fixture_session()
+        b1 = B(order_id="iid1", sku="sku1")
+        sess.add(b1)
+        sess.commit()
+
+        eq_(sess.scalar(select(B)), b1)
+
+
 class NameConflictTest(fixtures.MappedTest):
     @classmethod
     def define_tables(cls, metadata):