From: Mike Bayer Date: Sat, 11 Aug 2007 16:04:38 +0000 (+0000) Subject: - removed _calculate_correlations() methods, removed correlation_stack, select_stack; X-Git-Tag: rel_0_4beta1~8 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=ac219b0192814cea0611f7251f7bb3927e5c3201;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - removed _calculate_correlations() methods, removed correlation_stack, select_stack; all are merged into a single stack thats all within ansicompiler. clause visiting cut down significantly. --- diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 14bae1d170..bfb08d3376 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -150,12 +150,9 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor): # actually present in the generated SQL self.bind_names = {} - # when the compiler visits a SELECT statement, the clause object is appended - # to this stack. various visit operations will check this stack to determine - # additional choices (TODO: it seems to be all typemap stuff. shouldnt this only - # apply to the topmost-level SELECT statement ?) - self.select_stack = [] - + # a stack. what recursive compiler doesn't have a stack ? :) + self.stack = [] + # a dictionary of result-set column names (strings) to TypeEngine instances, # which will be passed to a ResultProxy and used for resultset-level value conversion self.typemap = {} @@ -184,13 +181,6 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor): # an ANSIIdentifierPreparer that formats the quoting of identifiers self.preparer = dialect.identifier_preparer - # a dictionary containing attributes about all select() - # elements located within the clause, regarding which are subqueries, which are - # selected from, and which elements should be correlated to an enclosing select. - # used mostly to determine the list of FROM elements for each select statement, as well - # as some dialect-specific rules regarding subqueries. - self.correlate_state = {} - # for UPDATE and INSERT statements, a set of columns whos values are being set # from a SQL expression (i.e., not one of the bind parameter values). if present, # default-value logic in the Dialect knows not to fire off column defaults @@ -230,11 +220,17 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor): self.string = self.process(self.statement) self.after_compile() - def process(self, obj, **kwargs): - return self.traverse_single(obj, **kwargs) + def process(self, obj, stack=None, **kwargs): + if stack: + self.stack.append(stack) + try: + return self.traverse_single(obj, **kwargs) + finally: + if stack: + self.stack.pop(-1) def is_subquery(self, select): - return self.correlate_state[select].get('is_subquery', False) + return self.stack and self.stack[-1].get('is_subquery') def get_whereclause(self, obj): """given a FROM clause, return an additional WHERE condition that should be @@ -292,7 +288,7 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor): def visit_label(self, label): labelname = self._truncated_identifier("colident", label.name) - if self.select_stack: + if self.stack and self.stack[-1].get('select'): self.typemap.setdefault(labelname.lower(), label.obj.type) if isinstance(label.obj, sql._ColumnClause): self.column_labels[label.obj._label] = labelname @@ -310,7 +306,7 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor): else: name = column.name - if self.select_stack: + if self.stack and self.stack[-1].get('select'): # if we are within a visit to a Select, set up the "typemap" # for this column which is used to translate result set values self.typemap.setdefault(name.lower(), column.type) @@ -369,28 +365,28 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor): return self.process(clause.clause_expr) def visit_cast(self, cast, **kwargs): - if self.select_stack: + if self.stack and self.stack[-1].get('select'): # not sure if we want to set the typemap here... self.typemap.setdefault("CAST", cast.type) return "CAST(%s AS %s)" % (self.process(cast.clause), self.process(cast.typeclause)) def visit_function(self, func, **kwargs): - if self.select_stack: + if self.stack and self.stack[-1].get('select'): self.typemap.setdefault(func.name, func.type) if not self.apply_function_parens(func): return ".".join(func.packagenames + [func.name]) else: return ".".join(func.packagenames + [func.name]) + (not func.group and " " or "") + self.process(func.clause_expr) - def visit_compound_select(self, cs, asfrom=False, **kwargs): - text = string.join([self.process(c) for c in cs.selects], " " + cs.keyword + " ") - group_by = self.process(cs._group_by_clause) + def visit_compound_select(self, cs, asfrom=False, parens=True, **kwargs): + text = string.join([self.process(c, asfrom=asfrom, parens=False) for c in cs.selects], " " + cs.keyword + " ") + group_by = self.process(cs._group_by_clause, asfrom=asfrom) if group_by: text += " GROUP BY " + group_by text += self.order_by_clause(cs) text += (cs._limit or cs._offset) and self.limit_clause(cs) or "" - if asfrom: + if asfrom and parens: return "(" + text + ")" else: return text @@ -499,7 +495,7 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor): # names look like table.colname. so if column is in a "selected from" # subquery, label it synoymously with its column name if \ - self.correlate_state[select].get('is_selected_from', False) and \ + (self.stack and self.stack[-1].get('is_selected_from')) and \ isinstance(column, sql._ColumnClause) and \ not column.is_literal and \ column.table is not None and \ @@ -507,16 +503,37 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor): return column.label(column.name) else: return None - - def visit_select(self, select, asfrom=False, **kwargs): - select._calculate_correlations(self.correlate_state) - self.select_stack.append(select) + def visit_select(self, select, asfrom=False, parens=True, **kwargs): + + stack_entry = {'select':select} + + if asfrom: + stack_entry['is_selected_from'] = stack_entry['is_subquery'] = True + elif self.stack and self.stack[-1].get('select'): + stack_entry['is_subquery'] = True + + if self.stack and self.stack[-1].get('from'): + existingfroms = self.stack[-1]['from'] + else: + existingfroms = None + froms = select._get_display_froms(existingfroms) + + correlate_froms = util.Set() + for f in froms: + correlate_froms.add(f) + for f2 in f._get_from_objects(): + correlate_froms.add(f2) + + # TODO: might want to propigate existing froms for select(select(select)) + # where innermost select should correlate to outermost +# if existingfroms: +# correlate_froms = correlate_froms.union(existingfroms) + stack_entry['from'] = correlate_froms + self.stack.append(stack_entry) # the actual list of columns to print in the SELECT column list. inner_columns = util.OrderedSet() - - froms = select._get_display_froms(self.correlate_state) for co in select.inner_columns: if select.use_labels: @@ -533,9 +550,7 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor): inner_columns.add(self.process(l)) else: inner_columns.add(self.process(co)) - - self.select_stack.pop(-1) - + collist = string.join(inner_columns.difference(util.Set([None])), ', ') text = " ".join(["SELECT"] + [self.process(x) for x in select._prefixes]) + " " @@ -579,7 +594,9 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor): text += (select._limit or select._offset) and self.limit_clause(select) or "" text += self.for_update_clause(select) - if asfrom: + self.stack.pop(-1) + + if asfrom and parens: return "(" + text + ")" else: return text @@ -652,7 +669,7 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor): " VALUES (" + string.join([c[1] for c in colparams], ', ') + ")") def visit_update(self, update_stmt): - update_stmt._calculate_correlations(self.correlate_state) + self.stack.append({'from':util.Set([update_stmt.table])}) # search for columns who will be required to have an explicit bound value. # for updates, this includes Python-side "onupdate" defaults. @@ -672,7 +689,9 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor): if update_stmt._whereclause: text += " WHERE " + self.process(update_stmt._whereclause) - + + self.stack.pop(-1) + return text def _get_colparams(self, stmt, required_cols): @@ -735,13 +754,15 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor): return values def visit_delete(self, delete_stmt): - delete_stmt._calculate_correlations(self.correlate_state) + self.stack.append({'from':util.Set([delete_stmt.table])}) text = "DELETE FROM " + self.preparer.format_table(delete_stmt.table) if delete_stmt._whereclause: text += " WHERE " + self.process(delete_stmt._whereclause) + self.stack.pop(-1) + return text def visit_savepoint(self, savepoint_stmt): diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 5524cdc056..28b750c4ad 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -3113,17 +3113,15 @@ class Select(_SelectBaseMixin, FromClause): _SelectBaseMixin.__init__(self, **kwargs) - def _get_display_froms(self, correlation_state=None): + def _get_display_froms(self, existing_froms=None): """return the full list of 'from' clauses to be displayed. - takes into account an optional 'correlation_state' - dictionary which contains information about this Select's - correlation to an enclosing select, which may cause some 'from' - clauses to not display in this Select's FROM clause. - this dictionary is generated during compile time by the - _calculate_correlations() method. - + takes into account a set of existing froms which + may be rendered in the FROM clause of enclosing selects; + this Select may want to leave those absent if it is automatically + correlating. """ + froms = util.OrderedSet() hide_froms = util.Set() @@ -3150,8 +3148,8 @@ class Select(_SelectBaseMixin, FromClause): if len(froms) > 1: corr = self.__correlate - if correlation_state is not None: - corr = correlation_state[self].get('correlate', util.Set()).union(corr) + if self._should_correlate and existing_froms is not None: + corr = existing_froms.union(corr) f = froms.difference(corr) if len(f) == 0: raise exceptions.InvalidRequestError("Select statement '%s' is overcorrelated; returned no 'from' clauses" % str(self.__dont_correlate())) @@ -3178,78 +3176,6 @@ class Select(_SelectBaseMixin, FromClause): for f in elem._get_from_objects(): froms.add(f) return froms - - def _calculate_correlations(self, correlation_state): - """generate a 'correlation_state' dictionary used by the _get_display_froms() method. - - The dictionary is passed in initially empty, or already - containing the state information added by an enclosing - Select construct. The method will traverse through all - embedded Select statements and add information about their - position and "from" objects to the dictionary. Those Select - statements will later consult the 'correlation_state' dictionary - when their list of 'FROM' clauses are generated using their - _get_display_froms() method. - """ - - if self not in correlation_state: - correlation_state[self] = {} - - display_froms = self._get_display_froms(correlation_state) - - class CorrelatedVisitor(NoColumnVisitor): - def __init__(self, is_where=False, is_column=False, is_from=False): - self.is_where = is_where - self.is_column = is_column - self.is_from = is_from - - def visit_compound_select(self, cs): - self.visit_select(cs) - - def visit_select(s, select): - if select not in correlation_state: - correlation_state[select] = {} - - if select is self: - return - - select_state = correlation_state[select] - if s.is_from: - select_state['is_selected_from'] = True - if s.is_where: - select_state['is_where'] = True - select_state['is_subquery'] = True - - if select._should_correlate: - corr = select_state.setdefault('correlate', util.Set()) - # not crazy about this part. need to be clearer on what elements in the - # subquery correspond to elements in the enclosing query. - for f in display_froms: - corr.add(f) - for f2 in f._get_from_objects(): - corr.add(f2) - - col_vis = CorrelatedVisitor(is_column=True) - where_vis = CorrelatedVisitor(is_where=True) - from_vis = CorrelatedVisitor(is_from=True) - - for col in self._raw_columns: - col_vis.traverse(col) - for f in col._get_from_objects(): - if f is not self: - from_vis.traverse(f) - - for col in list(self._order_by_clause) + list(self._group_by_clause): - col_vis.traverse(col) - - if self._whereclause is not None: - where_vis.traverse(self._whereclause) - for f in self._whereclause._get_from_objects(is_where=True): - if f is not self: - from_vis.traverse(f) - - for elem in self._froms: - from_vis.traverse(elem) def _get_inner_columns(self): for c in self._raw_columns: @@ -3439,24 +3365,6 @@ class _UpdateBase(ClauseElement): def supports_execution(self): return True - def _calculate_correlations(self, correlate_state): - class SelectCorrelator(NoColumnVisitor): - def visit_select(s, select): - if select._should_correlate: - select_state = correlate_state.setdefault(select, {}) - corr = select_state.setdefault('correlate', util.Set()) - corr.add(self.table) - - vis = SelectCorrelator() - - if self._whereclause is not None: - vis.traverse(self._whereclause) - - if getattr(self, 'parameters', None) is not None: - for key, value in self.parameters.items(): - if isinstance(value, ClauseElement): - vis.traverse(value) - def _process_colparams(self, parameters): """Receive the *values* of an ``INSERT`` or ``UPDATE`` statement and construct appropriate bind parameters.