From: Mike Bayer Date: Tue, 10 Jul 2012 15:10:42 +0000 (-0400) Subject: - [bug] Fixed more un-intuitivenesses in CTEs X-Git-Tag: rel_0_7_9~65 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=878af426f7eb19c257f1db83f5b1af34624c2c6a;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - [bug] Fixed more un-intuitivenesses in CTEs which prevented referring to a CTE in a union of itself without it being aliased. CTEs now render uniquely on name, rendering the outermost CTE of a given name only - all other references are rendered just as the name. This even includes other CTE/SELECTs that refer to different versions of the same CTE object, such as a SELECT or a UNION ALL of that SELECT. We are somewhat loosening the usual link between object identity and lexical identity in this case. A true name conflict between two unrelated CTEs now raises an error. --- diff --git a/CHANGES b/CHANGES index 1f0fd18601..a2f68f49d1 100644 --- a/CHANGES +++ b/CHANGES @@ -14,6 +14,21 @@ CHANGES positional binds + CTE support. [ticket:2521] + - [bug] Fixed more un-intuitivenesses in CTEs + which prevented referring to a CTE in a union + of itself without it being aliased. + CTEs now render uniquely + on name, rendering the outermost CTE of a given + name only - all other references are rendered + just as the name. This even includes other + CTE/SELECTs that refer to different versions + of the same CTE object, such as a SELECT + or a UNION ALL of that SELECT. We are + somewhat loosening the usual link between object + identity and lexical identity in this case. + A true name conflict between two unrelated + CTEs now raises an error. + - [bug] quoting is applied to the column names inside the WITH RECURSIVE clause of a common table expression according to the diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 2ba581384f..4ed1468444 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -17,7 +17,7 @@ strings :class:`~sqlalchemy.sql.compiler.GenericTypeCompiler` - renders type specification strings. -To generate user-defined SQL strings, see +To generate user-defined SQL strings, see :module:`~sqlalchemy.ext.compiler`. """ @@ -215,7 +215,7 @@ class SQLCompiler(engine.Compiled): driver/DB enforces this """ - def __init__(self, dialect, statement, column_keys=None, + def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs): """Construct a new ``DefaultCompiler`` object. @@ -259,11 +259,7 @@ class SQLCompiler(engine.Compiled): self.positiontup = [] self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle] - # collect CTEs to tack on top of a SELECT - self.ctes = util.OrderedDict() - self.ctes_recursive = False - if self.positional: - self.cte_positional = [] + self.ctes = None # an IdentifierPreparer that formats the quoting of identifiers self.preparer = dialect.identifier_preparer @@ -282,11 +278,25 @@ class SQLCompiler(engine.Compiled): if self.positional and dialect.paramstyle == 'numeric': self._apply_numbered_params() + @util.memoized_instancemethod + def _init_cte_state(self): + """Initialize collections related to CTEs only if + a CTE is located, to save on the overhead of + these collections otherwise. + + """ + # collect CTEs to tack on top of a SELECT + self.ctes = util.OrderedDict() + self.ctes_by_name = {} + self.ctes_recursive = False + if self.positional: + self.cte_positional = [] + def _apply_numbered_params(self): poscount = itertools.count(1) self.string = re.sub( - r'\[_POSITION\]', - lambda m:str(util.next(poscount)), + r'\[_POSITION\]', + lambda m:str(util.next(poscount)), self.string) @util.memoized_property @@ -320,11 +330,11 @@ class SQLCompiler(engine.Compiled): if _group_number: raise exc.InvalidRequestError( "A value is required for bind parameter %r, " - "in parameter group %d" % + "in parameter group %d" % (bindparam.key, _group_number)) else: raise exc.InvalidRequestError( - "A value is required for bind parameter %r" + "A value is required for bind parameter %r" % bindparam.key) else: pd[name] = bindparam.effective_value @@ -336,18 +346,18 @@ class SQLCompiler(engine.Compiled): if _group_number: raise exc.InvalidRequestError( "A value is required for bind parameter %r, " - "in parameter group %d" % + "in parameter group %d" % (bindparam.key, _group_number)) else: raise exc.InvalidRequestError( - "A value is required for bind parameter %r" + "A value is required for bind parameter %r" % bindparam.key) pd[self.bind_names[bindparam]] = bindparam.effective_value return pd @property def params(self): - """Return the bind param dictionary embedded into this + """Return the bind param dictionary embedded into this compiled object, for those values that are present.""" return self.construct_params(_check=False) @@ -363,8 +373,8 @@ class SQLCompiler(engine.Compiled): def visit_grouping(self, grouping, asfrom=False, **kwargs): return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")" - def visit_label(self, label, result_map=None, - within_label_clause=False, + def visit_label(self, label, result_map=None, + within_label_clause=False, within_columns_clause=False, **kw): # only render labels within the columns clause # or ORDER BY clause of a select. dialect-specific compilers @@ -377,20 +387,20 @@ class SQLCompiler(engine.Compiled): if result_map is not None: result_map[labelname.lower()] = ( - label.name, - (label, label.element, labelname, ) + + label.name, + (label, label.element, labelname, ) + label._alt_names, label.type) - return label.element._compiler_dispatch(self, + return label.element._compiler_dispatch(self, within_columns_clause=True, - within_label_clause=True, + within_label_clause=True, **kw) + \ OPERATORS[operators.as_] + \ self.preparer.format_label(label, labelname) else: - return label.element._compiler_dispatch(self, - within_columns_clause=False, + return label.element._compiler_dispatch(self, + within_columns_clause=False, **kw) def visit_column(self, column, result_map=None, **kwargs): @@ -404,8 +414,8 @@ class SQLCompiler(engine.Compiled): name = self._truncated_identifier("colident", name) if result_map is not None: - result_map[name.lower()] = (orig_name, - (column, name, column.key), + result_map[name.lower()] = (orig_name, + (column, name, column.key), column.type) if is_literal: @@ -419,7 +429,7 @@ class SQLCompiler(engine.Compiled): else: if table.schema: schema_prefix = self.preparer.quote_schema( - table.schema, + table.schema, table.quote_schema) + '.' else: schema_prefix = '' @@ -483,8 +493,8 @@ class SQLCompiler(engine.Compiled): else: sep = OPERATORS[clauselist.operator] return sep.join( - s for s in - (c._compiler_dispatch(self, **kwargs) + s for s in + (c._compiler_dispatch(self, **kwargs) for c in clauselist.clauses) if s) @@ -524,7 +534,7 @@ class SQLCompiler(engine.Compiled): def visit_extract(self, extract, **kwargs): field = self.extract_map.get(extract.field, extract.field) - return "EXTRACT(%s FROM %s)" % (field, + return "EXTRACT(%s FROM %s)" % (field, extract.expr._compiler_dispatch(self, **kwargs)) def visit_function(self, func, result_map=None, **kwargs): @@ -550,7 +560,7 @@ class SQLCompiler(engine.Compiled): def function_argspec(self, func, **kwargs): return func.clause_expr._compiler_dispatch(self, **kwargs) - def visit_compound_select(self, cs, asfrom=False, + def visit_compound_select(self, cs, asfrom=False, parens=True, compound_index=1, **kwargs): entry = self.stack and self.stack[-1] or {} self.stack.append({'from':entry.get('from', None), 'iswrapper':True}) @@ -558,8 +568,8 @@ class SQLCompiler(engine.Compiled): keyword = self.compound_keywords.get(cs.keyword) text = (" " + keyword + " ").join( - (c._compiler_dispatch(self, - asfrom=asfrom, parens=False, + (c._compiler_dispatch(self, + asfrom=asfrom, parens=False, compound_index=i, **kwargs) for i, c in enumerate(cs.selects)) ) @@ -600,8 +610,8 @@ class SQLCompiler(engine.Compiled): return self._operator_dispatch(binary.operator, binary, - lambda opstr: binary.left._compiler_dispatch(self, **kw) + - opstr + + lambda opstr: binary.left._compiler_dispatch(self, **kw) + + opstr + binary.right._compiler_dispatch( self, **kw), **kw @@ -610,36 +620,36 @@ class SQLCompiler(engine.Compiled): def visit_like_op(self, binary, **kw): escape = binary.modifiers.get("escape", None) return '%s LIKE %s' % ( - binary.left._compiler_dispatch(self, **kw), + binary.left._compiler_dispatch(self, **kw), binary.right._compiler_dispatch(self, **kw)) \ - + (escape and + + (escape and (' ESCAPE ' + self.render_literal_value(escape, None)) or '') def visit_notlike_op(self, binary, **kw): escape = binary.modifiers.get("escape", None) return '%s NOT LIKE %s' % ( - binary.left._compiler_dispatch(self, **kw), + binary.left._compiler_dispatch(self, **kw), binary.right._compiler_dispatch(self, **kw)) \ - + (escape and + + (escape and (' ESCAPE ' + self.render_literal_value(escape, None)) or '') def visit_ilike_op(self, binary, **kw): escape = binary.modifiers.get("escape", None) return 'lower(%s) LIKE lower(%s)' % ( - binary.left._compiler_dispatch(self, **kw), + binary.left._compiler_dispatch(self, **kw), binary.right._compiler_dispatch(self, **kw)) \ - + (escape and + + (escape and (' ESCAPE ' + self.render_literal_value(escape, None)) or '') def visit_notilike_op(self, binary, **kw): escape = binary.modifiers.get("escape", None) return 'lower(%s) NOT LIKE lower(%s)' % ( - binary.left._compiler_dispatch(self, **kw), + binary.left._compiler_dispatch(self, **kw), binary.right._compiler_dispatch(self, **kw)) \ - + (escape and + + (escape and (' ESCAPE ' + self.render_literal_value(escape, None)) or '') @@ -683,7 +693,7 @@ class SQLCompiler(engine.Compiled): "bindparam() name '%s' is reserved " "for automatic usage in the VALUES or SET " "clause of this " - "insert/update statement. Please use a " + "insert/update statement. Please use a " "name other than column name when using bindparam() " "with insert() or update() (for example, 'b_%s')." % (bindparam.key, bindparam.key) @@ -769,8 +779,10 @@ class SQLCompiler(engine.Compiled): self.positiontup.append(name) return self.bindtemplate % {'name':name} - def visit_cte(self, cte, asfrom=False, ashint=False, - fromhints=None, **kwargs): + def visit_cte(self, cte, asfrom=False, ashint=False, + fromhints=None, + **kwargs): + self._init_cte_state() if self.positional: kwargs['positional_names'] = self.cte_positional @@ -778,6 +790,26 @@ class SQLCompiler(engine.Compiled): cte_name = self._truncated_identifier("alias", cte.name) else: cte_name = cte.name + + if cte_name in self.ctes_by_name: + existing_cte = self.ctes_by_name[cte_name] + # we've generated a same-named CTE that we are enclosed in, + # or this is the same CTE. just return the name. + if cte in existing_cte._restates or cte is existing_cte: + return cte_name + elif existing_cte in cte._restates: + # we've generated a same-named CTE that is + # enclosed in us - we take precedence, so + # discard the text for the "inner". + del self.ctes[existing_cte] + else: + raise exc.CompileError( + "Multiple, unrelated CTEs found with " + "the same name: %r" % + cte_name) + + self.ctes_by_name[cte_name] = cte + if cte.cte_alias: if isinstance(cte.cte_alias, sql._truncated_label): cte_alias = self._truncated_identifier("alias", cte.cte_alias) @@ -794,12 +826,12 @@ class SQLCompiler(engine.Compiled): col_source = cte.original.selects[0] else: assert False - recur_cols = [c for c in + recur_cols = [c for c in util.unique_list(col_source.inner_columns) if c is not None] text += "(%s)" % (", ".join( - self.preparer.format_column(ident) + self.preparer.format_column(ident) for ident in recur_cols)) text += " AS \n" + \ cte.original._compiler_dispatch( @@ -814,7 +846,7 @@ class SQLCompiler(engine.Compiled): return self.preparer.format_alias(cte, cte_name) return text - def visit_alias(self, alias, asfrom=False, ashint=False, + def visit_alias(self, alias, asfrom=False, ashint=False, fromhints=None, **kwargs): if asfrom or ashint: if isinstance(alias.name, sql._truncated_label): @@ -825,7 +857,7 @@ class SQLCompiler(engine.Compiled): if ashint: return self.preparer.format_alias(alias, alias_name) elif asfrom: - ret = alias.original._compiler_dispatch(self, + ret = alias.original._compiler_dispatch(self, asfrom=True, **kwargs) + \ " AS " + \ self.preparer.format_alias(alias, alias_name) @@ -849,8 +881,8 @@ class SQLCompiler(engine.Compiled): select.use_labels and \ column._label: return _CompileLabel( - column, - column._label, + column, + column._label, alt_names=(column._key_label, ) ) @@ -860,9 +892,9 @@ class SQLCompiler(engine.Compiled): not column.is_literal and \ column.table is not None and \ not isinstance(column.table, sql.Select): - return _CompileLabel(column, sql._as_truncated(column.name), + return _CompileLabel(column, sql._as_truncated(column.name), alt_names=(column.key,)) - elif not isinstance(column, + elif not isinstance(column, (sql._UnaryExpression, sql._TextClause)) \ and (not hasattr(column, 'name') or \ isinstance(column, sql.Function)): @@ -879,9 +911,9 @@ class SQLCompiler(engine.Compiled): def get_crud_hint_text(self, table, text): return None - def visit_select(self, select, asfrom=False, parens=True, - iswrapper=False, fromhints=None, - compound_index=1, + def visit_select(self, select, asfrom=False, parens=True, + iswrapper=False, fromhints=None, + compound_index=1, positional_names=None, **kwargs): entry = self.stack and self.stack[-1] or {} @@ -901,7 +933,7 @@ class SQLCompiler(engine.Compiled): : iswrapper}) if compound_index==1 and not entry or entry.get('iswrapper', False): - column_clause_args = {'result_map':self.result_map, + column_clause_args = {'result_map':self.result_map, 'positional_names':positional_names} else: column_clause_args = {'positional_names':positional_names} @@ -912,7 +944,7 @@ class SQLCompiler(engine.Compiled): self.label_select_column(select, co, asfrom=asfrom).\ _compiler_dispatch(self, within_columns_clause=True, - **column_clause_args) + **column_clause_args) for co in util.unique_list(select.inner_columns) ] if c is not None @@ -925,9 +957,9 @@ class SQLCompiler(engine.Compiled): (from_, hinttext % { 'name':from_._compiler_dispatch( self, ashint=True) - }) - for (from_, dialect), hinttext in - select._hints.iteritems() + }) + for (from_, dialect), hinttext in + select._hints.iteritems() if dialect in ('*', self.dialect.name) ]) hint_text = self.get_select_hint_text(byfrom) @@ -936,7 +968,7 @@ class SQLCompiler(engine.Compiled): if select._prefixes: text += " ".join( - x._compiler_dispatch(self, **kwargs) + x._compiler_dispatch(self, **kwargs) for x in select._prefixes) + " " text += self.get_select_precolumns(select) text += ', '.join(inner_columns) @@ -945,13 +977,13 @@ class SQLCompiler(engine.Compiled): text += " \nFROM " if select._hints: - text += ', '.join([f._compiler_dispatch(self, - asfrom=True, fromhints=byfrom, - **kwargs) + text += ', '.join([f._compiler_dispatch(self, + asfrom=True, fromhints=byfrom, + **kwargs) for f in froms]) else: - text += ', '.join([f._compiler_dispatch(self, - asfrom=True, **kwargs) + text += ', '.join([f._compiler_dispatch(self, + asfrom=True, **kwargs) for f in froms]) else: text += self.default_from() @@ -1036,7 +1068,7 @@ class SQLCompiler(engine.Compiled): text += " OFFSET " + self.process(sql.literal(select._offset)) return text - def visit_table(self, table, asfrom=False, ashint=False, + def visit_table(self, table, asfrom=False, ashint=False, fromhints=None, **kwargs): if asfrom or ashint: if getattr(table, "schema", None): @@ -1056,10 +1088,10 @@ class SQLCompiler(engine.Compiled): def visit_join(self, join, asfrom=False, **kwargs): return ( - join.left._compiler_dispatch(self, asfrom=True, **kwargs) + - (join.isouter and " LEFT OUTER JOIN " or " JOIN ") + - join.right._compiler_dispatch(self, asfrom=True, **kwargs) + - " ON " + + join.left._compiler_dispatch(self, asfrom=True, **kwargs) + + (join.isouter and " LEFT OUTER JOIN " or " JOIN ") + + join.right._compiler_dispatch(self, asfrom=True, **kwargs) + + " ON " + join.onclause._compiler_dispatch(self, **kwargs) ) @@ -1071,7 +1103,7 @@ class SQLCompiler(engine.Compiled): not self.dialect.supports_default_values and \ not self.dialect.supports_empty_insert: raise exc.CompileError("The version of %s you are using does " - "not support empty inserts." % + "not support empty inserts." % self.dialect.name) preparer = self.preparer @@ -1089,13 +1121,13 @@ class SQLCompiler(engine.Compiled): if insert_stmt._hints: dialect_hints = dict([ (table, hint_text) - for (table, dialect), hint_text in + for (table, dialect), hint_text in insert_stmt._hints.items() if dialect in ('*', self.dialect.name) ]) if insert_stmt.table in dialect_hints: text += " " + self.get_crud_hint_text( - insert_stmt.table, + insert_stmt.table, dialect_hints[insert_stmt.table] ) @@ -1126,7 +1158,7 @@ class SQLCompiler(engine.Compiled): """Provide a hook for MySQL to add LIMIT to the UPDATE""" return None - def update_tables_clause(self, update_stmt, from_table, + def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw): """Provide a hook to override the initial table clause in an UPDATE statement. @@ -1136,19 +1168,19 @@ class SQLCompiler(engine.Compiled): """ return self.preparer.format_table(from_table) - def update_from_clause(self, update_stmt, - from_table, extra_froms, + def update_from_clause(self, update_stmt, + from_table, extra_froms, from_hints, **kw): - """Provide a hook to override the generation of an + """Provide a hook to override the generation of an UPDATE..FROM clause. MySQL and MSSQL override this. """ return "FROM " + ', '.join( - t._compiler_dispatch(self, asfrom=True, - fromhints=from_hints, **kw) + t._compiler_dispatch(self, asfrom=True, + fromhints=from_hints, **kw) for t in extra_froms) def visit_update(self, update_stmt, **kw): @@ -1161,20 +1193,20 @@ class SQLCompiler(engine.Compiled): colparams = self._get_colparams(update_stmt, extra_froms) text = "UPDATE " + self.update_tables_clause( - update_stmt, - update_stmt.table, + update_stmt, + update_stmt.table, extra_froms, **kw) if update_stmt._hints: dialect_hints = dict([ (table, hint_text) - for (table, dialect), hint_text in + for (table, dialect), hint_text in update_stmt._hints.items() if dialect in ('*', self.dialect.name) ]) if update_stmt.table in dialect_hints: text += " " + self.get_crud_hint_text( - update_stmt.table, + update_stmt.table, dialect_hints[update_stmt.table] ) else: @@ -1183,12 +1215,12 @@ class SQLCompiler(engine.Compiled): text += ' SET ' if extra_froms and self.render_table_with_column_in_update_from: text += ', '.join( - self.visit_column(c[0]) + + self.visit_column(c[0]) + '=' + c[1] for c in colparams ) else: text += ', '.join( - self.preparer.quote(c[0].name, c[0].quote) + + self.preparer.quote(c[0].name, c[0].quote) + '=' + c[1] for c in colparams ) @@ -1200,9 +1232,9 @@ class SQLCompiler(engine.Compiled): if extra_froms: extra_from_text = self.update_from_clause( - update_stmt, - update_stmt.table, - extra_froms, + update_stmt, + update_stmt.table, + extra_froms, dialect_hints, **kw) if extra_from_text: text += " " + extra_from_text @@ -1223,7 +1255,7 @@ class SQLCompiler(engine.Compiled): return text def _create_crud_bind_param(self, col, value, required=False): - bindparam = sql.bindparam(col.key, value, + bindparam = sql.bindparam(col.key, value, type_=col.type, required=required) bindparam._is_crud = True return bindparam._compiler_dispatch(self) @@ -1248,8 +1280,8 @@ class SQLCompiler(engine.Compiled): # compiled params - return binds for all columns if self.column_keys is None and stmt.parameters is None: return [ - (c, self._create_crud_bind_param(c, - None, required=True)) + (c, self._create_crud_bind_param(c, + None, required=True)) for c in stmt.table.columns ] @@ -1261,8 +1293,8 @@ class SQLCompiler(engine.Compiled): parameters = {} else: parameters = dict((sql._column_as_key(key), required) - for key in self.column_keys - if not stmt.parameters or + for key in self.column_keys + if not stmt.parameters or key not in stmt.parameters) if stmt.parameters is not None: @@ -1283,7 +1315,7 @@ class SQLCompiler(engine.Compiled): postfetch_lastrowid = need_pks and self.dialect.postfetch_lastrowid check_columns = {} - # special logic that only occurs for multi-table UPDATE + # special logic that only occurs for multi-table UPDATE # statements if extra_tables and stmt.parameters: assert self.isupdate @@ -1302,7 +1334,7 @@ class SQLCompiler(engine.Compiled): value = self.process(value.self_group()) values.append((c, value)) # determine tables which are actually - # to be updated - process onupdate and + # to be updated - process onupdate and # server_onupdate for these for t in affected_tables: for c in t.c: @@ -1323,7 +1355,7 @@ class SQLCompiler(engine.Compiled): self.postfetch.append(c) # iterating through columns at the top to maintain ordering. - # otherwise we might iterate through individual sets of + # otherwise we might iterate through individual sets of # "defaults", "primary key cols", etc. for c in stmt.table.columns: if c.key in parameters and c.key not in check_columns: @@ -1343,8 +1375,8 @@ class SQLCompiler(engine.Compiled): if c.primary_key and \ need_pks and \ ( - implicit_returning or - not postfetch_lastrowid or + implicit_returning or + not postfetch_lastrowid or c is not stmt.table._autoincrement_column ): @@ -1430,7 +1462,7 @@ class SQLCompiler(engine.Compiled): ).difference(check_columns) if check: util.warn( - "Unconsumed column names: %s" % + "Unconsumed column names: %s" % (", ".join(check)) ) @@ -1445,13 +1477,13 @@ class SQLCompiler(engine.Compiled): if delete_stmt._hints: dialect_hints = dict([ (table, hint_text) - for (table, dialect), hint_text in + for (table, dialect), hint_text in delete_stmt._hints.items() if dialect in ('*', self.dialect.name) ]) if delete_stmt.table in dialect_hints: text += " " + self.get_crud_hint_text( - delete_stmt.table, + delete_stmt.table, dialect_hints[delete_stmt.table] ) else: @@ -1545,7 +1577,7 @@ class DDLCompiler(engine.Compiled): text += separator separator = ", \n" text += "\t" + self.get_column_specification( - column, + column, first_pk=column.primary_key and \ not first_pk ) @@ -1557,16 +1589,16 @@ class DDLCompiler(engine.Compiled): text += " " + const except exc.CompileError, ce: # Py3K - #raise exc.CompileError("(in table '%s', column '%s'): %s" + #raise exc.CompileError("(in table '%s', column '%s'): %s" # % ( - # table.description, - # column.name, + # table.description, + # column.name, # ce.args[0] # )) from ce # Py2K - raise exc.CompileError("(in table '%s', column '%s'): %s" + raise exc.CompileError("(in table '%s', column '%s'): %s" % ( - table.description, + table.description, column.name, ce.args[0] )), None, sys.exc_info()[2] @@ -1587,17 +1619,17 @@ class DDLCompiler(engine.Compiled): if table.primary_key: constraints.append(table.primary_key) - constraints.extend([c for c in table._sorted_constraints + constraints.extend([c for c in table._sorted_constraints if c is not table.primary_key]) return ", \n\t".join(p for p in - (self.process(constraint) - for constraint in constraints + (self.process(constraint) + for constraint in constraints if ( constraint._create_rule is None or constraint._create_rule(self)) and ( - not self.dialect.supports_alter or + not self.dialect.supports_alter or not getattr(constraint, 'use_alter', False) )) if p is not None ) @@ -1625,7 +1657,7 @@ class DDLCompiler(engine.Compiled): if index.unique: text += "UNIQUE " text += "INDEX %s ON %s (%s)" \ - % (preparer.quote(self._index_identifier(index.name), + % (preparer.quote(self._index_identifier(index.name), index.quote), preparer.format_table(index.table), ', '.join(preparer.quote(c.name, c.quote) @@ -1751,7 +1783,7 @@ class DDLCompiler(engine.Compiled): text += "CONSTRAINT %s " % \ self.preparer.format_constraint(constraint) text += "UNIQUE (%s)" % ( - ', '.join(self.preparer.quote(c.name, c.quote) + ', '.join(self.preparer.quote(c.name, c.quote) for c in constraint)) text += self.define_constraint_deferrability(constraint) return text @@ -1797,7 +1829,7 @@ class GenericTypeCompiler(engine.TypeCompiler): {'precision': type_.precision} else: return "NUMERIC(%(precision)s, %(scale)s)" % \ - {'precision': type_.precision, + {'precision': type_.precision, 'scale' : type_.scale} def visit_DECIMAL(self, type_): @@ -1854,25 +1886,25 @@ class GenericTypeCompiler(engine.TypeCompiler): def visit_large_binary(self, type_): return self.visit_BLOB(type_) - def visit_boolean(self, type_): + def visit_boolean(self, type_): return self.visit_BOOLEAN(type_) - def visit_time(self, type_): + def visit_time(self, type_): return self.visit_TIME(type_) - def visit_datetime(self, type_): + def visit_datetime(self, type_): return self.visit_DATETIME(type_) - def visit_date(self, type_): + def visit_date(self, type_): return self.visit_DATE(type_) - def visit_big_integer(self, type_): + def visit_big_integer(self, type_): return self.visit_BIGINT(type_) - def visit_small_integer(self, type_): + def visit_small_integer(self, type_): return self.visit_SMALLINT(type_) - def visit_integer(self, type_): + def visit_integer(self, type_): return self.visit_INTEGER(type_) def visit_real(self, type_): @@ -1881,19 +1913,19 @@ class GenericTypeCompiler(engine.TypeCompiler): def visit_float(self, type_): return self.visit_FLOAT(type_) - def visit_numeric(self, type_): + def visit_numeric(self, type_): return self.visit_NUMERIC(type_) - def visit_string(self, type_): + def visit_string(self, type_): return self.visit_VARCHAR(type_) - def visit_unicode(self, type_): + def visit_unicode(self, type_): return self.visit_VARCHAR(type_) - def visit_text(self, type_): + def visit_text(self, type_): return self.visit_TEXT(type_) - def visit_unicode_text(self, type_): + def visit_unicode_text(self, type_): return self.visit_TEXT(type_) def visit_enum(self, type_): @@ -1917,7 +1949,7 @@ class IdentifierPreparer(object): illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS - def __init__(self, dialect, initial_quote='"', + def __init__(self, dialect, initial_quote='"', final_quote=None, escape_quote='"', omit_schema=False): """Construct a new ``IdentifierPreparer`` object. @@ -1981,7 +2013,7 @@ class IdentifierPreparer(object): def quote_schema(self, schema, force): """Quote a schema. - Subclasses should override this to provide database-dependent + Subclasses should override this to provide database-dependent quoting behavior. """ return self.quote(schema, force) @@ -2038,7 +2070,7 @@ class IdentifierPreparer(object): return self.quote(name, quote) - def format_column(self, column, use_table=False, + def format_column(self, column, use_table=False, name=None, table_name=None): """Prepare a quoted column name.""" @@ -2047,7 +2079,7 @@ class IdentifierPreparer(object): if not getattr(column, 'is_literal', False): if use_table: return self.format_table( - column.table, use_schema=False, + column.table, use_schema=False, name=table_name) + "." + \ self.quote(name, column.quote) else: diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 8359c3314a..f0549cc799 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -3754,12 +3754,15 @@ class CTE(Alias): """ __visit_name__ = 'cte' + def __init__(self, selectable, name=None, recursive=False, - cte_alias=False): + cte_alias=False, + _restates=frozenset()): self.recursive = recursive self.cte_alias = cte_alias + self._restates = _restates super(CTE, self).__init__(selectable, name=name) def alias(self, name=None): @@ -3774,14 +3777,16 @@ class CTE(Alias): return CTE( self.original.union(other), name=self.name, - recursive=self.recursive + recursive=self.recursive, + _restates=self._restates.union([self]) ) def union_all(self, other): return CTE( self.original.union_all(other), name=self.name, - recursive=self.recursive + recursive=self.recursive, + _restates=self._restates.union([self]) ) diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py index 36f992a86c..49a53a3ec7 100644 --- a/test/sql/test_cte.py +++ b/test/sql/test_cte.py @@ -1,8 +1,9 @@ from test.lib import fixtures -from test.lib.testing import AssertsCompiledSQL +from test.lib.testing import AssertsCompiledSQL, assert_raises_message from sqlalchemy.sql import table, column, select, func, literal from sqlalchemy.dialects import mssql from sqlalchemy.engine import default +from sqlalchemy.exc import CompileError class CTETest(fixtures.TestBase, AssertsCompiledSQL): @@ -119,6 +120,144 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): dialect=mssql.dialect() ) + def test_recursive_union_no_alias_one(self): + s1 = select([literal(0).label("x")]) + cte = s1.cte(name="cte", recursive=True) + cte = cte.union_all( + select([cte.c.x + 1]).where(cte.c.x < 10) + ) + s2 = select([cte]) + self.assert_compile(s2, + "WITH RECURSIVE cte(x) AS " + "(SELECT :param_1 AS x UNION ALL " + "SELECT cte.x + :x_1 AS anon_1 " + "FROM cte WHERE cte.x < :x_2) " + "SELECT cte.x FROM cte" + ) + + + def test_recursive_union_no_alias_two(self): + """ + + pg's example: + + WITH RECURSIVE t(n) AS ( + VALUES (1) + UNION ALL + SELECT n+1 FROM t WHERE n < 100 + ) + SELECT sum(n) FROM t; + + """ + + # I know, this is the PG VALUES keyword, + # we're cheating here. also yes we need the SELECT, + # sorry PG. + t = select([func.values(1).label("n")]).cte("t", recursive=True) + t = t.union_all(select([t.c.n + 1]).where(t.c.n < 100)) + s = select([func.sum(t.c.n)]) + self.assert_compile(s, + "WITH RECURSIVE t(n) AS " + "(SELECT values(:values_1) AS n " + "UNION ALL SELECT t.n + :n_1 AS anon_1 " + "FROM t " + "WHERE t.n < :n_2) " + "SELECT sum(t.n) AS sum_1 FROM t" + ) + + def test_recursive_union_no_alias_three(self): + # like test one, but let's refer to the CTE + # in a sibling CTE. + + s1 = select([literal(0).label("x")]) + cte = s1.cte(name="cte", recursive=True) + + # can't do it here... + #bar = select([cte]).cte('bar') + cte = cte.union_all( + select([cte.c.x + 1]).where(cte.c.x < 10) + ) + bar = select([cte]).cte('bar') + + s2 = select([cte, bar]) + self.assert_compile(s2, + "WITH RECURSIVE cte(x) AS " + "(SELECT :param_1 AS x UNION ALL " + "SELECT cte.x + :x_1 AS anon_1 " + "FROM cte WHERE cte.x < :x_2), " + "bar AS (SELECT cte.x AS x FROM cte) " + "SELECT cte.x, bar.x FROM cte, bar" + ) + + + def test_recursive_union_no_alias_four(self): + # like test one and three, but let's refer + # previous version of "cte". here we test + # how the compiler resolves multiple instances + # of "cte". + + s1 = select([literal(0).label("x")]) + cte = s1.cte(name="cte", recursive=True) + + bar = select([cte]).cte('bar') + cte = cte.union_all( + select([cte.c.x + 1]).where(cte.c.x < 10) + ) + + # outer cte rendered first, then bar, which + # includes "inner" cte + s2 = select([cte, bar]) + self.assert_compile(s2, + "WITH RECURSIVE cte(x) AS " + "(SELECT :param_1 AS x UNION ALL " + "SELECT cte.x + :x_1 AS anon_1 " + "FROM cte WHERE cte.x < :x_2), " + "bar AS (SELECT cte.x AS x FROM cte) " + "SELECT cte.x, bar.x FROM cte, bar" + ) + + # bar rendered, only includes "inner" cte, + # "outer" cte isn't present + s2 = select([bar]) + self.assert_compile(s2, + "WITH RECURSIVE cte(x) AS " + "(SELECT :param_1 AS x), " + "bar AS (SELECT cte.x AS x FROM cte) " + "SELECT bar.x FROM bar" + ) + + # bar rendered, but then the "outer" + # cte is rendered. + s2 = select([bar, cte]) + self.assert_compile(s2, + "WITH RECURSIVE bar AS (SELECT cte.x AS x FROM cte), " + "cte(x) AS " + "(SELECT :param_1 AS x UNION ALL " + "SELECT cte.x + :x_1 AS anon_1 " + "FROM cte WHERE cte.x < :x_2) " + + "SELECT bar.x, cte.x FROM bar, cte" + ) + + def test_conflicting_names(self): + """test a flat out name conflict.""" + + s1 = select([1]) + c1= s1.cte(name='cte1', recursive=True) + s2 = select([1]) + c2 = s2.cte(name='cte1', recursive=True) + + s = select([c1, c2]) + assert_raises_message( + CompileError, + "Multiple, unrelated CTEs found " + "with the same name: 'cte1'", + s.compile + ) + + + + def test_union(self): orders = table('orders', column('region'),