]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
repair Join.is_derived_from() to not rely on simple identity
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 7 Jun 2021 21:47:07 +0000 (17:47 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 7 Jun 2021 21:48:55 +0000 (17:48 -0400)
Fixed issue where query production for joinedload against a complex left
hand side involving joined-table inheritance could fail to produce a
correct query, due to a clause adaption issue.

Fixes: #6595
Change-Id: Id4b839d52447cdc103b392dd8946c4cfa7a829e1

doc/build/changelog/unreleased_14/6595.rst [new file with mode: 0644]
lib/sqlalchemy/sql/selectable.py
test/orm/test_eager_relations.py
test/sql/test_selectable.py

diff --git a/doc/build/changelog/unreleased_14/6595.rst b/doc/build/changelog/unreleased_14/6595.rst
new file mode 100644 (file)
index 0000000..a9f22da
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: bug, orm, regression
+    :tickets: 6595
+
+    Fixed issue where query production for joinedload against a complex left
+    hand side involving joined-table inheritance could fail to produce a
+    correct query, due to a clause adaption issue.
index e24585fa0ab403912faac0dc0fff656414ab3167..1610191d1e7e58b3a95bf8cfe596cdda119a57f4 100644 (file)
@@ -1115,7 +1115,9 @@ class Join(roles.DMLTableRole, FromClause):
 
     def is_derived_from(self, fromclause):
         return (
-            fromclause is self
+            # use hash() to ensure direct comparison to annotated works
+            # as well
+            hash(fromclause) == hash(self)
             or self.left.is_derived_from(fromclause)
             or self.right.is_derived_from(fromclause)
         )
index 4e11f986378763c43b7e6a94208bad6cc6e0b1c1..f888a5129848c655ca0f67ee16dbb18f9bbe93c0 100644 (file)
@@ -5488,6 +5488,109 @@ class CyclicalInheritingEagerTestThree(
         )
 
 
+class LoadFromJoinedInhWUnion(
+    fixtures.DeclarativeMappedTest, testing.AssertsCompiledSQL
+):
+    """test for #6595"""
+
+    __dialect__ = "default"
+    run_create_tables = None
+
+    @classmethod
+    def setup_classes(cls):
+        Base = cls.DeclarativeBasic
+
+        class Tag(Base):
+            __tablename__ = "tags"
+            id = Column(Integer, primary_key=True)
+            name = Column(String(50), primary_key=True)
+
+            sample_id = Column("sample_id", Integer, ForeignKey("sample.id"))
+
+        class BaseDataFile(Base):
+            __tablename__ = "base_data_file"
+            id = Column(Integer, primary_key=True)
+            type = Column(String(50))
+            __mapper_args__ = {
+                "polymorphic_identity": "base_data_file",
+                "polymorphic_on": type,
+            }
+
+        class Sample(BaseDataFile):
+            __tablename__ = "sample"
+            __mapper_args__ = {"polymorphic_identity": "sample"}
+            id = Column(
+                Integer,
+                ForeignKey("base_data_file.id"),
+                primary_key=True,
+            )
+            tags = relationship(
+                "Tag",
+            )
+
+    def test_one(self):
+        Sample = self.classes.Sample
+
+        session = fixture_session()
+        user_sample_query = session.query(Sample)
+
+        unioned = user_sample_query.union(user_sample_query)
+
+        q = unioned.options(joinedload(Sample.tags)).limit(10)
+
+        self.assert_compile(
+            q,
+            "SELECT anon_1.anon_2_sample_id AS anon_1_anon_2_sample_id, "
+            "anon_1.anon_2_base_data_file_type "
+            "AS anon_1_anon_2_base_data_file_type, "
+            "tags_1.id AS tags_1_id, tags_1.name AS tags_1_name, "
+            "tags_1.sample_id AS tags_1_sample_id FROM "
+            "(SELECT anon_2.sample_id AS anon_2_sample_id, "
+            "anon_2.base_data_file_type AS anon_2_base_data_file_type "
+            "FROM (SELECT sample.id AS sample_id, "
+            "base_data_file.id AS base_data_file_id, "
+            "base_data_file.type AS base_data_file_type "
+            "FROM base_data_file JOIN sample ON base_data_file.id = sample.id "
+            "UNION SELECT sample.id AS sample_id, "
+            "base_data_file.id AS base_data_file_id, "
+            "base_data_file.type AS base_data_file_type "
+            "FROM base_data_file "
+            "JOIN sample ON base_data_file.id = sample.id) AS anon_2 "
+            "LIMIT :param_1) AS anon_1 "
+            "LEFT OUTER JOIN tags AS tags_1 "
+            "ON anon_1.anon_2_sample_id = tags_1.sample_id",
+        )
+
+    def test_two(self):
+        Sample = self.classes.Sample
+
+        session = fixture_session()
+        user_sample_query = session.query(Sample)
+
+        unioned = user_sample_query.union(user_sample_query)
+
+        q = unioned.options(joinedload(Sample.tags))
+
+        self.assert_compile(
+            q,
+            "SELECT anon_1.sample_id AS anon_1_sample_id, "
+            "anon_1.base_data_file_type AS anon_1_base_data_file_type, "
+            "tags_1.id AS tags_1_id, tags_1.name AS tags_1_name, "
+            "tags_1.sample_id AS tags_1_sample_id "
+            "FROM (SELECT sample.id AS sample_id, "
+            "base_data_file.id AS base_data_file_id, "
+            "base_data_file.type AS base_data_file_type "
+            "FROM base_data_file JOIN sample ON base_data_file.id = sample.id "
+            "UNION SELECT sample.id AS sample_id, "
+            "base_data_file.id AS base_data_file_id, "
+            "base_data_file.type AS base_data_file_type "
+            "FROM base_data_file "
+            "JOIN sample ON base_data_file.id = sample.id) "
+            "AS anon_1 LEFT OUTER JOIN tags AS tags_1 "
+            "ON anon_1.sample_id = tags_1.sample_id",
+        )
+
+
 class EnsureColumnsAddedTest(
     fixtures.DeclarativeMappedTest, testing.AssertsCompiledSQL
 ):
index add07e01322c1235e3ae8dbfa9971d187734f4af..efa3be52374559710f0f9946fb99e3a8ca891a8f 100644 (file)
@@ -2594,6 +2594,38 @@ class DerivedTest(fixtures.TestBase, AssertsExecutionResults):
         assert select(t1, t2).alias("foo").is_derived_from(t1)
         assert not t2.select().alias("foo").is_derived_from(t1)
 
+    def test_join(self):
+        meta = MetaData()
+
+        t1 = Table(
+            "t1",
+            meta,
+            Column("c1", Integer, primary_key=True),
+            Column("c2", String(30)),
+        )
+        t2 = Table(
+            "t2",
+            meta,
+            Column("c1", Integer, primary_key=True),
+            Column("c2", String(30)),
+        )
+        t3 = Table(
+            "t3",
+            meta,
+            Column("c1", Integer, primary_key=True),
+            Column("c2", String(30)),
+        )
+
+        j1 = t1.join(t2, t1.c.c1 == t2.c.c1)
+
+        assert j1.is_derived_from(j1)
+
+        assert j1.is_derived_from(t1)
+
+        assert j1._annotate({"foo": "bar"}).is_derived_from(j1)
+
+        assert not j1.is_derived_from(t3)
+
 
 class AnnotationsTest(fixtures.TestBase):
     def test_hashing(self):