# actually present in the generated SQL
self.bind_names = {}
- # a dictionary which stores the string representation for every ClauseElement
- # processed by this compiler.
- self.strings = {}
-
- # a dictionary which stores the string representation for ClauseElements
- # processed by this compiler, which are to be used in the FROM clause
- # of a select. items are often placed in "froms" as well as "strings"
- # and sometimes with different representations.
- self.froms = {}
-
- # slightly hacky. maps FROM clauses to WHERE clauses, and used in select
- # generation to modify the WHERE clause of the select. currently a hack
- # used by the oracle module.
- self.wheres = {}
-
# when the compiler visits a SELECT statement, the clause object is appended
# to this stack. various visit operations will check this stack to determine
# additional choices (TODO: it seems to be all typemap stuff. shouldnt this only
# this re will search for params like :param
# it has a negative lookbehind for an extra ':' so that it doesnt match
# postgres '::text' tokens
- text = self.strings[self.statement]
+ text = self.string
if ':' not in text:
return
text = BIND_PARAMS.sub(getnum, text)
# un-escape any \:params
text = BIND_PARAMS_ESC.sub(lambda m: m.group(1), text)
- self.strings[self.statement] = text
+ self.string = text
+ def compile(self):
+ self.string = self.process(self.statement)
+ self.after_compile()
+
+ def process(self, obj, **kwargs):
+ return self.traverse_single(obj, **kwargs)
+
def is_subquery(self, select):
return self.correlate_state[select].get('is_subquery', False)
def get_whereclause(self, obj):
- return self.wheres.get(obj, None)
+ """given a FROM clause, return an additional WHERE condition that should be
+ applied to a SELECT.
+
+ Currently used by Oracle to provide WHERE criterion for JOIN and OUTER JOIN
+ constructs in non-ansi mode.
+ """
+
+ return None
def construct_params(self, params):
"""Return a sql.ClauseParameters object.
return ""
- def visit_grouping(self, grouping):
- self.strings[grouping] = self.froms[grouping] = "(" + self.strings[grouping.elem] + ")"
+ def visit_grouping(self, grouping, **kwargs):
+ return "(" + self.process(grouping.elem) + ")"
def visit_label(self, label):
labelname = self._truncated_identifier("colident", label.name)
if isinstance(label.obj, sql._ColumnClause):
self.column_labels[label.obj._label] = labelname
self.column_labels[label.name] = labelname
- self.strings[label] = " ".join([self.strings[label.obj], self.operator_string(sql.ColumnOperators.as_), self.preparer.format_label(label, labelname)])
+ return " ".join([self.process(label.obj), self.operator_string(sql.ColumnOperators.as_), self.preparer.format_label(label, labelname)])
- def visit_column(self, column):
+ def visit_column(self, column, **kwargs):
# there is actually somewhat of a ruleset when you would *not* necessarily
# want to truncate a column identifier, if its mapped to the name of a
# physical column. but thats very hard to identify at this point, and
else:
name = column.name
+ if len(self.select_stack):
+ # if we are within a visit to a Select, set up the "typemap"
+ # for this column which is used to translate result set values
+ self.typemap.setdefault(name.lower(), column.type)
+ self.column_labels.setdefault(column._label, name.lower())
+
if column.table is None or not column.table.named_with_column():
- self.strings[column] = self.preparer.format_column(column, name=name)
+ return self.preparer.format_column(column, name=name)
else:
if column.table.oid_column is column:
n = self.dialect.oid_column_name(column)
if n is not None:
- self.strings[column] = "%s.%s" % (self.preparer.format_table(column.table, use_schema=False, name=self._anonymize(column.table.name)), n)
+ return "%s.%s" % (self.preparer.format_table(column.table, use_schema=False, name=self._anonymize(column.table.name)), n)
elif len(column.table.primary_key) != 0:
pk = list(column.table.primary_key)[0]
pkname = (pk.is_literal and name or self._truncated_identifier("colident", pk.name))
- self.strings[column] = self.preparer.format_column_with_table(list(column.table.primary_key)[0], column_name=pkname, table_name=self._anonymize(column.table.name))
+ return self.preparer.format_column_with_table(list(column.table.primary_key)[0], column_name=pkname, table_name=self._anonymize(column.table.name))
else:
- self.strings[column] = None
+ return None
else:
- self.strings[column] = self.preparer.format_column_with_table(column, column_name=name, table_name=self._anonymize(column.table.name))
+ return self.preparer.format_column_with_table(column, column_name=name, table_name=self._anonymize(column.table.name))
- if len(self.select_stack):
- # if we are within a visit to a Select, set up the "typemap"
- # for this column which is used to translate result set values
- self.typemap.setdefault(name.lower(), column.type)
- self.column_labels.setdefault(column._label, name.lower())
- def visit_fromclause(self, fromclause):
- self.froms[fromclause] = fromclause.name
+ def visit_fromclause(self, fromclause, **kwargs):
+ return fromclause.name
- def visit_index(self, index):
- self.strings[index] = index.name
+ def visit_index(self, index, **kwargs):
+ return index.name
- def visit_typeclause(self, typeclause):
- self.strings[typeclause] = typeclause.type.dialect_impl(self.dialect).get_col_spec()
+ def visit_typeclause(self, typeclause, **kwargs):
+ return typeclause.type.dialect_impl(self.dialect).get_col_spec()
- def visit_textclause(self, textclause):
- self.strings[textclause] = textclause.text
- self.froms[textclause] = textclause.text
+ def visit_textclause(self, textclause, **kwargs):
+ for bind in textclause.bindparams.values():
+ self.process(bind)
if textclause.typemap is not None:
self.typemap.update(textclause.typemap)
+ return textclause.text
- def visit_null(self, null):
- self.strings[null] = 'NULL'
+ def visit_null(self, null, **kwargs):
+ return 'NULL'
- def visit_clauselist(self, clauselist):
+ def visit_clauselist(self, clauselist, **kwargs):
sep = clauselist.operator
if sep is None:
sep = " "
sep = ', '
else:
sep = " " + self.operator_string(clauselist.operator) + " "
- self.strings[clauselist] = string.join([s for s in [self.strings[c] for c in clauselist.clauses] if s is not None], sep)
+ return string.join([s for s in [self.process(c) for c in clauselist.clauses] if s is not None], sep)
def apply_function_parens(self, func):
return func.name.upper() not in ANSI_FUNCS or len(func.clauses) > 0
- def visit_calculatedclause(self, clause):
- self.strings[clause] = self.strings[clause.clause_expr]
+ def visit_calculatedclause(self, clause, **kwargs):
+ return self.process(clause.clause_expr)
- def visit_cast(self, cast):
+ def visit_cast(self, cast, **kwargs):
if len(self.select_stack):
# not sure if we want to set the typemap here...
self.typemap.setdefault("CAST", cast.type)
- self.strings[cast] = "CAST(%s AS %s)" % (self.strings[cast.clause],self.strings[cast.typeclause])
+ return "CAST(%s AS %s)" % (self.process(cast.clause), self.process(cast.typeclause))
- def visit_function(self, func):
+ def visit_function(self, func, **kwargs):
if len(self.select_stack):
self.typemap.setdefault(func.name, func.type)
if not self.apply_function_parens(func):
- self.strings[func] = ".".join(func.packagenames + [func.name])
- self.froms[func] = self.strings[func]
+ return ".".join(func.packagenames + [func.name])
else:
- self.strings[func] = ".".join(func.packagenames + [func.name]) + (not func.group and " " or "") + self.strings[func.clause_expr]
- self.froms[func] = self.strings[func]
+ return ".".join(func.packagenames + [func.name]) + (not func.group and " " or "") + self.process(func.clause_expr)
- def visit_compound_select(self, cs):
- text = string.join([self.strings[c] for c in cs.selects], " " + cs.keyword + " ")
- group_by = self.strings[cs._group_by_clause]
+ def visit_compound_select(self, cs, asfrom=False, **kwargs):
+ text = string.join([self.process(c) for c in cs.selects], " " + cs.keyword + " ")
+ group_by = self.process(cs._group_by_clause)
if group_by:
text += " GROUP BY " + group_by
text += self.order_by_clause(cs)
- text += self.visit_select_postclauses(cs)
- self.strings[cs] = text
- self.froms[cs] = "(" + text + ")"
+ text += (cs._limit or cs._offset) and self.limit_clause(cs) or ""
+
+ if asfrom:
+ return "(" + text + ")"
+ else:
+ return text
- def visit_unary(self, unary):
- s = self.strings[unary.element]
+ def visit_unary(self, unary, **kwargs):
+ s = self.process(unary.element)
if unary.operator:
s = self.operator_string(unary.operator) + " " + s
if unary.modifier:
s = s + " " + unary.modifier
- self.strings[unary] = s
+ return s
- def visit_binary(self, binary):
+ def visit_binary(self, binary, **kwargs):
op = self.operator_string(binary.operator)
if callable(op):
- self.strings[binary] = op(binary.left, binary.right)
+ return op(self.process(binary.left), self.process(binary.right))
else:
- self.strings[binary] = self.strings[binary.left] + " " + op + " " + self.strings[binary.right]
+ return self.process(binary.left) + " " + op + " " + self.process(binary.right)
def operator_string(self, operator):
return self.operators.get(operator, str(operator))
- def visit_bindparam(self, bindparam):
+ def visit_bindparam(self, bindparam, **kwargs):
# apply truncation to the ultimate generated name
if bindparam.shortname != bindparam.key:
key = bindparam.key + tag
count += 1
bindparam.key = key
- self.strings[bindparam] = self.bindparam_string(self._truncate_bindparam(bindparam))
+ return self.bindparam_string(self._truncate_bindparam(bindparam))
else:
existing = self.binds.get(bindparam.key)
if existing is not None and existing.unique:
raise exceptions.CompileError("Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key)
- self.strings[bindparam] = self.bindparam_string(self._truncate_bindparam(bindparam))
self.binds[bindparam.key] = bindparam
+ return self.bindparam_string(self._truncate_bindparam(bindparam))
def _truncate_bindparam(self, bindparam):
if bindparam in self.bind_names:
def bindparam_string(self, name):
return self.bindtemplate % name
- def visit_alias(self, alias):
- self.froms[alias] = self.froms[alias.original] + " AS " + self.preparer.format_alias(alias, self._anonymize(alias.name))
- self.strings[alias] = self.strings[alias.original]
-
- def enter_select(self, select):
- select._calculate_correlations(self.correlate_state)
- self.select_stack.append(select)
-
- def enter_update(self, update):
- update._calculate_correlations(self.correlate_state)
+ def visit_alias(self, alias, asfrom=False, **kwargs):
+ if asfrom:
+ return self.process(alias.original, asfrom=True, **kwargs) + " AS " + self.preparer.format_alias(alias, self._anonymize(alias.name))
+ else:
+ return self.process(alias.original, **kwargs)
- def enter_delete(self, delete):
- delete._calculate_correlations(self.correlate_state)
-
def label_select_column(self, select, column):
"""convert a column from a select's "columns" clause.
else:
return None
- def visit_select(self, select):
+ def visit_select(self, select, asfrom=False, **kwargs):
+
+ select._calculate_correlations(self.correlate_state)
+ self.select_stack.append(select)
+
# the actual list of columns to print in the SELECT column list.
- inner_columns = util.OrderedDict()
+ inner_columns = util.OrderedSet()
froms = select._get_display_froms(self.correlate_state)
- for f in froms:
- if f not in self.strings:
- self.traverse(f)
for co in select.inner_columns:
if select.use_labels:
labelname = co._label
if labelname is not None:
l = co.label(labelname)
- self.traverse(l)
- inner_columns[labelname] = l
+ inner_columns.add(self.process(l))
else:
self.traverse(co)
- inner_columns[self.strings[co]] = co
+ inner_columns.add(self.process(co))
else:
l = self.label_select_column(select, co)
if l is not None:
- self.traverse(l)
- inner_columns[self.strings[l.obj]] = l
+ inner_columns.add(self.process(l))
else:
- self.traverse(co)
- inner_columns[self.strings[co]] = co
+ inner_columns.add(self.process(co))
self.select_stack.pop(-1)
- collist = string.join([self.strings[v] for v in inner_columns.values()], ', ')
+ collist = string.join(inner_columns.difference(util.Set([None])), ', ')
text = "SELECT "
- text += self.visit_select_precolumns(select)
+ text += self.get_select_precolumns(select)
text += collist
whereclause = select._whereclause
from_strings = []
for f in froms:
- # special thingy used by oracle to redefine a join
+ from_strings.append(self.process(f, asfrom=True))
+
w = self.get_whereclause(f)
if w is not None:
- # TODO: move this more into the oracle module
if whereclause is not None:
- whereclause = self.traverse(sql.and_(w, whereclause), stop_on=util.Set([whereclause, w]))
+ whereclause = sql.and_(w, whereclause)
else:
whereclause = w
- from_strings.append(self.froms[f])
-
if len(froms):
text += " \nFROM "
text += string.join(from_strings, ', ')
text += self.default_from()
if whereclause is not None:
- t = self.strings[whereclause]
+ t = self.process(whereclause)
if t:
text += " \nWHERE " + t
- group_by = self.strings[select._group_by_clause]
+ group_by = self.process(select._group_by_clause)
if group_by:
text += " GROUP BY " + group_by
if select._having is not None:
- t = self.strings[select._having]
+ t = self.process(select._having)
if t:
text += " \nHAVING " + t
text += self.order_by_clause(select)
- text += self.visit_select_postclauses(select)
+ text += (select._limit or select._offset) and self.limit_clause(select) or ""
text += self.for_update_clause(select)
- self.strings[select] = text
- self.froms[select] = "(" + text + ")"
+ if asfrom:
+ return "(" + text + ")"
+ else:
+ return text
- def visit_select_precolumns(self, select):
+ def get_select_precolumns(self, select):
"""Called when building a ``SELECT`` statement, position is just before column list."""
-
return select._distinct and "DISTINCT " or ""
- def visit_select_postclauses(self, select):
- """Called when building a ``SELECT`` statement, position is after all other ``SELECT`` clauses.
-
- Most DB syntaxes put ``LIMIT``/``OFFSET`` here.
- """
-
- return (select._limit or select._offset) and self.limit_clause(select) or ""
-
def order_by_clause(self, select):
- order_by = self.strings[select._order_by_clause]
+ order_by = self.process(select._order_by_clause)
if order_by:
return " ORDER BY " + order_by
else:
text += " OFFSET " + str(select._offset)
return text
- def visit_table(self, table):
- self.froms[table] = self.preparer.format_table(table)
- self.strings[table] = ""
-
- def visit_join(self, join):
- righttext = self.froms[join.right]
- if join.isouter:
- self.froms[join] = (self.froms[join.left] + " LEFT OUTER JOIN " + righttext +
- " ON " + self.strings[join.onclause])
+ def visit_table(self, table, asfrom=False, **kwargs):
+ if asfrom:
+ return self.preparer.format_table(table)
else:
- self.froms[join] = (self.froms[join.left] + " JOIN " + righttext +
- " ON " + self.strings[join.onclause])
- self.strings[join] = self.froms[join]
+ return ""
+
+ def visit_join(self, join, asfrom=False, **kwargs):
+ return (self.process(join.left, asfrom=True) + (join.isouter and " LEFT OUTER JOIN " or " JOIN ") + \
+ self.process(join.right, asfrom=True) + " ON " + self.process(join.onclause))
def uses_sequences_for_inserts(self):
return False
self.isinsert = True
colparams = self._get_colparams(insert_stmt, required_cols)
- text = ("INSERT INTO " + self.preparer.format_table(insert_stmt.table) + " (" + string.join([self.preparer.format_column(c[0]) for c in colparams], ', ') + ")" +
+ return ("INSERT INTO " + self.preparer.format_table(insert_stmt.table) + " (" + string.join([self.preparer.format_column(c[0]) for c in colparams], ', ') + ")" +
" VALUES (" + string.join([c[1] for c in colparams], ', ') + ")")
- self.strings[insert_stmt] = text
-
def visit_update(self, update_stmt):
+ update_stmt._calculate_correlations(self.correlate_state)
# search for columns who will be required to have an explicit bound value.
# for updates, this includes Python-side "onupdate" defaults.
text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.format_column(c[0]), c[1]) for c in colparams], ', ')
if update_stmt._whereclause:
- text += " WHERE " + self.strings[update_stmt._whereclause]
+ text += " WHERE " + self.process(update_stmt._whereclause)
- self.strings[update_stmt] = text
+ return text
def _get_colparams(self, stmt, required_cols):
"""create a set of tuples representing column/string pairs for use
def create_clause_param(col, value):
self.traverse(value)
self.inline_params.add(col)
- return self.strings[value]
+ return self.process(value)
self.inline_params = util.Set()
return values
def visit_delete(self, delete_stmt):
+ delete_stmt._calculate_correlations(self.correlate_state)
+
text = "DELETE FROM " + self.preparer.format_table(delete_stmt.table)
if delete_stmt._whereclause:
- text += " WHERE " + self.strings[delete_stmt._whereclause]
+ text += " WHERE " + self.process(delete_stmt._whereclause)
- self.strings[delete_stmt] = text
+ return text
def visit_savepoint(self, savepoint_stmt):
- text = "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident)
- self.strings[savepoint_stmt] = text
+ return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident)
def visit_rollback_to_savepoint(self, savepoint_stmt):
- text = "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident)
- self.strings[savepoint_stmt] = text
+ return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident)
def visit_release_savepoint(self, savepoint_stmt):
- text = "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident)
- self.strings[savepoint_stmt] = text
+ return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident)
def __str__(self):
- return self.strings[self.statement]
+ return self.string
class ANSISchemaBase(engine.SchemaIterator):
def find_alterables(self, tables):
super(MSSQLCompiler, self).__init__(dialect, statement, parameters, **kwargs)
self.tablealiases = {}
- def visit_select_precolumns(self, select):
+ def get_select_precolumns(self, select):
""" MS-SQL puts TOP, it's version of LIMIT here """
s = select._distinct and "DISTINCT " or ""
if select._limit:
# Limit in mssql is after the select keyword
return ""
- def visit_table(self, table):
+ def _schema_aliased_table(self, table):
+ if getattr(table, 'schema', None) is not None:
+ if not self.tablealiases.has_key(table):
+ self.tablealiases[table] = table.alias()
+ return self.tablealiases[table]
+ else:
+ return None
+
+ def visit_table(self, table, mssql_aliased=False, **kwargs):
+ if mssql_aliased:
+ return super(MSSQLCompiler, self).visit_table(table, **kwargs)
+
# alias schema-qualified tables
- if getattr(table, 'schema', None) is not None and not self.tablealiases.has_key(table):
- alias = table.alias()
- self.tablealiases[table] = alias
- self.traverse(alias)
- self.froms[('alias', table)] = self.froms[table]
- for c in alias.c:
- self.traverse(c)
- self.traverse(alias.oid_column)
- self.tablealiases[alias] = self.froms[table]
- self.froms[table] = self.froms[alias]
+ alias = self._schema_aliased_table(table)
+ if alias is not None:
+ return self.process(alias, mssql_aliased=True, **kwargs)
else:
- super(MSSQLCompiler, self).visit_table(table)
+ return super(MSSQLCompiler, self).visit_table(table, **kwargs)
- def visit_alias(self, alias):
+ def visit_alias(self, alias, **kwargs):
# translate for schema-qualified table aliases
- if self.froms.has_key(('alias', alias.original)):
- self.froms[alias] = self.froms[('alias', alias.original)] + " AS " + alias.name
- self.strings[alias] = ""
- else:
- super(MSSQLCompiler, self).visit_alias(alias)
+ self.tablealiases[alias.original] = alias
+ return super(MSSQLCompiler, self).visit_alias(alias, **kwargs)
def visit_column(self, column):
- # translate for schema-qualified table aliases
- super(MSSQLCompiler, self).visit_column(column)
- if column.table is not None and self.tablealiases.has_key(column.table):
- self.strings[column] = \
- self.strings[self.tablealiases[column.table].corresponding_column(column)]
+ if column.table is not None:
+ # translate for schema-qualified table aliases
+ t = self._schema_aliased_table(column.table)
+ if t is not None:
+ return self.process(t.corresponding_column(column))
+ return super(MSSQLCompiler, self).visit_column(column)
def visit_binary(self, binary):
"""Move bind parameters to the right-hand side of an operator, where possible."""
- if isinstance(binary.left, sql._BindParamClause) and binary.operator == '=':
- binary.left, binary.right = binary.right, binary.left
- super(MSSQLCompiler, self).visit_binary(binary)
+ if isinstance(binary.left, sql._BindParamClause) and binary.operator == operator.eq:
+ return self.process(sql._BinaryExpression(binary.right, binary.left, binary.operator))
+ else:
+ return super(MSSQLCompiler, self).visit_binary(binary)
def label_select_column(self, select, column):
if isinstance(column, sql._Function):
return ''
def order_by_clause(self, select):
- order_by = self.strings[select._order_by_clause]
+ order_by = self.process(select._order_by_clause)
# MSSQL only allows ORDER BY in subqueries if there is a LIMIT
if order_by and (not self.is_subquery(select) or select._limit):
OracleDialect.logger = logging.class_logger(OracleDialect)
+class _OuterJoinColumn(sql.ClauseElement):
+ __visit_name__ = 'outer_join_column'
+ def __init__(self, column):
+ self.column = column
+
class OracleCompiler(ansisql.ANSICompiler):
"""Oracle compiler modifies the lexical structure of Select
statements to work under non-ANSI configured Oracle databases, if
}
)
+ def __init__(self, *args, **kwargs):
+ super(OracleCompiler, self).__init__(*args, **kwargs)
+ self.__wheres = {}
+
def default_from(self):
"""Called when a ``SELECT`` statement has no froms, and no ``FROM`` clause is to be appended.
def apply_function_parens(self, func):
return len(func.clauses) > 0
- def visit_join(self, join):
+ def visit_join(self, join, **kwargs):
if self.dialect.use_ansi:
- return ansisql.ANSICompiler.visit_join(self, join)
-
- self.froms[join] = self.froms[join.left] + ", " + self.froms[join.right]
- where = self.wheres.get(join.left, None)
+ return ansisql.ANSICompiler.visit_join(self, join, **kwargs)
+
+ (where, parentjoin) = self.__wheres.get(join, (None, None))
+
+ class VisitOn(sql.ClauseVisitor):
+ def visit_binary(s, binary):
+ if binary.operator == operator.eq:
+ if binary.left.table is join.right:
+ binary.left = _OuterJoinColumn(binary.left)
+ elif binary.right.table is join.right:
+ binary.right = _OuterJoinColumn(binary.right)
+
if where is not None:
- self.wheres[join] = sql.and_(where, join.onclause)
+ self.__wheres[join.left] = self.__wheres[parentjoin] = (sql.and_(VisitOn().traverse(join.onclause, clone=True), where), parentjoin)
else:
- self.wheres[join] = join.onclause
-# self.wheres[join] = sql.and_(self.wheres.get(join.left, None), join.onclause)
- self.strings[join] = self.froms[join]
-
- if join.isouter:
- # if outer join, push on the right side table as the current "outertable"
- self._outertable = join.right
-
- # now re-visit the onclause, which will be used as a where clause
- # (the first visit occured via the Join object itself right before it called visit_join())
- self.traverse(join.onclause)
-
- self._outertable = None
-
- self.traverse_single(self.wheres[join])
+ self.__wheres[join.left] = self.__wheres[join] = (VisitOn().traverse(join.onclause, clone=True), join)
+ return self.process(join.left, asfrom=True) + ", " + self.process(join.right, asfrom=True)
+
+ def get_whereclause(self, f):
+ if f in self.__wheres:
+ return self.__wheres[f][0]
+ else:
+ return None
+
+ def visit_outer_join_column(self, vc):
+ return self.process(vc.column) + "(+)"
def uses_sequences_for_inserts(self):
return True
- def visit_alias(self, alias):
+ def visit_alias(self, alias, asfrom=False, **kwargs):
"""Oracle doesn't like ``FROM table AS alias``. Is the AS standard SQL??"""
-
- self.froms[alias] = self.froms[alias.original] + " " + alias.name
- self.strings[alias] = self.strings[alias.original]
-
- def visit_column(self, column):
- ansisql.ANSICompiler.visit_column(self, column)
- if not self.dialect.use_ansi and getattr(self, '_outertable', None) is not None and column.table is self._outertable:
- self.strings[column] = self.strings[column] + "(+)"
+
+ if asfrom:
+ return self.process(alias.original) + " " + alias.name
+ else:
+ return self.process(alias.original)
def visit_insert(self, insert):
"""``INSERT`` s are required to have the primary keys be explicitly present.
"""Need to determine how to get ``LIMIT``/``OFFSET`` into a ``UNION`` for Oracle."""
pass
- def visit_select(self, select):
+ def visit_select(self, select, **kwargs):
"""Look for ``LIMIT`` and OFFSET in a select statement, and if
so tries to wrap it in a subquery with ``row_number()`` criterion.
"""
if not getattr(select, '_oracle_visit', None) and (select._limit is not None or select._offset is not None):
# to use ROW_NUMBER(), an ORDER BY is required.
- orderby = self.strings[select._order_by_clause]
+ orderby = self.process(select._order_by_clause)
if not orderby:
orderby = select.oid_column
self.traverse(orderby)
- orderby = self.strings[orderby]
+ orderby = self.process(orderby)
oldselect = select
select = select.column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("ora_rn")).order_by(None)
limitselect.append_whereclause("ora_rn<=%d" % (select._limit + select._offset))
else:
limitselect.append_whereclause("ora_rn<=%d" % select._limit)
- self.traverse(limitselect)
- self.strings[oldselect] = self.strings[limitselect]
- self.froms[oldselect] = self.froms[limitselect]
+ return self.process(limitselect)
else:
- ansisql.ANSICompiler.visit_select(self, select)
+ return ansisql.ANSICompiler.visit_select(self, select, **kwargs)
def limit_clause(self, select):
return ""