From: Mike Bayer Date: Sat, 3 Mar 2012 18:00:44 +0000 (-0500) Subject: - [feature] Added cte() method to Query, X-Git-Tag: rel_0_7_6~26 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=1607b74f8527905ecdc6133b4b4166a9ed675e09;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - [feature] Added cte() method to Query, invokes common table expression support from the Core (see below). [ticket:1859] - [feature] Added support for SQL standard common table expressions (CTE), allowing SELECT objects as the CTE source (DML not yet supported). This is invoked via the cte() method on any select() construct. [ticket:1859] --- diff --git a/CHANGES b/CHANGES index 4d8adcd87d..ca97051d60 100644 --- a/CHANGES +++ b/CHANGES @@ -10,6 +10,10 @@ CHANGES manager to Session, used with with: will temporarily disable autoflush. + - [feature] Added cte() method to Query, + invokes common table expression support + from the Core (see below). [ticket:1859] + - [bug] Fixed bug whereby MappedCollection would not get the appropriate collection instrumentation if it were only used @@ -53,6 +57,13 @@ CHANGES on the method object. [ticket:2352] - sql + - [feature] Added support for SQL standard + common table expressions (CTE), allowing + SELECT objects as the CTE source (DML + not yet supported). This is invoked via + the cte() method on any select() construct. + [ticket:1859] + - [bug] Added support for using the .key of a Column as a string identifier in a result set row. The .key is currently diff --git a/doc/build/core/expression_api.rst b/doc/build/core/expression_api.rst index ac6aa9e8b6..4cec26f982 100644 --- a/doc/build/core/expression_api.rst +++ b/doc/build/core/expression_api.rst @@ -163,6 +163,10 @@ Classes :members: :show-inheritance: +.. autoclass:: CTE + :members: + :show-inheritance: + .. autoclass:: Delete :members: where :show-inheritance: diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index f7c94aabc2..b73235875c 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -949,6 +949,13 @@ class MSSQLCompiler(compiler.SQLCompiler): ] return 'OUTPUT ' + ', '.join(columns) + def get_cte_preamble(self, recursive): + # SQL Server finds it too inconvenient to accept + # an entirely optional, SQL standard specified, + # "RECURSIVE" word with their "WITH", + # so here we go + return "WITH" + def label_select_column(self, select, column, asfrom): if isinstance(column, expression.Function): return column.label(None) diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index cafce5e3ce..5b7f7c9af4 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -450,6 +450,62 @@ class Query(object): """ return self.enable_eagerloads(False).statement.alias(name=name) + def cte(self, name=None, recursive=False): + """Return the full SELECT statement represented by this :class:`.Query` + represented as a common table expression (CTE). + + The :meth:`.Query.cte` method is new in 0.7.6. + + Parameters and usage are the same as those of the + :meth:`._SelectBase.cte` method; see that method for + further details. + + Here is the `Postgresql WITH + RECURSIVE example `_. + Note that, in this example, the ``included_parts`` cte and the ``incl_alias`` alias + of it are Core selectables, which + means the columns are accessed via the ``.c.`` attribute. The ``parts_alias`` + object is an :func:`.orm.aliased` instance of the ``Part`` entity, so column-mapped + attributes are available directly:: + + from sqlalchemy.orm import aliased + + class Part(Base): + __tablename__ = 'part' + part = Column(String) + sub_part = Column(String) + quantity = Column(Integer) + + included_parts = session.query( + Part.sub_part, + Part.part, + Part.quantity).\\ + filter(Part.part=="our part").\\ + cte(name="included_parts", recursive=True) + + incl_alias = aliased(included_parts, name="pr") + parts_alias = aliased(Part, name="p") + included_parts = included_parts.union( + session.query( + parts_alias.part, + parts_alias.sub_part, + parts_alias.quantity).\\ + filter(parts_alias.part==incl_alias.c.sub_part) + ) + + q = session.query( + included_parts.c.sub_part, + func.sum(included_parts.c.quantity).label('total_quantity') + ).\ + group_by(included_parts.c.sub_part) + + See also: + + :meth:`._SelectBase.cte` + + """ + return self.enable_eagerloads(False).statement.cte(name=name, recursive=recursive) + def label(self, name): """Return the full SELECT statement represented by this :class:`.Query`, converted to a scalar subquery with a label of the given name. diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index b955c56088..e8f86634d2 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -252,6 +252,10 @@ class SQLCompiler(engine.Compiled): # column targeting self.result_map = {} + # collect CTEs to tack on top of a SELECT + self.ctes = util.OrderedDict() + self.ctes_recursive = False + # true if the paramstyle is positional self.positional = dialect.positional if self.positional: @@ -749,6 +753,45 @@ class SQLCompiler(engine.Compiled): else: return self.bindtemplate % {'name':name} + def visit_cte(self, cte, asfrom=False, ashint=False, + fromhints=None, **kwargs): + if isinstance(cte.name, sql._truncated_label): + cte_name = self._truncated_identifier("alias", cte.name) + else: + cte_name = cte.name + if cte.cte_alias: + if isinstance(cte.cte_alias, sql._truncated_label): + cte_alias = self._truncated_identifier("alias", cte.cte_alias) + else: + cte_alias = cte.cte_alias + if not cte.cte_alias and cte not in self.ctes: + if cte.recursive: + self.ctes_recursive = True + text = self.preparer.format_alias(cte, cte_name) + if cte.recursive: + if isinstance(cte.original, sql.Select): + col_source = cte.original + elif isinstance(cte.original, sql.CompoundSelect): + col_source = cte.original.selects[0] + else: + assert False + recur_cols = [c.key for c in util.unique_list(col_source.inner_columns) + if c is not None] + + text += "(%s)" % (", ".join(recur_cols)) + text += " AS \n" + \ + cte.original._compiler_dispatch( + self, asfrom=True, **kwargs + ) + self.ctes[cte] = text + if asfrom: + if cte.cte_alias: + text = self.preparer.format_alias(cte, cte_alias) + text += " AS " + cte_name + else: + return self.preparer.format_alias(cte, cte_name) + return text + def visit_alias(self, alias, asfrom=False, ashint=False, fromhints=None, **kwargs): if asfrom or ashint: @@ -909,6 +952,15 @@ class SQLCompiler(engine.Compiled): if select.for_update: text += self.for_update_clause(select) + if self.ctes and \ + compound_index==1 and not entry: + cte_text = self.get_cte_preamble(self.ctes_recursive) + " " + cte_text += ", \n".join( + [txt for txt in self.ctes.values()] + ) + cte_text += "\n " + text = cte_text + text + self.stack.pop(-1) if asfrom and parens: @@ -916,6 +968,12 @@ class SQLCompiler(engine.Compiled): else: return text + def get_cte_preamble(self, recursive): + if recursive: + return "WITH RECURSIVE" + else: + return "WITH" + def get_select_precolumns(self, select): """Called when building a ``SELECT`` statement, position is just before column list. diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 4b61e6dc33..22fe6c420f 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -3719,6 +3719,47 @@ class Alias(FromClause): def bind(self): return self.element.bind +class CTE(Alias): + """Represent a Common Table Expression. + + The :class:`.CTE` object is obtained using the + :meth:`._SelectBase.cte` method from any selectable. + See that method for complete examples. + + New in 0.7.6. + + """ + __visit_name__ = 'cte' + def __init__(self, selectable, + name=None, + recursive=False, + cte_alias=False): + self.recursive = recursive + self.cte_alias = cte_alias + super(CTE, self).__init__(selectable, name=name) + + def alias(self, name=None): + return CTE( + self.original, + name=name, + recursive=self.recursive, + cte_alias = self.name + ) + + def union(self, other): + return CTE( + self.original.union(other), + name=self.name, + recursive=self.recursive + ) + + def union_all(self, other): + return CTE( + self.original.union_all(other), + name=self.name, + recursive=self.recursive + ) + class _Grouping(ColumnElement): """Represent a grouping within a column expression""" @@ -4289,6 +4330,125 @@ class _SelectBase(Executable, FromClause): """ return self.as_scalar().label(name) + def cte(self, name=None, recursive=False): + """Return a new :class:`.CTE`, or Common Table Expression instance. + + Common table expressions are a SQL standard whereby SELECT + statements can draw upon secondary statements specified along + with the primary statement, using a clause called "WITH". + Special semantics regarding UNION can also be employed to + allow "recursive" queries, where a SELECT statement can draw + upon the set of rows that have previously been selected. + + SQLAlchemy detects :class:`.CTE` objects, which are treated + similarly to :class:`.Alias` objects, as special elements + to be delivered to the FROM clause of the statement as well + as to a WITH clause at the top of the statement. + + The :meth:`._SelectBase.cte` method is new in 0.7.6. + + :param name: name given to the common table expression. Like + :meth:`._FromClause.alias`, the name can be left as ``None`` + in which case an anonymous symbol will be used at query + compile time. + :param recursive: if ``True``, will render ``WITH RECURSIVE``. + A recursive common table expression is intended to be used in + conjunction with UNION or UNION ALL in order to derive rows + from those already selected. + + The following examples illustrate two examples from + Postgresql's documentation at + http://www.postgresql.org/docs/8.4/static/queries-with.html. + + Example 1, non recursive:: + + from sqlalchemy import Table, Column, String, Integer, MetaData, \\ + select, func + + metadata = MetaData() + + orders = Table('orders', metadata, + Column('region', String), + Column('amount', Integer), + Column('product', String), + Column('quantity', Integer) + ) + + regional_sales = select([ + orders.c.region, + func.sum(orders.c.amount).label('total_sales') + ]).group_by(orders.c.region).cte("regional_sales") + + + top_regions = select([regional_sales.c.region]).\\ + where( + regional_sales.c.total_sales > + select([ + func.sum(regional_sales.c.total_sales)/10 + ]) + ).cte("top_regions") + + statement = select([ + orders.c.region, + orders.c.product, + func.sum(orders.c.quantity).label("product_units"), + func.sum(orders.c.amount).label("product_sales") + ]).where(orders.c.region.in_( + select([top_regions.c.region]) + )).group_by(orders.c.region, orders.c.product) + + result = conn.execute(statement).fetchall() + + Example 2, WITH RECURSIVE:: + + from sqlalchemy import Table, Column, String, Integer, MetaData, \\ + select, func + + metadata = MetaData() + + parts = Table('parts', metadata, + Column('part', String), + Column('sub_part', String), + Column('quantity', Integer), + ) + + included_parts = select([ + parts.c.sub_part, + parts.c.part, + parts.c.quantity]).\\ + where(parts.c.part=='our part').\\ + cte(recursive=True) + + + incl_alias = included_parts.alias() + parts_alias = parts.alias() + included_parts = included_parts.union( + select([ + parts_alias.c.part, + parts_alias.c.sub_part, + parts_alias.c.quantity + ]). + where(parts_alias.c.part==incl_alias.c.sub_part) + ) + + statement = select([ + included_parts.c.sub_part, + func.sum(included_parts.c.quantity).label('total_quantity') + ]).\ + select_from(included_parts.join(parts, + included_parts.c.part==parts.c.part)).\\ + group_by(included_parts.c.sub_part) + + result = conn.execute(statement).fetchall() + + + See also: + + :meth:`.orm.query.Query.cte` - ORM version of :meth:`._SelectBase.cte`. + + """ + return CTE(self, name=name, recursive=recursive) + @_generative @util.deprecated('0.6', message=":func:`.autocommit` is deprecated. Use " diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index 528a495583..970030d55b 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -2297,6 +2297,116 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): "SELECT x + foo() OVER () AS anon_1" ) + def test_cte_nonrecursive(self): + orders = table('orders', + column('region'), + column('amount'), + column('product'), + column('quantity') + ) + + regional_sales = select([ + orders.c.region, + func.sum(orders.c.amount).label('total_sales') + ]).group_by(orders.c.region).cte("regional_sales") + + top_regions = select([regional_sales.c.region]).\ + where( + regional_sales.c.total_sales > + select([ + func.sum(regional_sales.c.total_sales)/10 + ]) + ).cte("top_regions") + + s = select([ + orders.c.region, + orders.c.product, + func.sum(orders.c.quantity).label("product_units"), + func.sum(orders.c.amount).label("product_sales") + ]).where(orders.c.region.in_( + select([top_regions.c.region]) + )).group_by(orders.c.region, orders.c.product) + + # needs to render regional_sales first as top_regions + # refers to it + self.assert_compile( + s, + "WITH regional_sales AS (SELECT orders.region AS region, " + "sum(orders.amount) AS total_sales FROM orders " + "GROUP BY orders.region), " + "top_regions AS (SELECT " + "regional_sales.region AS region FROM regional_sales " + "WHERE regional_sales.total_sales > " + "(SELECT sum(regional_sales.total_sales) / :sum_1 AS " + "anon_1 FROM regional_sales)) " + "SELECT orders.region, orders.product, " + "sum(orders.quantity) AS product_units, " + "sum(orders.amount) AS product_sales " + "FROM orders WHERE orders.region " + "IN (SELECT top_regions.region FROM top_regions) " + "GROUP BY orders.region, orders.product" + ) + + def test_cte_recursive(self): + parts = table('parts', + column('part'), + column('sub_part'), + column('quantity'), + ) + + included_parts = select([ + parts.c.sub_part, + parts.c.part, + parts.c.quantity]).\ + where(parts.c.part=='our part').\ + cte(recursive=True) + + incl_alias = included_parts.alias() + parts_alias = parts.alias() + included_parts = included_parts.union( + select([ + parts_alias.c.part, + parts_alias.c.sub_part, + parts_alias.c.quantity]).\ + where(parts_alias.c.part==incl_alias.c.sub_part) + ) + + s = select([ + included_parts.c.sub_part, + func.sum(included_parts.c.quantity).label('total_quantity')]).\ + select_from(included_parts.join( + parts,included_parts.c.part==parts.c.part)).\ + group_by(included_parts.c.sub_part) + self.assert_compile(s, + "WITH RECURSIVE anon_1(sub_part, part, quantity) " + "AS (SELECT parts.sub_part AS sub_part, parts.part " + "AS part, parts.quantity AS quantity FROM parts " + "WHERE parts.part = :part_1 UNION SELECT parts_1.part " + "AS part, parts_1.sub_part AS sub_part, parts_1.quantity " + "AS quantity FROM parts AS parts_1, anon_1 AS anon_2 " + "WHERE parts_1.part = anon_2.sub_part) " + "SELECT anon_1.sub_part, " + "sum(anon_1.quantity) AS total_quantity FROM anon_1 " + "JOIN parts ON anon_1.part = parts.part " + "GROUP BY anon_1.sub_part" + ) + + # quick check that the "WITH RECURSIVE" varies per + # dialect + self.assert_compile(s, + "WITH anon_1(sub_part, part, quantity) " + "AS (SELECT parts.sub_part AS sub_part, parts.part " + "AS part, parts.quantity AS quantity FROM parts " + "WHERE parts.part = :part_1 UNION SELECT parts_1.part " + "AS part, parts_1.sub_part AS sub_part, parts_1.quantity " + "AS quantity FROM parts AS parts_1, anon_1 AS anon_2 " + "WHERE parts_1.part = anon_2.sub_part) " + "SELECT anon_1.sub_part, " + "sum(anon_1.quantity) AS total_quantity FROM anon_1 " + "JOIN parts ON anon_1.part = parts.part " + "GROUP BY anon_1.sub_part", + dialect=mssql.dialect() + ) def test_date_between(self): import datetime