]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- [bug] Fixed more un-intuitivenesses in CTEs
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 10 Jul 2012 15:10:42 +0000 (11:10 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 10 Jul 2012 15:10:42 +0000 (11:10 -0400)
    which prevented referring to a CTE in a union
    of itself without it being aliased.
    CTEs now render uniquely
    on name, rendering the outermost CTE of a given
    name only - all other references are rendered
    just as the name.   This even includes other
    CTE/SELECTs that refer to different versions
    of the same CTE object, such as a SELECT
    or a UNION ALL of that SELECT. We are
    somewhat loosening the usual link between object
    identity and lexical identity in this case.
    A true name conflict between two unrelated
    CTEs now raises an error.

CHANGES
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/expression.py
test/sql/test_cte.py

diff --git a/CHANGES b/CHANGES
index 1f0fd186010a143b6c5295636c28d0fb220a7104..a2f68f49d141326d21ee08aed45906d08f13b441 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -14,6 +14,21 @@ CHANGES
     positional binds + CTE support.
     [ticket:2521]
 
+  - [bug] Fixed more un-intuitivenesses in CTEs
+    which prevented referring to a CTE in a union
+    of itself without it being aliased.
+    CTEs now render uniquely
+    on name, rendering the outermost CTE of a given
+    name only - all other references are rendered
+    just as the name.   This even includes other
+    CTE/SELECTs that refer to different versions
+    of the same CTE object, such as a SELECT
+    or a UNION ALL of that SELECT. We are
+    somewhat loosening the usual link between object
+    identity and lexical identity in this case.
+    A true name conflict between two unrelated
+    CTEs now raises an error.
+
   - [bug] quoting is applied to the column names
     inside the WITH RECURSIVE clause of a
     common table expression according to the
index 2ba581384fb5f08bb15a778c00ec8d473d197c62..4ed1468444947cebbad7babcec87bb28e822916e 100644 (file)
@@ -17,7 +17,7 @@ strings
 :class:`~sqlalchemy.sql.compiler.GenericTypeCompiler` - renders
 type specification strings.
 
-To generate user-defined SQL strings, see 
+To generate user-defined SQL strings, see
 :module:`~sqlalchemy.ext.compiler`.
 
 """
@@ -215,7 +215,7 @@ class SQLCompiler(engine.Compiled):
     driver/DB enforces this
     """
 
-    def __init__(self, dialect, statement, column_keys=None, 
+    def __init__(self, dialect, statement, column_keys=None,
                     inline=False, **kwargs):
         """Construct a new ``DefaultCompiler`` object.
 
@@ -259,11 +259,7 @@ class SQLCompiler(engine.Compiled):
             self.positiontup = []
         self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle]
 
-        # collect CTEs to tack on top of a SELECT
-        self.ctes = util.OrderedDict()
-        self.ctes_recursive = False
-        if self.positional:
-            self.cte_positional = []
+        self.ctes = None
 
         # an IdentifierPreparer that formats the quoting of identifiers
         self.preparer = dialect.identifier_preparer
@@ -282,11 +278,25 @@ class SQLCompiler(engine.Compiled):
         if self.positional and dialect.paramstyle == 'numeric':
             self._apply_numbered_params()
 
+    @util.memoized_instancemethod
+    def _init_cte_state(self):
+        """Initialize collections related to CTEs only if
+        a CTE is located, to save on the overhead of
+        these collections otherwise.
+
+        """
+        # collect CTEs to tack on top of a SELECT
+        self.ctes = util.OrderedDict()
+        self.ctes_by_name = {}
+        self.ctes_recursive = False
+        if self.positional:
+            self.cte_positional = []
+
     def _apply_numbered_params(self):
         poscount = itertools.count(1)
         self.string = re.sub(
-                        r'\[_POSITION\]', 
-                        lambda m:str(util.next(poscount)), 
+                        r'\[_POSITION\]',
+                        lambda m:str(util.next(poscount)),
                         self.string)
 
     @util.memoized_property
@@ -320,11 +330,11 @@ class SQLCompiler(engine.Compiled):
                     if _group_number:
                         raise exc.InvalidRequestError(
                             "A value is required for bind parameter %r, "
-                            "in parameter group %d" % 
+                            "in parameter group %d" %
                             (bindparam.key, _group_number))
                     else:
                         raise exc.InvalidRequestError(
-                            "A value is required for bind parameter %r" 
+                            "A value is required for bind parameter %r"
                             % bindparam.key)
                 else:
                     pd[name] = bindparam.effective_value
@@ -336,18 +346,18 @@ class SQLCompiler(engine.Compiled):
                     if _group_number:
                         raise exc.InvalidRequestError(
                             "A value is required for bind parameter %r, "
-                            "in parameter group %d" % 
+                            "in parameter group %d" %
                             (bindparam.key, _group_number))
                     else:
                         raise exc.InvalidRequestError(
-                            "A value is required for bind parameter %r" 
+                            "A value is required for bind parameter %r"
                             % bindparam.key)
                 pd[self.bind_names[bindparam]] = bindparam.effective_value
             return pd
 
     @property
     def params(self):
-        """Return the bind param dictionary embedded into this 
+        """Return the bind param dictionary embedded into this
         compiled object, for those values that are present."""
         return self.construct_params(_check=False)
 
@@ -363,8 +373,8 @@ class SQLCompiler(engine.Compiled):
     def visit_grouping(self, grouping, asfrom=False, **kwargs):
         return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")"
 
-    def visit_label(self, label, result_map=None, 
-                            within_label_clause=False, 
+    def visit_label(self, label, result_map=None,
+                            within_label_clause=False,
                             within_columns_clause=False, **kw):
         # only render labels within the columns clause
         # or ORDER BY clause of a select.  dialect-specific compilers
@@ -377,20 +387,20 @@ class SQLCompiler(engine.Compiled):
 
             if result_map is not None:
                 result_map[labelname.lower()] = (
-                        label.name, 
-                        (label, label.element, labelname, ) + 
+                        label.name,
+                        (label, label.element, labelname, ) +
                             label._alt_names,
                         label.type)
 
-            return label.element._compiler_dispatch(self, 
+            return label.element._compiler_dispatch(self,
                                     within_columns_clause=True,
-                                    within_label_clause=True, 
+                                    within_label_clause=True,
                                     **kw) + \
                         OPERATORS[operators.as_] + \
                         self.preparer.format_label(label, labelname)
         else:
-            return label.element._compiler_dispatch(self, 
-                                    within_columns_clause=False, 
+            return label.element._compiler_dispatch(self,
+                                    within_columns_clause=False,
                                     **kw)
 
     def visit_column(self, column, result_map=None, **kwargs):
@@ -404,8 +414,8 @@ class SQLCompiler(engine.Compiled):
             name = self._truncated_identifier("colident", name)
 
         if result_map is not None:
-            result_map[name.lower()] = (orig_name, 
-                                        (column, name, column.key), 
+            result_map[name.lower()] = (orig_name,
+                                        (column, name, column.key),
                                         column.type)
 
         if is_literal:
@@ -419,7 +429,7 @@ class SQLCompiler(engine.Compiled):
         else:
             if table.schema:
                 schema_prefix = self.preparer.quote_schema(
-                                    table.schema, 
+                                    table.schema,
                                     table.quote_schema) + '.'
             else:
                 schema_prefix = ''
@@ -483,8 +493,8 @@ class SQLCompiler(engine.Compiled):
         else:
             sep = OPERATORS[clauselist.operator]
         return sep.join(
-                    s for s in 
-                    (c._compiler_dispatch(self, **kwargs) 
+                    s for s in
+                    (c._compiler_dispatch(self, **kwargs)
                     for c in clauselist.clauses)
                     if s)
 
@@ -524,7 +534,7 @@ class SQLCompiler(engine.Compiled):
 
     def visit_extract(self, extract, **kwargs):
         field = self.extract_map.get(extract.field, extract.field)
-        return "EXTRACT(%s FROM %s)" % (field, 
+        return "EXTRACT(%s FROM %s)" % (field,
                             extract.expr._compiler_dispatch(self, **kwargs))
 
     def visit_function(self, func, result_map=None, **kwargs):
@@ -550,7 +560,7 @@ class SQLCompiler(engine.Compiled):
     def function_argspec(self, func, **kwargs):
         return func.clause_expr._compiler_dispatch(self, **kwargs)
 
-    def visit_compound_select(self, cs, asfrom=False, 
+    def visit_compound_select(self, cs, asfrom=False,
                             parens=True, compound_index=1, **kwargs):
         entry = self.stack and self.stack[-1] or {}
         self.stack.append({'from':entry.get('from', None), 'iswrapper':True})
@@ -558,8 +568,8 @@ class SQLCompiler(engine.Compiled):
         keyword = self.compound_keywords.get(cs.keyword)
 
         text = (" " + keyword + " ").join(
-                            (c._compiler_dispatch(self, 
-                                            asfrom=asfrom, parens=False, 
+                            (c._compiler_dispatch(self,
+                                            asfrom=asfrom, parens=False,
                                             compound_index=i, **kwargs)
                             for i, c in enumerate(cs.selects))
                         )
@@ -600,8 +610,8 @@ class SQLCompiler(engine.Compiled):
 
         return self._operator_dispatch(binary.operator,
                     binary,
-                    lambda opstr: binary.left._compiler_dispatch(self, **kw) + 
-                                        opstr + 
+                    lambda opstr: binary.left._compiler_dispatch(self, **kw) +
+                                        opstr +
                                     binary.right._compiler_dispatch(
                                             self, **kw),
                     **kw
@@ -610,36 +620,36 @@ class SQLCompiler(engine.Compiled):
     def visit_like_op(self, binary, **kw):
         escape = binary.modifiers.get("escape", None)
         return '%s LIKE %s' % (
-                            binary.left._compiler_dispatch(self, **kw), 
+                            binary.left._compiler_dispatch(self, **kw),
                             binary.right._compiler_dispatch(self, **kw)) \
-            + (escape and 
+            + (escape and
                     (' ESCAPE ' + self.render_literal_value(escape, None))
                     or '')
 
     def visit_notlike_op(self, binary, **kw):
         escape = binary.modifiers.get("escape", None)
         return '%s NOT LIKE %s' % (
-                            binary.left._compiler_dispatch(self, **kw), 
+                            binary.left._compiler_dispatch(self, **kw),
                             binary.right._compiler_dispatch(self, **kw)) \
-            + (escape and 
+            + (escape and
                     (' ESCAPE ' + self.render_literal_value(escape, None))
                     or '')
 
     def visit_ilike_op(self, binary, **kw):
         escape = binary.modifiers.get("escape", None)
         return 'lower(%s) LIKE lower(%s)' % (
-                            binary.left._compiler_dispatch(self, **kw), 
+                            binary.left._compiler_dispatch(self, **kw),
                             binary.right._compiler_dispatch(self, **kw)) \
-            + (escape and 
+            + (escape and
                     (' ESCAPE ' + self.render_literal_value(escape, None))
                     or '')
 
     def visit_notilike_op(self, binary, **kw):
         escape = binary.modifiers.get("escape", None)
         return 'lower(%s) NOT LIKE lower(%s)' % (
-                            binary.left._compiler_dispatch(self, **kw), 
+                            binary.left._compiler_dispatch(self, **kw),
                             binary.right._compiler_dispatch(self, **kw)) \
-            + (escape and 
+            + (escape and
                     (' ESCAPE ' + self.render_literal_value(escape, None))
                     or '')
 
@@ -683,7 +693,7 @@ class SQLCompiler(engine.Compiled):
                         "bindparam() name '%s' is reserved "
                         "for automatic usage in the VALUES or SET "
                         "clause of this "
-                        "insert/update statement.   Please use a " 
+                        "insert/update statement.   Please use a "
                         "name other than column name when using bindparam() "
                         "with insert() or update() (for example, 'b_%s')."
                         % (bindparam.key, bindparam.key)
@@ -769,8 +779,10 @@ class SQLCompiler(engine.Compiled):
                 self.positiontup.append(name)
         return self.bindtemplate % {'name':name}
 
-    def visit_cte(self, cte, asfrom=False, ashint=False, 
-                                fromhints=None, **kwargs):
+    def visit_cte(self, cte, asfrom=False, ashint=False,
+                                fromhints=None,
+                                **kwargs):
+        self._init_cte_state()
         if self.positional:
             kwargs['positional_names'] = self.cte_positional
 
@@ -778,6 +790,26 @@ class SQLCompiler(engine.Compiled):
             cte_name = self._truncated_identifier("alias", cte.name)
         else:
             cte_name = cte.name
+
+        if cte_name in self.ctes_by_name:
+            existing_cte = self.ctes_by_name[cte_name]
+            # we've generated a same-named CTE that we are enclosed in,
+            # or this is the same CTE.  just return the name.
+            if cte in existing_cte._restates or cte is existing_cte:
+                return cte_name
+            elif existing_cte in cte._restates:
+                # we've generated a same-named CTE that is
+                # enclosed in us - we take precedence, so
+                # discard the text for the "inner".
+                del self.ctes[existing_cte]
+            else:
+                raise exc.CompileError(
+                        "Multiple, unrelated CTEs found with "
+                        "the same name: %r" %
+                        cte_name)
+
+        self.ctes_by_name[cte_name] = cte
+
         if cte.cte_alias:
             if isinstance(cte.cte_alias, sql._truncated_label):
                 cte_alias = self._truncated_identifier("alias", cte.cte_alias)
@@ -794,12 +826,12 @@ class SQLCompiler(engine.Compiled):
                     col_source = cte.original.selects[0]
                 else:
                     assert False
-                recur_cols = [c for c in 
+                recur_cols = [c for c in
                             util.unique_list(col_source.inner_columns)
                                 if c is not None]
 
                 text += "(%s)" % (", ".join(
-                                    self.preparer.format_column(ident) 
+                                    self.preparer.format_column(ident)
                                     for ident in recur_cols))
             text += " AS \n" + \
                         cte.original._compiler_dispatch(
@@ -814,7 +846,7 @@ class SQLCompiler(engine.Compiled):
                 return self.preparer.format_alias(cte, cte_name)
             return text
 
-    def visit_alias(self, alias, asfrom=False, ashint=False, 
+    def visit_alias(self, alias, asfrom=False, ashint=False,
                                 fromhints=None, **kwargs):
         if asfrom or ashint:
             if isinstance(alias.name, sql._truncated_label):
@@ -825,7 +857,7 @@ class SQLCompiler(engine.Compiled):
         if ashint:
             return self.preparer.format_alias(alias, alias_name)
         elif asfrom:
-            ret = alias.original._compiler_dispatch(self, 
+            ret = alias.original._compiler_dispatch(self,
                                 asfrom=True, **kwargs) + \
                                 " AS " + \
                     self.preparer.format_alias(alias, alias_name)
@@ -849,8 +881,8 @@ class SQLCompiler(engine.Compiled):
                 select.use_labels and \
                 column._label:
             return _CompileLabel(
-                    column, 
-                    column._label, 
+                    column,
+                    column._label,
                     alt_names=(column._key_label, )
                 )
 
@@ -860,9 +892,9 @@ class SQLCompiler(engine.Compiled):
             not column.is_literal and \
             column.table is not None and \
             not isinstance(column.table, sql.Select):
-            return _CompileLabel(column, sql._as_truncated(column.name), 
+            return _CompileLabel(column, sql._as_truncated(column.name),
                                         alt_names=(column.key,))
-        elif not isinstance(column, 
+        elif not isinstance(column,
                     (sql._UnaryExpression, sql._TextClause)) \
                 and (not hasattr(column, 'name') or \
                         isinstance(column, sql.Function)):
@@ -879,9 +911,9 @@ class SQLCompiler(engine.Compiled):
     def get_crud_hint_text(self, table, text):
         return None
 
-    def visit_select(self, select, asfrom=False, parens=True, 
-                            iswrapper=False, fromhints=None, 
-                            compound_index=1, 
+    def visit_select(self, select, asfrom=False, parens=True,
+                            iswrapper=False, fromhints=None,
+                            compound_index=1,
                             positional_names=None, **kwargs):
 
         entry = self.stack and self.stack[-1] or {}
@@ -901,7 +933,7 @@ class SQLCompiler(engine.Compiled):
                           : iswrapper})
 
         if compound_index==1 and not entry or entry.get('iswrapper', False):
-            column_clause_args = {'result_map':self.result_map, 
+            column_clause_args = {'result_map':self.result_map,
                                     'positional_names':positional_names}
         else:
             column_clause_args = {'positional_names':positional_names}
@@ -912,7 +944,7 @@ class SQLCompiler(engine.Compiled):
                 self.label_select_column(select, co, asfrom=asfrom).\
                     _compiler_dispatch(self,
                         within_columns_clause=True,
-                        **column_clause_args) 
+                        **column_clause_args)
                 for co in util.unique_list(select.inner_columns)
             ]
             if c is not None
@@ -925,9 +957,9 @@ class SQLCompiler(engine.Compiled):
                             (from_, hinttext % {
                                 'name':from_._compiler_dispatch(
                                     self, ashint=True)
-                            }) 
-                            for (from_, dialect), hinttext in 
-                            select._hints.iteritems() 
+                            })
+                            for (from_, dialect), hinttext in
+                            select._hints.iteritems()
                             if dialect in ('*', self.dialect.name)
                         ])
             hint_text = self.get_select_hint_text(byfrom)
@@ -936,7 +968,7 @@ class SQLCompiler(engine.Compiled):
 
         if select._prefixes:
             text += " ".join(
-                            x._compiler_dispatch(self, **kwargs) 
+                            x._compiler_dispatch(self, **kwargs)
                             for x in select._prefixes) + " "
         text += self.get_select_precolumns(select)
         text += ', '.join(inner_columns)
@@ -945,13 +977,13 @@ class SQLCompiler(engine.Compiled):
             text += " \nFROM "
 
             if select._hints:
-                text += ', '.join([f._compiler_dispatch(self, 
-                                    asfrom=True, fromhints=byfrom, 
-                                    **kwargs) 
+                text += ', '.join([f._compiler_dispatch(self,
+                                    asfrom=True, fromhints=byfrom,
+                                    **kwargs)
                                 for f in froms])
             else:
-                text += ', '.join([f._compiler_dispatch(self, 
-                                    asfrom=True, **kwargs) 
+                text += ', '.join([f._compiler_dispatch(self,
+                                    asfrom=True, **kwargs)
                                 for f in froms])
         else:
             text += self.default_from()
@@ -1036,7 +1068,7 @@ class SQLCompiler(engine.Compiled):
             text += " OFFSET " + self.process(sql.literal(select._offset))
         return text
 
-    def visit_table(self, table, asfrom=False, ashint=False, 
+    def visit_table(self, table, asfrom=False, ashint=False,
                         fromhints=None, **kwargs):
         if asfrom or ashint:
             if getattr(table, "schema", None):
@@ -1056,10 +1088,10 @@ class SQLCompiler(engine.Compiled):
 
     def visit_join(self, join, asfrom=False, **kwargs):
         return (
-            join.left._compiler_dispatch(self, asfrom=True, **kwargs) + 
-            (join.isouter and " LEFT OUTER JOIN " or " JOIN ") + 
-            join.right._compiler_dispatch(self, asfrom=True, **kwargs) + 
-            " ON " + 
+            join.left._compiler_dispatch(self, asfrom=True, **kwargs) +
+            (join.isouter and " LEFT OUTER JOIN " or " JOIN ") +
+            join.right._compiler_dispatch(self, asfrom=True, **kwargs) +
+            " ON " +
             join.onclause._compiler_dispatch(self, **kwargs)
         )
 
@@ -1071,7 +1103,7 @@ class SQLCompiler(engine.Compiled):
                 not self.dialect.supports_default_values and \
                 not self.dialect.supports_empty_insert:
             raise exc.CompileError("The version of %s you are using does "
-                                    "not support empty inserts." % 
+                                    "not support empty inserts." %
                                     self.dialect.name)
 
         preparer = self.preparer
@@ -1089,13 +1121,13 @@ class SQLCompiler(engine.Compiled):
         if insert_stmt._hints:
             dialect_hints = dict([
                 (table, hint_text)
-                for (table, dialect), hint_text in 
+                for (table, dialect), hint_text in
                 insert_stmt._hints.items()
                 if dialect in ('*', self.dialect.name)
             ])
             if insert_stmt.table in dialect_hints:
                 text += " " + self.get_crud_hint_text(
-                                    insert_stmt.table, 
+                                    insert_stmt.table,
                                     dialect_hints[insert_stmt.table]
                                 )
 
@@ -1126,7 +1158,7 @@ class SQLCompiler(engine.Compiled):
         """Provide a hook for MySQL to add LIMIT to the UPDATE"""
         return None
 
-    def update_tables_clause(self, update_stmt, from_table, 
+    def update_tables_clause(self, update_stmt, from_table,
                                             extra_froms, **kw):
         """Provide a hook to override the initial table clause
         in an UPDATE statement.
@@ -1136,19 +1168,19 @@ class SQLCompiler(engine.Compiled):
         """
         return self.preparer.format_table(from_table)
 
-    def update_from_clause(self, update_stmt, 
-                                from_table, extra_froms, 
+    def update_from_clause(self, update_stmt,
+                                from_table, extra_froms,
                                 from_hints,
                                 **kw):
-        """Provide a hook to override the generation of an 
+        """Provide a hook to override the generation of an
         UPDATE..FROM clause.
 
         MySQL and MSSQL override this.
 
         """
         return "FROM " + ', '.join(
-                    t._compiler_dispatch(self, asfrom=True, 
-                                    fromhints=from_hints, **kw) 
+                    t._compiler_dispatch(self, asfrom=True,
+                                    fromhints=from_hints, **kw)
                     for t in extra_froms)
 
     def visit_update(self, update_stmt, **kw):
@@ -1161,20 +1193,20 @@ class SQLCompiler(engine.Compiled):
         colparams = self._get_colparams(update_stmt, extra_froms)
 
         text = "UPDATE " + self.update_tables_clause(
-                                        update_stmt, 
-                                        update_stmt.table, 
+                                        update_stmt,
+                                        update_stmt.table,
                                         extra_froms, **kw)
 
         if update_stmt._hints:
             dialect_hints = dict([
                 (table, hint_text)
-                for (table, dialect), hint_text in 
+                for (table, dialect), hint_text in
                 update_stmt._hints.items()
                 if dialect in ('*', self.dialect.name)
             ])
             if update_stmt.table in dialect_hints:
                 text += " " + self.get_crud_hint_text(
-                                    update_stmt.table, 
+                                    update_stmt.table,
                                     dialect_hints[update_stmt.table]
                                 )
         else:
@@ -1183,12 +1215,12 @@ class SQLCompiler(engine.Compiled):
         text += ' SET '
         if extra_froms and self.render_table_with_column_in_update_from:
             text += ', '.join(
-                            self.visit_column(c[0]) + 
+                            self.visit_column(c[0]) +
                             '=' + c[1] for c in colparams
                             )
         else:
             text += ', '.join(
-                        self.preparer.quote(c[0].name, c[0].quote) + 
+                        self.preparer.quote(c[0].name, c[0].quote) +
                         '=' + c[1] for c in colparams
                             )
 
@@ -1200,9 +1232,9 @@ class SQLCompiler(engine.Compiled):
 
         if extra_froms:
             extra_from_text = self.update_from_clause(
-                                        update_stmt, 
-                                        update_stmt.table, 
-                                        extra_froms, 
+                                        update_stmt,
+                                        update_stmt.table,
+                                        extra_froms,
                                         dialect_hints, **kw)
             if extra_from_text:
                 text += " " + extra_from_text
@@ -1223,7 +1255,7 @@ class SQLCompiler(engine.Compiled):
         return text
 
     def _create_crud_bind_param(self, col, value, required=False):
-        bindparam = sql.bindparam(col.key, value, 
+        bindparam = sql.bindparam(col.key, value,
                             type_=col.type, required=required)
         bindparam._is_crud = True
         return bindparam._compiler_dispatch(self)
@@ -1248,8 +1280,8 @@ class SQLCompiler(engine.Compiled):
         # compiled params - return binds for all columns
         if self.column_keys is None and stmt.parameters is None:
             return [
-                        (c, self._create_crud_bind_param(c, 
-                                    None, required=True)) 
+                        (c, self._create_crud_bind_param(c,
+                                    None, required=True))
                         for c in stmt.table.columns
                     ]
 
@@ -1261,8 +1293,8 @@ class SQLCompiler(engine.Compiled):
             parameters = {}
         else:
             parameters = dict((sql._column_as_key(key), required)
-                              for key in self.column_keys 
-                              if not stmt.parameters or 
+                              for key in self.column_keys
+                              if not stmt.parameters or
                               key not in stmt.parameters)
 
         if stmt.parameters is not None:
@@ -1283,7 +1315,7 @@ class SQLCompiler(engine.Compiled):
         postfetch_lastrowid = need_pks and self.dialect.postfetch_lastrowid
 
         check_columns = {}
-        # special logic that only occurs for multi-table UPDATE 
+        # special logic that only occurs for multi-table UPDATE
         # statements
         if extra_tables and stmt.parameters:
             assert self.isupdate
@@ -1302,7 +1334,7 @@ class SQLCompiler(engine.Compiled):
                             value = self.process(value.self_group())
                         values.append((c, value))
             # determine tables which are actually
-            # to be updated - process onupdate and 
+            # to be updated - process onupdate and
             # server_onupdate for these
             for t in affected_tables:
                 for c in t.c:
@@ -1323,7 +1355,7 @@ class SQLCompiler(engine.Compiled):
                         self.postfetch.append(c)
 
         # iterating through columns at the top to maintain ordering.
-        # otherwise we might iterate through individual sets of 
+        # otherwise we might iterate through individual sets of
         # "defaults", "primary key cols", etc.
         for c in stmt.table.columns:
             if c.key in parameters and c.key not in check_columns:
@@ -1343,8 +1375,8 @@ class SQLCompiler(engine.Compiled):
                 if c.primary_key and \
                     need_pks and \
                     (
-                        implicit_returning or 
-                        not postfetch_lastrowid or 
+                        implicit_returning or
+                        not postfetch_lastrowid or
                         c is not stmt.table._autoincrement_column
                     ):
 
@@ -1430,7 +1462,7 @@ class SQLCompiler(engine.Compiled):
             ).difference(check_columns)
             if check:
                 util.warn(
-                    "Unconsumed column names: %s" % 
+                    "Unconsumed column names: %s" %
                     (", ".join(check))
                 )
 
@@ -1445,13 +1477,13 @@ class SQLCompiler(engine.Compiled):
         if delete_stmt._hints:
             dialect_hints = dict([
                 (table, hint_text)
-                for (table, dialect), hint_text in 
+                for (table, dialect), hint_text in
                 delete_stmt._hints.items()
                 if dialect in ('*', self.dialect.name)
             ])
             if delete_stmt.table in dialect_hints:
                 text += " " + self.get_crud_hint_text(
-                                    delete_stmt.table, 
+                                    delete_stmt.table,
                                     dialect_hints[delete_stmt.table]
                                 )
         else:
@@ -1545,7 +1577,7 @@ class DDLCompiler(engine.Compiled):
                 text += separator
                 separator = ", \n"
                 text += "\t" + self.get_column_specification(
-                                                column, 
+                                                column,
                                                 first_pk=column.primary_key and \
                                                 not first_pk
                                             )
@@ -1557,16 +1589,16 @@ class DDLCompiler(engine.Compiled):
                     text += " " + const
             except exc.CompileError, ce:
                 # Py3K
-                #raise exc.CompileError("(in table '%s', column '%s'): %s" 
+                #raise exc.CompileError("(in table '%s', column '%s'): %s"
                 #                             % (
-                #                                table.description, 
-                #                                column.name, 
+                #                                table.description,
+                #                                column.name,
                 #                                ce.args[0]
                 #                            )) from ce
                 # Py2K
-                raise exc.CompileError("(in table '%s', column '%s'): %s" 
+                raise exc.CompileError("(in table '%s', column '%s'): %s"
                                             % (
-                                                table.description, 
+                                                table.description,
                                                 column.name,
                                                 ce.args[0]
                                             )), None, sys.exc_info()[2]
@@ -1587,17 +1619,17 @@ class DDLCompiler(engine.Compiled):
         if table.primary_key:
             constraints.append(table.primary_key)
 
-        constraints.extend([c for c in table._sorted_constraints 
+        constraints.extend([c for c in table._sorted_constraints
                                 if c is not table.primary_key])
 
         return ", \n\t".join(p for p in
-                        (self.process(constraint) 
-                        for constraint in constraints 
+                        (self.process(constraint)
+                        for constraint in constraints
                         if (
                             constraint._create_rule is None or
                             constraint._create_rule(self))
                         and (
-                            not self.dialect.supports_alter or 
+                            not self.dialect.supports_alter or
                             not getattr(constraint, 'use_alter', False)
                         )) if p is not None
                 )
@@ -1625,7 +1657,7 @@ class DDLCompiler(engine.Compiled):
         if index.unique:
             text += "UNIQUE "
         text += "INDEX %s ON %s (%s)" \
-                    % (preparer.quote(self._index_identifier(index.name), 
+                    % (preparer.quote(self._index_identifier(index.name),
                         index.quote),
                        preparer.format_table(index.table),
                        ', '.join(preparer.quote(c.name, c.quote)
@@ -1751,7 +1783,7 @@ class DDLCompiler(engine.Compiled):
             text += "CONSTRAINT %s " % \
                     self.preparer.format_constraint(constraint)
         text += "UNIQUE (%s)" % (
-                    ', '.join(self.preparer.quote(c.name, c.quote) 
+                    ', '.join(self.preparer.quote(c.name, c.quote)
                             for c in constraint))
         text += self.define_constraint_deferrability(constraint)
         return text
@@ -1797,7 +1829,7 @@ class GenericTypeCompiler(engine.TypeCompiler):
                         {'precision': type_.precision}
         else:
             return "NUMERIC(%(precision)s, %(scale)s)" % \
-                        {'precision': type_.precision, 
+                        {'precision': type_.precision,
                         'scale' : type_.scale}
 
     def visit_DECIMAL(self, type_):
@@ -1854,25 +1886,25 @@ class GenericTypeCompiler(engine.TypeCompiler):
     def visit_large_binary(self, type_):
         return self.visit_BLOB(type_)
 
-    def visit_boolean(self, type_): 
+    def visit_boolean(self, type_):
         return self.visit_BOOLEAN(type_)
 
-    def visit_time(self, type_): 
+    def visit_time(self, type_):
         return self.visit_TIME(type_)
 
-    def visit_datetime(self, type_): 
+    def visit_datetime(self, type_):
         return self.visit_DATETIME(type_)
 
-    def visit_date(self, type_): 
+    def visit_date(self, type_):
         return self.visit_DATE(type_)
 
-    def visit_big_integer(self, type_): 
+    def visit_big_integer(self, type_):
         return self.visit_BIGINT(type_)
 
-    def visit_small_integer(self, type_): 
+    def visit_small_integer(self, type_):
         return self.visit_SMALLINT(type_)
 
-    def visit_integer(self, type_): 
+    def visit_integer(self, type_):
         return self.visit_INTEGER(type_)
 
     def visit_real(self, type_):
@@ -1881,19 +1913,19 @@ class GenericTypeCompiler(engine.TypeCompiler):
     def visit_float(self, type_):
         return self.visit_FLOAT(type_)
 
-    def visit_numeric(self, type_): 
+    def visit_numeric(self, type_):
         return self.visit_NUMERIC(type_)
 
-    def visit_string(self, type_): 
+    def visit_string(self, type_):
         return self.visit_VARCHAR(type_)
 
-    def visit_unicode(self, type_): 
+    def visit_unicode(self, type_):
         return self.visit_VARCHAR(type_)
 
-    def visit_text(self, type_): 
+    def visit_text(self, type_):
         return self.visit_TEXT(type_)
 
-    def visit_unicode_text(self, type_): 
+    def visit_unicode_text(self, type_):
         return self.visit_TEXT(type_)
 
     def visit_enum(self, type_):
@@ -1917,7 +1949,7 @@ class IdentifierPreparer(object):
 
     illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS
 
-    def __init__(self, dialect, initial_quote='"', 
+    def __init__(self, dialect, initial_quote='"',
                     final_quote=None, escape_quote='"', omit_schema=False):
         """Construct a new ``IdentifierPreparer`` object.
 
@@ -1981,7 +2013,7 @@ class IdentifierPreparer(object):
     def quote_schema(self, schema, force):
         """Quote a schema.
 
-        Subclasses should override this to provide database-dependent 
+        Subclasses should override this to provide database-dependent
         quoting behavior.
         """
         return self.quote(schema, force)
@@ -2038,7 +2070,7 @@ class IdentifierPreparer(object):
 
         return self.quote(name, quote)
 
-    def format_column(self, column, use_table=False, 
+    def format_column(self, column, use_table=False,
                             name=None, table_name=None):
         """Prepare a quoted column name."""
 
@@ -2047,7 +2079,7 @@ class IdentifierPreparer(object):
         if not getattr(column, 'is_literal', False):
             if use_table:
                 return self.format_table(
-                            column.table, use_schema=False, 
+                            column.table, use_schema=False,
                             name=table_name) + "." + \
                             self.quote(name, column.quote)
             else:
index 8359c3314a02ce6627bd8ce5f653c615a033f4b9..f0549cc799a86f168e1326fb0baca44d01900541 100644 (file)
@@ -3754,12 +3754,15 @@ class CTE(Alias):
 
     """
     __visit_name__ = 'cte'
+
     def __init__(self, selectable, 
                         name=None, 
                         recursive=False, 
-                        cte_alias=False):
+                        cte_alias=False,
+                        _restates=frozenset()):
         self.recursive = recursive
         self.cte_alias = cte_alias
+        self._restates = _restates
         super(CTE, self).__init__(selectable, name=name)
 
     def alias(self, name=None):
@@ -3774,14 +3777,16 @@ class CTE(Alias):
         return CTE(
             self.original.union(other),
             name=self.name,
-            recursive=self.recursive
+            recursive=self.recursive,
+            _restates=self._restates.union([self])
         )
 
     def union_all(self, other):
         return CTE(
             self.original.union_all(other),
             name=self.name,
-            recursive=self.recursive
+            recursive=self.recursive,
+            _restates=self._restates.union([self])
         )
 
 
index 36f992a86c84470e4c94044135cd4334071e7034..49a53a3ec77ff0c1598058bc2782161583aad182 100644 (file)
@@ -1,8 +1,9 @@
 from test.lib import fixtures
-from test.lib.testing import AssertsCompiledSQL
+from test.lib.testing import AssertsCompiledSQL, assert_raises_message
 from sqlalchemy.sql import table, column, select, func, literal
 from sqlalchemy.dialects import mssql
 from sqlalchemy.engine import default
+from sqlalchemy.exc import CompileError
 
 class CTETest(fixtures.TestBase, AssertsCompiledSQL):
 
@@ -119,6 +120,144 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
                 dialect=mssql.dialect()
             )
 
+    def test_recursive_union_no_alias_one(self):
+        s1 = select([literal(0).label("x")])
+        cte = s1.cte(name="cte", recursive=True)
+        cte = cte.union_all(
+            select([cte.c.x + 1]).where(cte.c.x < 10)
+        )
+        s2 = select([cte])
+        self.assert_compile(s2,
+        "WITH RECURSIVE cte(x) AS "
+        "(SELECT :param_1 AS x UNION ALL "
+        "SELECT cte.x + :x_1 AS anon_1 "
+        "FROM cte WHERE cte.x < :x_2) "
+        "SELECT cte.x FROM cte"
+        )
+
+
+    def test_recursive_union_no_alias_two(self):
+        """
+
+        pg's example:
+
+            WITH RECURSIVE t(n) AS (
+                VALUES (1)
+              UNION ALL
+                SELECT n+1 FROM t WHERE n < 100
+            )
+            SELECT sum(n) FROM t;
+
+        """
+
+        # I know, this is the PG VALUES keyword,
+        # we're cheating here.  also yes we need the SELECT,
+        # sorry PG.
+        t = select([func.values(1).label("n")]).cte("t", recursive=True)
+        t = t.union_all(select([t.c.n + 1]).where(t.c.n < 100))
+        s = select([func.sum(t.c.n)])
+        self.assert_compile(s,
+            "WITH RECURSIVE t(n) AS "
+            "(SELECT values(:values_1) AS n "
+            "UNION ALL SELECT t.n + :n_1 AS anon_1 "
+            "FROM t "
+            "WHERE t.n < :n_2) "
+            "SELECT sum(t.n) AS sum_1 FROM t"
+            )
+
+    def test_recursive_union_no_alias_three(self):
+        # like test one, but let's refer to the CTE
+        # in a sibling CTE.
+
+        s1 = select([literal(0).label("x")])
+        cte = s1.cte(name="cte", recursive=True)
+
+        # can't do it here...
+        #bar = select([cte]).cte('bar')
+        cte = cte.union_all(
+            select([cte.c.x + 1]).where(cte.c.x < 10)
+        )
+        bar = select([cte]).cte('bar')
+
+        s2 = select([cte, bar])
+        self.assert_compile(s2,
+        "WITH RECURSIVE cte(x) AS "
+        "(SELECT :param_1 AS x UNION ALL "
+        "SELECT cte.x + :x_1 AS anon_1 "
+        "FROM cte WHERE cte.x < :x_2), "
+        "bar AS (SELECT cte.x AS x FROM cte) "
+        "SELECT cte.x, bar.x FROM cte, bar"
+        )
+
+
+    def test_recursive_union_no_alias_four(self):
+        # like test one and three, but let's refer
+        # previous version of "cte".  here we test
+        # how the compiler resolves multiple instances
+        # of "cte".
+
+        s1 = select([literal(0).label("x")])
+        cte = s1.cte(name="cte", recursive=True)
+
+        bar = select([cte]).cte('bar')
+        cte = cte.union_all(
+            select([cte.c.x + 1]).where(cte.c.x < 10)
+        )
+
+        # outer cte rendered first, then bar, which
+        # includes "inner" cte
+        s2 = select([cte, bar])
+        self.assert_compile(s2,
+        "WITH RECURSIVE cte(x) AS "
+        "(SELECT :param_1 AS x UNION ALL "
+        "SELECT cte.x + :x_1 AS anon_1 "
+        "FROM cte WHERE cte.x < :x_2), "
+        "bar AS (SELECT cte.x AS x FROM cte) "
+        "SELECT cte.x, bar.x FROM cte, bar"
+        )
+
+        # bar rendered, only includes "inner" cte,
+        # "outer" cte isn't present
+        s2 = select([bar])
+        self.assert_compile(s2,
+        "WITH RECURSIVE cte(x) AS "
+        "(SELECT :param_1 AS x), "
+        "bar AS (SELECT cte.x AS x FROM cte) "
+        "SELECT bar.x FROM bar"
+        )
+
+        # bar rendered, but then the "outer"
+        # cte is rendered.
+        s2 = select([bar, cte])
+        self.assert_compile(s2,
+        "WITH RECURSIVE bar AS (SELECT cte.x AS x FROM cte), "
+        "cte(x) AS "
+        "(SELECT :param_1 AS x UNION ALL "
+        "SELECT cte.x + :x_1 AS anon_1 "
+        "FROM cte WHERE cte.x < :x_2) "
+
+        "SELECT bar.x, cte.x FROM bar, cte"
+        )
+
+    def test_conflicting_names(self):
+        """test a flat out name conflict."""
+
+        s1 = select([1])
+        c1= s1.cte(name='cte1', recursive=True)
+        s2 = select([1])
+        c2 = s2.cte(name='cte1', recursive=True)
+
+        s = select([c1, c2])
+        assert_raises_message(
+                CompileError,
+                "Multiple, unrelated CTEs found "
+                "with the same name: 'cte1'",
+                s.compile
+        )
+
+
+
+
     def test_union(self):
         orders = table('orders', 
             column('region'),