]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Recursive tests are real recursive query
authorEric Masseran <eric.masseran@gmail.com>
Fri, 1 Oct 2021 17:52:30 +0000 (19:52 +0200)
committerEric Masseran <eric.masseran@gmail.com>
Fri, 1 Oct 2021 17:52:30 +0000 (19:52 +0200)
test/sql/test_cte.py

index 2f847279aceffbd20821d0dd3bdc1ced28fada74..f6070239e43af4c81c6d0f14c8f28aec45692704 100644 (file)
@@ -1901,24 +1901,36 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
         nesting_cte = select(literal(1).label("inner_cte")).cte(
             "nesting", nesting=True
         )
-        stmt = select(
-            select(nesting_cte.c.inner_cte.label("outer_cte")).cte(
-                "cte", recursive=True
-            )
+
+        rec_cte = select(nesting_cte.c.inner_cte.label("outer_cte")).cte(
+            "rec_cte", recursive=True
+        )
+        rec_part = select(rec_cte.c.outer_cte).where(
+            rec_cte.c.outer_cte == literal(1)
         )
+        rec_cte = rec_cte.union(rec_part)
+
+        stmt = select(rec_cte)
 
         self.assert_compile(
             stmt,
-            "WITH RECURSIVE cte(outer_cte) AS (WITH nesting AS "
+            "WITH RECURSIVE rec_cte(outer_cte) AS (WITH nesting AS "
             "(SELECT :param_1 AS inner_cte) "
-            "SELECT nesting.inner_cte AS outer_cte FROM nesting) "
-            "SELECT cte.outer_cte FROM cte",
+            "SELECT nesting.inner_cte AS outer_cte FROM nesting UNION "
+            "SELECT rec_cte.outer_cte AS outer_cte FROM rec_cte "
+            "WHERE rec_cte.outer_cte = :param_2) "
+            "SELECT rec_cte.outer_cte FROM rec_cte",
         )
 
     def test_recursive_nesting_cte_in_cte(self):
         nesting_cte = select(literal(1).label("inner_cte")).cte(
-            "nesting", nesting=True, recursive=True
+            "nesting", recursive=True, nesting=True
+        )
+        rec_part = select(nesting_cte.c.inner_cte).where(
+            nesting_cte.c.inner_cte == literal(1)
         )
+        nesting_cte = nesting_cte.union(rec_part)
+
         stmt = select(
             select(nesting_cte.c.inner_cte.label("outer_cte")).cte("cte")
         )
@@ -1926,7 +1938,9 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
         self.assert_compile(
             stmt,
             "WITH cte AS (WITH RECURSIVE nesting(inner_cte) AS "
-            "(SELECT :param_1 AS inner_cte) "
+            "(SELECT :param_1 AS inner_cte UNION "
+            "SELECT nesting.inner_cte AS inner_cte FROM nesting "
+            "WHERE nesting.inner_cte = :param_2) "
             "SELECT nesting.inner_cte AS outer_cte FROM nesting) "
             "SELECT cte.outer_cte FROM cte",
         )
@@ -1976,19 +1990,32 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
         nesting_cte = select(literal(1).label("inner_cte")).cte(
             "nesting", nesting=True, recursive=True
         )
-        stmt = select(
-            select(nesting_cte.c.inner_cte.label("outer_cte")).cte(
-                "cte", recursive=True
-            )
+        nesting_rec_part = select(nesting_cte.c.inner_cte).where(
+            nesting_cte.c.inner_cte == literal(1)
+        )
+        nesting_cte = nesting_cte.union(nesting_rec_part)
+
+        rec_cte = select(nesting_cte.c.inner_cte.label("outer_cte")).cte(
+            "rec_cte", recursive=True
+        )
+        rec_part = select(rec_cte.c.outer_cte).where(
+            rec_cte.c.outer_cte == literal(1)
         )
+        rec_cte = rec_cte.union(rec_part)
+
+        stmt = select(rec_cte)
 
         self.assert_compile(
             stmt,
-            "WITH RECURSIVE cte(outer_cte) AS "
-            "(WITH RECURSIVE nesting(inner_cte) "
-            "AS (SELECT :param_1 AS inner_cte) "
-            "SELECT nesting.inner_cte AS outer_cte FROM nesting) "
-            "SELECT cte.outer_cte FROM cte",
+            "WITH RECURSIVE rec_cte(outer_cte) AS ("
+            "WITH RECURSIVE nesting(inner_cte) AS "
+            "(SELECT :param_1 AS inner_cte UNION "
+            "SELECT nesting.inner_cte AS inner_cte FROM nesting "
+            "WHERE nesting.inner_cte = :param_2) "
+            "SELECT nesting.inner_cte AS outer_cte FROM nesting UNION "
+            "SELECT rec_cte.outer_cte AS outer_cte FROM rec_cte "
+            "WHERE rec_cte.outer_cte = :param_3) "
+            "SELECT rec_cte.outer_cte FROM rec_cte",
         )
 
     def test_select_from_insert_cte_with_nesting(self):
@@ -2186,22 +2213,3 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
             "WHERE anon_1.inner_cte = :param_2) "
             "SELECT nesting.inner_cte FROM nesting",
         )
-
-    def test_recursive_cte_w_union(self):
-        nesting_cte = select(literal(1).label("inner_cte")).cte(
-            "nesting", recursive=True, nesting=True
-        )
-        rec_part = select(nesting_cte.c.inner_cte).where(
-            nesting_cte.c.inner_cte == literal(1)
-        )
-        nesting_cte = nesting_cte.union(rec_part)
-
-        stmt = select(nesting_cte.c.inner_cte)
-        self.assert_compile(
-            stmt,
-            "WITH RECURSIVE nesting(inner_cte) AS "
-            "(SELECT :param_1 AS inner_cte UNION "
-            "SELECT nesting.inner_cte AS inner_cte FROM nesting "
-            "WHERE nesting.inner_cte = :param_2) "
-            "SELECT nesting.inner_cte FROM nesting",
-        )