From: Mike Bayer Date: Thu, 26 Jul 2007 07:19:37 +0000 (+0000) Subject: - ANSICompiler now uses its own traversal when compiling, returning a composed X-Git-Tag: rel_0_4_6~18 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=778cb994f5b5765cde4cfd8cdb12064c31af0af5;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - ANSICompiler now uses its own traversal when compiling, returning a composed string from each visit_XXXX method, so that the full string is compiled at once without using any dictionary storage. dialects modified accordingly. tested on mysql/sqlite/postgres fully, tested with string-only tests for oracle/fb/informix/mssql so far. --- diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 2d738769d7..e9f8b45ad3 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -142,21 +142,6 @@ class ANSICompiler(engine.Compiled): # actually present in the generated SQL self.bind_names = {} - # a dictionary which stores the string representation for every ClauseElement - # processed by this compiler. - self.strings = {} - - # a dictionary which stores the string representation for ClauseElements - # processed by this compiler, which are to be used in the FROM clause - # of a select. items are often placed in "froms" as well as "strings" - # and sometimes with different representations. - self.froms = {} - - # slightly hacky. maps FROM clauses to WHERE clauses, and used in select - # generation to modify the WHERE clause of the select. currently a hack - # used by the oracle module. - self.wheres = {} - # when the compiler visits a SELECT statement, the clause object is appended # to this stack. various visit operations will check this stack to determine # additional choices (TODO: it seems to be all typemap stuff. shouldnt this only @@ -209,7 +194,7 @@ class ANSICompiler(engine.Compiled): # this re will search for params like :param # it has a negative lookbehind for an extra ':' so that it doesnt match # postgres '::text' tokens - text = self.strings[self.statement] + text = self.string if ':' not in text: return @@ -231,13 +216,27 @@ class ANSICompiler(engine.Compiled): text = BIND_PARAMS.sub(getnum, text) # un-escape any \:params text = BIND_PARAMS_ESC.sub(lambda m: m.group(1), text) - self.strings[self.statement] = text + self.string = text + def compile(self): + self.string = self.process(self.statement) + self.after_compile() + + def process(self, obj, **kwargs): + return self.traverse_single(obj, **kwargs) + def is_subquery(self, select): return self.correlate_state[select].get('is_subquery', False) def get_whereclause(self, obj): - return self.wheres.get(obj, None) + """given a FROM clause, return an additional WHERE condition that should be + applied to a SELECT. + + Currently used by Oracle to provide WHERE criterion for JOIN and OUTER JOIN + constructs in non-ansi mode. + """ + + return None def construct_params(self, params): """Return a sql.ClauseParameters object. @@ -276,8 +275,8 @@ class ANSICompiler(engine.Compiled): return "" - def visit_grouping(self, grouping): - self.strings[grouping] = self.froms[grouping] = "(" + self.strings[grouping.elem] + ")" + def visit_grouping(self, grouping, **kwargs): + return "(" + self.process(grouping.elem) + ")" def visit_label(self, label): labelname = self._truncated_identifier("colident", label.name) @@ -287,9 +286,9 @@ class ANSICompiler(engine.Compiled): if isinstance(label.obj, sql._ColumnClause): self.column_labels[label.obj._label] = labelname self.column_labels[label.name] = labelname - self.strings[label] = " ".join([self.strings[label.obj], self.operator_string(sql.ColumnOperators.as_), self.preparer.format_label(label, labelname)]) + return " ".join([self.process(label.obj), self.operator_string(sql.ColumnOperators.as_), self.preparer.format_label(label, labelname)]) - def visit_column(self, column): + def visit_column(self, column, **kwargs): # there is actually somewhat of a ruleset when you would *not* necessarily # want to truncate a column identifier, if its mapped to the name of a # physical column. but thats very hard to identify at this point, and @@ -300,47 +299,49 @@ class ANSICompiler(engine.Compiled): else: name = column.name + if len(self.select_stack): + # if we are within a visit to a Select, set up the "typemap" + # for this column which is used to translate result set values + self.typemap.setdefault(name.lower(), column.type) + self.column_labels.setdefault(column._label, name.lower()) + if column.table is None or not column.table.named_with_column(): - self.strings[column] = self.preparer.format_column(column, name=name) + return self.preparer.format_column(column, name=name) else: if column.table.oid_column is column: n = self.dialect.oid_column_name(column) if n is not None: - self.strings[column] = "%s.%s" % (self.preparer.format_table(column.table, use_schema=False, name=self._anonymize(column.table.name)), n) + return "%s.%s" % (self.preparer.format_table(column.table, use_schema=False, name=self._anonymize(column.table.name)), n) elif len(column.table.primary_key) != 0: pk = list(column.table.primary_key)[0] pkname = (pk.is_literal and name or self._truncated_identifier("colident", pk.name)) - self.strings[column] = self.preparer.format_column_with_table(list(column.table.primary_key)[0], column_name=pkname, table_name=self._anonymize(column.table.name)) + return self.preparer.format_column_with_table(list(column.table.primary_key)[0], column_name=pkname, table_name=self._anonymize(column.table.name)) else: - self.strings[column] = None + return None else: - self.strings[column] = self.preparer.format_column_with_table(column, column_name=name, table_name=self._anonymize(column.table.name)) + return self.preparer.format_column_with_table(column, column_name=name, table_name=self._anonymize(column.table.name)) - if len(self.select_stack): - # if we are within a visit to a Select, set up the "typemap" - # for this column which is used to translate result set values - self.typemap.setdefault(name.lower(), column.type) - self.column_labels.setdefault(column._label, name.lower()) - def visit_fromclause(self, fromclause): - self.froms[fromclause] = fromclause.name + def visit_fromclause(self, fromclause, **kwargs): + return fromclause.name - def visit_index(self, index): - self.strings[index] = index.name + def visit_index(self, index, **kwargs): + return index.name - def visit_typeclause(self, typeclause): - self.strings[typeclause] = typeclause.type.dialect_impl(self.dialect).get_col_spec() + def visit_typeclause(self, typeclause, **kwargs): + return typeclause.type.dialect_impl(self.dialect).get_col_spec() - def visit_textclause(self, textclause): - self.strings[textclause] = textclause.text - self.froms[textclause] = textclause.text + def visit_textclause(self, textclause, **kwargs): + for bind in textclause.bindparams.values(): + self.process(bind) if textclause.typemap is not None: self.typemap.update(textclause.typemap) + return textclause.text - def visit_null(self, null): - self.strings[null] = 'NULL' + def visit_null(self, null, **kwargs): + return 'NULL' - def visit_clauselist(self, clauselist): + def visit_clauselist(self, clauselist, **kwargs): sep = clauselist.operator if sep is None: sep = " " @@ -348,59 +349,60 @@ class ANSICompiler(engine.Compiled): sep = ', ' else: sep = " " + self.operator_string(clauselist.operator) + " " - self.strings[clauselist] = string.join([s for s in [self.strings[c] for c in clauselist.clauses] if s is not None], sep) + return string.join([s for s in [self.process(c) for c in clauselist.clauses] if s is not None], sep) def apply_function_parens(self, func): return func.name.upper() not in ANSI_FUNCS or len(func.clauses) > 0 - def visit_calculatedclause(self, clause): - self.strings[clause] = self.strings[clause.clause_expr] + def visit_calculatedclause(self, clause, **kwargs): + return self.process(clause.clause_expr) - def visit_cast(self, cast): + def visit_cast(self, cast, **kwargs): if len(self.select_stack): # not sure if we want to set the typemap here... self.typemap.setdefault("CAST", cast.type) - self.strings[cast] = "CAST(%s AS %s)" % (self.strings[cast.clause],self.strings[cast.typeclause]) + return "CAST(%s AS %s)" % (self.process(cast.clause), self.process(cast.typeclause)) - def visit_function(self, func): + def visit_function(self, func, **kwargs): if len(self.select_stack): self.typemap.setdefault(func.name, func.type) if not self.apply_function_parens(func): - self.strings[func] = ".".join(func.packagenames + [func.name]) - self.froms[func] = self.strings[func] + return ".".join(func.packagenames + [func.name]) else: - self.strings[func] = ".".join(func.packagenames + [func.name]) + (not func.group and " " or "") + self.strings[func.clause_expr] - self.froms[func] = self.strings[func] + return ".".join(func.packagenames + [func.name]) + (not func.group and " " or "") + self.process(func.clause_expr) - def visit_compound_select(self, cs): - text = string.join([self.strings[c] for c in cs.selects], " " + cs.keyword + " ") - group_by = self.strings[cs._group_by_clause] + def visit_compound_select(self, cs, asfrom=False, **kwargs): + text = string.join([self.process(c) for c in cs.selects], " " + cs.keyword + " ") + group_by = self.process(cs._group_by_clause) if group_by: text += " GROUP BY " + group_by text += self.order_by_clause(cs) - text += self.visit_select_postclauses(cs) - self.strings[cs] = text - self.froms[cs] = "(" + text + ")" + text += (cs._limit or cs._offset) and self.limit_clause(cs) or "" + + if asfrom: + return "(" + text + ")" + else: + return text - def visit_unary(self, unary): - s = self.strings[unary.element] + def visit_unary(self, unary, **kwargs): + s = self.process(unary.element) if unary.operator: s = self.operator_string(unary.operator) + " " + s if unary.modifier: s = s + " " + unary.modifier - self.strings[unary] = s + return s - def visit_binary(self, binary): + def visit_binary(self, binary, **kwargs): op = self.operator_string(binary.operator) if callable(op): - self.strings[binary] = op(binary.left, binary.right) + return op(self.process(binary.left), self.process(binary.right)) else: - self.strings[binary] = self.strings[binary.left] + " " + op + " " + self.strings[binary.right] + return self.process(binary.left) + " " + op + " " + self.process(binary.right) def operator_string(self, operator): return self.operators.get(operator, str(operator)) - def visit_bindparam(self, bindparam): + def visit_bindparam(self, bindparam, **kwargs): # apply truncation to the ultimate generated name if bindparam.shortname != bindparam.key: @@ -416,13 +418,13 @@ class ANSICompiler(engine.Compiled): key = bindparam.key + tag count += 1 bindparam.key = key - self.strings[bindparam] = self.bindparam_string(self._truncate_bindparam(bindparam)) + return self.bindparam_string(self._truncate_bindparam(bindparam)) else: existing = self.binds.get(bindparam.key) if existing is not None and existing.unique: raise exceptions.CompileError("Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key) - self.strings[bindparam] = self.bindparam_string(self._truncate_bindparam(bindparam)) self.binds[bindparam.key] = bindparam + return self.bindparam_string(self._truncate_bindparam(bindparam)) def _truncate_bindparam(self, bindparam): if bindparam in self.bind_names: @@ -465,20 +467,12 @@ class ANSICompiler(engine.Compiled): def bindparam_string(self, name): return self.bindtemplate % name - def visit_alias(self, alias): - self.froms[alias] = self.froms[alias.original] + " AS " + self.preparer.format_alias(alias, self._anonymize(alias.name)) - self.strings[alias] = self.strings[alias.original] - - def enter_select(self, select): - select._calculate_correlations(self.correlate_state) - self.select_stack.append(select) - - def enter_update(self, update): - update._calculate_correlations(self.correlate_state) + def visit_alias(self, alias, asfrom=False, **kwargs): + if asfrom: + return self.process(alias.original, asfrom=True, **kwargs) + " AS " + self.preparer.format_alias(alias, self._anonymize(alias.name)) + else: + return self.process(alias.original, **kwargs) - def enter_delete(self, delete): - delete._calculate_correlations(self.correlate_state) - def label_select_column(self, select, column): """convert a column from a select's "columns" clause. @@ -503,57 +497,53 @@ class ANSICompiler(engine.Compiled): else: return None - def visit_select(self, select): + def visit_select(self, select, asfrom=False, **kwargs): + + select._calculate_correlations(self.correlate_state) + self.select_stack.append(select) + # the actual list of columns to print in the SELECT column list. - inner_columns = util.OrderedDict() + inner_columns = util.OrderedSet() froms = select._get_display_froms(self.correlate_state) - for f in froms: - if f not in self.strings: - self.traverse(f) for co in select.inner_columns: if select.use_labels: labelname = co._label if labelname is not None: l = co.label(labelname) - self.traverse(l) - inner_columns[labelname] = l + inner_columns.add(self.process(l)) else: self.traverse(co) - inner_columns[self.strings[co]] = co + inner_columns.add(self.process(co)) else: l = self.label_select_column(select, co) if l is not None: - self.traverse(l) - inner_columns[self.strings[l.obj]] = l + inner_columns.add(self.process(l)) else: - self.traverse(co) - inner_columns[self.strings[co]] = co + inner_columns.add(self.process(co)) self.select_stack.pop(-1) - collist = string.join([self.strings[v] for v in inner_columns.values()], ', ') + collist = string.join(inner_columns.difference(util.Set([None])), ', ') text = "SELECT " - text += self.visit_select_precolumns(select) + text += self.get_select_precolumns(select) text += collist whereclause = select._whereclause from_strings = [] for f in froms: - # special thingy used by oracle to redefine a join + from_strings.append(self.process(f, asfrom=True)) + w = self.get_whereclause(f) if w is not None: - # TODO: move this more into the oracle module if whereclause is not None: - whereclause = self.traverse(sql.and_(w, whereclause), stop_on=util.Set([whereclause, w])) + whereclause = sql.and_(w, whereclause) else: whereclause = w - from_strings.append(self.froms[f]) - if len(froms): text += " \nFROM " text += string.join(from_strings, ', ') @@ -561,41 +551,34 @@ class ANSICompiler(engine.Compiled): text += self.default_from() if whereclause is not None: - t = self.strings[whereclause] + t = self.process(whereclause) if t: text += " \nWHERE " + t - group_by = self.strings[select._group_by_clause] + group_by = self.process(select._group_by_clause) if group_by: text += " GROUP BY " + group_by if select._having is not None: - t = self.strings[select._having] + t = self.process(select._having) if t: text += " \nHAVING " + t text += self.order_by_clause(select) - text += self.visit_select_postclauses(select) + text += (select._limit or select._offset) and self.limit_clause(select) or "" text += self.for_update_clause(select) - self.strings[select] = text - self.froms[select] = "(" + text + ")" + if asfrom: + return "(" + text + ")" + else: + return text - def visit_select_precolumns(self, select): + def get_select_precolumns(self, select): """Called when building a ``SELECT`` statement, position is just before column list.""" - return select._distinct and "DISTINCT " or "" - def visit_select_postclauses(self, select): - """Called when building a ``SELECT`` statement, position is after all other ``SELECT`` clauses. - - Most DB syntaxes put ``LIMIT``/``OFFSET`` here. - """ - - return (select._limit or select._offset) and self.limit_clause(select) or "" - def order_by_clause(self, select): - order_by = self.strings[select._order_by_clause] + order_by = self.process(select._order_by_clause) if order_by: return " ORDER BY " + order_by else: @@ -617,19 +600,15 @@ class ANSICompiler(engine.Compiled): text += " OFFSET " + str(select._offset) return text - def visit_table(self, table): - self.froms[table] = self.preparer.format_table(table) - self.strings[table] = "" - - def visit_join(self, join): - righttext = self.froms[join.right] - if join.isouter: - self.froms[join] = (self.froms[join.left] + " LEFT OUTER JOIN " + righttext + - " ON " + self.strings[join.onclause]) + def visit_table(self, table, asfrom=False, **kwargs): + if asfrom: + return self.preparer.format_table(table) else: - self.froms[join] = (self.froms[join.left] + " JOIN " + righttext + - " ON " + self.strings[join.onclause]) - self.strings[join] = self.froms[join] + return "" + + def visit_join(self, join, asfrom=False, **kwargs): + return (self.process(join.left, asfrom=True) + (join.isouter and " LEFT OUTER JOIN " or " JOIN ") + \ + self.process(join.right, asfrom=True) + " ON " + self.process(join.onclause)) def uses_sequences_for_inserts(self): return False @@ -658,12 +637,11 @@ class ANSICompiler(engine.Compiled): self.isinsert = True colparams = self._get_colparams(insert_stmt, required_cols) - text = ("INSERT INTO " + self.preparer.format_table(insert_stmt.table) + " (" + string.join([self.preparer.format_column(c[0]) for c in colparams], ', ') + ")" + + return ("INSERT INTO " + self.preparer.format_table(insert_stmt.table) + " (" + string.join([self.preparer.format_column(c[0]) for c in colparams], ', ') + ")" + " VALUES (" + string.join([c[1] for c in colparams], ', ') + ")") - self.strings[insert_stmt] = text - def visit_update(self, update_stmt): + update_stmt._calculate_correlations(self.correlate_state) # search for columns who will be required to have an explicit bound value. # for updates, this includes Python-side "onupdate" defaults. @@ -682,9 +660,9 @@ class ANSICompiler(engine.Compiled): text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.format_column(c[0]), c[1]) for c in colparams], ', ') if update_stmt._whereclause: - text += " WHERE " + self.strings[update_stmt._whereclause] + text += " WHERE " + self.process(update_stmt._whereclause) - self.strings[update_stmt] = text + return text def _get_colparams(self, stmt, required_cols): """create a set of tuples representing column/string pairs for use @@ -708,7 +686,7 @@ class ANSICompiler(engine.Compiled): def create_clause_param(col, value): self.traverse(value) self.inline_params.add(col) - return self.strings[value] + return self.process(value) self.inline_params = util.Set() @@ -746,27 +724,26 @@ class ANSICompiler(engine.Compiled): return values def visit_delete(self, delete_stmt): + delete_stmt._calculate_correlations(self.correlate_state) + text = "DELETE FROM " + self.preparer.format_table(delete_stmt.table) if delete_stmt._whereclause: - text += " WHERE " + self.strings[delete_stmt._whereclause] + text += " WHERE " + self.process(delete_stmt._whereclause) - self.strings[delete_stmt] = text + return text def visit_savepoint(self, savepoint_stmt): - text = "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident) - self.strings[savepoint_stmt] = text + return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident) def visit_rollback_to_savepoint(self, savepoint_stmt): - text = "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident) - self.strings[savepoint_stmt] = text + return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident) def visit_release_savepoint(self, savepoint_stmt): - text = "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident) - self.strings[savepoint_stmt] = text + return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident) def __str__(self): - return self.strings[self.statement] + return self.string class ANSISchemaBase(engine.SchemaIterator): def find_alterables(self, tables): diff --git a/lib/sqlalchemy/databases/firebird.py b/lib/sqlalchemy/databases/firebird.py index 103415b238..481e63a79a 100644 --- a/lib/sqlalchemy/databases/firebird.py +++ b/lib/sqlalchemy/databases/firebird.py @@ -306,23 +306,23 @@ class FBDialect(ansisql.ANSIDialect): class FBCompiler(ansisql.ANSICompiler): """Firebird specific idiosincrasies""" - def visit_alias(self, alias): + def visit_alias(self, alias, asfrom=False, **kwargs): # Override to not use the AS keyword which FB 1.5 does not like - self.froms[alias] = self.froms[alias.original] + " " + self.preparer.format_alias(alias) - self.strings[alias] = self.strings[alias.original] + if asfrom: + return self.process(alias.original, asfrom=True) + " " + self.preparer.format_alias(alias) + else: + return self.process(alias.original, asfrom=True) def visit_function(self, func): if len(func.clauses): - super(FBCompiler, self).visit_function(func) + return super(FBCompiler, self).visit_function(func) else: - self.strings[func] = func.name + return func.name - def visit_insert_column(self, column, parameters): - # all column primary key inserts must be explicitly present - if column.primary_key: - parameters[column.key] = None + def uses_sequences_for_inserts(self): + return True - def visit_select_precolumns(self, select): + def get_select_precolumns(self, select): """Called when building a ``SELECT`` statement, position is just before column list Firebird puts the limit and offset right after the ``SELECT``... @@ -338,7 +338,7 @@ class FBCompiler(ansisql.ANSICompiler): return result def limit_clause(self, select): - """Already taken care of in the `visit_select_precolumns` method.""" + """Already taken care of in the `get_select_precolumns` method.""" return "" diff --git a/lib/sqlalchemy/databases/informix.py b/lib/sqlalchemy/databases/informix.py index 396cd487c2..6c15fbe2a3 100644 --- a/lib/sqlalchemy/databases/informix.py +++ b/lib/sqlalchemy/databases/informix.py @@ -369,7 +369,7 @@ class InfoCompiler(ansisql.ANSICompiler): def default_from(self): return " from systables where tabname = 'systables' " - def visit_select_precolumns( self , select ): + def get_select_precolumns( self , select ): s = select._distinct and "DISTINCT " or "" # only has limit if select._limit: @@ -390,13 +390,14 @@ class InfoCompiler(ansisql.ANSICompiler): return c._label.lower() except: return '' - + + # TODO: dont modify the original select, generate a new one a = [ __label(c) for c in select._raw_columns ] for c in select.order_by_clause.clauses: if ( __label(c) not in a ) and getattr( c , 'name' , '' ) != 'oid': select.append_column( c ) - ansisql.ANSICompiler.visit_select(self, select) + return ansisql.ANSICompiler.visit_select(self, select) def limit_clause(self, select): return "" @@ -411,23 +412,20 @@ class InfoCompiler(ansisql.ANSICompiler): def visit_function( self , func ): if func.name.lower() == 'current_date': - self.strings[func] = "today" + return "today" elif func.name.lower() == 'current_time': - self.strings[func] = "CURRENT HOUR TO SECOND" + return "CURRENT HOUR TO SECOND" elif func.name.lower() in ( 'current_timestamp' , 'now' ): - self.strings[func] = "CURRENT YEAR TO SECOND" + return "CURRENT YEAR TO SECOND" else: - ansisql.ANSICompiler.visit_function( self , func ) + return ansisql.ANSICompiler.visit_function( self , func ) def visit_clauselist(self, list): try: li = [ c for c in list.clauses if c.name != 'oid' ] except: li = [ c for c in list.clauses ] - if list.parens: - self.strings[list] = "(" + ', '.join([s for s in [self.strings[c] for c in li] if s is not None ]) + ")" - else: - self.strings[list] = ', '.join([s for s in [self.strings[c] for c in li] if s is not None]) + return ', '.join([s for s in [self.process(c) for c in li] if s is not None]) class InfoSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, first_pk=False): diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py index 74103ecdbc..3f5a1708ee 100644 --- a/lib/sqlalchemy/databases/mssql.py +++ b/lib/sqlalchemy/databases/mssql.py @@ -789,7 +789,7 @@ class MSSQLCompiler(ansisql.ANSICompiler): super(MSSQLCompiler, self).__init__(dialect, statement, parameters, **kwargs) self.tablealiases = {} - def visit_select_precolumns(self, select): + def get_select_precolumns(self, select): """ MS-SQL puts TOP, it's version of LIMIT here """ s = select._distinct and "DISTINCT " or "" if select._limit: @@ -802,41 +802,44 @@ class MSSQLCompiler(ansisql.ANSICompiler): # Limit in mssql is after the select keyword return "" - def visit_table(self, table): + def _schema_aliased_table(self, table): + if getattr(table, 'schema', None) is not None: + if not self.tablealiases.has_key(table): + self.tablealiases[table] = table.alias() + return self.tablealiases[table] + else: + return None + + def visit_table(self, table, mssql_aliased=False, **kwargs): + if mssql_aliased: + return super(MSSQLCompiler, self).visit_table(table, **kwargs) + # alias schema-qualified tables - if getattr(table, 'schema', None) is not None and not self.tablealiases.has_key(table): - alias = table.alias() - self.tablealiases[table] = alias - self.traverse(alias) - self.froms[('alias', table)] = self.froms[table] - for c in alias.c: - self.traverse(c) - self.traverse(alias.oid_column) - self.tablealiases[alias] = self.froms[table] - self.froms[table] = self.froms[alias] + alias = self._schema_aliased_table(table) + if alias is not None: + return self.process(alias, mssql_aliased=True, **kwargs) else: - super(MSSQLCompiler, self).visit_table(table) + return super(MSSQLCompiler, self).visit_table(table, **kwargs) - def visit_alias(self, alias): + def visit_alias(self, alias, **kwargs): # translate for schema-qualified table aliases - if self.froms.has_key(('alias', alias.original)): - self.froms[alias] = self.froms[('alias', alias.original)] + " AS " + alias.name - self.strings[alias] = "" - else: - super(MSSQLCompiler, self).visit_alias(alias) + self.tablealiases[alias.original] = alias + return super(MSSQLCompiler, self).visit_alias(alias, **kwargs) def visit_column(self, column): - # translate for schema-qualified table aliases - super(MSSQLCompiler, self).visit_column(column) - if column.table is not None and self.tablealiases.has_key(column.table): - self.strings[column] = \ - self.strings[self.tablealiases[column.table].corresponding_column(column)] + if column.table is not None: + # translate for schema-qualified table aliases + t = self._schema_aliased_table(column.table) + if t is not None: + return self.process(t.corresponding_column(column)) + return super(MSSQLCompiler, self).visit_column(column) def visit_binary(self, binary): """Move bind parameters to the right-hand side of an operator, where possible.""" - if isinstance(binary.left, sql._BindParamClause) and binary.operator == '=': - binary.left, binary.right = binary.right, binary.left - super(MSSQLCompiler, self).visit_binary(binary) + if isinstance(binary.left, sql._BindParamClause) and binary.operator == operator.eq: + return self.process(sql._BinaryExpression(binary.right, binary.left, binary.operator)) + else: + return super(MSSQLCompiler, self).visit_binary(binary) def label_select_column(self, select, column): if isinstance(column, sql._Function): @@ -856,7 +859,7 @@ class MSSQLCompiler(ansisql.ANSICompiler): return '' def order_by_clause(self, select): - order_by = self.strings[select._order_by_clause] + order_by = self.process(select._order_by_clause) # MSSQL only allows ORDER BY in subqueries if there is a LIMIT if order_by and (not self.is_subquery(select) or select._limit): diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index bc0691d397..cbf71070f3 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -1315,13 +1315,13 @@ class MySQLCompiler(ansisql.ANSICompiler): } ) - def visit_cast(self, cast): + def visit_cast(self, cast, **kwargs): if isinstance(cast.type, (sqltypes.Date, sqltypes.Time, sqltypes.DateTime)): - return super(MySQLCompiler, self).visit_cast(cast) + return super(MySQLCompiler, self).visit_cast(cast, **kwargs) else: # so just skip the CAST altogether for now. # TODO: put whatever MySQL does for CAST here. - self.strings[cast] = self.strings[cast.clause] + return self.process(cast.clause) def for_update_clause(self, select): if select.for_update == 'read': diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py index 6d876d9109..2018b93ccc 100644 --- a/lib/sqlalchemy/databases/oracle.py +++ b/lib/sqlalchemy/databases/oracle.py @@ -457,6 +457,11 @@ class OracleDialect(ansisql.ANSIDialect): OracleDialect.logger = logging.class_logger(OracleDialect) +class _OuterJoinColumn(sql.ClauseElement): + __visit_name__ = 'outer_join_column' + def __init__(self, column): + self.column = column + class OracleCompiler(ansisql.ANSICompiler): """Oracle compiler modifies the lexical structure of Select statements to work under non-ANSI configured Oracle databases, if @@ -470,6 +475,10 @@ class OracleCompiler(ansisql.ANSICompiler): } ) + def __init__(self, *args, **kwargs): + super(OracleCompiler, self).__init__(*args, **kwargs) + self.__wheres = {} + def default_from(self): """Called when a ``SELECT`` statement has no froms, and no ``FROM`` clause is to be appended. @@ -481,44 +490,45 @@ class OracleCompiler(ansisql.ANSICompiler): def apply_function_parens(self, func): return len(func.clauses) > 0 - def visit_join(self, join): + def visit_join(self, join, **kwargs): if self.dialect.use_ansi: - return ansisql.ANSICompiler.visit_join(self, join) - - self.froms[join] = self.froms[join.left] + ", " + self.froms[join.right] - where = self.wheres.get(join.left, None) + return ansisql.ANSICompiler.visit_join(self, join, **kwargs) + + (where, parentjoin) = self.__wheres.get(join, (None, None)) + + class VisitOn(sql.ClauseVisitor): + def visit_binary(s, binary): + if binary.operator == operator.eq: + if binary.left.table is join.right: + binary.left = _OuterJoinColumn(binary.left) + elif binary.right.table is join.right: + binary.right = _OuterJoinColumn(binary.right) + if where is not None: - self.wheres[join] = sql.and_(where, join.onclause) + self.__wheres[join.left] = self.__wheres[parentjoin] = (sql.and_(VisitOn().traverse(join.onclause, clone=True), where), parentjoin) else: - self.wheres[join] = join.onclause -# self.wheres[join] = sql.and_(self.wheres.get(join.left, None), join.onclause) - self.strings[join] = self.froms[join] - - if join.isouter: - # if outer join, push on the right side table as the current "outertable" - self._outertable = join.right - - # now re-visit the onclause, which will be used as a where clause - # (the first visit occured via the Join object itself right before it called visit_join()) - self.traverse(join.onclause) - - self._outertable = None - - self.traverse_single(self.wheres[join]) + self.__wheres[join.left] = self.__wheres[join] = (VisitOn().traverse(join.onclause, clone=True), join) + return self.process(join.left, asfrom=True) + ", " + self.process(join.right, asfrom=True) + + def get_whereclause(self, f): + if f in self.__wheres: + return self.__wheres[f][0] + else: + return None + + def visit_outer_join_column(self, vc): + return self.process(vc.column) + "(+)" def uses_sequences_for_inserts(self): return True - def visit_alias(self, alias): + def visit_alias(self, alias, asfrom=False, **kwargs): """Oracle doesn't like ``FROM table AS alias``. Is the AS standard SQL??""" - - self.froms[alias] = self.froms[alias.original] + " " + alias.name - self.strings[alias] = self.strings[alias.original] - - def visit_column(self, column): - ansisql.ANSICompiler.visit_column(self, column) - if not self.dialect.use_ansi and getattr(self, '_outertable', None) is not None and column.table is self._outertable: - self.strings[column] = self.strings[column] + "(+)" + + if asfrom: + return self.process(alias.original) + " " + alias.name + else: + return self.process(alias.original) def visit_insert(self, insert): """``INSERT`` s are required to have the primary keys be explicitly present. @@ -537,18 +547,18 @@ class OracleCompiler(ansisql.ANSICompiler): """Need to determine how to get ``LIMIT``/``OFFSET`` into a ``UNION`` for Oracle.""" pass - def visit_select(self, select): + def visit_select(self, select, **kwargs): """Look for ``LIMIT`` and OFFSET in a select statement, and if so tries to wrap it in a subquery with ``row_number()`` criterion. """ if not getattr(select, '_oracle_visit', None) and (select._limit is not None or select._offset is not None): # to use ROW_NUMBER(), an ORDER BY is required. - orderby = self.strings[select._order_by_clause] + orderby = self.process(select._order_by_clause) if not orderby: orderby = select.oid_column self.traverse(orderby) - orderby = self.strings[orderby] + orderby = self.process(orderby) oldselect = select select = select.column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("ora_rn")).order_by(None) @@ -561,11 +571,9 @@ class OracleCompiler(ansisql.ANSICompiler): limitselect.append_whereclause("ora_rn<=%d" % (select._limit + select._offset)) else: limitselect.append_whereclause("ora_rn<=%d" % select._limit) - self.traverse(limitselect) - self.strings[oldselect] = self.strings[limitselect] - self.froms[oldselect] = self.froms[limitselect] + return self.process(limitselect) else: - ansisql.ANSICompiler.visit_select(self, select) + return ansisql.ANSICompiler.visit_select(self, select, **kwargs) def limit_clause(self, select): return "" diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index f97a18ff2a..d94618e1c0 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -559,7 +559,7 @@ class PGCompiler(ansisql.ANSICompiler): text += " OFFSET " + str(select._offset) return text - def visit_select_precolumns(self, select): + def get_select_precolumns(self, select): if select._distinct: if type(select._distinct) == bool: return "DISTINCT " diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py index 9cc35a8cc1..bb6a75738e 100644 --- a/lib/sqlalchemy/databases/sqlite.py +++ b/lib/sqlalchemy/databases/sqlite.py @@ -327,12 +327,12 @@ class SQLiteDialect(ansisql.ANSIDialect): class SQLiteCompiler(ansisql.ANSICompiler): def visit_cast(self, cast): if self.dialect.supports_cast: - super(SQLiteCompiler, self).visit_cast(cast) + return super(SQLiteCompiler, self).visit_cast(cast) else: if len(self.select_stack): # not sure if we want to set the typemap here... self.typemap.setdefault("CAST", cast.type) - self.strings[cast] = self.strings[cast.clause] + return self.process(cast.clause) def limit_clause(self, select): text = "" diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 5b1ebfcd39..35696c4271 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -430,10 +430,6 @@ class Compiled(sql.ClauseVisitor): self.bind = bind self.can_execute = statement.supports_execution() - def compile(self): - self.traverse(self.statement) - self.after_compile() - def __str__(self): """Return the string text of the generated SQL statement.""" diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index c9847f6e29..3b5defb04d 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -856,18 +856,15 @@ class ClauseVisitor(object): """ __traverse_options__ = {} - def traverse_single(self, obj): + def traverse_single(self, obj, **kwargs): meth = getattr(self, "visit_%s" % obj.__visit_name__, None) if meth: - return meth(obj) + return meth(obj, **kwargs) def traverse(self, obj, stop_on=None, clone=False): if clone: obj = obj._clone() - # entry flag indicates to also call a before-descent "enter_XXXX" method - entry = self.__traverse_options__.get('entry', False) - v = self visitors = [] while v is not None: @@ -877,12 +874,6 @@ class ClauseVisitor(object): def _trav(obj): if stop_on is not None and obj in stop_on: return - if entry: - for v in visitors: - meth = getattr(v, "enter_%s" % obj.__visit_name__, None) - if meth: - meth(obj) - if clone: obj._copy_internals() for c in obj.get_children(**self.__traverse_options__): diff --git a/test/orm/sharding/alltests.py b/test/orm/sharding/alltests.py index 443c5cd793..0cdb838a9d 100644 --- a/test/orm/sharding/alltests.py +++ b/test/orm/sharding/alltests.py @@ -1,8 +1,6 @@ import testbase import unittest -import orm.inheritance.alltests as inheritance - def suite(): modules_to_test = ( 'orm.sharding.shard', @@ -13,7 +11,6 @@ def suite(): for token in name.split('.')[1:]: mod = getattr(mod, token) alltests.addTest(unittest.findTestCases(mod, suiteClass=None)) - alltests.addTest(inheritance.suite()) return alltests diff --git a/test/sql/labels.py b/test/sql/labels.py index 5cca4160e5..553a3a3bc3 100644 --- a/test/sql/labels.py +++ b/test/sql/labels.py @@ -94,7 +94,7 @@ class LongLabelsTest(PersistTest): x = select([tt], use_labels=True, order_by=tt.oid_column).compile(dialect=dialect) #print x # assert it doesnt end with "ORDER BY foo.some_large_named_table_this_is_the_primarykey_column" - assert str(x).endswith("""ORDER BY foo.some_large_named_table_t_1""") + assert str(x).endswith("""ORDER BY foo.some_large_named_table_t_2""") if __name__ == '__main__': testbase.main() diff --git a/test/sql/select.py b/test/sql/select.py index 497c69baa1..72f328535d 100644 --- a/test/sql/select.py +++ b/test/sql/select.py @@ -130,6 +130,15 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A crit = q.c.myid == table1.c.myid self.runtest(select(['*'], crit), """SELECT * FROM (SELECT mytable.myid AS myid FROM mytable ORDER BY mytable.myid) AS foo, mytable WHERE foo.myid = mytable.myid""", dialect=sqlite.dialect()) self.runtest(select(['*'], crit), """SELECT * FROM (SELECT mytable.myid AS myid FROM mytable) AS foo, mytable WHERE foo.myid = mytable.myid""", dialect=mssql.dialect()) + + def testmssql_aliases_schemas(self): + self.runtest(table4.select(), "SELECT remotetable.rem_id, remotetable.datatype_id, remotetable.value FROM remote_owner.remotetable") + + dialect = mssql.dialect() + self.runtest(table4.select(), "SELECT remotetable_1.rem_id, remotetable_1.datatype_id, remotetable_1.value FROM remote_owner.remotetable AS remotetable_1", dialect=dialect) + + # TODO: this is probably incorrect; no "AS " is being applied to the table + self.runtest(table1.join(table4, table1.c.myid==table4.c.rem_id).select(), "SELECT mytable.myid, mytable.name, mytable.description, remotetable.rem_id, remotetable.datatype_id, remotetable.value FROM mytable JOIN remote_owner.remotetable ON remotetable.rem_id = mytable.myid") def testdontovercorrelate(self): self.runtest(select([table1], from_obj=[table1, table1.select()]), """SELECT mytable.myid, mytable.name, mytable.description FROM mytable, (SELECT mytable.myid AS myid, mytable.name AS name, mytable.description AS description FROM mytable)""") @@ -232,7 +241,9 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A order_by = ['dist', places.c.nm] ) - self.runtest(q,"SELECT places.id, places.nm, zips.zipcode, latlondist((SELECT zips.latitude FROM zips WHERE zips.zipcode = :zips_zipcode_1), (SELECT zips.longitude FROM zips WHERE zips.zipcode = :zips_zipcode_2)) AS dist FROM places, zips WHERE zips.zipcode = :zips_zipcode ORDER BY dist, places.nm") + self.runtest(q,"SELECT places.id, places.nm, zips.zipcode, latlondist((SELECT zips.latitude FROM zips WHERE " + "zips.zipcode = :zips_zipcode), (SELECT zips.longitude FROM zips WHERE zips.zipcode = :zips_zipcode_1)) AS dist " + "FROM places, zips WHERE zips.zipcode = :zips_zipcode_2 ORDER BY dist, places.nm") zalias = zips.alias('main_zip') qlat = select([zips.c.latitude], zips.c.zipcode == zalias.c.zipcode, scalar=True) @@ -759,6 +770,7 @@ EXISTS (select yay from foo where boo = lar)", dialect=postgres.dialect() ) + self.runtest(query, "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername \ FROM mytable, myothertable WHERE mytable.myid = myothertable.otherid(+) AND \ @@ -1052,7 +1064,8 @@ class CRUDTest(SQLTest): values = { table1.c.name : table1.c.name + "lala", table1.c.myid : func.do_stuff(table1.c.myid, literal('hoho')) - }), "UPDATE mytable SET myid=do_stuff(mytable.myid, :literal_2), name=(mytable.name || :mytable_name) WHERE mytable.myid = hoho(:hoho) AND mytable.name = :literal || mytable.name || :literal_1") + }), "UPDATE mytable SET myid=do_stuff(mytable.myid, :literal), name=(mytable.name || :mytable_name) " + "WHERE mytable.myid = hoho(:hoho) AND mytable.name = :literal_1 || mytable.name || :literal_2") def testcorrelatedupdate(self): # test against a straight text subquery diff --git a/test/zblog/tables.py b/test/zblog/tables.py index 3a5059972d..5b4054a195 100644 --- a/test/zblog/tables.py +++ b/test/zblog/tables.py @@ -10,7 +10,7 @@ users = Table('users', metadata, Column('user_id', Integer, primary_key=True), Column('user_name', String(30), nullable=False), Column('fullname', String(100), nullable=False), - Column('password', String(30), nullable=False), + Column('password', String(40), nullable=False), Column('groupname', String(20), nullable=False), )