]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
deep compare CTEs before considering them conflicting
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 5 Aug 2022 21:25:05 +0000 (17:25 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 5 Aug 2022 21:26:49 +0000 (17:26 -0400)
Fixed issue where referencing a CTE multiple times in conjunction with a
polymorphic SELECT could result in multiple "clones" of the same CTE being
constructed, which would then trigger these two CTEs as duplicates. To
resolve, the two CTEs are deep-compared when this occurs to ensure that
they are equivalent, then are treated as equivalent.

Fixes: #8357
Change-Id: I1f634a9cf7a6c4256912aac1a00506aecea3b0e2
(cherry picked from commit 85fa363c846f4ed287565c43c32e2cca29470e25)

doc/build/changelog/unreleased_14/8357.rst [new file with mode: 0644]
lib/sqlalchemy/sql/compiler.py
test/orm/inheritance/test_polymorphic_rel.py
test/sql/test_cte.py

diff --git a/doc/build/changelog/unreleased_14/8357.rst b/doc/build/changelog/unreleased_14/8357.rst
new file mode 100644 (file)
index 0000000..129368b
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 8357
+
+    Fixed issue where referencing a CTE multiple times in conjunction with a
+    polymorphic SELECT could result in multiple "clones" of the same CTE being
+    constructed, which would then trigger these two CTEs as duplicates. To
+    resolve, the two CTEs are deep-compared when this occurs to ensure that
+    they are equivalent, then are treated as equivalent.
+
index 330f3c3bc86672f19f6a2cb9917af1b0f9644e05..c9b6ba670c2c2368af372ecbf02b516f44e52634 100644 (file)
@@ -2708,10 +2708,19 @@ class SQLCompiler(Compiled):
 
                 del self.level_name_by_cte[existing_cte_reference_cte]
             else:
-                raise exc.CompileError(
-                    "Multiple, unrelated CTEs found with "
-                    "the same name: %r" % cte_name
-                )
+                # if the two CTEs are deep-copy identical, consider them
+                # the same, **if** they are clones, that is, they came from
+                # the ORM or other visit method
+                if (
+                    cte._is_clone_of is not None
+                    or existing_cte._is_clone_of is not None
+                ) and cte.compare(existing_cte):
+                    is_new_cte = False
+                else:
+                    raise exc.CompileError(
+                        "Multiple, unrelated CTEs found with "
+                        "the same name: %r" % cte_name
+                    )
 
         if not asfrom and not is_new_cte:
             return None
index aa8d9eaec687676285cf783f8fff7e767680b7ca..9ccec61ee1296cffe4f87a6c60eb06435dfdb741 100644 (file)
@@ -1,5 +1,6 @@
 from sqlalchemy import desc
 from sqlalchemy import exc as sa_exc
+from sqlalchemy import exists
 from sqlalchemy import func
 from sqlalchemy import select
 from sqlalchemy import testing
@@ -64,6 +65,44 @@ class _PolymorphicTestBase(object):
         )
         e1, e2, e3, b1, m1 = cls.e1, cls.e2, cls.e3, cls.b1, cls.m1
 
+    @testing.requires.ctes
+    def test_cte_clone_issue(self):
+        """test #8357"""
+
+        sess = fixture_session()
+
+        cte = select(Engineer.person_id).cte(name="test_cte")
+
+        stmt = (
+            select(Engineer)
+            .where(exists().where(Engineer.person_id == cte.c.person_id))
+            .where(exists().where(Engineer.person_id == cte.c.person_id))
+        ).order_by(Engineer.person_id)
+
+        self.assert_compile(
+            stmt,
+            "WITH test_cte AS (SELECT engineers.person_id AS person_id "
+            "FROM people JOIN engineers ON people.person_id = "
+            "engineers.person_id) SELECT engineers.person_id, "
+            "people.person_id AS person_id_1, people.company_id, "
+            "people.name, people.type, engineers.status, "
+            "engineers.engineer_name, engineers.primary_language FROM people "
+            "JOIN engineers ON people.person_id = engineers.person_id WHERE "
+            "(EXISTS (SELECT * FROM test_cte WHERE engineers.person_id = "
+            "test_cte.person_id)) AND (EXISTS (SELECT * FROM test_cte "
+            "WHERE engineers.person_id = test_cte.person_id)) "
+            "ORDER BY engineers.person_id",
+        )
+        result = sess.scalars(stmt)
+        eq_(
+            result.all(),
+            [
+                Engineer(name="dilbert"),
+                Engineer(name="wally"),
+                Engineer(name="vlad"),
+            ],
+        )
+
     def test_loads_at_once(self):
         """
         Test that all objects load from the full query, when
index d146ae6066483da4d3dcc822e41a448dbae8c5a0..fed371f62946d5613923945a99ae83eefe871cf9 100644 (file)
@@ -486,20 +486,38 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
             "SELECT cs1.x, cs2.x AS x_1 FROM bar AS cs1, cte AS cs2",
         )
 
-    def test_conflicting_names(self):
+    @testing.combinations(True, False, argnames="identical")
+    @testing.combinations(True, False, argnames="use_clone")
+    def test_conflicting_names(self, identical, use_clone):
         """test a flat out name conflict."""
 
         s1 = select(1)
         c1 = s1.cte(name="cte1", recursive=True)
-        s2 = select(1)
-        c2 = s2.cte(name="cte1", recursive=True)
+        if use_clone:
+            c2 = c1._clone()
+            if not identical:
+                c2 = c2.union(select(2))
+        else:
+            if identical:
+                s2 = select(1)
+            else:
+                s2 = select(column("q"))
+            c2 = s2.cte(name="cte1", recursive=True)
 
         s = select(c1, c2)
-        assert_raises_message(
-            CompileError,
-            "Multiple, unrelated CTEs found " "with the same name: 'cte1'",
-            s.compile,
-        )
+
+        if use_clone and identical:
+            self.assert_compile(
+                s,
+                'WITH RECURSIVE cte1("1") AS (SELECT 1) SELECT cte1.1, '
+                'cte1.1 AS "1_1" FROM cte1',
+            )
+        else:
+            assert_raises_message(
+                CompileError,
+                "Multiple, unrelated CTEs found " "with the same name: 'cte1'",
+                s.compile,
+            )
 
     def test_with_recursive_no_name_currently_buggy(self):
         s1 = select(1)