From: Mike Bayer Date: Mon, 12 Jul 2021 18:28:19 +0000 (-0400) Subject: implement independent CTEs X-Git-Tag: rel_1_4_21~9^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=204ff1f60cf911b00b7494942fc58bc715dddeed;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git implement independent CTEs Added new method :meth:`_sql.HasCTE.add_cte` to each of the :func:`_sql.select`, :func:`_sql.insert`, :func:`_sql.update` and :func:`_sql.delete` constructs. This method will add the given :class:`_sql.CTE` as an "independent" CTE of the statement, meaning it renders in the WITH clause above the statement unconditionally even if it is not otherwise referenced in the primary statement. This is a popular use case on the PostgreSQL database where a CTE is used for a DML statement that runs against database rows independently of the primary statement. Fixes: #6752 Change-Id: Ibf635763e40269cbd10f4c17e208850d8e8d0188 --- diff --git a/doc/build/changelog/unreleased_14/6752.rst b/doc/build/changelog/unreleased_14/6752.rst new file mode 100644 index 0000000000..f97b0f8c88 --- /dev/null +++ b/doc/build/changelog/unreleased_14/6752.rst @@ -0,0 +1,12 @@ +.. change:: + :tags: usecase, sql + :tickets: 6752 + + Added new method :meth:`_sql.HasCTE.add_cte` to each of the + :func:`_sql.select`, :func:`_sql.insert`, :func:`_sql.update` and + :func:`_sql.delete` constructs. This method will add the given + :class:`_sql.CTE` as an "independent" CTE of the statement, meaning it + renders in the WITH clause above the statement unconditionally even if it + is not otherwise referenced in the primary statement. This is a popular use + case on the PostgreSQL database where a CTE is used for a DML statement + that runs against database rows independently of the primary statement. diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index 1da1fee86b..16a68c8ffd 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -871,6 +871,10 @@ class HasCTEImpl(ReturnsRowsImpl): __slots__ = () +class IsCTEImpl(RoleImpl): + __slots__ = () + + class JoinTargetImpl(RoleImpl): __slots__ = () diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 4b3b2c293c..880479a37f 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -2313,8 +2313,7 @@ class SQLCompiler(Compiled): ) and not existing.proxy_set.intersection(bindparam.proxy_set): raise exc.CompileError( "Bind parameter '%s' conflicts with " - "unique bind parameter of the same name" - % bindparam.key + "unique bind parameter of the same name" % name ) elif existing._is_crud or bindparam._is_crud: raise exc.CompileError( @@ -3075,6 +3074,10 @@ class SQLCompiler(Compiled): else: byfrom = None + if select_stmt._independent_ctes: + for cte in select_stmt._independent_ctes: + cte._compiler_dispatch(self, **kwargs) + if select_stmt._prefixes: text += self._generate_prefixes( select_stmt, select_stmt._prefixes, **kwargs @@ -3551,6 +3554,10 @@ class SQLCompiler(Compiled): if insert_stmt._hints: _, table_text = self._setup_crud_hints(insert_stmt, table_text) + if insert_stmt._independent_ctes: + for cte in insert_stmt._independent_ctes: + cte._compiler_dispatch(self, **kw) + text += table_text if crud_params_single or not supports_default_values: @@ -3700,6 +3707,10 @@ class SQLCompiler(Compiled): else: dialect_hints = None + if update_stmt._independent_ctes: + for cte in update_stmt._independent_ctes: + cte._compiler_dispatch(self, **kw) + text += table_text text += " SET " @@ -3808,6 +3819,10 @@ class SQLCompiler(Compiled): else: dialect_hints = None + if delete_stmt._independent_ctes: + for cte in delete_stmt._independent_ctes: + cte._compiler_dispatch(self, **kw) + text += table_text if delete_stmt._returning: diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index de847fb7fa..74f5a1d05b 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -224,7 +224,17 @@ def _handle_values_anonymous_param(compiler, col, value, name, **kw): # rather than having # compiler.visit_bindparam()->compiler._truncated_identifier make up a # name. Saves on call counts also. - if value.unique and isinstance(value.key, elements._truncated_label): + + # for INSERT/UPDATE that's a CTE, we don't need names to match to + # external parameters and these would also conflict in the case where + # multiple insert/update are combined together using CTEs + is_cte = "visiting_cte" in kw + + if ( + not is_cte + and value.unique + and isinstance(value.key, elements._truncated_label) + ): compiler.truncated_names[("bindparam", value.key)] = name if value.type._isnull: @@ -460,7 +470,6 @@ def _append_param_parameter( values, kw, ): - value = parameters.pop(col_key) col_value = compiler.preparer.format_column( diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index dd012ac86a..048475040f 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -829,6 +829,7 @@ class Insert(ValuesBase): + HasPrefixes._has_prefixes_traverse_internals + DialectKWArgs._dialect_kwargs_traverse_internals + Executable._executable_traverse_internals + + HasCTE._has_ctes_traverse_internals ) @ValuesBase._constructor_20_deprecations( @@ -1119,6 +1120,7 @@ class Update(DMLWhereBase, ValuesBase): + HasPrefixes._has_prefixes_traverse_internals + DialectKWArgs._dialect_kwargs_traverse_internals + Executable._executable_traverse_internals + + HasCTE._has_ctes_traverse_internals ) @ValuesBase._constructor_20_deprecations( @@ -1357,6 +1359,7 @@ class Delete(DMLWhereBase, UpdateBase): + HasPrefixes._has_prefixes_traverse_internals + DialectKWArgs._dialect_kwargs_traverse_internals + Executable._executable_traverse_internals + + HasCTE._has_ctes_traverse_internals ) @ValuesBase._constructor_20_deprecations( diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py index a5eefc7b54..b9010397cf 100644 --- a/lib/sqlalchemy/sql/roles.py +++ b/lib/sqlalchemy/sql/roles.py @@ -182,6 +182,10 @@ class HasCTERole(ReturnsRowsRole): pass +class IsCTERole(SQLRole): + _role_name = "CTE object" + + class CompoundElementRole(AllowsLambdaRole, SQLRole): """SELECT statements inside a CompoundSelect, e.g. UNION, EXTRACT, etc.""" diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 235c74ea76..bd9ce0f20e 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -1986,6 +1986,7 @@ class TableSample(AliasedReturnsRows): class CTE( roles.DMLTableRole, + roles.IsCTERole, Generative, HasPrefixes, HasSuffixes, @@ -2110,6 +2111,79 @@ class HasCTE(roles.HasCTERole): """ + _has_ctes_traverse_internals = [ + ("_independent_ctes", InternalTraversal.dp_clauseelement_list), + ] + + _independent_ctes = () + + @_generative + def add_cte(self, cte): + """Add a :class:`_sql.CTE` to this statement object that will be + independently rendered even if not referenced in the statement + otherwise. + + This feature is useful for the use case of embedding a DML statement + such as an INSERT or UPDATE as a CTE inline with a primary statement + that may draw from its results indirectly; while PostgreSQL is known + to support this usage, it may not be supported by other backends. + + E.g.:: + + from sqlalchemy import table, column, select + t = table('t', column('c1'), column('c2')) + + ins = t.insert().values({"c1": "x", "c2": "y"}).cte() + + stmt = select(t).add_cte(ins) + + Would render:: + + WITH anon_1 AS + (INSERT INTO t (c1, c2) VALUES (:param_1, :param_2)) + SELECT t.c1, t.c2 + FROM t + + Above, the "anon_1" CTE is not referred towards in the SELECT + statement, however still accomplishes the task of running an INSERT + statement. + + Similarly in a DML-related context, using the PostgreSQL + :class:`_postgresql.Insert` construct to generate an "upsert":: + + from sqlalchemy import table, column + from sqlalchemy.dialects.postgresql import insert + + t = table("t", column("c1"), column("c2")) + + delete_statement_cte = ( + t.delete().where(t.c.c1 < 1).cte("deletions") + ) + + insert_stmt = insert(t).values({"c1": 1, "c2": 2}) + update_statement = insert_stmt.on_conflict_do_update( + index_elements=[t.c.c1], + set_={ + "c1": insert_stmt.excluded.c1, + "c2": insert_stmt.excluded.c2, + }, + ).add_cte(delete_statement_cte) + + print(update_statement) + + The above statement renders as:: + + WITH deletions AS + (DELETE FROM t WHERE t.c1 < %(c1_1)s) + INSERT INTO t (c1, c2) VALUES (%(c1)s, %(c2)s) + ON CONFLICT (c1) DO UPDATE SET c1 = excluded.c1, c2 = excluded.c2 + + .. versionadded:: 1.4.21 + + """ + cte = coercions.expect(roles.IsCTERole, cte) + self._independent_ctes += (cte,) + def cte(self, name=None, recursive=False): r"""Return a new :class:`_expression.CTE`, or Common Table Expression instance. @@ -4622,6 +4696,7 @@ class Select( ("_distinct_on", InternalTraversal.dp_clauseelement_tuple), ("_label_style", InternalTraversal.dp_plain_obj), ] + + HasCTE._has_ctes_traverse_internals + HasPrefixes._has_prefixes_traverse_internals + HasSuffixes._has_suffixes_traverse_internals + HasHints._has_hints_traverse_internals diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py index e48de9d21d..c08038df6a 100644 --- a/test/dialect/postgresql/test_compiler.py +++ b/test/dialect/postgresql/test_compiler.py @@ -2518,7 +2518,7 @@ class InsertOnConflictTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile( stmt, "WITH i_upsert AS " - "(INSERT INTO mytable (name) VALUES (%(name)s) " + "(INSERT INTO mytable (name) VALUES (%(param_1)s) " "ON CONFLICT (name, description) " "WHERE description != %(description_1)s " "DO UPDATE SET name = excluded.name " @@ -2527,6 +2527,30 @@ class InsertOnConflictTest(fixtures.TestBase, AssertsCompiledSQL): "FROM i_upsert", ) + def test_combined_with_cte(self): + t = table("t", column("c1"), column("c2")) + + delete_statement_cte = t.delete().where(t.c.c1 < 1).cte("deletions") + + insert_stmt = insert(t).values([{"c1": 1, "c2": 2}]) + update_stmt = insert_stmt.on_conflict_do_update( + index_elements=[t.c.c1], + set_={ + col.name: col + for col in insert_stmt.excluded + if col.name in ("c1", "c2") + }, + ).add_cte(delete_statement_cte) + + self.assert_compile( + update_stmt, + "WITH deletions AS (DELETE FROM t WHERE t.c1 < %(c1_1)s) " + "INSERT INTO t (c1, c2) VALUES (%(c1_m0)s, %(c2_m0)s) " + "ON CONFLICT (c1) DO UPDATE SET c1 = excluded.c1, " + "c2 = excluded.c2", + checkparams={"c1_m0": 1, "c2_m0": 2, "c1_1": 1}, + ) + def test_quote_raw_string_col(self): t = table("t", column("FancyName"), column("other name")) diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py index e96a47553b..1ffa9b50af 100644 --- a/test/sql/test_compare.py +++ b/test/sql/test_compare.py @@ -490,6 +490,16 @@ class CoreFixtures(object): select(table_a.c.a).join(table_b, table_a.c.a == table_b.c.b), select(table_a.c.a).join(table_c, table_a.c.a == table_c.c.x), ), + lambda: ( + select(table_a.c.a), + select(table_a.c.a).add_cte(table_b.insert().cte()), + table_a.insert(), + table_a.delete(), + table_a.update(), + table_a.insert().add_cte(table_b.insert().cte()), + table_a.delete().add_cte(table_b.insert().cte()), + table_a.update().add_cte(table_b.insert().cte()), + ), lambda: ( select(table_a.c.a).cte(), select(table_a.c.a).cte(recursive=True), diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py index 01186c340c..e8a8a3150c 100644 --- a/test/sql/test_cte.py +++ b/test/sql/test_cte.py @@ -1015,8 +1015,8 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile( insert, - "WITH upsert AS (UPDATE orders SET amount=:amount, " - "product=:product, quantity=:quantity " + "WITH upsert AS (UPDATE orders SET amount=:param_5, " + "product=:param_6, quantity=:param_7 " "WHERE orders.region = :region_1 " "RETURNING orders.region, orders.amount, " "orders.product, orders.quantity) " @@ -1025,6 +1025,16 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): ":param_3 AS anon_3, :param_4 AS anon_4 WHERE NOT (EXISTS " "(SELECT upsert.region, upsert.amount, upsert.product, " "upsert.quantity FROM upsert))", + checkparams={ + "param_1": "Region1", + "param_2": 1.0, + "param_3": "Product1", + "param_4": 1, + "param_5": 1.0, + "param_6": "Product1", + "param_7": 1, + "region_1": "Region1", + }, ) eq_(insert.compile().isinsert, True) @@ -1106,9 +1116,10 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile( stmt.select(), - "WITH anon_1 AS (UPDATE orders SET region=:region " + "WITH anon_1 AS (UPDATE orders SET region=:param_1 " "WHERE orders.region = :region_1 RETURNING orders.region) " "SELECT anon_1.region FROM anon_1", + checkparams={"param_1": "y", "region_1": "x"}, ) eq_(stmt.select().compile().isupdate, False) @@ -1122,8 +1133,9 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile( stmt.select(), "WITH anon_1 AS (INSERT INTO orders (region) " - "VALUES (:region) RETURNING orders.region) " + "VALUES (:param_1) RETURNING orders.region) " "SELECT anon_1.region FROM anon_1", + checkparams={"param_1": "y"}, ) eq_(stmt.select().compile().isinsert, False) @@ -1196,10 +1208,11 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile( stmt, "WITH t AS " - "(UPDATE products SET price=:price " + "(UPDATE products SET price=:param_1 " "RETURNING products.id, products.price) " "SELECT t.id, t.price " "FROM t", + checkparams={"param_1": "someprice"}, ) eq_(stmt.compile().isupdate, False) @@ -1257,10 +1270,11 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile( stmt, "WITH pd AS " - "(INSERT INTO products (id, price) VALUES (:id, :price) " + "(INSERT INTO products (id, price) VALUES (:param_1, :param_2) " "RETURNING products.id, products.price) " "SELECT pd.id, pd.price " "FROM pd", + checkparams={"param_1": 1, "param_2": 27.0}, ) eq_(stmt.compile().isinsert, False) @@ -1353,6 +1367,132 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): ) eq_(stmt.compile().isdelete, True) + def test_select_uses_independent_cte(self): + products = table("products", column("id"), column("price")) + + upd_cte = ( + products.update().values(price=10).where(products.c.price > 50) + ).cte() + + stmt = products.select().where(products.c.price < 45).add_cte(upd_cte) + + self.assert_compile( + stmt, + "WITH anon_1 AS (UPDATE products SET price=:param_1 " + "WHERE products.price > :price_1) " + "SELECT products.id, products.price " + "FROM products WHERE products.price < :price_2", + checkparams={"param_1": 10, "price_1": 50, "price_2": 45}, + ) + + def test_insert_uses_independent_cte(self): + products = table("products", column("id"), column("price")) + + upd_cte = ( + products.update().values(price=10).where(products.c.price > 50) + ).cte() + + stmt = ( + products.insert().values({"id": 1, "price": 20}).add_cte(upd_cte) + ) + + self.assert_compile( + stmt, + "WITH anon_1 AS (UPDATE products SET price=:param_1 " + "WHERE products.price > :price_1) " + "INSERT INTO products (id, price) VALUES (:id, :price)", + checkparams={"id": 1, "price": 20, "param_1": 10, "price_1": 50}, + ) + + def test_update_uses_independent_cte(self): + products = table("products", column("id"), column("price")) + + upd_cte = ( + products.update().values(price=10).where(products.c.price > 50) + ).cte() + + stmt = ( + products.update() + .values(price=5) + .where(products.c.price < 50) + .add_cte(upd_cte) + ) + + self.assert_compile( + stmt, + "WITH anon_1 AS (UPDATE products SET price=:param_1 " + "WHERE products.price > :price_1) UPDATE products " + "SET price=:price WHERE products.price < :price_2", + checkparams={ + "param_1": 10, + "price": 5, + "price_1": 50, + "price_2": 50, + }, + ) + + def test_update_w_insert_independent_cte(self): + products = table("products", column("id"), column("price")) + + ins_cte = (products.insert().values({"id": 1, "price": 10})).cte() + + stmt = ( + products.update() + .values(price=5) + .where(products.c.price < 50) + .add_cte(ins_cte) + ) + + self.assert_compile( + stmt, + "WITH anon_1 AS (INSERT INTO products (id, price) " + "VALUES (:param_1, :param_2)) " + "UPDATE products SET price=:price WHERE products.price < :price_1", + checkparams={ + "price": 5, + "param_1": 1, + "param_2": 10, + "price_1": 50, + }, + ) + + def test_delete_uses_independent_cte(self): + products = table("products", column("id"), column("price")) + + upd_cte = ( + products.update().values(price=10).where(products.c.price > 50) + ).cte() + + stmt = products.delete().where(products.c.price < 45).add_cte(upd_cte) + + self.assert_compile( + stmt, + "WITH anon_1 AS (UPDATE products SET price=:param_1 " + "WHERE products.price > :price_1) " + "DELETE FROM products WHERE products.price < :price_2", + checkparams={"param_1": 10, "price_1": 50, "price_2": 45}, + ) + + def test_independent_cte_can_be_referenced(self): + products = table("products", column("id"), column("price")) + + cte = products.select().cte("pd") + + stmt = ( + products.update() + .where(products.c.price == cte.c.price) + .add_cte(cte) + ) + + self.assert_compile( + stmt, + "WITH pd AS " + "(SELECT products.id AS id, products.price AS price " + "FROM products) " + "UPDATE products SET id=:id, price=:price FROM pd " + "WHERE products.price = pd.price", + ) + def test_standalone_function(self): a = table("a", column("x")) a_stmt = select(a)