]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
support add_cte() for TextualSelect
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 23 Feb 2022 17:50:36 +0000 (12:50 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 23 Feb 2022 20:25:34 +0000 (15:25 -0500)
Fixed issue where the :meth:`.HasCTE.add_cte` method as called upon a
:class:`.TextualSelect` instance was not being accommodated by the SQL
compiler. The fix additionally adds more "SELECT"-like compiler behavior to
:class:`.TextualSelect` including that DML CTEs such as UPDATE and INSERT
may be accommodated.

Fixes: #7760
Change-Id: Id97062d882e9b2a81b8e31c2bfaa9cfc5f77d5c1

doc/build/changelog/unreleased_14/7760.rst [new file with mode: 0644]
lib/sqlalchemy/orm/context.py
lib/sqlalchemy/sql/compiler.py
test/sql/test_cte.py

diff --git a/doc/build/changelog/unreleased_14/7760.rst b/doc/build/changelog/unreleased_14/7760.rst
new file mode 100644 (file)
index 0000000..2f0d403
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 7760
+
+    Fixed issue where the :meth:`.HasCTE.add_cte` method as called upon a
+    :class:`.TextualSelect` instance was not being accommodated by the SQL
+    compiler. The fix additionally adds more "SELECT"-like compiler behavior to
+    :class:`.TextualSelect` including that DML CTEs such as UPDATE and INSERT
+    may be accommodated.
index f51abde0c3ebabf1c3046dcf1e50a141fcf46650..63ed10d501b141b11dfca7ffea85d5550121f0d4 100644 (file)
@@ -27,6 +27,7 @@ from .. import future
 from .. import inspect
 from .. import sql
 from .. import util
+from ..sql import ClauseElement
 from ..sql import coercions
 from ..sql import expression
 from ..sql import roles
@@ -486,8 +487,8 @@ class ORMFromStatementCompileState(ORMCompileState):
                 entity.setup_compile_state(self)
 
             # we did the setup just to get primary columns.
-            self.statement = expression.TextualSelect(
-                self.statement, self.primary_columns, positional=False
+            self.statement = _AdHocColumnsStatement(
+                self.statement, self.primary_columns
             )
         else:
             # allow TextualSelect with implicit columns as well
@@ -514,6 +515,65 @@ class ORMFromStatementCompileState(ORMCompileState):
         return None
 
 
+class _AdHocColumnsStatement(ClauseElement):
+    """internal object created to somewhat act like a SELECT when we
+    are selecting columns from a DML RETURNING.
+
+
+    """
+
+    __visit_name__ = None
+
+    def __init__(self, text, columns):
+        self.element = text
+        self.column_args = [
+            coercions.expect(roles.ColumnsClauseRole, c) for c in columns
+        ]
+
+    def _generate_cache_key(self):
+        raise NotImplementedError()
+
+    def _gen_cache_key(self, anon_map, bindparams):
+        raise NotImplementedError()
+
+    def _compiler_dispatch(
+        self, compiler, compound_index=None, asfrom=False, **kw
+    ):
+        """provide a fixed _compiler_dispatch method."""
+
+        toplevel = not compiler.stack
+        entry = (
+            compiler._default_stack_entry if toplevel else compiler.stack[-1]
+        )
+
+        populate_result_map = (
+            toplevel
+            # these two might not be needed
+            or (
+                compound_index == 0
+                and entry.get("need_result_map_for_compound", False)
+            )
+            or entry.get("need_result_map_for_nested", False)
+        )
+
+        if populate_result_map:
+            compiler._ordered_columns = (
+                compiler._textual_ordered_columns
+            ) = False
+
+            # enable looser result column matching.  this is shown to be
+            # needed by test_query.py::TextTest
+            compiler._loose_column_name_matching = True
+
+            for c in self.column_args:
+                compiler.process(
+                    c,
+                    within_columns_clause=True,
+                    add_to_result_map=compiler._add_to_result_map,
+                )
+        return compiler.process(self.element, **kw)
+
+
 @sql.base.CompileState.plugin_for("orm", "select")
 class ORMSelectCompileState(ORMCompileState, SelectState):
     _already_joined_edges = ()
index 4a169f719da28befbe6651b3d3f58f6c180ddb76..b140f9297576cf390495d164d65590fd18d76ce6 100644 (file)
@@ -1596,6 +1596,17 @@ class SQLCompiler(Compiled):
         toplevel = not self.stack
         entry = self._default_stack_entry if toplevel else self.stack[-1]
 
+        new_entry = {
+            "correlate_froms": set(),
+            "asfrom_froms": set(),
+            "selectable": taf,
+        }
+        self.stack.append(new_entry)
+
+        if taf._independent_ctes:
+            for cte in taf._independent_ctes:
+                cte._compiler_dispatch(self, **kw)
+
         populate_result_map = (
             toplevel
             or (
@@ -1623,7 +1634,12 @@ class SQLCompiler(Compiled):
                     add_to_result_map=self._add_to_result_map,
                 )
 
-        return self.process(taf.element, **kw)
+        text = self.process(taf.element, **kw)
+        if self.ctes:
+            nesting_level = len(self.stack) if not toplevel else None
+            text = self._render_cte_clause(nesting_level=nesting_level) + text
+
+        return text
 
     def visit_null(self, expr, **kw):
         return "NULL"
index 64479b9692f8397254d3bbf31be81473dbbe2a05..b0569250485f91665d8b64941857d8e28901ebe7 100644 (file)
@@ -1601,6 +1601,51 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
             },
         )
 
+    def test_textual_select_uses_independent_cte_one(self):
+        """test #7760"""
+        products = table("products", column("id"), column("price"))
+
+        upd_cte = (
+            products.update().values(price=10).where(products.c.price > 50)
+        ).cte()
+
+        stmt = (
+            text(
+                "SELECT products.id, products.price "
+                "FROM products WHERE products.price < :price_2"
+            )
+            .columns(products.c.id, products.c.price)
+            .bindparams(price_2=45)
+            .add_cte(upd_cte)
+        )
+
+        self.assert_compile(
+            stmt,
+            "WITH anon_1 AS (UPDATE products SET price=:param_1 "
+            "WHERE products.price > :price_1) "
+            "SELECT products.id, products.price "
+            "FROM products WHERE products.price < :price_2",
+            checkparams={"param_1": 10, "price_1": 50, "price_2": 45},
+        )
+
+    def test_textual_select_uses_independent_cte_two(self):
+
+        foo = table("foo", column("id"))
+        bar = table("bar", column("id"), column("attr"), column("foo_id"))
+        s1 = select(foo.c.id)
+        s2 = text(
+            "SELECT bar.id, bar.attr FROM bar "
+            "WHERE bar.foo_id IN (SELECT id FROM baz)"
+        ).columns(bar.c.id, bar.c.attr)
+        s3 = s2.add_cte(s1.cte(name="baz"))
+
+        self.assert_compile(
+            s3,
+            "WITH baz AS (SELECT foo.id AS id FROM foo) "
+            "SELECT bar.id, bar.attr FROM bar WHERE bar.foo_id IN "
+            "(SELECT id FROM baz)",
+        )
+
     def test_insert_uses_independent_cte(self):
         products = table("products", column("id"), column("price"))