]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- CTE functionality has been expanded to support all DML, allowing
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 11 Feb 2016 17:12:19 +0000 (12:12 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 11 Feb 2016 17:27:28 +0000 (12:27 -0500)
INSERT, UPDATE, and DELETE statements to both specify their own
WITH clause, as well as for these statements themselves to be
CTE expressions when they include a RETURNING clause.
fixes #2551

12 files changed:
doc/build/changelog/changelog_11.rst
doc/build/changelog/migration_11.rst
doc/build/core/selectable.rst
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/crud.py
lib/sqlalchemy/sql/dml.py
lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/sql/selectable.py
lib/sqlalchemy/testing/assertions.py
test/sql/test_cte.py
test/sql/test_insert.py

index 2473a02a2cba48ab13d3b1871f8ae9dfc8b548e3..273bffb83dff99507dd485e69171a984906a6ecb 100644 (file)
 .. changelog::
     :version: 1.1.0b1
 
+    .. change::
+        :tags: feature, sql
+        :tickets: 2551
+
+        CTE functionality has been expanded to support all DML, allowing
+        INSERT, UPDATE, and DELETE statements to both specify their own
+        WITH clause, as well as for these statements themselves to be
+        CTE expressions when they include a RETURNING clause.
+
+        .. seealso::
+
+            :ref:`change_2551`
+
     .. change::
         :tags: bug, orm
         :tickets: 3641
index 3d65ede80d40afaf8652997f8ddfdc37a1b56cf7..7eb8e800f33a27cf5b859a9628e16932eab3caa1 100644 (file)
@@ -529,6 +529,65 @@ remains unchanged.
 New Features and Improvements - Core
 ====================================
 
+.. _change_2551:
+
+CTE Support for INSERT, UPDATE, DELETE
+--------------------------------------
+
+One of the most widely requested features is support for common table
+expressions (CTE) that work with INSERT, UPDATE, DELETE, and is now implemented.
+An INSERT/UPDATE/DELETE can both draw from a WITH clause that's stated at the
+top of the SQL, as well as can be used as a CTE itself in the context of
+a larger statement.
+
+As part of this change, an INSERT from SELECT that includes a CTE will now
+render the CTE at the top of the entire statement, rather than nested
+in the SELECT statement as was the case in 1.0.
+
+Below is an example that renders UPDATE, INSERT and SELECT all in one
+statement::
+
+    >>> from sqlalchemy import table, column, select, literal, exists
+    >>> orders = table(
+    ...     'orders',
+    ...     column('region'),
+    ...     column('amount'),
+    ...     column('product'),
+    ...     column('quantity')
+    ... )
+    >>>
+    >>> upsert = (
+    ...     orders.update()
+    ...     .where(orders.c.region == 'Region1')
+    ...     .values(amount=1.0, product='Product1', quantity=1)
+    ...     .returning(*(orders.c._all_columns)).cte('upsert'))
+    >>>
+    >>> insert = orders.insert().from_select(
+    ...     orders.c.keys(),
+    ...     select([
+    ...         literal('Region1'), literal(1.0),
+    ...         literal('Product1'), literal(1)
+    ...     ]).where(~exists(upsert.select()))
+    ... )
+    >>>
+    >>> print(insert)  # note formatting added for clarity
+    WITH upsert AS
+    (UPDATE orders SET amount=:amount, product=:product, quantity=:quantity
+     WHERE orders.region = :region_1
+     RETURNING orders.region, orders.amount, orders.product, orders.quantity
+    )
+    INSERT INTO orders (region, amount, product, quantity)
+    SELECT
+        :param_1 AS anon_1, :param_2 AS anon_2,
+        :param_3 AS anon_3, :param_4 AS anon_4
+    WHERE NOT (
+        EXISTS (
+            SELECT upsert.region, upsert.amount,
+                   upsert.product, upsert.quantity
+            FROM upsert))
+
+:ticket:`2551`
+
 .. _change_3216:
 
 The ``.autoincrement`` directive is no longer implicitly enabled for a composite primary key column
index e73ce7b64cf6657a5e6bf4a6b369117675326c66..a582ab4dc58d8ca382aca1ac80c2b208b64e367d 100644 (file)
@@ -57,6 +57,9 @@ elements are themselves :class:`.ColumnElement` subclasses).
    :members:
    :inherited-members:
 
+.. autoclass:: HasCTE
+   :members:
+
 .. autoclass:: HasPrefixes
    :members:
 
index 8a25f570af11d6cb2bf57072b382f4c1ff160686..ad7b9130bfeb85c305b4a7d197f871f900046e80 100644 (file)
@@ -478,8 +478,6 @@ class Query(object):
         """Return the full SELECT statement represented by this
         :class:`.Query` represented as a common table expression (CTE).
 
-        .. versionadded:: 0.7.6
-
         Parameters and usage are the same as those of the
         :meth:`.SelectBase.cte` method; see that method for
         further details.
@@ -528,7 +526,7 @@ class Query(object):
 
         .. seealso::
 
-            :meth:`.SelectBase.cte`
+            :meth:`.HasCTE.cte`
 
         """
         return self.enable_eagerloads(False).\
index cc9a49a914751760d40ed2785bf9195dead23415..a2fc0fe68a77d1eb0b7f9e909c188114356f8f35 100644 (file)
@@ -418,6 +418,11 @@ class SQLCompiler(Compiled):
         self.truncated_names = {}
         Compiled.__init__(self, dialect, statement, **kwargs)
 
+        if (
+                self.isinsert or self.isupdate or self.isdelete
+        ) and statement._returning:
+            self.returning = statement._returning
+
         if self.positional and dialect.paramstyle == 'numeric':
             self._apply_numbered_params()
 
@@ -1659,7 +1664,7 @@ class SQLCompiler(Compiled):
             if per_dialect:
                 text += " " + self.get_statement_hint_text(per_dialect)
 
-        if self.ctes and self._is_toplevel_select(select):
+        if self.ctes and toplevel:
             text = self._render_cte_clause() + text
 
         if select._suffixes:
@@ -1673,20 +1678,6 @@ class SQLCompiler(Compiled):
         else:
             return text
 
-    def _is_toplevel_select(self, select):
-        """Return True if the stack is placed at the given select, and
-        is also the outermost SELECT, meaning there is either no stack
-        before this one, or the enclosing stack is a topmost INSERT.
-
-        """
-        return (
-            self.stack[-1]['selectable'] is select and
-            (
-                len(self.stack) == 1 or self.isinsert and len(self.stack) == 2
-                and self.statement is self.stack[0]['selectable']
-            )
-        )
-
     def _setup_select_hints(self, select):
         byfrom = dict([
             (from_, hinttext % {
@@ -1876,14 +1867,16 @@ class SQLCompiler(Compiled):
             )
         return dialect_hints, table_text
 
-    def visit_insert(self, insert_stmt, **kw):
+    def visit_insert(self, insert_stmt, asfrom=False, **kw):
+        toplevel = not self.stack
+
         self.stack.append(
             {'correlate_froms': set(),
              "asfrom_froms": set(),
              "selectable": insert_stmt})
 
-        self.isinsert = True
-        crud_params = crud._get_crud_params(self, insert_stmt, **kw)
+        crud_params = crud._setup_crud_params(
+            self, insert_stmt, crud.ISINSERT, **kw)
 
         if not crud_params and \
                 not self.dialect.supports_default_values and \
@@ -1929,12 +1922,13 @@ class SQLCompiler(Compiled):
                                          for c in crud_params_single])
 
         if self.returning or insert_stmt._returning:
-            self.returning = self.returning or insert_stmt._returning
             returning_clause = self.returning_clause(
-                insert_stmt, self.returning)
+                insert_stmt, self.returning or insert_stmt._returning)
 
             if self.returning_precedes_values:
                 text += " " + returning_clause
+        else:
+            returning_clause = None
 
         if insert_stmt.select is not None:
             text += " %s" % self.process(self._insert_from_select, **kw)
@@ -1953,12 +1947,18 @@ class SQLCompiler(Compiled):
             text += " VALUES (%s)" % \
                 ', '.join([c[1] for c in crud_params])
 
-        if self.returning and not self.returning_precedes_values:
+        if returning_clause and not self.returning_precedes_values:
             text += " " + returning_clause
 
+        if self.ctes and toplevel:
+            text = self._render_cte_clause() + text
+
         self.stack.pop(-1)
 
-        return text
+        if asfrom:
+            return "(" + text + ")"
+        else:
+            return text
 
     def update_limit_clause(self, update_stmt):
         """Provide a hook for MySQL to add LIMIT to the UPDATE"""
@@ -1972,8 +1972,8 @@ class SQLCompiler(Compiled):
         MySQL overrides this.
 
         """
-        return from_table._compiler_dispatch(self, asfrom=True,
-                                             iscrud=True, **kw)
+        kw['asfrom'] = True
+        return from_table._compiler_dispatch(self, iscrud=True, **kw)
 
     def update_from_clause(self, update_stmt,
                            from_table, extra_froms,
@@ -1990,14 +1990,14 @@ class SQLCompiler(Compiled):
                                  fromhints=from_hints, **kw)
             for t in extra_froms)
 
-    def visit_update(self, update_stmt, **kw):
+    def visit_update(self, update_stmt, asfrom=False, **kw):
+        toplevel = not self.stack
+
         self.stack.append(
             {'correlate_froms': set([update_stmt.table]),
              "asfrom_froms": set([update_stmt.table]),
              "selectable": update_stmt})
 
-        self.isupdate = True
-
         extra_froms = update_stmt._extra_froms
 
         text = "UPDATE "
@@ -2009,7 +2009,8 @@ class SQLCompiler(Compiled):
         table_text = self.update_tables_clause(update_stmt, update_stmt.table,
                                                extra_froms, **kw)
 
-        crud_params = crud._get_crud_params(self, update_stmt, **kw)
+        crud_params = crud._setup_crud_params(
+            self, update_stmt, crud.ISUPDATE, **kw)
 
         if update_stmt._hints:
             dialect_hints, table_text = self._setup_crud_hints(
@@ -2029,11 +2030,9 @@ class SQLCompiler(Compiled):
         )
 
         if self.returning or update_stmt._returning:
-            if not self.returning:
-                self.returning = update_stmt._returning
             if self.returning_precedes_values:
                 text += " " + self.returning_clause(
-                    update_stmt, self.returning)
+                    update_stmt, self.returning or update_stmt._returning)
 
         if extra_froms:
             extra_from_text = self.update_from_clause(
@@ -2053,23 +2052,33 @@ class SQLCompiler(Compiled):
         if limit_clause:
             text += " " + limit_clause
 
-        if self.returning and not self.returning_precedes_values:
+        if (self.returning or update_stmt._returning) and \
+                not self.returning_precedes_values:
             text += " " + self.returning_clause(
-                update_stmt, self.returning)
+                update_stmt, self.returning or update_stmt._returning)
+
+        if self.ctes and toplevel:
+            text = self._render_cte_clause() + text
 
         self.stack.pop(-1)
 
-        return text
+        if asfrom:
+            return "(" + text + ")"
+        else:
+            return text
 
     @util.memoized_property
     def _key_getters_for_crud_column(self):
-        return crud._key_getters_for_crud_column(self)
+        return crud._key_getters_for_crud_column(self, self.statement)
+
+    def visit_delete(self, delete_stmt, asfrom=False, **kw):
+        toplevel = not self.stack
 
-    def visit_delete(self, delete_stmt, **kw):
         self.stack.append({'correlate_froms': set([delete_stmt.table]),
                            "asfrom_froms": set([delete_stmt.table]),
                            "selectable": delete_stmt})
-        self.isdelete = True
+
+        crud._setup_crud_params(self, delete_stmt, crud.ISDELETE, **kw)
 
         text = "DELETE "
 
@@ -2088,7 +2097,6 @@ class SQLCompiler(Compiled):
         text += table_text
 
         if delete_stmt._returning:
-            self.returning = delete_stmt._returning
             if self.returning_precedes_values:
                 text += " " + self.returning_clause(
                     delete_stmt, delete_stmt._returning)
@@ -2098,13 +2106,19 @@ class SQLCompiler(Compiled):
             if t:
                 text += " WHERE " + t
 
-        if self.returning and not self.returning_precedes_values:
+        if delete_stmt._returning and not self.returning_precedes_values:
             text += " " + self.returning_clause(
                 delete_stmt, delete_stmt._returning)
 
+        if self.ctes and toplevel:
+            text = self._render_cte_clause() + text
+
         self.stack.pop(-1)
 
-        return text
+        if asfrom:
+            return "(" + text + ")"
+        else:
+            return text
 
     def visit_savepoint(self, savepoint_stmt):
         return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)
index a01b72e61b4edb097bbcdccb9cc793bf4658e894..58cd80995e9a458161c79155c32c7d0f348759a6 100644 (file)
@@ -25,6 +25,41 @@ values present.
 
 """)
 
+ISINSERT = util.symbol('ISINSERT')
+ISUPDATE = util.symbol('ISUPDATE')
+ISDELETE = util.symbol('ISDELETE')
+
+
+def _setup_crud_params(compiler, stmt, local_stmt_type, **kw):
+    restore_isinsert = compiler.isinsert
+    restore_isupdate = compiler.isupdate
+    restore_isdelete = compiler.isdelete
+
+    should_restore = (
+        restore_isinsert or restore_isupdate or restore_isdelete
+    ) or len(compiler.stack) > 1
+
+    if local_stmt_type is ISINSERT:
+        compiler.isupdate = False
+        compiler.isinsert = True
+    elif local_stmt_type is ISUPDATE:
+        compiler.isupdate = True
+        compiler.isinsert = False
+    elif local_stmt_type is ISDELETE:
+        if not should_restore:
+            compiler.isdelete = True
+    else:
+        assert False, "ISINSERT, ISUPDATE, or ISDELETE expected"
+
+    try:
+        if local_stmt_type in (ISINSERT, ISUPDATE):
+            return _get_crud_params(compiler, stmt, **kw)
+    finally:
+        if should_restore:
+            compiler.isinsert = restore_isinsert
+            compiler.isupdate = restore_isupdate
+            compiler.isdelete = restore_isdelete
+
 
 def _get_crud_params(compiler, stmt, **kw):
     """create a set of tuples representing column/string pairs for use
@@ -59,7 +94,7 @@ def _get_crud_params(compiler, stmt, **kw):
     # but in the case of mysql multi-table update, the rules for
     # .key must conditionally take tablename into account
     _column_as_key, _getattr_col_key, _col_bind_name = \
-        _key_getters_for_crud_column(compiler)
+        _key_getters_for_crud_column(compiler, stmt)
 
     # if we have statement parameters - set defaults in the
     # compiled params
@@ -128,15 +163,15 @@ def _create_bind_param(
     return bindparam
 
 
-def _key_getters_for_crud_column(compiler):
-    if compiler.isupdate and compiler.statement._extra_froms:
+def _key_getters_for_crud_column(compiler, stmt):
+    if compiler.isupdate and stmt._extra_froms:
         # when extra tables are present, refer to the columns
         # in those extra tables as table-qualified, including in
         # dictionaries and when rendering bind param names.
         # the "main" table of the statement remains unqualified,
         # allowing the most compatibility with a non-multi-table
         # statement.
-        _et = set(compiler.statement._extra_froms)
+        _et = set(stmt._extra_froms)
 
         def _column_as_key(key):
             str_key = elements._column_as_key(key)
@@ -609,7 +644,9 @@ def _get_returning_modifiers(compiler, stmt):
                                     stmt.table.implicit_returning and
                                     stmt._return_defaults)
     else:
-        implicit_return_defaults = False
+        # this line is unused, currently we are always
+        # isinsert or isupdate
+        implicit_return_defaults = False  # pragma: no cover
 
     if implicit_return_defaults:
         if stmt._return_defaults is True:
index 7b506f9db9d44328a575b6a294e9995bc3766b16..8f368dcdbb3cbe2a6beac829aac9e7cdef551d90 100644 (file)
@@ -9,15 +9,18 @@ Provide :class:`.Insert`, :class:`.Update` and :class:`.Delete`.
 
 """
 
-from .base import Executable, _generative, _from_objects, DialectKWArgs
+from .base import Executable, _generative, _from_objects, DialectKWArgs, \
+    ColumnCollection
 from .elements import ClauseElement, _literal_as_text, Null, and_, _clone, \
     _column_as_key
-from .selectable import _interpret_as_from, _interpret_as_select, HasPrefixes
+from .selectable import _interpret_as_from, _interpret_as_select, \
+    HasPrefixes, HasCTE
 from .. import util
 from .. import exc
 
 
-class UpdateBase(DialectKWArgs, HasPrefixes, Executable, ClauseElement):
+class UpdateBase(
+        HasCTE, DialectKWArgs, HasPrefixes, Executable, ClauseElement):
     """Form the base for ``INSERT``, ``UPDATE``, and ``DELETE`` statements.
 
     """
index a63eac2f831d135c0f7fa8ab10361ef4f06683c1..36f7f7fe1292b65eb124b319a860a6f56f25f01a 100644 (file)
@@ -46,8 +46,8 @@ from .base import ColumnCollection, Generative, Executable, \
 
 from .selectable import Alias, Join, Select, Selectable, TableClause, \
     CompoundSelect, CTE, FromClause, FromGrouping, SelectBase, \
-    alias, GenerativeSelect, \
-    subquery, HasPrefixes, HasSuffixes, Exists, ScalarSelect, TextAsFrom
+    alias, GenerativeSelect, subquery, HasCTE, HasPrefixes, HasSuffixes, \
+    Exists, ScalarSelect, TextAsFrom
 
 
 from .dml import Insert, Update, Delete, UpdateBase, ValuesBase
index c3906c2f2f1a2bfba8f35a577287bc705bcccb0d..fcd22a786c436b35580c4bcb4b4834f9a3c97717 100644 (file)
@@ -1195,6 +1195,15 @@ class CTE(Generative, HasSuffixes, Alias):
             self._suffixes = _suffixes
         super(CTE, self).__init__(selectable, name=name)
 
+    @util.dependencies("sqlalchemy.sql.dml")
+    def _populate_column_collection(self, dml):
+        if isinstance(self.element, dml.UpdateBase):
+            for col in self.element._returning:
+                col._make_proxy(self)
+        else:
+            for col in self.element.columns._all_columns:
+                col._make_proxy(self)
+
     def alias(self, name=None, flat=False):
         return CTE(
             self.original,
@@ -1223,6 +1232,164 @@ class CTE(Generative, HasSuffixes, Alias):
         )
 
 
+class HasCTE(object):
+    """Mixin that declares a class to include CTE support.
+
+    .. versionadded:: 1.1
+
+    """
+
+    def cte(self, name=None, recursive=False):
+        """Return a new :class:`.CTE`, or Common Table Expression instance.
+
+        Common table expressions are a SQL standard whereby SELECT
+        statements can draw upon secondary statements specified along
+        with the primary statement, using a clause called "WITH".
+        Special semantics regarding UNION can also be employed to
+        allow "recursive" queries, where a SELECT statement can draw
+        upon the set of rows that have previously been selected.
+
+        CTEs can also be applied to DML constructs UPDATE, INSERT
+        and DELETE on some databases, both as a source of CTE rows
+        when combined with RETURNING, as well as a consumer of
+        CTE rows.
+
+        SQLAlchemy detects :class:`.CTE` objects, which are treated
+        similarly to :class:`.Alias` objects, as special elements
+        to be delivered to the FROM clause of the statement as well
+        as to a WITH clause at the top of the statement.
+
+        .. versionchanged:: 1.1 Added support for UPDATE/INSERT/DELETE as
+           CTE, CTEs added to UPDATE/INSERT/DELETE.
+
+        :param name: name given to the common table expression.  Like
+         :meth:`._FromClause.alias`, the name can be left as ``None``
+         in which case an anonymous symbol will be used at query
+         compile time.
+        :param recursive: if ``True``, will render ``WITH RECURSIVE``.
+         A recursive common table expression is intended to be used in
+         conjunction with UNION ALL in order to derive rows
+         from those already selected.
+
+        The following examples include two from Postgresql's documentation at
+        http://www.postgresql.org/docs/current/static/queries-with.html,
+        as well as additional examples.
+
+        Example 1, non recursive::
+
+            from sqlalchemy import (Table, Column, String, Integer,
+                                    MetaData, select, func)
+
+            metadata = MetaData()
+
+            orders = Table('orders', metadata,
+                Column('region', String),
+                Column('amount', Integer),
+                Column('product', String),
+                Column('quantity', Integer)
+            )
+
+            regional_sales = select([
+                                orders.c.region,
+                                func.sum(orders.c.amount).label('total_sales')
+                            ]).group_by(orders.c.region).cte("regional_sales")
+
+
+            top_regions = select([regional_sales.c.region]).\\
+                    where(
+                        regional_sales.c.total_sales >
+                        select([
+                            func.sum(regional_sales.c.total_sales)/10
+                        ])
+                    ).cte("top_regions")
+
+            statement = select([
+                        orders.c.region,
+                        orders.c.product,
+                        func.sum(orders.c.quantity).label("product_units"),
+                        func.sum(orders.c.amount).label("product_sales")
+                ]).where(orders.c.region.in_(
+                    select([top_regions.c.region])
+                )).group_by(orders.c.region, orders.c.product)
+
+            result = conn.execute(statement).fetchall()
+
+        Example 2, WITH RECURSIVE::
+
+            from sqlalchemy import (Table, Column, String, Integer,
+                                    MetaData, select, func)
+
+            metadata = MetaData()
+
+            parts = Table('parts', metadata,
+                Column('part', String),
+                Column('sub_part', String),
+                Column('quantity', Integer),
+            )
+
+            included_parts = select([
+                                parts.c.sub_part,
+                                parts.c.part,
+                                parts.c.quantity]).\\
+                                where(parts.c.part=='our part').\\
+                                cte(recursive=True)
+
+
+            incl_alias = included_parts.alias()
+            parts_alias = parts.alias()
+            included_parts = included_parts.union_all(
+                select([
+                    parts_alias.c.sub_part,
+                    parts_alias.c.part,
+                    parts_alias.c.quantity
+                ]).
+                    where(parts_alias.c.part==incl_alias.c.sub_part)
+            )
+
+            statement = select([
+                        included_parts.c.sub_part,
+                        func.sum(included_parts.c.quantity).
+                          label('total_quantity')
+                    ]).\\
+                    group_by(included_parts.c.sub_part)
+
+            result = conn.execute(statement).fetchall()
+
+        Example 3, an upsert using UPDATE and INSERT with CTEs::
+
+            orders = table(
+                'orders',
+                column('region'),
+                column('amount'),
+                column('product'),
+                column('quantity')
+            )
+
+            upsert = (
+                orders.update()
+                .where(orders.c.region == 'Region1')
+                .values(amount=1.0, product='Product1', quantity=1)
+                .returning(*(orders.c._all_columns)).cte('upsert'))
+
+            insert = orders.insert().from_select(
+                orders.c.keys(),
+                select([
+                    literal('Region1'), literal(1.0),
+                    literal('Product1'), literal(1)
+                ).where(exists(upsert.select()))
+            )
+
+            connection.execute(insert)
+
+        .. seealso::
+
+            :meth:`.orm.query.Query.cte` - ORM version of
+            :meth:`.HasCTE.cte`.
+
+        """
+        return CTE(self, name=name, recursive=recursive)
+
+
 class FromGrouping(FromClause):
     """Represent a grouping of a FROM clause"""
     __visit_name__ = 'grouping'
@@ -1497,7 +1664,7 @@ class ForUpdateArg(ClauseElement):
             self.of = None
 
 
-class SelectBase(Executable, FromClause):
+class SelectBase(HasCTE, Executable, FromClause):
     """Base class for SELECT statements.
 
 
@@ -1531,125 +1698,6 @@ class SelectBase(Executable, FromClause):
         """
         return self.as_scalar().label(name)
 
-    def cte(self, name=None, recursive=False):
-        """Return a new :class:`.CTE`, or Common Table Expression instance.
-
-        Common table expressions are a SQL standard whereby SELECT
-        statements can draw upon secondary statements specified along
-        with the primary statement, using a clause called "WITH".
-        Special semantics regarding UNION can also be employed to
-        allow "recursive" queries, where a SELECT statement can draw
-        upon the set of rows that have previously been selected.
-
-        SQLAlchemy detects :class:`.CTE` objects, which are treated
-        similarly to :class:`.Alias` objects, as special elements
-        to be delivered to the FROM clause of the statement as well
-        as to a WITH clause at the top of the statement.
-
-        .. versionadded:: 0.7.6
-
-        :param name: name given to the common table expression.  Like
-         :meth:`._FromClause.alias`, the name can be left as ``None``
-         in which case an anonymous symbol will be used at query
-         compile time.
-        :param recursive: if ``True``, will render ``WITH RECURSIVE``.
-         A recursive common table expression is intended to be used in
-         conjunction with UNION ALL in order to derive rows
-         from those already selected.
-
-        The following examples illustrate two examples from
-        Postgresql's documentation at
-        http://www.postgresql.org/docs/8.4/static/queries-with.html.
-
-        Example 1, non recursive::
-
-            from sqlalchemy import (Table, Column, String, Integer,
-                                    MetaData, select, func)
-
-            metadata = MetaData()
-
-            orders = Table('orders', metadata,
-                Column('region', String),
-                Column('amount', Integer),
-                Column('product', String),
-                Column('quantity', Integer)
-            )
-
-            regional_sales = select([
-                                orders.c.region,
-                                func.sum(orders.c.amount).label('total_sales')
-                            ]).group_by(orders.c.region).cte("regional_sales")
-
-
-            top_regions = select([regional_sales.c.region]).\\
-                    where(
-                        regional_sales.c.total_sales >
-                        select([
-                            func.sum(regional_sales.c.total_sales)/10
-                        ])
-                    ).cte("top_regions")
-
-            statement = select([
-                        orders.c.region,
-                        orders.c.product,
-                        func.sum(orders.c.quantity).label("product_units"),
-                        func.sum(orders.c.amount).label("product_sales")
-                ]).where(orders.c.region.in_(
-                    select([top_regions.c.region])
-                )).group_by(orders.c.region, orders.c.product)
-
-            result = conn.execute(statement).fetchall()
-
-        Example 2, WITH RECURSIVE::
-
-            from sqlalchemy import (Table, Column, String, Integer,
-                                    MetaData, select, func)
-
-            metadata = MetaData()
-
-            parts = Table('parts', metadata,
-                Column('part', String),
-                Column('sub_part', String),
-                Column('quantity', Integer),
-            )
-
-            included_parts = select([
-                                parts.c.sub_part,
-                                parts.c.part,
-                                parts.c.quantity]).\\
-                                where(parts.c.part=='our part').\\
-                                cte(recursive=True)
-
-
-            incl_alias = included_parts.alias()
-            parts_alias = parts.alias()
-            included_parts = included_parts.union_all(
-                select([
-                    parts_alias.c.sub_part,
-                    parts_alias.c.part,
-                    parts_alias.c.quantity
-                ]).
-                    where(parts_alias.c.part==incl_alias.c.sub_part)
-            )
-
-            statement = select([
-                        included_parts.c.sub_part,
-                        func.sum(included_parts.c.quantity).
-                          label('total_quantity')
-                    ]).\\
-                    group_by(included_parts.c.sub_part)
-
-            result = conn.execute(statement).fetchall()
-
-
-        .. seealso::
-
-            :meth:`.orm.query.Query.cte` - ORM version of
-            :meth:`.SelectBase.cte`.
-
-        """
-        return CTE(self, name=name, recursive=recursive)
-
     @_generative
     @util.deprecated('0.6',
                      message="``autocommit()`` is deprecated. Use "
index bb5a962561c6ed5292261c43cea1275b8e32d7a5..21f9f68fb8bb5ef6c26330b46ad23856fe2df9cf 100644 (file)
@@ -296,6 +296,8 @@ class AssertsCompiledSQL(object):
                 dialect = config.db.dialect
             elif dialect == 'default':
                 dialect = default.DefaultDialect()
+            elif dialect == 'default_enhanced':
+                dialect = default.StrCompileDialect()
             elif isinstance(dialect, util.string_types):
                 dialect = url.URL(dialect).get_dialect()()
 
index b59914afce021b4e477b423e78261a31b40c72ae..aa674403e6228fe98aa864d20b7b8bf9775c96a8 100644 (file)
@@ -1,6 +1,6 @@
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import AssertsCompiledSQL, assert_raises_message
-from sqlalchemy.sql import table, column, select, func, literal
+from sqlalchemy.sql import table, column, select, func, literal, exists, and_
 from sqlalchemy.dialects import mssql
 from sqlalchemy.engine import default
 from sqlalchemy.exc import CompileError
@@ -8,7 +8,7 @@ from sqlalchemy.exc import CompileError
 
 class CTETest(fixtures.TestBase, AssertsCompiledSQL):
 
-    __dialect__ = 'default'
+    __dialect__ = 'default_enhanced'
 
     def test_nonrecursive(self):
         orders = table('orders',
@@ -492,3 +492,151 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
             'regional_sales WHERE "order"."order" > regional_sales."order"',
             dialect='postgresql'
         )
+
+    def test_upsert_from_select(self):
+        orders = table(
+            'orders',
+            column('region'),
+            column('amount'),
+            column('product'),
+            column('quantity')
+        )
+
+        upsert = (
+            orders.update()
+            .where(orders.c.region == 'Region1')
+            .values(amount=1.0, product='Product1', quantity=1)
+            .returning(*(orders.c._all_columns)).cte('upsert'))
+
+        insert = orders.insert().from_select(
+            orders.c.keys(),
+            select([
+                literal('Region1'), literal(1.0),
+                literal('Product1'), literal(1)
+            ]).where(~exists(upsert.select()))
+        )
+
+        self.assert_compile(
+            insert,
+            "WITH upsert AS (UPDATE orders SET amount=:amount, "
+            "product=:product, quantity=:quantity "
+            "WHERE orders.region = :region_1 "
+            "RETURNING orders.region, orders.amount, "
+            "orders.product, orders.quantity) "
+            "INSERT INTO orders (region, amount, product, quantity) "
+            "SELECT :param_1 AS anon_1, :param_2 AS anon_2, "
+            ":param_3 AS anon_3, :param_4 AS anon_4 WHERE NOT (EXISTS "
+            "(SELECT upsert.region, upsert.amount, upsert.product, "
+            "upsert.quantity FROM upsert))"
+        )
+
+    def test_pg_example_one(self):
+        products = table('products', column('id'), column('date'))
+        products_log = table('products_log', column('id'), column('date'))
+
+        moved_rows = products.delete().where(and_(
+            products.c.date >= 'dateone',
+            products.c.date < 'datetwo')).returning(*products.c).\
+            cte('moved_rows')
+
+        stmt = products_log.insert().from_select(
+            products_log.c, moved_rows.select())
+        self.assert_compile(
+            stmt,
+            "WITH moved_rows AS "
+            "(DELETE FROM products WHERE products.date >= :date_1 "
+            "AND products.date < :date_2 "
+            "RETURNING products.id, products.date) "
+            "INSERT INTO products_log (id, date) "
+            "SELECT moved_rows.id, moved_rows.date FROM moved_rows"
+        )
+
+    def test_pg_example_two(self):
+        products = table('products', column('id'), column('price'))
+
+        t = products.update().values(price='someprice').\
+            returning(*products.c).cte('t')
+        stmt = t.select()
+
+        self.assert_compile(
+            stmt,
+            "WITH t AS "
+            "(UPDATE products SET price=:price "
+            "RETURNING products.id, products.price) "
+            "SELECT t.id, t.price "
+            "FROM t"
+        )
+
+    def test_pg_example_three(self):
+
+        parts = table(
+            'parts',
+            column('part'),
+            column('sub_part'),
+        )
+
+        included_parts = select([
+            parts.c.sub_part,
+            parts.c.part]).\
+            where(parts.c.part == 'our part').\
+            cte("included_parts", recursive=True)
+
+        pr = included_parts.alias('pr')
+        p = parts.alias('p')
+        included_parts = included_parts.union_all(
+            select([
+                p.c.sub_part,
+                p.c.part]).
+            where(p.c.part == pr.c.sub_part)
+        )
+        stmt = parts.delete().where(
+            parts.c.part.in_(select([included_parts.c.part]))).returning(
+            parts.c.part)
+
+        # the outer RETURNING is a bonus over what PG's docs have
+        self.assert_compile(
+            stmt,
+            "WITH RECURSIVE included_parts(sub_part, part) AS "
+            "(SELECT parts.sub_part AS sub_part, parts.part AS part "
+            "FROM parts "
+            "WHERE parts.part = :part_1 "
+            "UNION ALL SELECT p.sub_part AS sub_part, p.part AS part "
+            "FROM parts AS p, included_parts AS pr "
+            "WHERE p.part = pr.sub_part) "
+            "DELETE FROM parts WHERE parts.part IN "
+            "(SELECT included_parts.part FROM included_parts) "
+            "RETURNING parts.part"
+        )
+
+    def test_insert_in_the_cte(self):
+        products = table('products', column('id'), column('price'))
+
+        cte = products.insert().values(id=1, price=27.0).\
+            returning(*products.c).cte('pd')
+
+        stmt = select([cte])
+
+        self.assert_compile(
+            stmt,
+            "WITH pd AS "
+            "(INSERT INTO products (id, price) VALUES (:id, :price) "
+            "RETURNING products.id, products.price) "
+            "SELECT pd.id, pd.price "
+            "FROM pd"
+        )
+
+    def test_update_pulls_from_cte(self):
+        products = table('products', column('id'), column('price'))
+
+        cte = products.select().cte('pd')
+
+        stmt = products.update().where(products.c.price == cte.c.price)
+
+        self.assert_compile(
+            stmt,
+            "WITH pd AS "
+            "(SELECT products.id AS id, products.price AS price "
+            "FROM products) "
+            "UPDATE products SET id=:id, price=:price FROM pd "
+            "WHERE products.price = pd.price"
+        )
index ea4de032c8359c379cd435b1b4c4cd3ff132e807..513757d5be0dbd8e0d8b24bc68fb660c42fc3408 100644 (file)
@@ -188,9 +188,10 @@ class InsertTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL):
             from_select(("otherid", "othername"), sel)
         self.assert_compile(
             ins,
-            "INSERT INTO myothertable (otherid, othername) WITH anon_1 AS "
+            "WITH anon_1 AS "
             "(SELECT mytable.name AS name FROM mytable "
             "WHERE mytable.name = :name_1) "
+            "INSERT INTO myothertable (otherid, othername) "
             "SELECT mytable.myid, mytable.name FROM mytable, anon_1 "
             "WHERE mytable.name = anon_1.name",
             checkparams={"name_1": "bar"}
@@ -205,9 +206,9 @@ class InsertTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL):
 
         self.assert_compile(
             ins,
-            "INSERT INTO mytable (myid, name, description) "
             "WITH c AS (SELECT mytable.myid AS myid, mytable.name AS name, "
             "mytable.description AS description FROM mytable) "
+            "INSERT INTO mytable (myid, name, description) "
             "SELECT c.myid, c.name, c.description FROM c"
         )