From: Mike Bayer Date: Thu, 11 Feb 2016 17:12:19 +0000 (-0500) Subject: - CTE functionality has been expanded to support all DML, allowing X-Git-Tag: rel_1_1_0b1~98^2~27 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=e5f1a3fb7dc1888ed187fdeae8171e4ff322dab6;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - CTE functionality has been expanded to support all DML, allowing INSERT, UPDATE, and DELETE statements to both specify their own WITH clause, as well as for these statements themselves to be CTE expressions when they include a RETURNING clause. fixes #2551 --- diff --git a/doc/build/changelog/changelog_11.rst b/doc/build/changelog/changelog_11.rst index 2473a02a2c..273bffb83d 100644 --- a/doc/build/changelog/changelog_11.rst +++ b/doc/build/changelog/changelog_11.rst @@ -21,6 +21,19 @@ .. changelog:: :version: 1.1.0b1 + .. change:: + :tags: feature, sql + :tickets: 2551 + + CTE functionality has been expanded to support all DML, allowing + INSERT, UPDATE, and DELETE statements to both specify their own + WITH clause, as well as for these statements themselves to be + CTE expressions when they include a RETURNING clause. + + .. seealso:: + + :ref:`change_2551` + .. change:: :tags: bug, orm :tickets: 3641 diff --git a/doc/build/changelog/migration_11.rst b/doc/build/changelog/migration_11.rst index 3d65ede80d..7eb8e800f3 100644 --- a/doc/build/changelog/migration_11.rst +++ b/doc/build/changelog/migration_11.rst @@ -529,6 +529,65 @@ remains unchanged. New Features and Improvements - Core ==================================== +.. _change_2551: + +CTE Support for INSERT, UPDATE, DELETE +-------------------------------------- + +One of the most widely requested features is support for common table +expressions (CTE) that work with INSERT, UPDATE, DELETE, and is now implemented. +An INSERT/UPDATE/DELETE can both draw from a WITH clause that's stated at the +top of the SQL, as well as can be used as a CTE itself in the context of +a larger statement. + +As part of this change, an INSERT from SELECT that includes a CTE will now +render the CTE at the top of the entire statement, rather than nested +in the SELECT statement as was the case in 1.0. + +Below is an example that renders UPDATE, INSERT and SELECT all in one +statement:: + + >>> from sqlalchemy import table, column, select, literal, exists + >>> orders = table( + ... 'orders', + ... column('region'), + ... column('amount'), + ... column('product'), + ... column('quantity') + ... ) + >>> + >>> upsert = ( + ... orders.update() + ... .where(orders.c.region == 'Region1') + ... .values(amount=1.0, product='Product1', quantity=1) + ... .returning(*(orders.c._all_columns)).cte('upsert')) + >>> + >>> insert = orders.insert().from_select( + ... orders.c.keys(), + ... select([ + ... literal('Region1'), literal(1.0), + ... literal('Product1'), literal(1) + ... ]).where(~exists(upsert.select())) + ... ) + >>> + >>> print(insert) # note formatting added for clarity + WITH upsert AS + (UPDATE orders SET amount=:amount, product=:product, quantity=:quantity + WHERE orders.region = :region_1 + RETURNING orders.region, orders.amount, orders.product, orders.quantity + ) + INSERT INTO orders (region, amount, product, quantity) + SELECT + :param_1 AS anon_1, :param_2 AS anon_2, + :param_3 AS anon_3, :param_4 AS anon_4 + WHERE NOT ( + EXISTS ( + SELECT upsert.region, upsert.amount, + upsert.product, upsert.quantity + FROM upsert)) + +:ticket:`2551` + .. _change_3216: The ``.autoincrement`` directive is no longer implicitly enabled for a composite primary key column diff --git a/doc/build/core/selectable.rst b/doc/build/core/selectable.rst index e73ce7b64c..a582ab4dc5 100644 --- a/doc/build/core/selectable.rst +++ b/doc/build/core/selectable.rst @@ -57,6 +57,9 @@ elements are themselves :class:`.ColumnElement` subclasses). :members: :inherited-members: +.. autoclass:: HasCTE + :members: + .. autoclass:: HasPrefixes :members: diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 8a25f570af..ad7b9130bf 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -478,8 +478,6 @@ class Query(object): """Return the full SELECT statement represented by this :class:`.Query` represented as a common table expression (CTE). - .. versionadded:: 0.7.6 - Parameters and usage are the same as those of the :meth:`.SelectBase.cte` method; see that method for further details. @@ -528,7 +526,7 @@ class Query(object): .. seealso:: - :meth:`.SelectBase.cte` + :meth:`.HasCTE.cte` """ return self.enable_eagerloads(False).\ diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index cc9a49a914..a2fc0fe68a 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -418,6 +418,11 @@ class SQLCompiler(Compiled): self.truncated_names = {} Compiled.__init__(self, dialect, statement, **kwargs) + if ( + self.isinsert or self.isupdate or self.isdelete + ) and statement._returning: + self.returning = statement._returning + if self.positional and dialect.paramstyle == 'numeric': self._apply_numbered_params() @@ -1659,7 +1664,7 @@ class SQLCompiler(Compiled): if per_dialect: text += " " + self.get_statement_hint_text(per_dialect) - if self.ctes and self._is_toplevel_select(select): + if self.ctes and toplevel: text = self._render_cte_clause() + text if select._suffixes: @@ -1673,20 +1678,6 @@ class SQLCompiler(Compiled): else: return text - def _is_toplevel_select(self, select): - """Return True if the stack is placed at the given select, and - is also the outermost SELECT, meaning there is either no stack - before this one, or the enclosing stack is a topmost INSERT. - - """ - return ( - self.stack[-1]['selectable'] is select and - ( - len(self.stack) == 1 or self.isinsert and len(self.stack) == 2 - and self.statement is self.stack[0]['selectable'] - ) - ) - def _setup_select_hints(self, select): byfrom = dict([ (from_, hinttext % { @@ -1876,14 +1867,16 @@ class SQLCompiler(Compiled): ) return dialect_hints, table_text - def visit_insert(self, insert_stmt, **kw): + def visit_insert(self, insert_stmt, asfrom=False, **kw): + toplevel = not self.stack + self.stack.append( {'correlate_froms': set(), "asfrom_froms": set(), "selectable": insert_stmt}) - self.isinsert = True - crud_params = crud._get_crud_params(self, insert_stmt, **kw) + crud_params = crud._setup_crud_params( + self, insert_stmt, crud.ISINSERT, **kw) if not crud_params and \ not self.dialect.supports_default_values and \ @@ -1929,12 +1922,13 @@ class SQLCompiler(Compiled): for c in crud_params_single]) if self.returning or insert_stmt._returning: - self.returning = self.returning or insert_stmt._returning returning_clause = self.returning_clause( - insert_stmt, self.returning) + insert_stmt, self.returning or insert_stmt._returning) if self.returning_precedes_values: text += " " + returning_clause + else: + returning_clause = None if insert_stmt.select is not None: text += " %s" % self.process(self._insert_from_select, **kw) @@ -1953,12 +1947,18 @@ class SQLCompiler(Compiled): text += " VALUES (%s)" % \ ', '.join([c[1] for c in crud_params]) - if self.returning and not self.returning_precedes_values: + if returning_clause and not self.returning_precedes_values: text += " " + returning_clause + if self.ctes and toplevel: + text = self._render_cte_clause() + text + self.stack.pop(-1) - return text + if asfrom: + return "(" + text + ")" + else: + return text def update_limit_clause(self, update_stmt): """Provide a hook for MySQL to add LIMIT to the UPDATE""" @@ -1972,8 +1972,8 @@ class SQLCompiler(Compiled): MySQL overrides this. """ - return from_table._compiler_dispatch(self, asfrom=True, - iscrud=True, **kw) + kw['asfrom'] = True + return from_table._compiler_dispatch(self, iscrud=True, **kw) def update_from_clause(self, update_stmt, from_table, extra_froms, @@ -1990,14 +1990,14 @@ class SQLCompiler(Compiled): fromhints=from_hints, **kw) for t in extra_froms) - def visit_update(self, update_stmt, **kw): + def visit_update(self, update_stmt, asfrom=False, **kw): + toplevel = not self.stack + self.stack.append( {'correlate_froms': set([update_stmt.table]), "asfrom_froms": set([update_stmt.table]), "selectable": update_stmt}) - self.isupdate = True - extra_froms = update_stmt._extra_froms text = "UPDATE " @@ -2009,7 +2009,8 @@ class SQLCompiler(Compiled): table_text = self.update_tables_clause(update_stmt, update_stmt.table, extra_froms, **kw) - crud_params = crud._get_crud_params(self, update_stmt, **kw) + crud_params = crud._setup_crud_params( + self, update_stmt, crud.ISUPDATE, **kw) if update_stmt._hints: dialect_hints, table_text = self._setup_crud_hints( @@ -2029,11 +2030,9 @@ class SQLCompiler(Compiled): ) if self.returning or update_stmt._returning: - if not self.returning: - self.returning = update_stmt._returning if self.returning_precedes_values: text += " " + self.returning_clause( - update_stmt, self.returning) + update_stmt, self.returning or update_stmt._returning) if extra_froms: extra_from_text = self.update_from_clause( @@ -2053,23 +2052,33 @@ class SQLCompiler(Compiled): if limit_clause: text += " " + limit_clause - if self.returning and not self.returning_precedes_values: + if (self.returning or update_stmt._returning) and \ + not self.returning_precedes_values: text += " " + self.returning_clause( - update_stmt, self.returning) + update_stmt, self.returning or update_stmt._returning) + + if self.ctes and toplevel: + text = self._render_cte_clause() + text self.stack.pop(-1) - return text + if asfrom: + return "(" + text + ")" + else: + return text @util.memoized_property def _key_getters_for_crud_column(self): - return crud._key_getters_for_crud_column(self) + return crud._key_getters_for_crud_column(self, self.statement) + + def visit_delete(self, delete_stmt, asfrom=False, **kw): + toplevel = not self.stack - def visit_delete(self, delete_stmt, **kw): self.stack.append({'correlate_froms': set([delete_stmt.table]), "asfrom_froms": set([delete_stmt.table]), "selectable": delete_stmt}) - self.isdelete = True + + crud._setup_crud_params(self, delete_stmt, crud.ISDELETE, **kw) text = "DELETE " @@ -2088,7 +2097,6 @@ class SQLCompiler(Compiled): text += table_text if delete_stmt._returning: - self.returning = delete_stmt._returning if self.returning_precedes_values: text += " " + self.returning_clause( delete_stmt, delete_stmt._returning) @@ -2098,13 +2106,19 @@ class SQLCompiler(Compiled): if t: text += " WHERE " + t - if self.returning and not self.returning_precedes_values: + if delete_stmt._returning and not self.returning_precedes_values: text += " " + self.returning_clause( delete_stmt, delete_stmt._returning) + if self.ctes and toplevel: + text = self._render_cte_clause() + text + self.stack.pop(-1) - return text + if asfrom: + return "(" + text + ")" + else: + return text def visit_savepoint(self, savepoint_stmt): return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt) diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index a01b72e61b..58cd80995e 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -25,6 +25,41 @@ values present. """) +ISINSERT = util.symbol('ISINSERT') +ISUPDATE = util.symbol('ISUPDATE') +ISDELETE = util.symbol('ISDELETE') + + +def _setup_crud_params(compiler, stmt, local_stmt_type, **kw): + restore_isinsert = compiler.isinsert + restore_isupdate = compiler.isupdate + restore_isdelete = compiler.isdelete + + should_restore = ( + restore_isinsert or restore_isupdate or restore_isdelete + ) or len(compiler.stack) > 1 + + if local_stmt_type is ISINSERT: + compiler.isupdate = False + compiler.isinsert = True + elif local_stmt_type is ISUPDATE: + compiler.isupdate = True + compiler.isinsert = False + elif local_stmt_type is ISDELETE: + if not should_restore: + compiler.isdelete = True + else: + assert False, "ISINSERT, ISUPDATE, or ISDELETE expected" + + try: + if local_stmt_type in (ISINSERT, ISUPDATE): + return _get_crud_params(compiler, stmt, **kw) + finally: + if should_restore: + compiler.isinsert = restore_isinsert + compiler.isupdate = restore_isupdate + compiler.isdelete = restore_isdelete + def _get_crud_params(compiler, stmt, **kw): """create a set of tuples representing column/string pairs for use @@ -59,7 +94,7 @@ def _get_crud_params(compiler, stmt, **kw): # but in the case of mysql multi-table update, the rules for # .key must conditionally take tablename into account _column_as_key, _getattr_col_key, _col_bind_name = \ - _key_getters_for_crud_column(compiler) + _key_getters_for_crud_column(compiler, stmt) # if we have statement parameters - set defaults in the # compiled params @@ -128,15 +163,15 @@ def _create_bind_param( return bindparam -def _key_getters_for_crud_column(compiler): - if compiler.isupdate and compiler.statement._extra_froms: +def _key_getters_for_crud_column(compiler, stmt): + if compiler.isupdate and stmt._extra_froms: # when extra tables are present, refer to the columns # in those extra tables as table-qualified, including in # dictionaries and when rendering bind param names. # the "main" table of the statement remains unqualified, # allowing the most compatibility with a non-multi-table # statement. - _et = set(compiler.statement._extra_froms) + _et = set(stmt._extra_froms) def _column_as_key(key): str_key = elements._column_as_key(key) @@ -609,7 +644,9 @@ def _get_returning_modifiers(compiler, stmt): stmt.table.implicit_returning and stmt._return_defaults) else: - implicit_return_defaults = False + # this line is unused, currently we are always + # isinsert or isupdate + implicit_return_defaults = False # pragma: no cover if implicit_return_defaults: if stmt._return_defaults is True: diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 7b506f9db9..8f368dcdbb 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -9,15 +9,18 @@ Provide :class:`.Insert`, :class:`.Update` and :class:`.Delete`. """ -from .base import Executable, _generative, _from_objects, DialectKWArgs +from .base import Executable, _generative, _from_objects, DialectKWArgs, \ + ColumnCollection from .elements import ClauseElement, _literal_as_text, Null, and_, _clone, \ _column_as_key -from .selectable import _interpret_as_from, _interpret_as_select, HasPrefixes +from .selectable import _interpret_as_from, _interpret_as_select, \ + HasPrefixes, HasCTE from .. import util from .. import exc -class UpdateBase(DialectKWArgs, HasPrefixes, Executable, ClauseElement): +class UpdateBase( + HasCTE, DialectKWArgs, HasPrefixes, Executable, ClauseElement): """Form the base for ``INSERT``, ``UPDATE``, and ``DELETE`` statements. """ diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index a63eac2f83..36f7f7fe12 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -46,8 +46,8 @@ from .base import ColumnCollection, Generative, Executable, \ from .selectable import Alias, Join, Select, Selectable, TableClause, \ CompoundSelect, CTE, FromClause, FromGrouping, SelectBase, \ - alias, GenerativeSelect, \ - subquery, HasPrefixes, HasSuffixes, Exists, ScalarSelect, TextAsFrom + alias, GenerativeSelect, subquery, HasCTE, HasPrefixes, HasSuffixes, \ + Exists, ScalarSelect, TextAsFrom from .dml import Insert, Update, Delete, UpdateBase, ValuesBase diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index c3906c2f2f..fcd22a786c 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -1195,6 +1195,15 @@ class CTE(Generative, HasSuffixes, Alias): self._suffixes = _suffixes super(CTE, self).__init__(selectable, name=name) + @util.dependencies("sqlalchemy.sql.dml") + def _populate_column_collection(self, dml): + if isinstance(self.element, dml.UpdateBase): + for col in self.element._returning: + col._make_proxy(self) + else: + for col in self.element.columns._all_columns: + col._make_proxy(self) + def alias(self, name=None, flat=False): return CTE( self.original, @@ -1223,6 +1232,164 @@ class CTE(Generative, HasSuffixes, Alias): ) +class HasCTE(object): + """Mixin that declares a class to include CTE support. + + .. versionadded:: 1.1 + + """ + + def cte(self, name=None, recursive=False): + """Return a new :class:`.CTE`, or Common Table Expression instance. + + Common table expressions are a SQL standard whereby SELECT + statements can draw upon secondary statements specified along + with the primary statement, using a clause called "WITH". + Special semantics regarding UNION can also be employed to + allow "recursive" queries, where a SELECT statement can draw + upon the set of rows that have previously been selected. + + CTEs can also be applied to DML constructs UPDATE, INSERT + and DELETE on some databases, both as a source of CTE rows + when combined with RETURNING, as well as a consumer of + CTE rows. + + SQLAlchemy detects :class:`.CTE` objects, which are treated + similarly to :class:`.Alias` objects, as special elements + to be delivered to the FROM clause of the statement as well + as to a WITH clause at the top of the statement. + + .. versionchanged:: 1.1 Added support for UPDATE/INSERT/DELETE as + CTE, CTEs added to UPDATE/INSERT/DELETE. + + :param name: name given to the common table expression. Like + :meth:`._FromClause.alias`, the name can be left as ``None`` + in which case an anonymous symbol will be used at query + compile time. + :param recursive: if ``True``, will render ``WITH RECURSIVE``. + A recursive common table expression is intended to be used in + conjunction with UNION ALL in order to derive rows + from those already selected. + + The following examples include two from Postgresql's documentation at + http://www.postgresql.org/docs/current/static/queries-with.html, + as well as additional examples. + + Example 1, non recursive:: + + from sqlalchemy import (Table, Column, String, Integer, + MetaData, select, func) + + metadata = MetaData() + + orders = Table('orders', metadata, + Column('region', String), + Column('amount', Integer), + Column('product', String), + Column('quantity', Integer) + ) + + regional_sales = select([ + orders.c.region, + func.sum(orders.c.amount).label('total_sales') + ]).group_by(orders.c.region).cte("regional_sales") + + + top_regions = select([regional_sales.c.region]).\\ + where( + regional_sales.c.total_sales > + select([ + func.sum(regional_sales.c.total_sales)/10 + ]) + ).cte("top_regions") + + statement = select([ + orders.c.region, + orders.c.product, + func.sum(orders.c.quantity).label("product_units"), + func.sum(orders.c.amount).label("product_sales") + ]).where(orders.c.region.in_( + select([top_regions.c.region]) + )).group_by(orders.c.region, orders.c.product) + + result = conn.execute(statement).fetchall() + + Example 2, WITH RECURSIVE:: + + from sqlalchemy import (Table, Column, String, Integer, + MetaData, select, func) + + metadata = MetaData() + + parts = Table('parts', metadata, + Column('part', String), + Column('sub_part', String), + Column('quantity', Integer), + ) + + included_parts = select([ + parts.c.sub_part, + parts.c.part, + parts.c.quantity]).\\ + where(parts.c.part=='our part').\\ + cte(recursive=True) + + + incl_alias = included_parts.alias() + parts_alias = parts.alias() + included_parts = included_parts.union_all( + select([ + parts_alias.c.sub_part, + parts_alias.c.part, + parts_alias.c.quantity + ]). + where(parts_alias.c.part==incl_alias.c.sub_part) + ) + + statement = select([ + included_parts.c.sub_part, + func.sum(included_parts.c.quantity). + label('total_quantity') + ]).\\ + group_by(included_parts.c.sub_part) + + result = conn.execute(statement).fetchall() + + Example 3, an upsert using UPDATE and INSERT with CTEs:: + + orders = table( + 'orders', + column('region'), + column('amount'), + column('product'), + column('quantity') + ) + + upsert = ( + orders.update() + .where(orders.c.region == 'Region1') + .values(amount=1.0, product='Product1', quantity=1) + .returning(*(orders.c._all_columns)).cte('upsert')) + + insert = orders.insert().from_select( + orders.c.keys(), + select([ + literal('Region1'), literal(1.0), + literal('Product1'), literal(1) + ).where(exists(upsert.select())) + ) + + connection.execute(insert) + + .. seealso:: + + :meth:`.orm.query.Query.cte` - ORM version of + :meth:`.HasCTE.cte`. + + """ + return CTE(self, name=name, recursive=recursive) + + class FromGrouping(FromClause): """Represent a grouping of a FROM clause""" __visit_name__ = 'grouping' @@ -1497,7 +1664,7 @@ class ForUpdateArg(ClauseElement): self.of = None -class SelectBase(Executable, FromClause): +class SelectBase(HasCTE, Executable, FromClause): """Base class for SELECT statements. @@ -1531,125 +1698,6 @@ class SelectBase(Executable, FromClause): """ return self.as_scalar().label(name) - def cte(self, name=None, recursive=False): - """Return a new :class:`.CTE`, or Common Table Expression instance. - - Common table expressions are a SQL standard whereby SELECT - statements can draw upon secondary statements specified along - with the primary statement, using a clause called "WITH". - Special semantics regarding UNION can also be employed to - allow "recursive" queries, where a SELECT statement can draw - upon the set of rows that have previously been selected. - - SQLAlchemy detects :class:`.CTE` objects, which are treated - similarly to :class:`.Alias` objects, as special elements - to be delivered to the FROM clause of the statement as well - as to a WITH clause at the top of the statement. - - .. versionadded:: 0.7.6 - - :param name: name given to the common table expression. Like - :meth:`._FromClause.alias`, the name can be left as ``None`` - in which case an anonymous symbol will be used at query - compile time. - :param recursive: if ``True``, will render ``WITH RECURSIVE``. - A recursive common table expression is intended to be used in - conjunction with UNION ALL in order to derive rows - from those already selected. - - The following examples illustrate two examples from - Postgresql's documentation at - http://www.postgresql.org/docs/8.4/static/queries-with.html. - - Example 1, non recursive:: - - from sqlalchemy import (Table, Column, String, Integer, - MetaData, select, func) - - metadata = MetaData() - - orders = Table('orders', metadata, - Column('region', String), - Column('amount', Integer), - Column('product', String), - Column('quantity', Integer) - ) - - regional_sales = select([ - orders.c.region, - func.sum(orders.c.amount).label('total_sales') - ]).group_by(orders.c.region).cte("regional_sales") - - - top_regions = select([regional_sales.c.region]).\\ - where( - regional_sales.c.total_sales > - select([ - func.sum(regional_sales.c.total_sales)/10 - ]) - ).cte("top_regions") - - statement = select([ - orders.c.region, - orders.c.product, - func.sum(orders.c.quantity).label("product_units"), - func.sum(orders.c.amount).label("product_sales") - ]).where(orders.c.region.in_( - select([top_regions.c.region]) - )).group_by(orders.c.region, orders.c.product) - - result = conn.execute(statement).fetchall() - - Example 2, WITH RECURSIVE:: - - from sqlalchemy import (Table, Column, String, Integer, - MetaData, select, func) - - metadata = MetaData() - - parts = Table('parts', metadata, - Column('part', String), - Column('sub_part', String), - Column('quantity', Integer), - ) - - included_parts = select([ - parts.c.sub_part, - parts.c.part, - parts.c.quantity]).\\ - where(parts.c.part=='our part').\\ - cte(recursive=True) - - - incl_alias = included_parts.alias() - parts_alias = parts.alias() - included_parts = included_parts.union_all( - select([ - parts_alias.c.sub_part, - parts_alias.c.part, - parts_alias.c.quantity - ]). - where(parts_alias.c.part==incl_alias.c.sub_part) - ) - - statement = select([ - included_parts.c.sub_part, - func.sum(included_parts.c.quantity). - label('total_quantity') - ]).\\ - group_by(included_parts.c.sub_part) - - result = conn.execute(statement).fetchall() - - - .. seealso:: - - :meth:`.orm.query.Query.cte` - ORM version of - :meth:`.SelectBase.cte`. - - """ - return CTE(self, name=name, recursive=recursive) - @_generative @util.deprecated('0.6', message="``autocommit()`` is deprecated. Use " diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index bb5a962561..21f9f68fb8 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -296,6 +296,8 @@ class AssertsCompiledSQL(object): dialect = config.db.dialect elif dialect == 'default': dialect = default.DefaultDialect() + elif dialect == 'default_enhanced': + dialect = default.StrCompileDialect() elif isinstance(dialect, util.string_types): dialect = url.URL(dialect).get_dialect()() diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py index b59914afce..aa674403e6 100644 --- a/test/sql/test_cte.py +++ b/test/sql/test_cte.py @@ -1,6 +1,6 @@ from sqlalchemy.testing import fixtures from sqlalchemy.testing import AssertsCompiledSQL, assert_raises_message -from sqlalchemy.sql import table, column, select, func, literal +from sqlalchemy.sql import table, column, select, func, literal, exists, and_ from sqlalchemy.dialects import mssql from sqlalchemy.engine import default from sqlalchemy.exc import CompileError @@ -8,7 +8,7 @@ from sqlalchemy.exc import CompileError class CTETest(fixtures.TestBase, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = 'default_enhanced' def test_nonrecursive(self): orders = table('orders', @@ -492,3 +492,151 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): 'regional_sales WHERE "order"."order" > regional_sales."order"', dialect='postgresql' ) + + def test_upsert_from_select(self): + orders = table( + 'orders', + column('region'), + column('amount'), + column('product'), + column('quantity') + ) + + upsert = ( + orders.update() + .where(orders.c.region == 'Region1') + .values(amount=1.0, product='Product1', quantity=1) + .returning(*(orders.c._all_columns)).cte('upsert')) + + insert = orders.insert().from_select( + orders.c.keys(), + select([ + literal('Region1'), literal(1.0), + literal('Product1'), literal(1) + ]).where(~exists(upsert.select())) + ) + + self.assert_compile( + insert, + "WITH upsert AS (UPDATE orders SET amount=:amount, " + "product=:product, quantity=:quantity " + "WHERE orders.region = :region_1 " + "RETURNING orders.region, orders.amount, " + "orders.product, orders.quantity) " + "INSERT INTO orders (region, amount, product, quantity) " + "SELECT :param_1 AS anon_1, :param_2 AS anon_2, " + ":param_3 AS anon_3, :param_4 AS anon_4 WHERE NOT (EXISTS " + "(SELECT upsert.region, upsert.amount, upsert.product, " + "upsert.quantity FROM upsert))" + ) + + def test_pg_example_one(self): + products = table('products', column('id'), column('date')) + products_log = table('products_log', column('id'), column('date')) + + moved_rows = products.delete().where(and_( + products.c.date >= 'dateone', + products.c.date < 'datetwo')).returning(*products.c).\ + cte('moved_rows') + + stmt = products_log.insert().from_select( + products_log.c, moved_rows.select()) + self.assert_compile( + stmt, + "WITH moved_rows AS " + "(DELETE FROM products WHERE products.date >= :date_1 " + "AND products.date < :date_2 " + "RETURNING products.id, products.date) " + "INSERT INTO products_log (id, date) " + "SELECT moved_rows.id, moved_rows.date FROM moved_rows" + ) + + def test_pg_example_two(self): + products = table('products', column('id'), column('price')) + + t = products.update().values(price='someprice').\ + returning(*products.c).cte('t') + stmt = t.select() + + self.assert_compile( + stmt, + "WITH t AS " + "(UPDATE products SET price=:price " + "RETURNING products.id, products.price) " + "SELECT t.id, t.price " + "FROM t" + ) + + def test_pg_example_three(self): + + parts = table( + 'parts', + column('part'), + column('sub_part'), + ) + + included_parts = select([ + parts.c.sub_part, + parts.c.part]).\ + where(parts.c.part == 'our part').\ + cte("included_parts", recursive=True) + + pr = included_parts.alias('pr') + p = parts.alias('p') + included_parts = included_parts.union_all( + select([ + p.c.sub_part, + p.c.part]). + where(p.c.part == pr.c.sub_part) + ) + stmt = parts.delete().where( + parts.c.part.in_(select([included_parts.c.part]))).returning( + parts.c.part) + + # the outer RETURNING is a bonus over what PG's docs have + self.assert_compile( + stmt, + "WITH RECURSIVE included_parts(sub_part, part) AS " + "(SELECT parts.sub_part AS sub_part, parts.part AS part " + "FROM parts " + "WHERE parts.part = :part_1 " + "UNION ALL SELECT p.sub_part AS sub_part, p.part AS part " + "FROM parts AS p, included_parts AS pr " + "WHERE p.part = pr.sub_part) " + "DELETE FROM parts WHERE parts.part IN " + "(SELECT included_parts.part FROM included_parts) " + "RETURNING parts.part" + ) + + def test_insert_in_the_cte(self): + products = table('products', column('id'), column('price')) + + cte = products.insert().values(id=1, price=27.0).\ + returning(*products.c).cte('pd') + + stmt = select([cte]) + + self.assert_compile( + stmt, + "WITH pd AS " + "(INSERT INTO products (id, price) VALUES (:id, :price) " + "RETURNING products.id, products.price) " + "SELECT pd.id, pd.price " + "FROM pd" + ) + + def test_update_pulls_from_cte(self): + products = table('products', column('id'), column('price')) + + cte = products.select().cte('pd') + + stmt = products.update().where(products.c.price == cte.c.price) + + self.assert_compile( + stmt, + "WITH pd AS " + "(SELECT products.id AS id, products.price AS price " + "FROM products) " + "UPDATE products SET id=:id, price=:price FROM pd " + "WHERE products.price = pd.price" + ) diff --git a/test/sql/test_insert.py b/test/sql/test_insert.py index ea4de032c8..513757d5be 100644 --- a/test/sql/test_insert.py +++ b/test/sql/test_insert.py @@ -188,9 +188,10 @@ class InsertTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL): from_select(("otherid", "othername"), sel) self.assert_compile( ins, - "INSERT INTO myothertable (otherid, othername) WITH anon_1 AS " + "WITH anon_1 AS " "(SELECT mytable.name AS name FROM mytable " "WHERE mytable.name = :name_1) " + "INSERT INTO myothertable (otherid, othername) " "SELECT mytable.myid, mytable.name FROM mytable, anon_1 " "WHERE mytable.name = anon_1.name", checkparams={"name_1": "bar"} @@ -205,9 +206,9 @@ class InsertTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL): self.assert_compile( ins, - "INSERT INTO mytable (myid, name, description) " "WITH c AS (SELECT mytable.myid AS myid, mytable.name AS name, " "mytable.description AS description FROM mytable) " + "INSERT INTO mytable (myid, name, description) " "SELECT c.myid, c.name, c.description FROM c" )