]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
check that two CTEs aren't just annotated forms of the same thing
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 20 Feb 2025 17:50:25 +0000 (12:50 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 20 Feb 2025 18:33:54 +0000 (13:33 -0500)
Fixed issue where using :func:`_orm.aliased` around a :class:`.CTE`
construct could cause inappropriate "duplicate CTE" errors in cases where
that aliased construct appeared multiple times in a single statement.

Fixes: #12364
Change-Id: I9625cd83e9baf5312cdc644b38951353708d3b86
(cherry picked from commit 42ddb1fd5f1e29682bcd6ccc7b835999aafec12e)

doc/build/changelog/unreleased_20/12364.rst [new file with mode: 0644]
lib/sqlalchemy/sql/compiler.py
test/sql/test_cte.py

diff --git a/doc/build/changelog/unreleased_20/12364.rst b/doc/build/changelog/unreleased_20/12364.rst
new file mode 100644 (file)
index 0000000..59f5d24
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 12364
+
+    Fixed issue where using :func:`_orm.aliased` around a :class:`.CTE`
+    construct could cause inappropriate "duplicate CTE" errors in cases where
+    that aliased construct appeared multiple times in a single statement.
index 49e8ce500e8faf14607c07d90f18d41480f1760c..da476849ea06b8426147a7d112053ee89922c40d 100644 (file)
@@ -4074,15 +4074,28 @@ class SQLCompiler(Compiled):
 
                 del self.level_name_by_cte[existing_cte_reference_cte]
             else:
-                # if the two CTEs are deep-copy identical, consider them
-                # the same, **if** they are clones, that is, they came from
-                # the ORM or other visit method
                 if (
-                    cte._is_clone_of is not None
-                    or existing_cte._is_clone_of is not None
-                ) and cte.compare(existing_cte):
+                    # if the two CTEs have the same hash, which we expect
+                    # here means that one/both is an annotated of the other
+                    (hash(cte) == hash(existing_cte))
+                    # or...
+                    or (
+                        (
+                            # if they are clones, i.e. they came from the ORM
+                            # or some other visit method
+                            cte._is_clone_of is not None
+                            or existing_cte._is_clone_of is not None
+                        )
+                        # and are deep-copy identical
+                        and cte.compare(existing_cte)
+                    )
+                ):
+                    # then consider these two CTEs the same
                     is_new_cte = False
                 else:
+                    # otherwise these are two CTEs that either will render
+                    # differently, or were indicated separately by the user,
+                    # with the same name
                     raise exc.CompileError(
                         "Multiple, unrelated CTEs found with "
                         "the same name: %r" % cte_name
index 383f2adaabd0748e5a78ebc054ef828426691d07..d0ecc38c86f26dcf035daf1fc8fb2337c9fc9056 100644 (file)
@@ -8,6 +8,7 @@ from sqlalchemy import Table
 from sqlalchemy import testing
 from sqlalchemy import text
 from sqlalchemy import true
+from sqlalchemy import union_all
 from sqlalchemy import update
 from sqlalchemy.dialects import mssql
 from sqlalchemy.engine import default
@@ -492,16 +493,22 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
         )
 
     @testing.combinations(True, False, argnames="identical")
-    @testing.combinations(True, False, argnames="use_clone")
-    def test_conflicting_names(self, identical, use_clone):
+    @testing.variation("clone_type", ["none", "clone", "annotated"])
+    def test_conflicting_names(self, identical, clone_type):
         """test a flat out name conflict."""
 
         s1 = select(1)
         c1 = s1.cte(name="cte1", recursive=True)
-        if use_clone:
+        if clone_type.clone:
             c2 = c1._clone()
             if not identical:
                 c2 = c2.union(select(2))
+        elif clone_type.annotated:
+            # this does not seem to trigger the issue that was fixed in
+            # #12364 howver is still a worthy test
+            c2 = c1._annotate({"foo": "bar"})
+            if not identical:
+                c2 = c2.union(select(2))
         else:
             if identical:
                 s2 = select(1)
@@ -511,12 +518,20 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
 
         s = select(c1, c2)
 
-        if use_clone and identical:
+        if clone_type.clone and identical:
             self.assert_compile(
                 s,
                 'WITH RECURSIVE cte1("1") AS (SELECT 1) SELECT cte1.1, '
                 'cte1.1 AS "1_1" FROM cte1',
             )
+        elif clone_type.annotated and identical:
+            # annotated seems to have a slightly different rendering
+            # scheme here
+            self.assert_compile(
+                s,
+                'WITH RECURSIVE cte1("1") AS (SELECT 1) SELECT cte1.1, '
+                'cte1.1 AS "1__1" FROM cte1',
+            )
         else:
             assert_raises_message(
                 CompileError,
@@ -524,6 +539,32 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
                 s.compile,
             )
 
+    @testing.variation("annotated", [True, False])
+    def test_cte_w_annotated(self, annotated):
+        """test #12364"""
+
+        A = table("a", column("i"), column("j"))
+        B = table("b", column("i"), column("j"))
+
+        a = select(A).where(A.c.i > A.c.j).cte("filtered_a")
+
+        if annotated:
+            a = a._annotate({"foo": "bar"})
+
+        a1 = select(a.c.i, literal(1).label("j"))
+        b = select(B).join(a, a.c.i == B.c.i).where(B.c.j.is_not(None))
+
+        query = union_all(a1, b)
+        self.assert_compile(
+            query,
+            "WITH filtered_a AS "
+            "(SELECT a.i AS i, a.j AS j FROM a WHERE a.i > a.j) "
+            "SELECT filtered_a.i, :param_1 AS j FROM filtered_a "
+            "UNION ALL SELECT b.i, b.j "
+            "FROM b JOIN filtered_a ON filtered_a.i = b.i "
+            "WHERE b.j IS NOT NULL",
+        )
+
     def test_with_recursive_no_name_currently_buggy(self):
         s1 = select(1)
         c1 = s1.cte(name="cte1", recursive=True)