git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
implement independent CTEs
author: Mike Bayer <mike_mp@zzzcomputing.com>
Mon, 12 Jul 2021 18:28:19 +0000 (14:28 -0400)
committer: Mike Bayer <mike_mp@zzzcomputing.com>
Mon, 12 Jul 2021 21:10:58 +0000 (17:10 -0400)
Added new method :meth:`_sql.HasCTE.add_cte` to each of the
:func:`_sql.select`, :func:`_sql.insert`, :func:`_sql.update` and
:func:`_sql.delete` constructs. This method will add the given
:class:`_sql.CTE` as an "independent" CTE of the statement, meaning it
renders in the WITH clause above the statement unconditionally even if it
is not otherwise referenced in the primary statement. This is a popular use
case on the PostgreSQL database where a CTE is used for a DML statement
that runs against database rows independently of the primary statement.

Fixes: #6752
Change-Id: Ibf635763e40269cbd10f4c17e208850d8e8d0188

doc/build/changelog/unreleased_14/6752.rst [new file with mode: 0644]
lib/sqlalchemy/sql/coercions.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/crud.py
lib/sqlalchemy/sql/dml.py
lib/sqlalchemy/sql/roles.py
lib/sqlalchemy/sql/selectable.py
test/dialect/postgresql/test_compiler.py
test/sql/test_compare.py
test/sql/test_cte.py

diff --git a/doc/build/changelog/unreleased_14/6752.rst b/doc/build/changelog/unreleased_14/6752.rst
new file mode 100644
index 0000000..f97b0f8
--- /dev/null
+++ b/doc/build/changelog/unreleased_14/6752.rst
@@ -0,0 +1,12 @@
+.. change::
+    :tags: usecase, sql
+    :tickets: 6752
+
+    Added new method :meth:`_sql.HasCTE.add_cte` to each of the
+    :func:`_sql.select`, :func:`_sql.insert`, :func:`_sql.update` and
+    :func:`_sql.delete` constructs. This method will add the given
+    :class:`_sql.CTE` as an "independent" CTE of the statement, meaning it
+    renders in the WITH clause above the statement unconditionally even if it
+    is not otherwise referenced in the primary statement. This is a popular use
+    case on the PostgreSQL database where a CTE is used for a DML statement
+    that runs against database rows independently of the primary statement.
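diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py
index 1da1fee86b7d0c0d625c1970d4638dbe77378b9f..16a68c8ffd06f38861ed6a4da517ba67dca9b87b 100644
--- a/lib/sqlalchemy/sql/coercions.py
+++ b/lib/sqlalchemy/sql/coercions.py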
@@ -871,6 +871,10 @@ class HasCTEImpl(ReturnsRowsImpl):
     __slots__ = ()
 
 
+class IsCTEImpl(RoleImpl):
+    __slots__ = ()
+
+
 class JoinTargetImpl(RoleImpl):
     __slots__ = ()
 
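diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 4b3b2c293c6c81a11d643a375873e6757691e908..880479a37f691d01f3277a24ac487ada2dea8529 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py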
@@ -2313,8 +2313,7 @@ class SQLCompiler(Compiled):
                 ) and not existing.proxy_set.intersection(bindparam.proxy_set):
                     raise exc.CompileError(
                         "Bind parameter '%s' conflicts with "
-                        "unique bind parameter of the same name"
-                        % bindparam.key
+                        "unique bind parameter of the same name" % name
                     )
                 elif existing._is_crud or bindparam._is_crud:
                     raise exc.CompileError(
@@ -3075,6 +3074,10 @@ class SQLCompiler(Compiled):
         else:
             byfrom = None
 
+        if select_stmt._independent_ctes:
+            for cte in select_stmt._independent_ctes:
+                cte._compiler_dispatch(self, **kwargs)
+
         if select_stmt._prefixes:
             text += self._generate_prefixes(
                 select_stmt, select_stmt._prefixes, **kwargs
@@ -3551,6 +3554,10 @@ class SQLCompiler(Compiled):
         if insert_stmt._hints:
             _, table_text = self._setup_crud_hints(insert_stmt, table_text)
 
+        if insert_stmt._independent_ctes:
+            for cte in insert_stmt._independent_ctes:
+                cte._compiler_dispatch(self, **kw)
+
         text += table_text
 
         if crud_params_single or not supports_default_values:
@@ -3700,6 +3707,10 @@ class SQLCompiler(Compiled):
         else:
             dialect_hints = None
 
+        if update_stmt._independent_ctes:
+            for cte in update_stmt._independent_ctes:
+                cte._compiler_dispatch(self, **kw)
+
         text += table_text
 
         text += " SET "
@@ -3808,6 +3819,10 @@ class SQLCompiler(Compiled):
         else:
             dialect_hints = None
 
+        if delete_stmt._independent_ctes:
+            for cte in delete_stmt._independent_ctes:
+                cte._compiler_dispatch(self, **kw)
+
         text += table_text
 
         if delete_stmt._returning:
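diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py
index de847fb7fac6924749d2a2999665b7c76dacfc83..74f5a1d05b3d8e86bd60b3d98e20d753367bf5e3 100644
--- a/lib/sqlalchemy/sql/crud.py
+++ b/lib/sqlalchemy/sql/crud.py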
@@ -224,7 +224,17 @@ def _handle_values_anonymous_param(compiler, col, value, name, **kw):
     # rather than having
     # compiler.visit_bindparam()->compiler._truncated_identifier make up a
     # name.  Saves on call counts also.
-    if value.unique and isinstance(value.key, elements._truncated_label):
+
+    # for INSERT/UPDATE that's a CTE, we don't need names to match to
+    # external parameters and these would also conflict in the case where
+    # multiple insert/update are combined together using CTEs
+    is_cte = "visiting_cte" in kw
+
+    if (
+        not is_cte
+        and value.unique
+        and isinstance(value.key, elements._truncated_label)
+    ):
         compiler.truncated_names[("bindparam", value.key)] = name
 
     if value.type._isnull:
@@ -460,7 +470,6 @@ def _append_param_parameter(
     values,
     kw,
 ):
-
     value = parameters.pop(col_key)
 
     col_value = compiler.preparer.format_column(
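diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py
index dd012ac86ab6d6e66171da87af476621131a45a7..048475040f6361c006b1818593036babce20da55 100644
--- a/lib/sqlalchemy/sql/dml.py
+++ b/lib/sqlalchemy/sql/dml.py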
@@ -829,6 +829,7 @@ class Insert(ValuesBase):
         + HasPrefixes._has_prefixes_traverse_internals
         + DialectKWArgs._dialect_kwargs_traverse_internals
         + Executable._executable_traverse_internals
+        + HasCTE._has_ctes_traverse_internals
     )
 
     @ValuesBase._constructor_20_deprecations(
@@ -1119,6 +1120,7 @@ class Update(DMLWhereBase, ValuesBase):
         + HasPrefixes._has_prefixes_traverse_internals
         + DialectKWArgs._dialect_kwargs_traverse_internals
         + Executable._executable_traverse_internals
+        + HasCTE._has_ctes_traverse_internals
     )
 
     @ValuesBase._constructor_20_deprecations(
@@ -1357,6 +1359,7 @@ class Delete(DMLWhereBase, UpdateBase):
         + HasPrefixes._has_prefixes_traverse_internals
         + DialectKWArgs._dialect_kwargs_traverse_internals
         + Executable._executable_traverse_internals
+        + HasCTE._has_ctes_traverse_internals
     )
 
     @ValuesBase._constructor_20_deprecations(
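diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py
index a5eefc7b545e588524c651ab8a8e8cde777ff511..b9010397cf5775a9b79f4367e4a012d2985a16c1 100644
--- a/lib/sqlalchemy/sql/roles.py
+++ b/lib/sqlalchemy/sql/roles.py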
@@ -182,6 +182,10 @@ class HasCTERole(ReturnsRowsRole):
     pass
 
 
+class IsCTERole(SQLRole):
+    _role_name = "CTE object"
+
+
 class CompoundElementRole(AllowsLambdaRole, SQLRole):
     """SELECT statements inside a CompoundSelect, e.g. UNION, EXTRACT, etc."""
 
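diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index 235c74ea76d7a5d1d93056243034ec7b75dd5b25..bd9ce0f20e4e7a150a35b52c38eac5104a7ce010 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py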
@@ -1986,6 +1986,7 @@ class TableSample(AliasedReturnsRows):
 
 class CTE(
     roles.DMLTableRole,
+    roles.IsCTERole,
     Generative,
     HasPrefixes,
     HasSuffixes,
@@ -2110,6 +2111,79 @@ class HasCTE(roles.HasCTERole):
 
     """
 
+    _has_ctes_traverse_internals = [
+        ("_independent_ctes", InternalTraversal.dp_clauseelement_list),
+    ]
+
+    _independent_ctes = ()
+
+    @_generative
+    def add_cte(self, cte):
+        """Add a :class:`_sql.CTE` to this statement object that will be
+        independently rendered even if not referenced in the statement
+        otherwise.
+
+        This feature is useful for the use case of embedding a DML statement
+        such as an INSERT or UPDATE as a CTE inline with a primary statement
+        that may draw from its results indirectly; while PostgreSQL is known
+        to support this usage, it may not be supported by other backends.
+
+        E.g.::
+
+            from sqlalchemy import table, column, select
+            t = table('t', column('c1'), column('c2'))
+
+            ins = t.insert().values({"c1": "x", "c2": "y"}).cte()
+
+            stmt = select(t).add_cte(ins)
+
+        Would render::
+
+            WITH anon_1 AS
+            (INSERT INTO t (c1, c2) VALUES (:param_1, :param_2))
+            SELECT t.c1, t.c2
+            FROM t
+
+        Above, the "anon_1" CTE is not referred to in the SELECT
+        statement; however, it still accomplishes the task of running an
+        INSERT statement.
+
+        Similarly in a DML-related context, using the PostgreSQL
+        :class:`_postgresql.Insert` construct to generate an "upsert"::
+
+            from sqlalchemy import table, column
+            from sqlalchemy.dialects.postgresql import insert
+
+            t = table("t", column("c1"), column("c2"))
+
+            delete_statement_cte = (
+                t.delete().where(t.c.c1 < 1).cte("deletions")
+            )
+
+            insert_stmt = insert(t).values({"c1": 1, "c2": 2})
+            update_statement = insert_stmt.on_conflict_do_update(
+                index_elements=[t.c.c1],
+                set_={
+                    "c1": insert_stmt.excluded.c1,
+                    "c2": insert_stmt.excluded.c2,
+                },
+            ).add_cte(delete_statement_cte)
+
+            print(update_statement)
+
+        The above statement renders as::
+
+            WITH deletions AS
+            (DELETE FROM t WHERE t.c1 < %(c1_1)s)
+            INSERT INTO t (c1, c2) VALUES (%(c1)s, %(c2)s)
+            ON CONFLICT (c1) DO UPDATE SET c1 = excluded.c1, c2 = excluded.c2
+
+        .. versionadded:: 1.4.21
+
+        """
+        cte = coercions.expect(roles.IsCTERole, cte)
+        self._independent_ctes += (cte,)
+
     def cte(self, name=None, recursive=False):
         r"""Return a new :class:`_expression.CTE`,
         or Common Table Expression instance.
@@ -4622,6 +4696,7 @@ class Select(
             ("_distinct_on", InternalTraversal.dp_clauseelement_tuple),
             ("_label_style", InternalTraversal.dp_plain_obj),
         ]
+        + HasCTE._has_ctes_traverse_internals
         + HasPrefixes._has_prefixes_traverse_internals
         + HasSuffixes._has_suffixes_traverse_internals
         + HasHints._has_hints_traverse_internals
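diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py
index e48de9d21d7411f08e890e8aacdad5078a450967..c08038df6a14e79ea7cf4304feb0f0c5349850c6 100644
--- a/test/dialect/postgresql/test_compiler.py
+++ b/test/dialect/postgresql/test_compiler.py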
@@ -2518,7 +2518,7 @@ class InsertOnConflictTest(fixtures.TestBase, AssertsCompiledSQL):
         self.assert_compile(
             stmt,
             "WITH i_upsert AS "
-            "(INSERT INTO mytable (name) VALUES (%(name)s) "
+            "(INSERT INTO mytable (name) VALUES (%(param_1)s) "
             "ON CONFLICT (name, description) "
             "WHERE description != %(description_1)s "
             "DO UPDATE SET name = excluded.name "
@@ -2527,6 +2527,30 @@ class InsertOnConflictTest(fixtures.TestBase, AssertsCompiledSQL):
             "FROM i_upsert",
         )
 
+    def test_combined_with_cte(self):
+        t = table("t", column("c1"), column("c2"))
+
+        delete_statement_cte = t.delete().where(t.c.c1 < 1).cte("deletions")
+
+        insert_stmt = insert(t).values([{"c1": 1, "c2": 2}])
+        update_stmt = insert_stmt.on_conflict_do_update(
+            index_elements=[t.c.c1],
+            set_={
+                col.name: col
+                for col in insert_stmt.excluded
+                if col.name in ("c1", "c2")
+            },
+        ).add_cte(delete_statement_cte)
+
+        self.assert_compile(
+            update_stmt,
+            "WITH deletions AS (DELETE FROM t WHERE t.c1 < %(c1_1)s) "
+            "INSERT INTO t (c1, c2) VALUES (%(c1_m0)s, %(c2_m0)s) "
+            "ON CONFLICT (c1) DO UPDATE SET c1 = excluded.c1, "
+            "c2 = excluded.c2",
+            checkparams={"c1_m0": 1, "c2_m0": 2, "c1_1": 1},
+        )
+
     def test_quote_raw_string_col(self):
         t = table("t", column("FancyName"), column("other name"))
 
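diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py
index e96a47553b17d47ab36ac57abe555653e4d6c94b..1ffa9b50af8d1ff4f113b9d733a37c252759069f 100644
--- a/test/sql/test_compare.py
+++ b/test/sql/test_compare.py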
@@ -490,6 +490,16 @@ class CoreFixtures(object):
             select(table_a.c.a).join(table_b, table_a.c.a == table_b.c.b),
             select(table_a.c.a).join(table_c, table_a.c.a == table_c.c.x),
         ),
+        lambda: (
+            select(table_a.c.a),
+            select(table_a.c.a).add_cte(table_b.insert().cte()),
+            table_a.insert(),
+            table_a.delete(),
+            table_a.update(),
+            table_a.insert().add_cte(table_b.insert().cte()),
+            table_a.delete().add_cte(table_b.insert().cte()),
+            table_a.update().add_cte(table_b.insert().cte()),
+        ),
         lambda: (
             select(table_a.c.a).cte(),
             select(table_a.c.a).cte(recursive=True),
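diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py
index 01186c340cc297529389ab12ae724e8fa94bc99d..e8a8a3150ce9deb6d2907c1d6e1bd0dddb5bb3ac 100644
--- a/test/sql/test_cte.py
+++ b/test/sql/test_cte.py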
@@ -1015,8 +1015,8 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
 
         self.assert_compile(
             insert,
-            "WITH upsert AS (UPDATE orders SET amount=:amount, "
-            "product=:product, quantity=:quantity "
+            "WITH upsert AS (UPDATE orders SET amount=:param_5, "
+            "product=:param_6, quantity=:param_7 "
             "WHERE orders.region = :region_1 "
             "RETURNING orders.region, orders.amount, "
             "orders.product, orders.quantity) "
@@ -1025,6 +1025,16 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
             ":param_3 AS anon_3, :param_4 AS anon_4 WHERE NOT (EXISTS "
             "(SELECT upsert.region, upsert.amount, upsert.product, "
             "upsert.quantity FROM upsert))",
+            checkparams={
+                "param_1": "Region1",
+                "param_2": 1.0,
+                "param_3": "Product1",
+                "param_4": 1,
+                "param_5": 1.0,
+                "param_6": "Product1",
+                "param_7": 1,
+                "region_1": "Region1",
+            },
         )
 
         eq_(insert.compile().isinsert, True)
@@ -1106,9 +1116,10 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
 
         self.assert_compile(
             stmt.select(),
-            "WITH anon_1 AS (UPDATE orders SET region=:region "
+            "WITH anon_1 AS (UPDATE orders SET region=:param_1 "
             "WHERE orders.region = :region_1 RETURNING orders.region) "
             "SELECT anon_1.region FROM anon_1",
+            checkparams={"param_1": "y", "region_1": "x"},
         )
 
         eq_(stmt.select().compile().isupdate, False)
@@ -1122,8 +1133,9 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
         self.assert_compile(
             stmt.select(),
             "WITH anon_1 AS (INSERT INTO orders (region) "
-            "VALUES (:region) RETURNING orders.region) "
+            "VALUES (:param_1) RETURNING orders.region) "
             "SELECT anon_1.region FROM anon_1",
+            checkparams={"param_1": "y"},
         )
         eq_(stmt.select().compile().isinsert, False)
 
@@ -1196,10 +1208,11 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
         self.assert_compile(
             stmt,
             "WITH t AS "
-            "(UPDATE products SET price=:price "
+            "(UPDATE products SET price=:param_1 "
             "RETURNING products.id, products.price) "
             "SELECT t.id, t.price "
             "FROM t",
+            checkparams={"param_1": "someprice"},
         )
         eq_(stmt.compile().isupdate, False)
 
@@ -1257,10 +1270,11 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
         self.assert_compile(
             stmt,
             "WITH pd AS "
-            "(INSERT INTO products (id, price) VALUES (:id, :price) "
+            "(INSERT INTO products (id, price) VALUES (:param_1, :param_2) "
             "RETURNING products.id, products.price) "
             "SELECT pd.id, pd.price "
             "FROM pd",
+            checkparams={"param_1": 1, "param_2": 27.0},
         )
         eq_(stmt.compile().isinsert, False)
 
@@ -1353,6 +1367,132 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
         )
         eq_(stmt.compile().isdelete, True)
 
+    def test_select_uses_independent_cte(self):
+        products = table("products", column("id"), column("price"))
+
+        upd_cte = (
+            products.update().values(price=10).where(products.c.price > 50)
+        ).cte()
+
+        stmt = products.select().where(products.c.price < 45).add_cte(upd_cte)
+
+        self.assert_compile(
+            stmt,
+            "WITH anon_1 AS (UPDATE products SET price=:param_1 "
+            "WHERE products.price > :price_1) "
+            "SELECT products.id, products.price "
+            "FROM products WHERE products.price < :price_2",
+            checkparams={"param_1": 10, "price_1": 50, "price_2": 45},
+        )
+
+    def test_insert_uses_independent_cte(self):
+        products = table("products", column("id"), column("price"))
+
+        upd_cte = (
+            products.update().values(price=10).where(products.c.price > 50)
+        ).cte()
+
+        stmt = (
+            products.insert().values({"id": 1, "price": 20}).add_cte(upd_cte)
+        )
+
+        self.assert_compile(
+            stmt,
+            "WITH anon_1 AS (UPDATE products SET price=:param_1 "
+            "WHERE products.price > :price_1) "
+            "INSERT INTO products (id, price) VALUES (:id, :price)",
+            checkparams={"id": 1, "price": 20, "param_1": 10, "price_1": 50},
+        )
+
+    def test_update_uses_independent_cte(self):
+        products = table("products", column("id"), column("price"))
+
+        upd_cte = (
+            products.update().values(price=10).where(products.c.price > 50)
+        ).cte()
+
+        stmt = (
+            products.update()
+            .values(price=5)
+            .where(products.c.price < 50)
+            .add_cte(upd_cte)
+        )
+
+        self.assert_compile(
+            stmt,
+            "WITH anon_1 AS (UPDATE products SET price=:param_1 "
+            "WHERE products.price > :price_1) UPDATE products "
+            "SET price=:price WHERE products.price < :price_2",
+            checkparams={
+                "param_1": 10,
+                "price": 5,
+                "price_1": 50,
+                "price_2": 50,
+            },
+        )
+
+    def test_update_w_insert_independent_cte(self):
+        products = table("products", column("id"), column("price"))
+
+        ins_cte = (products.insert().values({"id": 1, "price": 10})).cte()
+
+        stmt = (
+            products.update()
+            .values(price=5)
+            .where(products.c.price < 50)
+            .add_cte(ins_cte)
+        )
+
+        self.assert_compile(
+            stmt,
+            "WITH anon_1 AS (INSERT INTO products (id, price) "
+            "VALUES (:param_1, :param_2)) "
+            "UPDATE products SET price=:price WHERE products.price < :price_1",
+            checkparams={
+                "price": 5,
+                "param_1": 1,
+                "param_2": 10,
+                "price_1": 50,
+            },
+        )
+
+    def test_delete_uses_independent_cte(self):
+        products = table("products", column("id"), column("price"))
+
+        upd_cte = (
+            products.update().values(price=10).where(products.c.price > 50)
+        ).cte()
+
+        stmt = products.delete().where(products.c.price < 45).add_cte(upd_cte)
+
+        self.assert_compile(
+            stmt,
+            "WITH anon_1 AS (UPDATE products SET price=:param_1 "
+            "WHERE products.price > :price_1) "
+            "DELETE FROM products WHERE products.price < :price_2",
+            checkparams={"param_1": 10, "price_1": 50, "price_2": 45},
+        )
+
+    def test_independent_cte_can_be_referenced(self):
+        products = table("products", column("id"), column("price"))
+
+        cte = products.select().cte("pd")
+
+        stmt = (
+            products.update()
+            .where(products.c.price == cte.c.price)
+            .add_cte(cte)
+        )
+
+        self.assert_compile(
+            stmt,
+            "WITH pd AS "
+            "(SELECT products.id AS id, products.price AS price "
+            "FROM products) "
+            "UPDATE products SET id=:id, price=:price FROM pd "
+            "WHERE products.price = pd.price",
+        )
+
     def test_standalone_function(self):
         a = table("a", column("x"))
         a_stmt = select(a)