# actually present in the generated SQL
self.bind_names = {}
- # when the compiler visits a SELECT statement, the clause object is appended
- # to this stack. various visit operations will check this stack to determine
- # additional choices (TODO: it seems to be all typemap stuff. shouldnt this only
- # apply to the topmost-level SELECT statement ?)
- self.select_stack = []
-
+ # a stack. what recursive compiler doesn't have a stack ? :)
+ self.stack = []
+
# a dictionary of result-set column names (strings) to TypeEngine instances,
# which will be passed to a ResultProxy and used for resultset-level value conversion
self.typemap = {}
# an ANSIIdentifierPreparer that formats the quoting of identifiers
self.preparer = dialect.identifier_preparer
- # a dictionary containing attributes about all select()
- # elements located within the clause, regarding which are subqueries, which are
- # selected from, and which elements should be correlated to an enclosing select.
- # used mostly to determine the list of FROM elements for each select statement, as well
- # as some dialect-specific rules regarding subqueries.
- self.correlate_state = {}
-
# for UPDATE and INSERT statements, a set of columns whos values are being set
# from a SQL expression (i.e., not one of the bind parameter values). if present,
# default-value logic in the Dialect knows not to fire off column defaults
self.string = self.process(self.statement)
self.after_compile()
- def process(self, obj, **kwargs):
- return self.traverse_single(obj, **kwargs)
+ def process(self, obj, stack=None, **kwargs):
+ if stack:
+ self.stack.append(stack)
+ try:
+ return self.traverse_single(obj, **kwargs)
+ finally:
+ if stack:
+ self.stack.pop(-1)
def is_subquery(self, select):
- return self.correlate_state[select].get('is_subquery', False)
+ return self.stack and self.stack[-1].get('is_subquery')
def get_whereclause(self, obj):
"""given a FROM clause, return an additional WHERE condition that should be
def visit_label(self, label):
labelname = self._truncated_identifier("colident", label.name)
- if self.select_stack:
+ if self.stack and self.stack[-1].get('select'):
self.typemap.setdefault(labelname.lower(), label.obj.type)
if isinstance(label.obj, sql._ColumnClause):
self.column_labels[label.obj._label] = labelname
else:
name = column.name
- if self.select_stack:
+ if self.stack and self.stack[-1].get('select'):
# if we are within a visit to a Select, set up the "typemap"
# for this column which is used to translate result set values
self.typemap.setdefault(name.lower(), column.type)
return self.process(clause.clause_expr)
def visit_cast(self, cast, **kwargs):
- if self.select_stack:
+ if self.stack and self.stack[-1].get('select'):
# not sure if we want to set the typemap here...
self.typemap.setdefault("CAST", cast.type)
return "CAST(%s AS %s)" % (self.process(cast.clause), self.process(cast.typeclause))
def visit_function(self, func, **kwargs):
- if self.select_stack:
+ if self.stack and self.stack[-1].get('select'):
self.typemap.setdefault(func.name, func.type)
if not self.apply_function_parens(func):
return ".".join(func.packagenames + [func.name])
else:
return ".".join(func.packagenames + [func.name]) + (not func.group and " " or "") + self.process(func.clause_expr)
- def visit_compound_select(self, cs, asfrom=False, **kwargs):
- text = string.join([self.process(c) for c in cs.selects], " " + cs.keyword + " ")
- group_by = self.process(cs._group_by_clause)
+ def visit_compound_select(self, cs, asfrom=False, parens=True, **kwargs):
+ text = string.join([self.process(c, asfrom=asfrom, parens=False) for c in cs.selects], " " + cs.keyword + " ")
+ group_by = self.process(cs._group_by_clause, asfrom=asfrom)
if group_by:
text += " GROUP BY " + group_by
text += self.order_by_clause(cs)
text += (cs._limit or cs._offset) and self.limit_clause(cs) or ""
- if asfrom:
+ if asfrom and parens:
return "(" + text + ")"
else:
return text
# names look like table.colname. so if column is in a "selected from"
# subquery, label it synoymously with its column name
if \
- self.correlate_state[select].get('is_selected_from', False) and \
+ (self.stack and self.stack[-1].get('is_selected_from')) and \
isinstance(column, sql._ColumnClause) and \
not column.is_literal and \
column.table is not None and \
return column.label(column.name)
else:
return None
-
- def visit_select(self, select, asfrom=False, **kwargs):
- select._calculate_correlations(self.correlate_state)
- self.select_stack.append(select)
+ def visit_select(self, select, asfrom=False, parens=True, **kwargs):
+
+ stack_entry = {'select':select}
+
+ if asfrom:
+ stack_entry['is_selected_from'] = stack_entry['is_subquery'] = True
+ elif self.stack and self.stack[-1].get('select'):
+ stack_entry['is_subquery'] = True
+
+ if self.stack and self.stack[-1].get('from'):
+ existingfroms = self.stack[-1]['from']
+ else:
+ existingfroms = None
+ froms = select._get_display_froms(existingfroms)
+
+ correlate_froms = util.Set()
+ for f in froms:
+ correlate_froms.add(f)
+ for f2 in f._get_from_objects():
+ correlate_froms.add(f2)
+
+ # TODO: might want to propigate existing froms for select(select(select))
+ # where innermost select should correlate to outermost
+# if existingfroms:
+# correlate_froms = correlate_froms.union(existingfroms)
+ stack_entry['from'] = correlate_froms
+ self.stack.append(stack_entry)
# the actual list of columns to print in the SELECT column list.
inner_columns = util.OrderedSet()
-
- froms = select._get_display_froms(self.correlate_state)
for co in select.inner_columns:
if select.use_labels:
inner_columns.add(self.process(l))
else:
inner_columns.add(self.process(co))
-
- self.select_stack.pop(-1)
-
+
collist = string.join(inner_columns.difference(util.Set([None])), ', ')
text = " ".join(["SELECT"] + [self.process(x) for x in select._prefixes]) + " "
text += (select._limit or select._offset) and self.limit_clause(select) or ""
text += self.for_update_clause(select)
- if asfrom:
+ self.stack.pop(-1)
+
+ if asfrom and parens:
return "(" + text + ")"
else:
return text
" VALUES (" + string.join([c[1] for c in colparams], ', ') + ")")
def visit_update(self, update_stmt):
- update_stmt._calculate_correlations(self.correlate_state)
+ self.stack.append({'from':util.Set([update_stmt.table])})
# search for columns who will be required to have an explicit bound value.
# for updates, this includes Python-side "onupdate" defaults.
if update_stmt._whereclause:
text += " WHERE " + self.process(update_stmt._whereclause)
-
+
+ self.stack.pop(-1)
+
return text
def _get_colparams(self, stmt, required_cols):
return values
def visit_delete(self, delete_stmt):
- delete_stmt._calculate_correlations(self.correlate_state)
+ self.stack.append({'from':util.Set([delete_stmt.table])})
text = "DELETE FROM " + self.preparer.format_table(delete_stmt.table)
if delete_stmt._whereclause:
text += " WHERE " + self.process(delete_stmt._whereclause)
+ self.stack.pop(-1)
+
return text
def visit_savepoint(self, savepoint_stmt):
_SelectBaseMixin.__init__(self, **kwargs)
- def _get_display_froms(self, correlation_state=None):
+ def _get_display_froms(self, existing_froms=None):
"""return the full list of 'from' clauses to be displayed.
- takes into account an optional 'correlation_state'
- dictionary which contains information about this Select's
- correlation to an enclosing select, which may cause some 'from'
- clauses to not display in this Select's FROM clause.
- this dictionary is generated during compile time by the
- _calculate_correlations() method.
-
+ takes into account a set of existing froms which
+ may be rendered in the FROM clause of enclosing selects;
+ this Select may want to leave those absent if it is automatically
+ correlating.
"""
+
froms = util.OrderedSet()
hide_froms = util.Set()
if len(froms) > 1:
corr = self.__correlate
- if correlation_state is not None:
- corr = correlation_state[self].get('correlate', util.Set()).union(corr)
+ if self._should_correlate and existing_froms is not None:
+ corr = existing_froms.union(corr)
f = froms.difference(corr)
if len(f) == 0:
raise exceptions.InvalidRequestError("Select statement '%s' is overcorrelated; returned no 'from' clauses" % str(self.__dont_correlate()))
for f in elem._get_from_objects():
froms.add(f)
return froms
-
- def _calculate_correlations(self, correlation_state):
- """generate a 'correlation_state' dictionary used by the _get_display_froms() method.
-
- The dictionary is passed in initially empty, or already
- containing the state information added by an enclosing
- Select construct. The method will traverse through all
- embedded Select statements and add information about their
- position and "from" objects to the dictionary. Those Select
- statements will later consult the 'correlation_state' dictionary
- when their list of 'FROM' clauses are generated using their
- _get_display_froms() method.
- """
-
- if self not in correlation_state:
- correlation_state[self] = {}
-
- display_froms = self._get_display_froms(correlation_state)
-
- class CorrelatedVisitor(NoColumnVisitor):
- def __init__(self, is_where=False, is_column=False, is_from=False):
- self.is_where = is_where
- self.is_column = is_column
- self.is_from = is_from
-
- def visit_compound_select(self, cs):
- self.visit_select(cs)
-
- def visit_select(s, select):
- if select not in correlation_state:
- correlation_state[select] = {}
-
- if select is self:
- return
-
- select_state = correlation_state[select]
- if s.is_from:
- select_state['is_selected_from'] = True
- if s.is_where:
- select_state['is_where'] = True
- select_state['is_subquery'] = True
-
- if select._should_correlate:
- corr = select_state.setdefault('correlate', util.Set())
- # not crazy about this part. need to be clearer on what elements in the
- # subquery correspond to elements in the enclosing query.
- for f in display_froms:
- corr.add(f)
- for f2 in f._get_from_objects():
- corr.add(f2)
-
- col_vis = CorrelatedVisitor(is_column=True)
- where_vis = CorrelatedVisitor(is_where=True)
- from_vis = CorrelatedVisitor(is_from=True)
-
- for col in self._raw_columns:
- col_vis.traverse(col)
- for f in col._get_from_objects():
- if f is not self:
- from_vis.traverse(f)
-
- for col in list(self._order_by_clause) + list(self._group_by_clause):
- col_vis.traverse(col)
-
- if self._whereclause is not None:
- where_vis.traverse(self._whereclause)
- for f in self._whereclause._get_from_objects(is_where=True):
- if f is not self:
- from_vis.traverse(f)
-
- for elem in self._froms:
- from_vis.traverse(elem)
def _get_inner_columns(self):
for c in self._raw_columns:
def supports_execution(self):
return True
- def _calculate_correlations(self, correlate_state):
- class SelectCorrelator(NoColumnVisitor):
- def visit_select(s, select):
- if select._should_correlate:
- select_state = correlate_state.setdefault(select, {})
- corr = select_state.setdefault('correlate', util.Set())
- corr.add(self.table)
-
- vis = SelectCorrelator()
-
- if self._whereclause is not None:
- vis.traverse(self._whereclause)
-
- if getattr(self, 'parameters', None) is not None:
- for key, value in self.parameters.items():
- if isinstance(value, ClauseElement):
- vis.traverse(value)
-
def _process_colparams(self, parameters):
"""Receive the *values* of an ``INSERT`` or ``UPDATE``
statement and construct appropriate bind parameters.