git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- [feature] Added cte() method to Query,
author Mike Bayer <mike_mp@zzzcomputing.com>
Sat, 3 Mar 2012 18:00:44 +0000 (13:00 -0500)
committer Mike Bayer <mike_mp@zzzcomputing.com>
Sat, 3 Mar 2012 18:00:44 +0000 (13:00 -0500)
invokes common table expression support
from the Core (see below). [ticket:1859]

- [feature] Added support for SQL standard
common table expressions (CTE), allowing
SELECT objects as the CTE source (DML
not yet supported).  This is invoked via
the cte() method on any select() construct.
[ticket:1859]

CHANGES
doc/build/core/expression_api.rst
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/expression.py
test/sql/test_compiler.py

diff --git a/CHANGES b/CHANGES
index 4d8adcd87da98a98abfada96676bb6e06fe60390..ca97051d60c1f0ff9a672161ff20e46569b413e2 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -10,6 +10,10 @@ CHANGES
     manager to Session, used with with:
     will temporarily disable autoflush.
 
+  - [feature] Added cte() method to Query,
+    invokes common table expression support
+    from the Core (see below). [ticket:1859]
+
   - [bug] Fixed bug whereby MappedCollection
     would not get the appropriate collection
     instrumentation if it were only used
@@ -53,6 +57,13 @@ CHANGES
     on the method object.  [ticket:2352]
 
 - sql
+  - [feature] Added support for SQL standard
+    common table expressions (CTE), allowing
+    SELECT objects as the CTE source (DML
+    not yet supported).  This is invoked via
+    the cte() method on any select() construct.
+    [ticket:1859]
+
   - [bug] Added support for using the .key
     of a Column as a string identifier in a 
     result set row.   The .key is currently
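diff --git a/doc/build/core/expression_api.rst b/doc/build/core/expression_api.rst
index ac6aa9e8b60f1c29d40a8eb7b34746c714e4263d..4cec26f98265731a6c3e03aaa9e8f8af6894ccd8 100644 (file)
--- a/doc/build/core/expression_api.rst
+++ b/doc/build/core/expression_api.rst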
@@ -163,6 +163,10 @@ Classes
    :members:
    :show-inheritance:
 
+.. autoclass:: CTE
+   :members:
+   :show-inheritance:
+
 .. autoclass:: Delete
    :members: where
    :show-inheritance:
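diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py
index f7c94aabc20a3cc60f2713f03b4b2d918f0150e4..b73235875c0d503cc0f9590c4c06437c43074301 100644 (file)
--- a/lib/sqlalchemy/dialects/mssql/base.py
+++ b/lib/sqlalchemy/dialects/mssql/base.py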
@@ -949,6 +949,13 @@ class MSSQLCompiler(compiler.SQLCompiler):
         ]
         return 'OUTPUT ' + ', '.join(columns)
 
+    def get_cte_preamble(self, recursive):
+        # SQL Server does not accept the optional,
+        # SQL-standard "RECURSIVE" keyword after "WITH",
+        # so the preamble is always plain "WITH",
+        # regardless of the ``recursive`` flag.
+        return "WITH"
+
     def label_select_column(self, select, column, asfrom):
         if isinstance(column, expression.Function):
             return column.label(None)
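diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index cafce5e3ce01b25097f8adf1e178446402bae3c3..5b7f7c9af4160f9d663e0b07e8575bde929ff64f 100644 (file)
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py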
@@ -450,6 +450,62 @@ class Query(object):
         """
         return self.enable_eagerloads(False).statement.alias(name=name)
 
+    def cte(self, name=None, recursive=False):
+        """Return the full SELECT statement represented by this :class:`.Query`
+        rendered as a common table expression (CTE).
+
+        The :meth:`.Query.cte` method is new in 0.7.6.
+        
+        Parameters and usage are the same as those of the 
+        :meth:`._SelectBase.cte` method; see that method for 
+        further details.
+        
+        Here is the `Postgresql WITH 
+        RECURSIVE example <http://www.postgresql.org/docs/8.4/static/queries-with.html>`_.
+        Note that, in this example, the ``included_parts`` cte and the ``incl_alias`` alias
+        of it are Core selectables, which
+        means the columns are accessed via the ``.c.`` attribute.  The ``parts_alias``
+        object is an :func:`.orm.aliased` instance of the ``Part`` entity, so column-mapped
+        attributes are available directly::
+
+            from sqlalchemy.orm import aliased
+
+            class Part(Base):
+                __tablename__ = 'part'
+                part = Column(String, primary_key=True)
+                sub_part = Column(String, primary_key=True)
+                quantity = Column(Integer)
+
+            included_parts = session.query(
+                                Part.sub_part, 
+                                Part.part, 
+                                Part.quantity).\\
+                                    filter(Part.part=="our part").\\
+                                    cte(name="included_parts", recursive=True)
+
+            incl_alias = aliased(included_parts, name="pr")
+            parts_alias = aliased(Part, name="p")
+            included_parts = included_parts.union(
+                session.query(
+                    parts_alias.sub_part,
+                    parts_alias.part,
+                    parts_alias.quantity).\\
+                        filter(parts_alias.part==incl_alias.c.sub_part)
+                )
+
+            q = session.query(
+                    included_parts.c.sub_part,
+                    func.sum(included_parts.c.quantity).label('total_quantity')
+                ).\\
+                group_by(included_parts.c.sub_part)
+
+        See also:
+        
+        :meth:`._SelectBase.cte`
+
+        """
+        return self.enable_eagerloads(False).statement.cte(name=name, recursive=recursive)
+
     def label(self, name):
         """Return the full SELECT statement represented by this :class:`.Query`, converted 
         to a scalar subquery with a label of the given name.
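diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index b955c5608813ac81c60d1cbae9d95d5efafe6c80..e8f86634d2d3a52096bb99f6cf9d6e394b641850 100644 (file)
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py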
@@ -252,6 +252,10 @@ class SQLCompiler(engine.Compiled):
         # column targeting
         self.result_map = {}
 
+        # collect CTEs to tack on top of a SELECT
+        self.ctes = util.OrderedDict()
+        self.ctes_recursive = False
+
         # true if the paramstyle is positional
         self.positional = dialect.positional
         if self.positional:
@@ -749,6 +753,45 @@ class SQLCompiler(engine.Compiled):
         else:
             return self.bindtemplate % {'name':name}
 
+    def visit_cte(self, cte, asfrom=False, ashint=False, 
+                                fromhints=None, **kwargs):
+        if isinstance(cte.name, sql._truncated_label):
+            cte_name = self._truncated_identifier("alias", cte.name)
+        else:
+            cte_name = cte.name
+        if cte.cte_alias:
+            if isinstance(cte.cte_alias, sql._truncated_label):
+                cte_alias = self._truncated_identifier("alias", cte.cte_alias)
+            else:
+                cte_alias = cte.cte_alias
+        if not cte.cte_alias and cte not in self.ctes:
+            if cte.recursive:
+                self.ctes_recursive = True
+            text = self.preparer.format_alias(cte, cte_name)
+            if cte.recursive:
+                if isinstance(cte.original, sql.Select):
+                    col_source = cte.original
+                elif isinstance(cte.original, sql.CompoundSelect):
+                    col_source = cte.original.selects[0]
+                else:
+                    assert False
+                recur_cols = [c.key for c in util.unique_list(col_source.inner_columns)
+                                if c is not None]
+
+                text += "(%s)" % (", ".join(recur_cols))
+            text += " AS \n" + \
+                        cte.original._compiler_dispatch(
+                                self, asfrom=True, **kwargs
+                            )
+            self.ctes[cte] = text
+        if asfrom:
+            if cte.cte_alias:
+                text = self.preparer.format_alias(cte, cte_alias)
+                text += " AS " + cte_name
+            else:
+                return self.preparer.format_alias(cte, cte_name)
+            return text
+
     def visit_alias(self, alias, asfrom=False, ashint=False, 
                                 fromhints=None, **kwargs):
         if asfrom or ashint:
@@ -909,6 +952,15 @@ class SQLCompiler(engine.Compiled):
         if select.for_update:
             text += self.for_update_clause(select)
 
+        if self.ctes and \
+            compound_index==1 and not entry:
+            cte_text = self.get_cte_preamble(self.ctes_recursive) + " "
+            cte_text += ", \n".join(
+                [txt for txt in self.ctes.values()]
+            )
+            cte_text += "\n "
+            text = cte_text + text
+
         self.stack.pop(-1)
 
         if asfrom and parens:
@@ -916,6 +968,12 @@ class SQLCompiler(engine.Compiled):
         else:
             return text
 
+    def get_cte_preamble(self, recursive):
+        if recursive:
+            return "WITH RECURSIVE"
+        else:
+            return "WITH"
+
     def get_select_precolumns(self, select):
         """Called when building a ``SELECT`` statement, position is just
         before column list.
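diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index 4b61e6dc334a441bea0928273d5505d7e46dec2e..22fe6c420f2d58d7c41bb3e6eab559c8631ee3f6 100644 (file)
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py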
@@ -3719,6 +3719,47 @@ class Alias(FromClause):
     def bind(self):
         return self.element.bind
 
+class CTE(Alias):
+    """Represent a Common Table Expression.
+    
+    The :class:`.CTE` object is obtained using the
+    :meth:`._SelectBase.cte` method from any selectable.
+    See that method for complete examples.
+    
+    New in 0.7.6.
+
+    """
+    __visit_name__ = 'cte'
+    def __init__(self, selectable, 
+                        name=None, 
+                        recursive=False, 
+                        cte_alias=False):
+        self.recursive = recursive
+        self.cte_alias = cte_alias
+        super(CTE, self).__init__(selectable, name=name)
+
+    def alias(self, name=None):
+        return CTE(
+            self.original,
+            name=name,
+            recursive=self.recursive,
+            cte_alias = self.name
+        )
+
+    def union(self, other):
+        return CTE(
+            self.original.union(other),
+            name=self.name,
+            recursive=self.recursive
+        )
+
+    def union_all(self, other):
+        return CTE(
+            self.original.union_all(other),
+            name=self.name,
+            recursive=self.recursive
+        )
+
 
 class _Grouping(ColumnElement):
     """Represent a grouping within a column expression"""
@@ -4289,6 +4330,125 @@ class _SelectBase(Executable, FromClause):
         """
         return self.as_scalar().label(name)
 
+    def cte(self, name=None, recursive=False):
+        """Return a new :class:`.CTE`, or Common Table Expression instance.
+        
+        Common table expressions are a SQL standard whereby SELECT
+        statements can draw upon secondary statements specified along
+        with the primary statement, using a clause called "WITH".
+        Special semantics regarding UNION can also be employed to 
+        allow "recursive" queries, where a SELECT statement can draw 
+        upon the set of rows that have previously been selected.
+        
+        SQLAlchemy detects :class:`.CTE` objects, which are treated
+        similarly to :class:`.Alias` objects, as special elements
+        to be delivered to the FROM clause of the statement as well
+        as to a WITH clause at the top of the statement.
+
+        The :meth:`._SelectBase.cte` method is new in 0.7.6.
+        
+        :param name: name given to the common table expression.  Like
+         :meth:`.FromClause.alias`, the name can be left as ``None``
+         in which case an anonymous symbol will be used at query
+         compile time.
+        :param recursive: if ``True``, will render ``WITH RECURSIVE``.
+         A recursive common table expression is intended to be used in 
+         conjunction with UNION or UNION ALL in order to derive rows
+         from those already selected.
+
+        The following two examples are adapted from
+        PostgreSQL's documentation at
+        http://www.postgresql.org/docs/8.4/static/queries-with.html.
+        
+        Example 1, non recursive::
+        
+            from sqlalchemy import Table, Column, String, Integer, MetaData, \\
+                select, func
+
+            metadata = MetaData()
+
+            orders = Table('orders', metadata,
+                Column('region', String),
+                Column('amount', Integer),
+                Column('product', String),
+                Column('quantity', Integer)
+            )
+
+            regional_sales = select([
+                                orders.c.region, 
+                                func.sum(orders.c.amount).label('total_sales')
+                            ]).group_by(orders.c.region).cte("regional_sales")
+
+
+            top_regions = select([regional_sales.c.region]).\\
+                    where(
+                        regional_sales.c.total_sales > 
+                        select([
+                            func.sum(regional_sales.c.total_sales)/10
+                        ])
+                    ).cte("top_regions")
+
+            statement = select([
+                        orders.c.region, 
+                        orders.c.product, 
+                        func.sum(orders.c.quantity).label("product_units"), 
+                        func.sum(orders.c.amount).label("product_sales")
+                ]).where(orders.c.region.in_(
+                    select([top_regions.c.region])
+                )).group_by(orders.c.region, orders.c.product)
+        
+            result = conn.execute(statement).fetchall()
+            
+        Example 2, WITH RECURSIVE::
+
+            from sqlalchemy import Table, Column, String, Integer, MetaData, \\
+                select, func
+
+            metadata = MetaData()
+
+            parts = Table('parts', metadata,
+                Column('part', String),
+                Column('sub_part', String),
+                Column('quantity', Integer),
+            )
+
+            included_parts = select([
+                                parts.c.sub_part, 
+                                parts.c.part, 
+                                parts.c.quantity]).\\
+                                where(parts.c.part=='our part').\\
+                                cte(recursive=True)
+
+
+            incl_alias = included_parts.alias()
+            parts_alias = parts.alias()
+            included_parts = included_parts.union(
+                select([
+                    parts_alias.c.sub_part,
+                    parts_alias.c.part,
+                    parts_alias.c.quantity
+                ]).
+                    where(parts_alias.c.part==incl_alias.c.sub_part)
+            )
+
+            statement = select([
+                        included_parts.c.sub_part, 
+                        func.sum(included_parts.c.quantity).label('total_quantity')
+                    ]).\\
+                    select_from(included_parts.join(parts,
+                                included_parts.c.part==parts.c.part)).\\
+                    group_by(included_parts.c.sub_part)
+
+            result = conn.execute(statement).fetchall()
+
+        
+        See also:
+        
+        :meth:`.orm.query.Query.cte` - ORM version of :meth:`._SelectBase.cte`.
+
+        """
+        return CTE(self, name=name, recursive=recursive)
+
     @_generative
     @util.deprecated('0.6',
                      message=":func:`.autocommit` is deprecated. Use "
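diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py
index 528a49558363588d840a26eea9ce65ad3d5844b0..970030d55b50bf768f65e2f99fd1e3d7fe0b63a2 100644 (file)
--- a/test/sql/test_compiler.py
+++ b/test/sql/test_compiler.py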
@@ -2297,6 +2297,116 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
             "SELECT x + foo() OVER () AS anon_1"
         )
 
+    def test_cte_nonrecursive(self):
+        orders = table('orders', 
+            column('region'),
+            column('amount'),
+            column('product'),
+            column('quantity')
+        )
+
+        regional_sales = select([
+                            orders.c.region, 
+                            func.sum(orders.c.amount).label('total_sales')
+                        ]).group_by(orders.c.region).cte("regional_sales")
+
+        top_regions = select([regional_sales.c.region]).\
+                where(
+                    regional_sales.c.total_sales > 
+                    select([
+                        func.sum(regional_sales.c.total_sales)/10
+                    ])
+                ).cte("top_regions")
+
+        s = select([
+                    orders.c.region, 
+                    orders.c.product, 
+                    func.sum(orders.c.quantity).label("product_units"), 
+                    func.sum(orders.c.amount).label("product_sales")
+            ]).where(orders.c.region.in_(
+                select([top_regions.c.region])
+            )).group_by(orders.c.region, orders.c.product)
+
+        # needs to render regional_sales first as top_regions
+        # refers to it
+        self.assert_compile(
+            s,
+            "WITH regional_sales AS (SELECT orders.region AS region, "
+            "sum(orders.amount) AS total_sales FROM orders "
+            "GROUP BY orders.region), "
+            "top_regions AS (SELECT "
+            "regional_sales.region AS region FROM regional_sales "
+            "WHERE regional_sales.total_sales > "
+            "(SELECT sum(regional_sales.total_sales) / :sum_1 AS "
+            "anon_1 FROM regional_sales)) "
+            "SELECT orders.region, orders.product, "
+            "sum(orders.quantity) AS product_units, "
+            "sum(orders.amount) AS product_sales "
+            "FROM orders WHERE orders.region "
+            "IN (SELECT top_regions.region FROM top_regions) "
+            "GROUP BY orders.region, orders.product"
+        )
+
+    def test_cte_recursive(self):
+        parts = table('parts', 
+            column('part'),
+            column('sub_part'),
+            column('quantity'),
+        )
+
+        included_parts = select([
+                            parts.c.sub_part, 
+                            parts.c.part, 
+                            parts.c.quantity]).\
+                            where(parts.c.part=='our part').\
+                                cte(recursive=True)
+
+        incl_alias = included_parts.alias()
+        parts_alias = parts.alias()
+        included_parts = included_parts.union(
+            select([
+                parts_alias.c.part, 
+                parts_alias.c.sub_part, 
+                parts_alias.c.quantity]).\
+                where(parts_alias.c.part==incl_alias.c.sub_part)
+            )
+
+        s = select([
+            included_parts.c.sub_part, 
+            func.sum(included_parts.c.quantity).label('total_quantity')]).\
+            select_from(included_parts.join(
+                    parts,included_parts.c.part==parts.c.part)).\
+            group_by(included_parts.c.sub_part)
+        self.assert_compile(s, 
+                "WITH RECURSIVE anon_1(sub_part, part, quantity) "
+                "AS (SELECT parts.sub_part AS sub_part, parts.part "
+                "AS part, parts.quantity AS quantity FROM parts "
+                "WHERE parts.part = :part_1 UNION SELECT parts_1.part "
+                "AS part, parts_1.sub_part AS sub_part, parts_1.quantity "
+                "AS quantity FROM parts AS parts_1, anon_1 AS anon_2 "
+                "WHERE parts_1.part = anon_2.sub_part) "
+                "SELECT anon_1.sub_part, "
+                "sum(anon_1.quantity) AS total_quantity FROM anon_1 "
+                "JOIN parts ON anon_1.part = parts.part "
+                "GROUP BY anon_1.sub_part"
+            )
+
+        # quick check that the "WITH RECURSIVE" varies per
+        # dialect
+        self.assert_compile(s, 
+                "WITH anon_1(sub_part, part, quantity) "
+                "AS (SELECT parts.sub_part AS sub_part, parts.part "
+                "AS part, parts.quantity AS quantity FROM parts "
+                "WHERE parts.part = :part_1 UNION SELECT parts_1.part "
+                "AS part, parts_1.sub_part AS sub_part, parts_1.quantity "
+                "AS quantity FROM parts AS parts_1, anon_1 AS anon_2 "
+                "WHERE parts_1.part = anon_2.sub_part) "
+                "SELECT anon_1.sub_part, "
+                "sum(anon_1.quantity) AS total_quantity FROM anon_1 "
+                "JOIN parts ON anon_1.part = parts.part "
+                "GROUP BY anon_1.sub_part",
+                dialect=mssql.dialect()
+            )
 
     def test_date_between(self):
         import datetime