From 13ac46eb3fc92d9e4ad3582057ddcc75e6b71201 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sun, 24 Jun 2007 19:58:41 +0000 Subject: [PATCH] - merge of generative_sql branch - copy_container() removed. ClauseVisitor.traverse() now features "clone" flag which allows traversal with copy-and-modify-in-place behavior - select() objects copyable now [ticket:52] [ticket:569] - improved support for custom column_property() attributes which feature correlated subqueries...work better with eager loading now. - accept_visitor() methods removed. ClauseVisitor now genererates method names based on class names, or an optional __visit_name__ attribute. calls regular visit_XXX methods as they exist, can optionally call an additional "pre-descent" enter_XXX method to allow stack-based operations on traversals - select() and union()'s now have "generative" behavior. methods like order_by() and group_by() return a *new* instance - the original instance is left unchanged. non-generative methods remain as well. - the internals of select/union vastly simplified - all decision making regarding "is subquery" and "correlation" pushed to SQL generation phase. select() elements are now *never* mutated by their enclosing containers or by any dialect's compilation process --- CHANGES | 14 + lib/sqlalchemy/ansisql.py | 176 ++-- lib/sqlalchemy/databases/firebird.py | 10 +- lib/sqlalchemy/databases/informix.py | 14 +- lib/sqlalchemy/databases/mssql.py | 26 +- lib/sqlalchemy/databases/mysql.py | 12 +- lib/sqlalchemy/databases/oracle.py | 76 +- lib/sqlalchemy/databases/postgres.py | 10 +- lib/sqlalchemy/databases/sqlite.py | 10 +- lib/sqlalchemy/engine/base.py | 92 +- lib/sqlalchemy/ext/sqlsoup.py | 2 +- lib/sqlalchemy/orm/mapper.py | 4 +- lib/sqlalchemy/orm/properties.py | 26 +- lib/sqlalchemy/orm/query.py | 32 +- lib/sqlalchemy/orm/strategies.py | 54 +- lib/sqlalchemy/schema.py | 111 +-- lib/sqlalchemy/sql.py | 1119 ++++++++++------------ lib/sqlalchemy/sql_util.py | 28 +- lib/sqlalchemy/util.py | 4 +- test/orm/generative.py | 4 +- test/orm/inheritance/poly_linked_list.py | 30 +- test/orm/mapper.py | 36 +- test/perf/masseagerload.py | 1 + test/sql/alltests.py | 2 + test/sql/generative.py | 212 ++++ test/sql/select.py | 36 +- 26 files changed, 1166 insertions(+), 975 deletions(-) create mode 100644 test/sql/generative.py diff --git a/CHANGES b/CHANGES index febda7d016..b118856fee 100644 --- a/CHANGES +++ b/CHANGES @@ -12,6 +12,8 @@ auto-construction of joins which cross the same paths but are querying divergent criteria. ClauseElements at the front of filter_by() are removed (use filter()). + - improved support for custom column_property() attributes which + feature correlated subqueries...work better with eager loading now. - along with recent speedups to ResultProxy, total number of function calls significantly reduced for large loads. test/perf/masseagerload.py reports 0.4 as having the fewest number @@ -35,6 +37,18 @@ - added undefer_group() MapperOption, sets a set of "deferred" columns joined by a "group" to load as "undeferred". - sql + - significant architectural overhaul to SQL elements (ClauseElement). + all elements share a common "mutability" framework which allows a + consistent approach to in-place modifications of elements as well as + generative behavior. improves stability of the ORM which makes + heavy usage of mutations to SQL expressions. + - select() and union()'s now have "generative" behavior. methods like + order_by() and group_by() return a *new* instance - the original instance + is left unchanged. non-generative methods remain as well. + - the internals of select/union vastly simplified - all decision making + regarding "is subquery" and "correlation" pushed to SQL generation phase. + select() elements are now *never* mutated by their enclosing containers + or by any dialect's compilation process [ticket:52] [ticket:569] - result sets from CRUD operations close their underlying cursor immediately. will also autoclose the connection if defined for the operation; this allows more efficient usage of connections for successive CRUD operations diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index c489d7929a..e8610f8644 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -66,13 +66,13 @@ class ANSIDialect(default.DefaultDialect): """ return ANSIIdentifierPreparer(self) -class ANSICompiler(sql.Compiled): +class ANSICompiler(engine.Compiled): """Default implementation of Compiled. Compiles ClauseElements into ANSI-compliant SQL strings. """ - __traverse_options__ = {'column_collections':False} + __traverse_options__ = {'column_collections':False, 'entry':True} def __init__(self, dialect, statement, parameters=None, **kwargs): """Construct a new ``ANSICompiler`` object. @@ -92,7 +92,7 @@ class ANSICompiler(sql.Compiled): correspond to the keys present in the parameters. """ - sql.Compiled.__init__(self, dialect, statement, parameters, **kwargs) + super(ANSICompiler, self).__init__(dialect, statement, parameters, **kwargs) # if we are insert/update. set to true when we visit an INSERT or UPDATE self.isinsert = self.isupdate = False @@ -158,7 +158,14 @@ class ANSICompiler(sql.Compiled): # an ANSIIdentifierPreparer that formats the quoting of identifiers self.preparer = dialect.identifier_preparer - + + # a dictionary containing attributes about all select() + # elements located within the clause, regarding which are subqueries, which are + # selected from, and which elements should be correlated to an enclosing select. + # used mostly to determine the list of FROM elements for each select statement, as well + # as some dialect-specific rules regarding subqueries. + self.correlate_state = {} + # for UPDATE and INSERT statements, a set of columns whos values are being set # from a SQL expression (i.e., not one of the bind parameter values). if present, # default-value logic in the Dialect knows not to fire off column defaults @@ -193,7 +200,10 @@ class ANSICompiler(sql.Compiled): def get_str(self, obj): return self.strings[obj] - + + def is_subquery(self, select): + return self.correlate_state[select].get('is_subquery', False) + def get_whereclause(self, obj): return self.wheres.get(obj, None) @@ -343,7 +353,7 @@ class ANSICompiler(sql.Compiled): def visit_compound_select(self, cs): text = string.join([self.get_str(c) for c in cs.selects], " " + cs.keyword + " ") - group_by = self.get_str(cs.group_by_clause) + group_by = self.get_str(cs._group_by_clause) if group_by: text += " GROUP BY " + group_by text += self.order_by_clause(cs) @@ -424,40 +434,68 @@ class ANSICompiler(sql.Compiled): self.froms[alias] = self.get_from_text(alias.original) + " AS " + self.preparer.format_alias(alias) self.strings[alias] = self.get_str(alias.original) + def enter_select(self, select): + select.calculate_correlations(self.correlate_state) + self.select_stack.append(select) + + def enter_update(self, update): + update.calculate_correlations(self.correlate_state) + + def enter_delete(self, delete): + delete.calculate_correlations(self.correlate_state) + + def label_select_column(self, select, column): + """convert a column from a select's "columns" clause. + + given a select() and a column element from its inner_columns collection, return a + Label object if this column should be labeled in the columns clause. Otherwise, + return None and the column will be used as-is. + + The calling method will traverse the returned label to acquire its string + representation. + """ + + # SQLite doesnt like selecting from a subquery where the column + # names look like table.colname. so if column is in a "selected from" + # subquery, label it synoymously with its column name + if \ + self.correlate_state[select].get('is_selected_from', False) and \ + isinstance(column, sql._ColumnClause) and \ + not column.is_literal and \ + column.table is not None and \ + not isinstance(column.table, sql.Select): + return column.label(column.name) + else: + return None + def visit_select(self, select): # the actual list of columns to print in the SELECT column list. inner_columns = util.OrderedDict() - - self.select_stack.append(select) - for c in select._raw_columns: - if hasattr(c, '_selectable'): - s = c._selectable() + + froms = select.get_display_froms(self.correlate_state) + for f in froms: + if f not in self.strings: + self.traverse(f) + + for co in select.inner_columns: + if select.use_labels: + labelname = co._label + if labelname is not None: + l = co.label(labelname) + self.traverse(l) + inner_columns[labelname] = l + else: + self.traverse(co) + inner_columns[self.get_str(co)] = co else: - self.traverse(c) - inner_columns[self.get_str(c)] = c - continue - for co in s.columns: - if select.use_labels: - labelname = co._label - if labelname is not None: - l = co.label(labelname) - self.traverse(l) - inner_columns[labelname] = l - else: - self.traverse(co) - inner_columns[self.get_str(co)] = co - # TODO: figure this out, a ColumnClause with a select as a parent - # is different from any other kind of parent - elif select.is_selected_from and isinstance(co, sql._ColumnClause) and not co.is_literal and co.table is not None and not isinstance(co.table, sql.Select): - # SQLite doesnt like selecting from a subquery where the column - # names look like table.colname, so add a label synonomous with - # the column name - l = co.label(co.name) + l = self.label_select_column(select, co) + if l is not None: self.traverse(l) inner_columns[self.get_str(l.obj)] = l else: self.traverse(co) inner_columns[self.get_str(co)] = co + self.select_stack.pop(-1) collist = string.join([self.get_str(v) for v in inner_columns.values()], ', ') @@ -466,29 +504,10 @@ class ANSICompiler(sql.Compiled): text += self.visit_select_precolumns(select) text += collist - whereclause = select.whereclause - - froms = [] - for f in select.froms: - - if self.parameters is not None: - # TODO: whack this feature in 0.4 - # look at our own parameters, see if they - # are all present in the form of BindParamClauses. if - # not, then append to the above whereclause column conditions - # matching those keys - for c in f.columns: - if sql.is_column(c) and self.parameters.has_key(c.key) and not self.binds.has_key(c.key): - value = self.parameters[c.key] - else: - continue - clause = c==value - if whereclause is not None: - whereclause = self.traverse(sql.and_(clause, whereclause), stop_on=util.Set([whereclause])) - else: - whereclause = clause - self.traverse(whereclause) + whereclause = select._whereclause + from_strings = [] + for f in froms: # special thingy used by oracle to redefine a join w = self.get_whereclause(f) if w is not None: @@ -500,11 +519,11 @@ class ANSICompiler(sql.Compiled): t = self.get_from_text(f) if t is not None: - froms.append(t) + from_strings.append(t) if len(froms): text += " \nFROM " - text += string.join(froms, ', ') + text += string.join(from_strings, ', ') else: text += self.default_from() @@ -513,12 +532,12 @@ class ANSICompiler(sql.Compiled): if t: text += " \nWHERE " + t - group_by = self.get_str(select.group_by_clause) + group_by = self.get_str(select._group_by_clause) if group_by: text += " GROUP BY " + group_by - if select.having is not None: - t = self.get_str(select.having) + if select._having is not None: + t = self.get_str(select._having) if t: text += " \nHAVING " + t @@ -532,7 +551,7 @@ class ANSICompiler(sql.Compiled): def visit_select_precolumns(self, select): """Called when building a ``SELECT`` statement, position is just before column list.""" - return select.distinct and "DISTINCT " or "" + return select._distinct and "DISTINCT " or "" def visit_select_postclauses(self, select): """Called when building a ``SELECT`` statement, position is after all other ``SELECT`` clauses. @@ -540,10 +559,10 @@ class ANSICompiler(sql.Compiled): Most DB syntaxes put ``LIMIT``/``OFFSET`` here. """ - return (select.limit or select.offset) and self.limit_clause(select) or "" + return (select._limit or select._offset) and self.limit_clause(select) or "" def order_by_clause(self, select): - order_by = self.get_str(select.order_by_clause) + order_by = self.get_str(select._order_by_clause) if order_by: return " ORDER BY " + order_by else: @@ -557,12 +576,12 @@ class ANSICompiler(sql.Compiled): def limit_clause(self, select): text = "" - if select.limit is not None: - text += " \n LIMIT " + str(select.limit) - if select.offset is not None: - if select.limit is None: + if select._limit is not None: + text += " \n LIMIT " + str(select._limit) + if select._offset is not None: + if select._limit is None: text += " \n LIMIT -1" - text += " OFFSET " + str(select.offset) + text += " OFFSET " + str(select._offset) return text def visit_table(self, table): @@ -696,8 +715,8 @@ class ANSICompiler(sql.Compiled): text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.format_column(c[0]), create_param(*c)) for c in colparams], ', ') - if update_stmt.whereclause: - text += " WHERE " + self.get_str(update_stmt.whereclause) + if update_stmt._whereclause: + text += " WHERE " + self.get_str(update_stmt._whereclause) self.strings[update_stmt] = text @@ -755,13 +774,14 @@ class ANSICompiler(sql.Compiled): if sql._is_literal(value): value = sql.bindparam(c.key, value, type=c.type, unique=True) values.append((c, value)) + return values def visit_delete(self, delete_stmt): text = "DELETE FROM " + self.preparer.format_table(delete_stmt.table) - if delete_stmt.whereclause: - text += " WHERE " + self.get_str(delete_stmt.whereclause) + if delete_stmt._whereclause: + text += " WHERE " + self.get_str(delete_stmt._whereclause) self.strings[delete_stmt] = text @@ -795,7 +815,7 @@ class ANSISchemaGenerator(ANSISchemaBase): def visit_metadata(self, metadata): collection = [t for t in metadata.table_iterator(reverse=False, tables=self.tables) if (not self.checkfirst or not self.dialect.has_table(self.connection, t.name, schema=t.schema))] for table in collection: - table.accept_visitor(self) + self.traverse_single(table) if self.dialect.supports_alter(): for alterable in self.find_alterables(collection): self.add_foreignkey(alterable) @@ -803,7 +823,7 @@ class ANSISchemaGenerator(ANSISchemaBase): def visit_table(self, table): for column in table.columns: if column.default is not None: - column.default.accept_visitor(self) + self.traverse_single(column.default) #if column.onupdate is not None: # column.onupdate.accept_visitor(visitor) @@ -820,20 +840,20 @@ class ANSISchemaGenerator(ANSISchemaBase): if column.primary_key: first_pk = True for constraint in column.constraints: - constraint.accept_visitor(self) + self.traverse_single(constraint) # On some DB order is significant: visit PK first, then the # other constraints (engine.ReflectionTest.testbasic failed on FB2) if len(table.primary_key): - table.primary_key.accept_visitor(self) + self.traverse_single(table.primary_key) for constraint in [c for c in table.constraints if c is not table.primary_key]: - constraint.accept_visitor(self) + self.traverse_single(constraint) self.append("\n)%s\n\n" % self.post_create_table(table)) self.execute() if hasattr(table, 'indexes'): for index in table.indexes: - index.accept_visitor(self) + self.traverse_single(index) def post_create_table(self, table): return '' @@ -929,7 +949,7 @@ class ANSISchemaDropper(ANSISchemaBase): for alterable in self.find_alterables(collection): self.drop_foreignkey(alterable) for table in collection: - table.accept_visitor(self) + self.traverse_single(table) def visit_index(self, index): self.append("\nDROP INDEX " + index.name) @@ -942,7 +962,7 @@ class ANSISchemaDropper(ANSISchemaBase): def visit_table(self, table): for column in table.columns: if column.default is not None: - column.default.accept_visitor(self) + self.traverse_single(column.default) self.append("\nDROP TABLE " + self.preparer.format_table(table)) self.execute() diff --git a/lib/sqlalchemy/databases/firebird.py b/lib/sqlalchemy/databases/firebird.py index a02781c846..64f5842384 100644 --- a/lib/sqlalchemy/databases/firebird.py +++ b/lib/sqlalchemy/databases/firebird.py @@ -324,11 +324,11 @@ class FBCompiler(ansisql.ANSICompiler): """ result = "" - if select.limit: - result += " FIRST %d " % select.limit - if select.offset: - result +=" SKIP %d " % select.offset - if select.distinct: + if select._limit: + result += " FIRST %d " % select._limit + if select._offset: + result +=" SKIP %d " % select._offset + if select._distinct: result += " DISTINCT " return result diff --git a/lib/sqlalchemy/databases/informix.py b/lib/sqlalchemy/databases/informix.py index 2fb508280c..99bc3896c9 100644 --- a/lib/sqlalchemy/databases/informix.py +++ b/lib/sqlalchemy/databases/informix.py @@ -373,19 +373,19 @@ class InfoCompiler(ansisql.ANSICompiler): return " from systables where tabname = 'systables' " def visit_select_precolumns( self , select ): - s = select.distinct and "DISTINCT " or "" + s = select._distinct and "DISTINCT " or "" # only has limit - if select.limit: - off = select.offset or 0 - s += " FIRST %s " % ( select.limit + off ) + if select._limit: + off = select._offset or 0 + s += " FIRST %s " % ( select._limit + off ) else: s += "" return s def visit_select(self, select): - if select.offset: - self.offset = select.offset - self.limit = select.limit or 0 + if select._offset: + self.offset = select._offset + self.limit = select._limit or 0 # the column in order by clause must in select too def __label( c ): diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py index 2b6808eaca..8b81884bb9 100644 --- a/lib/sqlalchemy/databases/mssql.py +++ b/lib/sqlalchemy/databases/mssql.py @@ -25,7 +25,7 @@ * Support for auto-fetching of ``@@IDENTITY/@@SCOPE_IDENTITY()`` on ``INSERT`` -* ``select.limit`` implemented as ``SELECT TOP n`` +* ``select._limit`` implemented as ``SELECT TOP n`` Known issues / TODO: @@ -756,10 +756,10 @@ class MSSQLCompiler(ansisql.ANSICompiler): def visit_select_precolumns(self, select): """ MS-SQL puts TOP, it's version of LIMIT here """ - s = select.distinct and "DISTINCT " or "" - if select.limit: - s += "TOP %s " % (select.limit,) - if select.offset: + s = select._distinct and "DISTINCT " or "" + if select._limit: + s += "TOP %s " % (select._limit,) + if select._offset: raise exceptions.InvalidRequestError('MSSQL does not support LIMIT with an offset') return s @@ -803,13 +803,11 @@ class MSSQLCompiler(ansisql.ANSICompiler): binary.left, binary.right = binary.right, binary.left super(MSSQLCompiler, self).visit_binary(binary) - def visit_select(self, select): - # label function calls, so they return a name in cursor.description - for i,c in enumerate(select._raw_columns): - if isinstance(c, sql._Function): - select._raw_columns[i] = c.label(c.name + "_" + hex(random.randint(0, 65535))[2:]) - - super(MSSQLCompiler, self).visit_select(select) + def label_select_column(self, select, column): + if isinstance(column, sql._Function): + return co.label(co.name + "_" + hex(random.randint(0, 65535))[2:]) + else: + return super(MSSQLCompiler, self).label_select_column(select, column) function_rewrites = {'current_date': 'getdate', 'length': 'len', @@ -823,10 +821,10 @@ class MSSQLCompiler(ansisql.ANSICompiler): return '' def order_by_clause(self, select): - order_by = self.get_str(select.order_by_clause) + order_by = self.get_str(select._order_by_clause) # MSSQL only allows ORDER BY in subqueries if there is a LIMIT - if order_by and (not select.is_subquery or select.limit): + if order_by and (not self.is_subquery(select) or select._limit): return " ORDER BY " + order_by else: return "" diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index e45536a756..66bb306b56 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -1200,13 +1200,13 @@ class MySQLCompiler(ansisql.ANSICompiler): def limit_clause(self, select): text = "" - if select.limit is not None: - text += " \n LIMIT " + str(select.limit) - if select.offset is not None: - if select.limit is None: - # striaght from the MySQL docs, I kid you not + if select._limit is not None: + text += " \n LIMIT " + str(select._limit) + if select._offset is not None: + if select._limit is None: + # straight from the MySQL docs, I kid you not text += " \n LIMIT 18446744073709551615" - text += " OFFSET " + str(select.offset) + text += " OFFSET " + str(select._offset) return text class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator): diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py index 4210a94974..b88bea663f 100644 --- a/lib/sqlalchemy/databases/oracle.py +++ b/lib/sqlalchemy/databases/oracle.py @@ -433,11 +433,6 @@ class OracleCompiler(ansisql.ANSICompiler): the use_ansi flag is False. """ - def __init__(self, *args, **kwargs): - super(OracleCompiler, self).__init__(*args, **kwargs) - # we have to modify SELECT objects a little bit, so store state here - self._select_state = {} - def default_from(self): """Called when a ``SELECT`` statement has no froms, and no ``FROM`` clause is to be appended. @@ -472,7 +467,7 @@ class OracleCompiler(ansisql.ANSICompiler): self._outertable = None - self.wheres[join].accept_visitor(self) + self.traverse_single(self.wheres[join]) def visit_insert_sequence(self, column, sequence, parameters): """This is the `sequence` equivalent to ``ANSICompiler``'s @@ -508,74 +503,35 @@ class OracleCompiler(ansisql.ANSICompiler): def _TODO_visit_compound_select(self, select): """Need to determine how to get ``LIMIT``/``OFFSET`` into a ``UNION`` for Oracle.""" - - if getattr(select, '_oracle_visit', False): - # cancel out the compiled order_by on the select - if hasattr(select, "order_by_clause"): - self.strings[select.order_by_clause] = "" - ansisql.ANSICompiler.visit_compound_select(self, select) - return - - if select.limit is not None or select.offset is not None: - select._oracle_visit = True - # to use ROW_NUMBER(), an ORDER BY is required. - orderby = self.strings[select.order_by_clause] - if not orderby: - orderby = select.oid_column - self.traverse(orderby) - orderby = self.strings[orderby] - class SelectVisitor(sql.NoColumnVisitor): - def visit_select(self, select): - select.append_column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("ora_rn")) - SelectVisitor().traverse(select) - limitselect = sql.select([c for c in select.c if c.key!='ora_rn']) - if select.offset is not None: - limitselect.append_whereclause("ora_rn>%d" % select.offset) - if select.limit is not None: - limitselect.append_whereclause("ora_rn<=%d" % (select.limit + select.offset)) - else: - limitselect.append_whereclause("ora_rn<=%d" % select.limit) - self.traverse(limitselect) - self.strings[select] = self.strings[limitselect] - self.froms[select] = self.froms[limitselect] - else: - ansisql.ANSICompiler.visit_compound_select(self, select) + pass def visit_select(self, select): """Look for ``LIMIT`` and OFFSET in a select statement, and if so tries to wrap it in a subquery with ``row_number()`` criterion. """ - # TODO: put a real copy-container on Select and copy, or somehow make this - # not modify the Select statement - if self._select_state.get((select, 'visit'), False): - # cancel out the compiled order_by on the select - if hasattr(select, "order_by_clause"): - self.strings[select.order_by_clause] = "" - ansisql.ANSICompiler.visit_select(self, select) - return - - if select.limit is not None or select.offset is not None: - self._select_state[(select, 'visit')] = True + if not getattr(select, '_oracle_visit', None) and (select._limit is not None or select._offset is not None): # to use ROW_NUMBER(), an ORDER BY is required. - orderby = self.strings[select.order_by_clause] + orderby = self.strings[select._order_by_clause] if not orderby: orderby = select.oid_column self.traverse(orderby) orderby = self.strings[orderby] - if not hasattr(select, '_oracle_visit'): - select.append_column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("ora_rn")) - select._oracle_visit = True + + oldselect = select + select = select.column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("ora_rn")).order_by(None) + select._oracle_visit = True + limitselect = sql.select([c for c in select.c if c.key!='ora_rn']) - if select.offset is not None: - limitselect.append_whereclause("ora_rn>%d" % select.offset) - if select.limit is not None: - limitselect.append_whereclause("ora_rn<=%d" % (select.limit + select.offset)) + if select._offset is not None: + limitselect.append_whereclause("ora_rn>%d" % select._offset) + if select._limit is not None: + limitselect.append_whereclause("ora_rn<=%d" % (select._limit + select._offset)) else: - limitselect.append_whereclause("ora_rn<=%d" % select.limit) + limitselect.append_whereclause("ora_rn<=%d" % select._limit) self.traverse(limitselect) - self.strings[select] = self.strings[limitselect] - self.froms[select] = self.froms[limitselect] + self.strings[oldselect] = self.strings[limitselect] + self.froms[oldselect] = self.froms[limitselect] else: ansisql.ANSICompiler.visit_select(self, select) diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index b48a709d8c..ea92cf7fd2 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -423,16 +423,16 @@ class PGCompiler(ansisql.ANSICompiler): return text def visit_select_precolumns(self, select): - if select.distinct: - if type(select.distinct) == bool: + if select._distinct: + if type(select._distinct) == bool: return "DISTINCT " - if type(select.distinct) == list: + if type(select._distinct) == list: dist_set = "DISTINCT ON (" - for col in select.distinct: + for col in select._distinct: dist_set += self.strings[col] + ", " dist_set = dist_set[:-2] + ") " return dist_set - return "DISTINCT ON (" + str(select.distinct) + ") " + return "DISTINCT ON (" + str(select._distinct) + ") " else: return "" diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py index 0bd7cf6aee..e3282e028a 100644 --- a/lib/sqlalchemy/databases/sqlite.py +++ b/lib/sqlalchemy/databases/sqlite.py @@ -327,12 +327,12 @@ class SQLiteCompiler(ansisql.ANSICompiler): def limit_clause(self, select): text = "" - if select.limit is not None: - text += " \n LIMIT " + str(select.limit) - if select.offset is not None: - if select.limit is None: + if select._limit is not None: + text += " \n LIMIT " + str(select._limit) + if select._offset is not None: + if select._limit is None: text += " \n LIMIT -1" - text += " OFFSET " + str(select.offset) + text += " OFFSET " + str(select._offset) else: text += " OFFSET 0" return text diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index b1e1ee5cc9..de4d6b2aeb 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -364,8 +364,90 @@ class ExecutionContext(object): raise NotImplementedError() +class Compiled(sql.ClauseVisitor): + """Represent a compiled SQL expression. + + The ``__str__`` method of the ``Compiled`` object should produce + the actual text of the statement. ``Compiled`` objects are + specific to their underlying database dialect, and also may + or may not be specific to the columns referenced within a + particular set of bind parameters. In no case should the + ``Compiled`` object be dependent on the actual values of those + bind parameters, even though it may reference those values as + defaults. + """ + + def __init__(self, dialect, statement, parameters, engine=None): + """Construct a new ``Compiled`` object. + + statement + ``ClauseElement`` to be compiled. + + parameters + Optional dictionary indicating a set of bind parameters + specified with this ``Compiled`` object. These parameters + are the *default* values corresponding to the + ``ClauseElement``'s ``_BindParamClauses`` when the + ``Compiled`` is executed. In the case of an ``INSERT`` or + ``UPDATE`` statement, these parameters will also result in + the creation of new ``_BindParamClause`` objects for each + key and will also affect the generated column list in an + ``INSERT`` statement and the ``SET`` clauses of an + ``UPDATE`` statement. The keys of the parameter dictionary + can either be the string names of columns or + ``_ColumnClause`` objects. + + engine + Optional Engine to compile this statement against. + """ + self.dialect = dialect + self.statement = statement + self.parameters = parameters + self.engine = engine + self.can_execute = statement.supports_execution() + + def compile(self): + self.traverse(self.statement) + self.after_compile() + + def __str__(self): + """Return the string text of the generated SQL statement.""" + + raise NotImplementedError() + + def get_params(self, **params): + """Deprecated. use construct_params(). (supports unicode names) + """ + + return self.construct_params(params) + + def construct_params(self, params): + """Return the bind params for this compiled object. + + Will start with the default parameters specified when this + ``Compiled`` object was first constructed, and will override + those values with those sent via `**params`, which are + key/value pairs. Each key should match one of the + ``_BindParamClause`` objects compiled into this object; either + the `key` or `shortname` property of the ``_BindParamClause``. + """ + raise NotImplementedError() + + def execute(self, *multiparams, **params): + """Execute this compiled object.""" + + e = self.engine + if e is None: + raise exceptions.InvalidRequestError("This Compiled object is not bound to any engine.") + return e.execute_compiled(self, *multiparams, **params) + + def scalar(self, *multiparams, **params): + """Execute this compiled object and return the result's scalar value.""" + + return self.execute(*multiparams, **params).scalar() + -class Connectable(sql.Executor): +class Connectable(object): """Interface for an object that can provide an Engine and a Connection object which correponds to that Engine.""" def contextual_connect(self): @@ -522,7 +604,7 @@ class Connection(Connectable): raise exceptions.InvalidRequestError("Unexecuteable object type: " + str(type(object))) def execute_default(self, default, **kwargs): - return default.accept_visitor(self.__engine.dialect.defaultrunner(self)) + return self.__engine.dialect.defaultrunner(self).traverse_single(default) def execute_text(self, statement, *multiparams, **params): if len(multiparams) == 0: @@ -729,7 +811,7 @@ class Engine(Connectable): else: conn = connection try: - element.accept_visitor(visitorcallable(conn, **kwargs)) + visitorcallable(conn, **kwargs).traverse(element) finally: if connection is None: conn.close() @@ -1248,13 +1330,13 @@ class DefaultRunner(schema.SchemaVisitor): def get_column_default(self, column): if column.default is not None: - return column.default.accept_visitor(self) + return self.traverse_single(column.default) else: return None def get_column_onupdate(self, column): if column.onupdate is not None: - return column.onupdate.accept_visitor(self) + return self.traverse_single(column.onupdate) else: return None diff --git a/lib/sqlalchemy/ext/sqlsoup.py b/lib/sqlalchemy/ext/sqlsoup.py index a9b93bc564..15be667090 100644 --- a/lib/sqlalchemy/ext/sqlsoup.py +++ b/lib/sqlalchemy/ext/sqlsoup.py @@ -425,7 +425,7 @@ def _selectable_name(selectable): if isinstance(selectable, sql.Alias): return _selectable_name(selectable.selectable) elif isinstance(selectable, sql.Select): - return ''.join([_selectable_name(s) for s in selectable.froms]) + return ''.join([_selectable_name(s) for s in selectable.get_display_froms()]) elif isinstance(selectable, schema.Table): return selectable.name.capitalize() else: diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index cb12611306..8b0878688d 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1546,7 +1546,7 @@ class Mapper(object): return obj def _deferred_inheritance_condition(self, needs_tables): - cond = self.inherit_condition.copy_container() + cond = self.inherit_condition param_names = [] def visit_binary(binary): @@ -1560,7 +1560,7 @@ class Mapper(object): elif rightcol not in needs_tables: binary.right = sql.bindparam(rightcol.name, None, type=binary.right.type, unique=True) param_names.append(rightcol) - mapperutil.BinaryVisitor(visit_binary).traverse(cond) + cond = mapperutil.BinaryVisitor(visit_binary).traverse(cond, clone=True) return cond, param_names def translate_row(self, tomapper, row): diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index b5b8f83069..79fa101d25 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -399,15 +399,13 @@ class PropertyLoader(StrategizedProperty): # if the target mapper loads polymorphically, adapt the clauses to the target's selectable if self.loads_polymorphic: if self.secondaryjoin: - self.polymorphic_secondaryjoin = self.secondaryjoin.copy_container() - sql_util.ClauseAdapter(self.mapper.select_table).traverse(self.polymorphic_secondaryjoin) - self.polymorphic_primaryjoin = self.primaryjoin.copy_container() + self.polymorphic_secondaryjoin = sql_util.ClauseAdapter(self.mapper.select_table).traverse(self.secondaryjoin, clone=True) + self.polymorphic_primaryjoin = self.primaryjoin else: - self.polymorphic_primaryjoin = self.primaryjoin.copy_container() if self.direction is sync.ONETOMANY: - sql_util.ClauseAdapter(self.mapper.select_table, include=self.foreign_keys, equivalents=target_equivalents).traverse(self.polymorphic_primaryjoin) + self.polymorphic_primaryjoin = sql_util.ClauseAdapter(self.mapper.select_table, include=self.foreign_keys, equivalents=target_equivalents).traverse(self.primaryjoin, clone=True) elif self.direction is sync.MANYTOONE: - sql_util.ClauseAdapter(self.mapper.select_table, exclude=self.foreign_keys, equivalents=target_equivalents).traverse(self.polymorphic_primaryjoin) + self.polymorphic_primaryjoin = sql_util.ClauseAdapter(self.mapper.select_table, exclude=self.foreign_keys, equivalents=target_equivalents).traverse(self.primaryjoin, clone=True) self.polymorphic_secondaryjoin = None # load "polymorphic" versions of the columns present in "remote_side" - this is # important for lazy-clause generation which goes off the polymorphic target selectable @@ -422,8 +420,8 @@ class PropertyLoader(StrategizedProperty): else: raise exceptions.AssertionError(str(self) + ": Could not find corresponding column for " + str(c) + " in selectable " + str(self.mapper.select_table)) else: - self.polymorphic_primaryjoin = self.primaryjoin.copy_container() - self.polymorphic_secondaryjoin = self.secondaryjoin and self.secondaryjoin.copy_container() or None + self.polymorphic_primaryjoin = self.primaryjoin + self.polymorphic_secondaryjoin = self.secondaryjoin def _post_init(self): if logging.is_info_enabled(self.logger): @@ -466,17 +464,13 @@ class PropertyLoader(StrategizedProperty): return self._parent_join_cache[(parent, primary, secondary)] except KeyError: parent_equivalents = parent._get_equivalent_columns() - primaryjoin = self.polymorphic_primaryjoin.copy_container() - if self.secondaryjoin is not None: - secondaryjoin = self.polymorphic_secondaryjoin.copy_container() - else: - secondaryjoin = None + secondaryjoin = self.polymorphic_secondaryjoin if self.direction is sync.ONETOMANY: - sql_util.ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(primaryjoin) + primaryjoin = sql_util.ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin, clone=True) elif self.direction is sync.MANYTOONE: - sql_util.ClauseAdapter(parent.select_table, include=self.foreign_keys, equivalents=parent_equivalents).traverse(primaryjoin) + primaryjoin = sql_util.ClauseAdapter(parent.select_table, include=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin, clone=True) elif self.secondaryjoin: - sql_util.ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(primaryjoin) + primaryjoin = sql_util.ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin, clone=True) if secondaryjoin is not None: if secondary and not primary: diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 37f3232486..0fb8939c93 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -331,15 +331,15 @@ class Query(object): else: if prop.secondary: if create_aliases: - join = prop.get_join(mapper, primary=True, secondary=False).copy_container() + join = prop.get_join(mapper, primary=True, secondary=False) secondary_alias = prop.secondary.alias() if alias is not None: - sql_util.ClauseAdapter(alias).traverse(join) + join = sql_util.ClauseAdapter(alias).traverse(join, clone=True) sql_util.ClauseAdapter(secondary_alias).traverse(join) clause = clause.join(secondary_alias, join) alias = prop.select_table.alias() - join = prop.get_join(mapper, primary=False).copy_container() - sql_util.ClauseAdapter(secondary_alias).traverse(join) + join = prop.get_join(mapper, primary=False) + join = sql_util.ClauseAdapter(secondary_alias).traverse(join, clone=True) sql_util.ClauseAdapter(alias).traverse(join) clause = clause.join(alias, join) else: @@ -347,11 +347,11 @@ class Query(object): clause = clause.join(prop.select_table, prop.get_join(mapper, primary=False)) else: if create_aliases: - join = prop.get_join(mapper).copy_container() + join = prop.get_join(mapper) if alias is not None: - sql_util.ClauseAdapter(alias).traverse(join) + join = sql_util.ClauseAdapter(alias).traverse(join, clone=True) alias = prop.select_table.alias() - sql_util.ClauseAdapter(alias).traverse(join) + join = sql_util.ClauseAdapter(alias).traverse(join, clone=True) clause = clause.join(alias, join) else: clause = clause.join(prop.select_table, prop.get_join(mapper)) @@ -401,7 +401,7 @@ class Query(object): For performance, only use subselect if `order_by` attribute is set. """ - ops = {'distinct':self._distinct, 'order_by':self._order_by, 'from_obj':self._from_obj} + ops = {'distinct':self._distinct, 'order_by':self._order_by or None, 'from_obj':self._from_obj} if self._order_by is not False: s1 = sql.select([col], self._criterion, **ops).alias('u') @@ -781,12 +781,8 @@ class Query(object): # from there context = QueryContext(self) order_by = context.order_by - group_by = context.group_by from_obj = context.from_obj lockmode = context.lockmode - distinct = context.distinct - limit = context.limit - offset = context.offset if order_by is False: order_by = self.mapper.order_by if order_by is False: @@ -821,20 +817,20 @@ class Query(object): else: cf = [] - s2 = sql.select(self.table.primary_key + list(cf), whereclause, use_labels=True, from_obj=from_obj, **context.select_args()) + s2 = sql.select(self.primary_key_columns + list(cf), whereclause, use_labels=True, from_obj=from_obj, correlate=False, **context.select_args()) if order_by: - s2.order_by(*util.to_list(order_by)) + s2 = s2.order_by(*util.to_list(order_by)) s3 = s2.alias('tbl_row_count') - crit = s3.primary_key==self.table.primary_key + crit = s3.primary_key==self.primary_key_columns statement = sql.select([], crit, use_labels=True, for_update=for_update) # now for the order by, convert the columns to their corresponding columns # in the "rowcount" query, and tack that new order by onto the "rowcount" query if order_by: - statement.order_by(*sql_util.ClauseAdapter(s3).copy_and_process(order_by)) + statement.append_order_by(*sql_util.ClauseAdapter(s3).copy_and_process(order_by)) else: statement = sql.select([], whereclause, from_obj=from_obj, use_labels=True, for_update=for_update, **context.select_args()) if order_by: - statement.order_by(*util.to_list(order_by)) + statement.append_order_by(*util.to_list(order_by)) # for a DISTINCT query, you need the columns explicitly specified in order # to use it in "order_by". ensure they are in the column criterion (particularly oid). @@ -1101,7 +1097,7 @@ class QueryContext(OperationContext): ``QueryContext`` that can be applied to a ``sql.Select`` statement. """ - return {'limit':self.limit, 'offset':self.offset, 'distinct':self.distinct, 'group_by':self.group_by} + return {'limit':self.limit, 'offset':self.offset, 'distinct':self.distinct, 'group_by':self.group_by or None} def accept_option(self, opt): """Accept a ``MapperOption`` which will process (modify) the diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 8de2d00e5f..b4a66a6dc3 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -290,7 +290,7 @@ class LazyLoader(AbstractRelationLoader): # based polymorphic loads on a per-query basis, this code needs to switch between "mapper" and "select_mapper", # probably via the query's own "mapper" property, and also use one of two "lazy" clauses, # one against the "union" the other not - for primary_key in self.select_mapper.pks_by_table[self.select_mapper.mapped_table]: + for primary_key in self.select_mapper.primary_key: bind = self.lazyreverse[primary_key] ident.append(params[bind.key]) return q.get(ident) @@ -303,6 +303,15 @@ class LazyLoader(AbstractRelationLoader): q = q.options(*options) q = q.filter(self.lazywhere).params(**params) + result = q.all() + if self.uselist: + return result + else: + if len(result): + return result[0] + else: + return None + if self.uselist: return q.all() else: @@ -378,16 +387,15 @@ class LazyLoader(AbstractRelationLoader): sql.bindparam(bind_label(), None, shortname=rightcol.name, type=binary.left.type, unique=True)) reverse[leftcol] = binds[col] - lazywhere = primaryjoin.copy_container() + lazywhere = primaryjoin li = mapperutil.BinaryVisitor(visit_binary) if not secondaryjoin or not reverse_direction: - li.traverse(lazywhere) + lazywhere = li.traverse(lazywhere, clone=True) if secondaryjoin is not None: - secondaryjoin = secondaryjoin.copy_container() if reverse_direction: - li.traverse(secondaryjoin) + secondaryjoin = li.traverse(secondaryjoin, clone=True) lazywhere = sql.and_(lazywhere, secondaryjoin) return (lazywhere, binds, reverse) _create_lazy_clause = classmethod(_create_lazy_clause) @@ -461,18 +469,18 @@ class EagerLoader(AbstractRelationLoader): else: aliasizer = sql_util.ClauseAdapter(self.eagertarget).\ chain(sql_util.ClauseAdapter(self.eagersecondary)) - self.eagersecondaryjoin = eagerloader.polymorphic_secondaryjoin.copy_container() - aliasizer.traverse(self.eagersecondaryjoin) - self.eagerprimary = eagerloader.polymorphic_primaryjoin.copy_container() - aliasizer.traverse(self.eagerprimary) + self.eagersecondaryjoin = eagerloader.polymorphic_secondaryjoin + self.eagersecondaryjoin = aliasizer.traverse(self.eagersecondaryjoin, clone=True) + self.eagerprimary = eagerloader.polymorphic_primaryjoin + self.eagerprimary = aliasizer.traverse(self.eagerprimary, clone=True) else: - self.eagerprimary = eagerloader.polymorphic_primaryjoin.copy_container() + self.eagerprimary = eagerloader.polymorphic_primaryjoin if parentclauses is not None: aliasizer = sql_util.ClauseAdapter(self.eagertarget) aliasizer.chain(sql_util.ClauseAdapter(parentclauses.eagertarget, exclude=eagerloader.parent_property.remote_side)) else: aliasizer = sql_util.ClauseAdapter(self.eagertarget) - aliasizer.traverse(self.eagerprimary) + self.eagerprimary = aliasizer.traverse(self.eagerprimary, clone=True) if eagerloader.order_by: self.eager_order_by = sql_util.ClauseAdapter(self.eagertarget).copy_and_process(util.to_list(eagerloader.order_by)) @@ -492,8 +500,15 @@ class EagerLoader(AbstractRelationLoader): if column in self.extra_cols: return self.extra_cols[column] - aliased_column = column.copy_container() - sql_util.ClauseAdapter(self.eagertarget).traverse(aliased_column) + aliased_column = column + # for column-level subqueries, swap out its selectable with our + # eager version as appropriate, and manually build the + # "correlation" list of the subquery. + class ModifySubquery(sql.ClauseVisitor): + def visit_select(s, select): + select._should_correlate = False + select.append_correlation(self.eagertarget) + aliased_column = sql_util.ClauseAdapter(self.eagertarget).chain(ModifySubquery()).traverse(aliased_column, clone=True) alias = self._aliashash(column.name) aliased_column = aliased_column.label(alias) self._row_decorator.map[column] = alias @@ -561,7 +576,7 @@ class EagerLoader(AbstractRelationLoader): # this will locate the selectable inside of any containers it may be a part of (such # as a join). if its inside of a join, we want to outer join on that join, not the # selectable. - for fromclause in statement.froms: + for fromclause in statement.get_display_froms(): if fromclause is localparent.mapped_table: towrap = fromclause break @@ -571,7 +586,7 @@ class EagerLoader(AbstractRelationLoader): break else: raise exceptions.InvalidRequestError("EagerLoader cannot locate a clause with which to outer join to, in query '%s' %s" % (str(statement), localparent.mapped_table)) - + try: clauses = self.clauses[parentclauses] except KeyError: @@ -584,16 +599,17 @@ class EagerLoader(AbstractRelationLoader): if self.secondaryjoin is not None: statement._outerjoin = sql.outerjoin(towrap, clauses.eagersecondary, clauses.eagerprimary).outerjoin(clauses.eagertarget, clauses.eagersecondaryjoin) if self.order_by is False and self.secondary.default_order_by() is not None: - statement.order_by(*clauses.eagersecondary.default_order_by()) + statement.append_order_by(*clauses.eagersecondary.default_order_by()) else: statement._outerjoin = towrap.outerjoin(clauses.eagertarget, clauses.eagerprimary) if self.order_by is False and clauses.eagertarget.default_order_by() is not None: - statement.order_by(*clauses.eagertarget.default_order_by()) + statement.append_order_by(*clauses.eagertarget.default_order_by()) if clauses.eager_order_by: - statement.order_by(*util.to_list(clauses.eager_order_by)) - + statement.append_order_by(*util.to_list(clauses.eager_order_by)) + statement.append_from(statement._outerjoin) + for value in self.select_mapper.props.values(): value.setup(context, eagertable=clauses.eagertarget, parentclauses=clauses, parentmapper=self.select_mapper) diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 5b2d229c4b..713064d9cc 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -28,6 +28,8 @@ __all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'Index', ' class SchemaItem(object): """Base class for items that define a database schema.""" + __metaclass__ = sql._FigureVisitName + def _init_items(self, *args): """Initialize the list of child items for this SchemaItem.""" @@ -128,7 +130,7 @@ def _get_table_key(name, schema): else: return schema + "." + name -class _TableSingleton(type): +class _TableSingleton(sql._FigureVisitName): """A metaclass used by the ``Table`` object to provide singleton behavior.""" def __call__(self, name, metadata, *args, **kwargs): @@ -721,11 +723,6 @@ class ForeignKey(SchemaItem): column = property(lambda s: s._init_column()) - def accept_visitor(self, visitor): - """Call the `visit_foreign_key` method on the given visitor.""" - - visitor.visit_foreign_key(self) - def _get_parent(self): return self.parent @@ -777,9 +774,6 @@ class PassiveDefault(DefaultGenerator): super(PassiveDefault, self).__init__(**kwargs) self.arg = arg - def accept_visitor(self, visitor): - return visitor.visit_passive_default(self) - def __repr__(self): return "PassiveDefault(%s)" % repr(self.arg) @@ -794,13 +788,12 @@ class ColumnDefault(DefaultGenerator): super(ColumnDefault, self).__init__(**kwargs) self.arg = arg - def accept_visitor(self, visitor): - """Call the visit_column_default method on the given visitor.""" - + def _visit_name(self): if self.for_update: - return visitor.visit_column_onupdate(self) + return "column_onupdate" else: - return visitor.visit_column_default(self) + return "column_default" + __visit_name__ = property(_visit_name) def __repr__(self): return "ColumnDefault(%s)" % repr(self.arg) @@ -834,10 +827,6 @@ class Sequence(DefaultGenerator): def drop(self, connectable=None, checkfirst=True): self.get_engine(connectable=connectable).drop(self, checkfirst=checkfirst) - def accept_visitor(self, visitor): - """Call the visit_seauence method on the given visitor.""" - - return visitor.visit_sequence(self) class Constraint(SchemaItem): """Represent a table-level ``Constraint`` such as a composite primary key, foreign key, or unique constraint. @@ -876,11 +865,12 @@ class CheckConstraint(Constraint): super(CheckConstraint, self).__init__(name) self.sqltext = sqltext - def accept_visitor(self, visitor): + def _visit_name(self): if isinstance(self.parent, Table): - visitor.visit_check_constraint(self) + return "check_constraint" else: - visitor.visit_column_check_constraint(self) + return "column_check_constraint" + __visit_name__ = property(_visit_name) def _set_parent(self, parent): self.parent = parent @@ -909,9 +899,6 @@ class ForeignKeyConstraint(Constraint): for (c, r) in zip(self.__colnames, self.__refcolnames): self.append_element(c,r) - def accept_visitor(self, visitor): - visitor.visit_foreign_key_constraint(self) - def append_element(self, col, refcol): fk = ForeignKey(refcol, constraint=self, name=self.name, onupdate=self.onupdate, ondelete=self.ondelete, use_alter=self.use_alter) fk._set_parent(self.table.c[col]) @@ -935,9 +922,6 @@ class PrimaryKeyConstraint(Constraint): for c in self.__colnames: self.append_column(table.c[c]) - def accept_visitor(self, visitor): - visitor.visit_primary_key_constraint(self) - def add(self, col): self.append_column(col) @@ -969,9 +953,6 @@ class UniqueConstraint(Constraint): def append_column(self, col): self.columns.add(col) - def accept_visitor(self, visitor): - visitor.visit_unique_constraint(self) - def copy(self): return UniqueConstraint(name=self.name, *self.__colnames) @@ -1048,9 +1029,6 @@ class Index(SchemaItem): else: self.get_engine().drop(self) - def accept_visitor(self, visitor): - visitor.visit_index(self) - def __str__(self): return repr(self) @@ -1063,6 +1041,8 @@ class Index(SchemaItem): class MetaData(SchemaItem): """Represent a collection of Tables and their associated schema constructs.""" + __visit_name__ = 'metadata' + def __init__(self, url=None, engine=None, **kwargs): """create a new MetaData object. @@ -1174,9 +1154,6 @@ class MetaData(SchemaItem): connectable = self.get_engine() connectable.drop(self, checkfirst=checkfirst, tables=tables) - def accept_visitor(self, visitor): - visitor.visit_metadata(self) - def _derived_metadata(self): return self @@ -1186,6 +1163,8 @@ class BoundMetaData(MetaData): """ + __visit_name__ = 'metadata' + def __init__(self, engine_or_url, **kwargs): from sqlalchemy.engine.url import URL if isinstance(engine_or_url, basestring) or isinstance(engine_or_url, URL): @@ -1200,6 +1179,8 @@ multiple ``Engine`` implementations on a dynamically alterable, thread-local basis. """ + __visit_name__ = 'metadata' + def __init__(self, threadlocal=True, **kwargs): if threadlocal: self.context = util.ThreadLocal() @@ -1245,61 +1226,3 @@ class SchemaVisitor(sql.ClauseVisitor): """Define the visiting for ``SchemaItem`` objects.""" __traverse_options__ = {'schema_visitor':True} - - def visit_schema(self, schema): - """Visit a generic ``SchemaItem``.""" - pass - - def visit_table(self, table): - """Visit a ``Table``.""" - pass - - def visit_column(self, column): - """Visit a ``Column``.""" - pass - - def visit_foreign_key(self, join): - """Visit a ``ForeignKey``.""" - pass - - def visit_index(self, index): - """Visit an ``Index``.""" - pass - - def visit_passive_default(self, default): - """Visit a passive default.""" - pass - - def visit_column_default(self, default): - """Visit a ``ColumnDefault``.""" - pass - - def visit_column_onupdate(self, onupdate): - """Visit a ``ColumnDefault`` with the `for_update` flag set.""" - pass - - def visit_sequence(self, sequence): - """Visit a ``Sequence``.""" - pass - - def visit_primary_key_constraint(self, constraint): - """Visit a ``PrimaryKeyConstraint``.""" - pass - - def visit_foreign_key_constraint(self, constraint): - """Visit a ``ForeignKeyConstraint``.""" - pass - - def visit_unique_constraint(self, constraint): - """Visit a ``UniqueConstraint``.""" - pass - - def visit_check_constraint(self, constraint): - """Visit a ``CheckConstraint``.""" - pass - - def visit_column_check_constraint(self, constraint): - """Visit a ``CheckConstraint`` on a ``Column``.""" - pass - - diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 31aa4788ac..afeb7dd668 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -28,11 +28,10 @@ from sqlalchemy import util, exceptions, logging from sqlalchemy import types as sqltypes import string, re, random, sets - __all__ = ['AbstractDialect', 'Alias', 'ClauseElement', 'ClauseParameters', 'ClauseVisitor', 'ColumnCollection', 'ColumnElement', - 'Compiled', 'CompoundSelect', 'Executor', 'FromClause', 'Join', - 'Select', 'Selectable', 'TableClause', 'alias', 'and_', 'asc', + 'CompoundSelect', 'Delete', 'FromClause', 'Insert', 'Join', + 'Select', 'Selectable', 'TableClause', 'Update', 'alias', 'and_', 'asc', 'between_', 'bindparam', 'case', 'cast', 'column', 'delete', 'desc', 'except_', 'except_all', 'exists', 'extract', 'func', 'modifier', 'insert', 'intersect', 'intersect_all', 'join', 'literal', @@ -126,7 +125,7 @@ def join(left, right, onclause=None, **kwargs): return Join(left, right, onclause, **kwargs) -def select(columns=None, whereclause = None, from_obj = [], **kwargs): +def select(columns=None, whereclause=None, from_obj=[], **kwargs): """Returns a ``SELECT`` clause element. Similar functionality is also available via the ``select()`` method on any @@ -237,7 +236,7 @@ def select(columns=None, whereclause = None, from_obj = [], **kwargs): """ - return Select(columns, whereclause = whereclause, from_obj = from_obj, **kwargs) + return Select(columns, whereclause=whereclause, from_obj=from_obj, **kwargs) def subquery(alias, *args, **kwargs): """Return an [sqlalchemy.sql#Alias] object derived from a [sqlalchemy.sql#Select]. @@ -253,7 +252,7 @@ def subquery(alias, *args, **kwargs): return Select(*args, **kwargs).alias(alias) def insert(table, values = None, **kwargs): - """Return an [sqlalchemy.sql#_Insert] clause element. + """Return an [sqlalchemy.sql#Insert] clause element. Similar functionality is available via the ``insert()`` method on [sqlalchemy.schema#Table]. @@ -286,10 +285,10 @@ def insert(table, values = None, **kwargs): against the ``INSERT`` statement. """ - return _Insert(table, values, **kwargs) + return Insert(table, values, **kwargs) def update(table, whereclause = None, values = None, **kwargs): - """Return an [sqlalchemy.sql#_Update] clause element. + """Return an [sqlalchemy.sql#Update] clause element. Similar functionality is available via the ``update()`` method on [sqlalchemy.schema#Table]. @@ -326,10 +325,10 @@ def update(table, whereclause = None, values = None, **kwargs): against the ``UPDATE`` statement. """ - return _Update(table, whereclause, values, **kwargs) + return Update(table, whereclause, values, **kwargs) def delete(table, whereclause = None, **kwargs): - """Return a [sqlalchemy.sql#_Delete] clause element. + """Return a [sqlalchemy.sql#Delete] clause element. Similar functionality is available via the ``delete()`` method on [sqlalchemy.schema#Table]. @@ -343,7 +342,7 @@ def delete(table, whereclause = None, **kwargs): """ - return _Delete(table, whereclause, **kwargs) + return Delete(table, whereclause, **kwargs) def and_(*clauses): """Join a list of clauses together using the ``AND`` operator. @@ -384,7 +383,7 @@ def between(ctest, cleft, cright): provides similar functionality. """ - return _BinaryExpression(ctest, and_(_literals_as_binds(cleft, type=ctest.type), _literals_as_binds(cright, type=ctest.type)), 'BETWEEN') + return _BinaryExpression(ctest, and_(_literal_as_binds(cleft, type=ctest.type), _literal_as_binds(cright, type=ctest.type)), 'BETWEEN') def between_(*args, **kwargs): """synonym for [sqlalchemy.sql#between()] (deprecated).""" @@ -757,13 +756,13 @@ def _compound_select(keyword, *selects, **kwargs): def _is_literal(element): return not isinstance(element, ClauseElement) -def _literals_as_text(element): +def _literal_as_text(element): if _is_literal(element): return _TextClause(unicode(element)) else: return element -def _literals_as_binds(element, name='literal', type=None): +def _literal_as_binds(element, name='literal', type=None): if _is_literal(element): if element is None: return null() @@ -860,22 +859,46 @@ class ClauseVisitor(object): these options can indicate modifications to the set of elements returned, such as to not return column collections (column_collections=False) or to return Schema-level items - (schema_visitor=True).""" + (schema_visitor=True). + + """ __traverse_options__ = {} - def traverse(self, obj, stop_on=None): - stack = [obj] - traversal = [] - while len(stack) > 0: - t = stack.pop() - if stop_on is None or t not in stop_on: - traversal.insert(0, t) - for c in t.get_children(**self.__traverse_options__): - stack.append(c) - for target in traversal: - v = self - while v is not None: - target.accept_visitor(v) - v = getattr(v, '_next', None) + + def traverse_single(self, obj): + meth = getattr(self, "visit_%s" % obj.__visit_name__, None) + if meth: + return meth(obj) + + def traverse(self, obj, stop_on=None, clone=False): + if clone: + obj = obj._clone() + + # entry flag indicates to also call a before-descent "enter_XXXX" method + entry = self.__traverse_options__.get('entry', False) + + v = self + visitors = [] + while v is not None: + visitors.append(v) + v = getattr(v, '_next', None) + + def _trav(obj): + if stop_on is not None and obj in stop_on: + return + if entry: + for v in visitors: + meth = getattr(v, "enter_%s" % obj.__visit_name__, None) + if meth: + meth(obj) + + for c in obj.get_children(clone=clone, **self.__traverse_options__): + _trav(c) + + for v in visitors: + meth = getattr(v, "visit_%s" % obj.__visit_name__, None) + if meth: + meth(obj) + _trav(obj) return obj def chain(self, visitor): @@ -887,78 +910,6 @@ class ClauseVisitor(object): tail = tail._next tail._next = visitor return self - - def visit_column(self, column): - pass - def visit_table(self, table): - pass - def visit_fromclause(self, fromclause): - pass - def visit_bindparam(self, bindparam): - pass - def visit_textclause(self, textclause): - pass - def visit_compound(self, compound): - pass - def visit_compound_select(self, compound): - pass - def visit_binary(self, binary): - pass - def visit_unary(self, unary): - pass - def visit_alias(self, alias): - pass - def visit_select(self, select): - pass - def visit_join(self, join): - pass - def visit_null(self, null): - pass - def visit_clauselist(self, list): - pass - def visit_calculatedclause(self, calcclause): - pass - def visit_grouping(self, gr): - pass - def visit_function(self, func): - pass - def visit_cast(self, cast): - pass - def visit_label(self, label): - pass - def visit_typeclause(self, typeclause): - pass - -class LoggingClauseVisitor(ClauseVisitor): - """extends ClauseVisitor to include debug logging of all traversal. - - To install this visitor, set logging.DEBUG for - 'sqlalchemy.sql.ClauseVisitor' **before** you import the - sqlalchemy.sql module. - """ - - def traverse(self, obj, stop_on=None): - stack = [(obj, "")] - traversal = [] - while len(stack) > 0: - (t, indent) = stack.pop() - if stop_on is None or t not in stop_on: - traversal.insert(0, (t, indent)) - for c in t.get_children(**self.__traverse_options__): - stack.append((c, indent + " ")) - - for (target, indent) in traversal: - self.logger.debug(indent + repr(target)) - v = self - while v is not None: - target.accept_visitor(v) - v = getattr(v, '_next', None) - return obj - -LoggingClauseVisitor.logger = logging.class_logger(ClauseVisitor) - -if logging.is_debug_enabled(LoggingClauseVisitor.logger): - ClauseVisitor=LoggingClauseVisitor class NoColumnVisitor(ClauseVisitor): """a ClauseVisitor that will not traverse the exported Column @@ -971,109 +922,31 @@ class NoColumnVisitor(ClauseVisitor): """ __traverse_options__ = {'column_collections':False} - -class Executor(object): - """Interface representing a "thing that can produce Compiled objects - and execute them".""" - def execute_compiled(self, compiled, parameters, echo=None, **kwargs): - """Execute a Compiled object.""" - - raise NotImplementedError() - - def compiler(self, statement, parameters, **kwargs): - """Return a Compiled object for the given statement and parameters.""" - - raise NotImplementedError() - -class Compiled(ClauseVisitor): - """Represent a compiled SQL expression. - - The ``__str__`` method of the ``Compiled`` object should produce - the actual text of the statement. ``Compiled`` objects are - specific to their underlying database dialect, and also may - or may not be specific to the columns referenced within a - particular set of bind parameters. In no case should the - ``Compiled`` object be dependent on the actual values of those - bind parameters, even though it may reference those values as - defaults. - """ - - def __init__(self, dialect, statement, parameters, engine=None): - """Construct a new ``Compiled`` object. - - statement - ``ClauseElement`` to be compiled. - - parameters - Optional dictionary indicating a set of bind parameters - specified with this ``Compiled`` object. These parameters - are the *default* values corresponding to the - ``ClauseElement``'s ``_BindParamClauses`` when the - ``Compiled`` is executed. In the case of an ``INSERT`` or - ``UPDATE`` statement, these parameters will also result in - the creation of new ``_BindParamClause`` objects for each - key and will also affect the generated column list in an - ``INSERT`` statement and the ``SET`` clauses of an - ``UPDATE`` statement. The keys of the parameter dictionary - can either be the string names of columns or - ``_ColumnClause`` objects. - - engine - Optional Engine to compile this statement against. - """ - self.dialect = dialect - self.statement = statement - self.parameters = parameters - self.engine = engine - self.can_execute = statement.supports_execution() - - def compile(self): - self.traverse(self.statement) - self.after_compile() - - def __str__(self): - """Return the string text of the generated SQL statement.""" - - raise NotImplementedError() - def get_params(self, **params): - """Deprecated. use construct_params(). (supports unicode names) - """ - - return self.construct_params(params) - - def construct_params(self, params): - """Return the bind params for this compiled object. - - Will start with the default parameters specified when this - ``Compiled`` object was first constructed, and will override - those values with those sent via `**params`, which are - key/value pairs. Each key should match one of the - ``_BindParamClause`` objects compiled into this object; either - the `key` or `shortname` property of the ``_BindParamClause``. - """ - raise NotImplementedError() +class _FigureVisitName(type): + def __init__(cls, clsname, bases, dict): + if not '__visit_name__' in cls.__dict__: + m = re.match(r'_?(\w+?)(?:Expression|Clause|Element|$)', clsname) + x = m.group(1) + x = re.sub(r'(?!^)[A-Z]', lambda m:'_'+m.group(0).lower(), x) + cls.__visit_name__ = x.lower() + super(_FigureVisitName, cls).__init__(clsname, bases, dict) - def execute(self, *multiparams, **params): - """Execute this compiled object.""" - - e = self.engine - if e is None: - raise exceptions.InvalidRequestError("This Compiled object is not bound to any engine.") - return e.execute_compiled(self, *multiparams, **params) - - def scalar(self, *multiparams, **params): - """Execute this compiled object and return the result's scalar value.""" - - return self.execute(*multiparams, **params).scalar() - class ClauseElement(object): """Base class for elements of a programmatically constructed SQL expression. """ + __metaclass__ = _FigureVisitName + + def _clone(self): + # shallow copy. mutator operations always create + # clones of container objects. + c = self.__class__.__new__(self.__class__) + c.__dict__ = self.__dict__.copy() + return c - def _get_from_objects(self): + def _get_from_objects(self, **modifiers): """Return objects represented in this ``ClauseElement`` that should be added to the ``FROM`` list of a query, when this ``ClauseElement`` is placed in the column clause of a @@ -1082,7 +955,7 @@ class ClauseElement(object): raise NotImplementedError(repr(self)) - def _hide_froms(self): + def _hide_froms(self, **modifiers): """Return a list of ``FROM`` clause elements which this ``ClauseElement`` replaces. """ @@ -1098,18 +971,16 @@ class ClauseElement(object): return self is other - def accept_visitor(self, visitor): - """Accept a ``ClauseVisitor`` and call the appropriate - ``visit_xxx`` method. - """ - - raise NotImplementedError(repr(self)) - - def get_children(self, **kwargs): + def get_children(self, clone=False, **kwargs): """return immediate child elements of this ``ClauseElement``. this is used for visit traversal. + clone indicates child items should be _cloned(), replacing + the elements contained by this element, and the cloned + copy returned. this allows modifying traversals + to take place. + \**kwargs may contain flags that change the collection that is returned, for example to return a subset of items in order to cut down on larger traversals, or to return @@ -1127,18 +998,6 @@ class ClauseElement(object): return False - def copy_container(self): - """Return a copy of this ``ClauseElement``, if this - ``ClauseElement`` contains other ``ClauseElements``. - - If this ``ClauseElement`` is not a container, it should return - self. This is used to create copies of expression trees that - still reference the same *leaf nodes*. The new structure can - then be restructured without affecting the original. - """ - - return self - def _find_engine(self): """Default strategy for locating an engine within the clause element. @@ -1429,9 +1288,6 @@ class Selectable(ClauseElement): def _selectable(self): return self - def accept_visitor(self, visitor): - raise NotImplementedError(repr(self)) - def select(self, whereclauses = None, **params): return select([self], whereclauses, **params) @@ -1589,19 +1445,18 @@ class FromClause(Selectable): clause of a ``SELECT`` statement. """ + __visit_name__ = 'fromclause' + def __init__(self, name=None): self.name = name - def _get_from_objects(self): + def _get_from_objects(self, **modifiers): # this could also be [self], at the moment it doesnt matter to the Select object return [] def default_order_by(self): return [self.oid_column] - def accept_visitor(self, visitor): - visitor.visit_fromclause(self) - def count(self, whereclause=None, **params): if len(self.primary_key): col = list(self.primary_key)[0] @@ -1643,6 +1498,13 @@ class FromClause(Selectable): FindCols().traverse(self) return ret + def is_derived_from(self, fromclause): + """return True if this FromClause is 'derived' from the given FromClause. + + An example would be an Alias of a Table is derived from that Table.""" + + return False + def corresponding_column(self, column, raiseerr=True, keys_ok=False, require_embedded=False): """Given a ``ColumnElement``, return the exported ``ColumnElement`` object from this ``Selectable`` which @@ -1701,6 +1563,15 @@ class FromClause(Selectable): self._export_columns() return getattr(self, name) + def _clone_from_clause(self): + # delete all the "generated" collections of columns for a newly cloned FromClause, + # so that they will be re-derived from the item. + # this is because FromClause subclasses, when cloned, need to reestablish new "proxied" + # columns that are linked to the new item + for attr in ('_columns', '_primary_key' '_foreign_keys', '_orig_cols', '_oid_column'): + if hasattr(self, attr): + delattr(self, attr) + columns = property(lambda s:s._get_exported_attribute('_columns')) c = property(lambda s:s._get_exported_attribute('_columns')) primary_key = property(lambda s:s._get_exported_attribute('_primary_key')) @@ -1731,7 +1602,7 @@ class FromClause(Selectable): self._primary_key = ColumnCollection() self._foreign_keys = util.Set() self._orig_cols = {} - for co in self._adjusted_exportable_columns(): + for co in self._flatten_exportable_columns(): cp = self._proxy_column(co) for ci in cp.orig_set: # note that some ambiguity is raised here, whereby a selectable might have more than @@ -1741,13 +1612,13 @@ class FromClause(Selectable): for ci in self.oid_column.orig_set: self._orig_cols[ci] = self.oid_column - def _adjusted_exportable_columns(self): + def _flatten_exportable_columns(self): """return the list of ColumnElements represented within this FromClause's _exportable_columns""" export = self._exportable_columns() for column in export: - try: + if hasattr(column, '_selectable'): s = column._selectable() - except AttributeError: + else: continue for co in s.columns: yield co @@ -1764,6 +1635,8 @@ class _BindParamClause(ClauseElement, _CompareMixin): Public constructor is the ``bindparam()`` function. """ + __visit_name__ = 'bindparam' + def __init__(self, key, value, shortname=None, type=None, unique=False): """Construct a _BindParamClause. @@ -1805,15 +1678,9 @@ class _BindParamClause(ClauseElement, _CompareMixin): self.unique = unique self.type = sqltypes.to_instance(type) - def accept_visitor(self, visitor): - visitor.visit_bindparam(self) - - def _get_from_objects(self): + def _get_from_objects(self, **modifiers): return [] - def copy_container(self): - return _BindParamClause(self.key, self.value, self.shortname, self.type, unique=self.unique) - def typeprocess(self, value, dialect): return self.type.dialect_impl(dialect).convert_bind_param(value, dialect) @@ -1836,13 +1703,12 @@ class _TypeClause(ClauseElement): Used by the ``Case`` statement. """ + __visit_name__ = 'typeclause' + def __init__(self, type): self.type = type - def accept_visitor(self, visitor): - visitor.visit_typeclause(self) - - def _get_from_objects(self): + def _get_from_objects(self, **modifiers): return [] class _TextClause(ClauseElement): @@ -1851,6 +1717,8 @@ class _TextClause(ClauseElement): Public constructor is the ``text()`` function. """ + __visit_name__ = 'textclause' + def __init__(self, text = "", engine=None, bindparams=None, typemap=None): self._engine = engine self.bindparams = {} @@ -1879,13 +1747,13 @@ class _TextClause(ClauseElement): columns = property(lambda s:[]) - def get_children(self, **kwargs): + def get_children(self, clone=False, **kwargs): + if clone: + self.bindparams = [b._clone() for b in self.bindparams] + return self.bindparams.values() - def accept_visitor(self, visitor): - visitor.visit_textclause(self) - - def _get_from_objects(self): + def _get_from_objects(self, **modifiers): return [] def supports_execution(self): @@ -1900,10 +1768,7 @@ class _Null(ColumnElement): def __init__(self): self.type = sqltypes.NULLTYPE - def accept_visitor(self, visitor): - visitor.visit_null(self) - - def _get_from_objects(self): + def _get_from_objects(self, **modifiers): return [] class ClauseList(ClauseElement): @@ -1911,14 +1776,16 @@ class ClauseList(ClauseElement): By default, is comma-separated, such as a column listing. """ - + __visit_name__ = 'clauselist' + def __init__(self, *clauses, **kwargs): self.clauses = [] self.operator = kwargs.pop('operator', ',') self.group = kwargs.pop('group', True) self.group_contents = kwargs.pop('group_contents', True) for c in clauses: - if c is None: continue + if c is None: + continue self.append(c) def __iter__(self): @@ -1926,10 +1793,6 @@ class ClauseList(ClauseElement): def __len__(self): return len(self.clauses) - def copy_container(self): - clauses = [clause.copy_container() for clause in self.clauses] - return ClauseList(operator=self.operator, *clauses) - def self_group(self, against=None): if self.group: return _Grouping(self) @@ -1940,20 +1803,20 @@ class ClauseList(ClauseElement): # TODO: not sure if i like the 'group_contents' flag. need to define the difference between # a ClauseList of ClauseLists, and a "flattened" ClauseList of ClauseLists. flatten() method ? if self.group_contents: - self.clauses.append(_literals_as_text(clause).self_group(against=self.operator)) + self.clauses.append(_literal_as_text(clause).self_group(against=self.operator)) else: - self.clauses.append(_literals_as_text(clause)) + self.clauses.append(_literal_as_text(clause)) - def get_children(self, **kwargs): + def get_children(self, clone=False, **kwargs): + if clone: + self.clauses = [clause._clone() for clause in self.clauses] + return self.clauses - def accept_visitor(self, visitor): - visitor.visit_clauselist(self) - - def _get_from_objects(self): + def _get_from_objects(self, **modifiers): f = [] for c in self.clauses: - f += c._get_from_objects() + f += c._get_from_objects(**modifiers) return f def self_group(self, against=None): @@ -1984,7 +1847,8 @@ class _CalculatedClause(ColumnElement): Extends ``ColumnElement`` to provide column-level comparison operators. """ - + __visit_name__ = 'calculatedclause' + def __init__(self, name, *clauses, **kwargs): self.name = name self.type = sqltypes.to_instance(kwargs.get('type', None)) @@ -1998,17 +1862,13 @@ class _CalculatedClause(ColumnElement): key = property(lambda self:self.name or "_calc_") - def copy_container(self): - clauses = [clause.copy_container() for clause in self.clauses] - return _CalculatedClause(type=self.type, engine=self._engine, *clauses) - - def get_children(self, **kwargs): + def get_children(self, clone=False, **kwargs): + if clone: + self.clause_expr = self.clause_expr._clone() return self.clause_expr, - def accept_visitor(self, visitor): - visitor.visit_calculatedclause(self) - def _get_from_objects(self): - return self.clauses._get_from_objects() + def _get_from_objects(self, **modifiers): + return self.clauses._get_from_objects(**modifiers) def _bind_param(self, obj): return _BindParamClause(self.name, obj, type=self.type, unique=True) @@ -2043,18 +1903,16 @@ class _Function(_CalculatedClause, FromClause): key = property(lambda self:self.name) - - def append(self, clause): - self.clauses.append(_literals_as_binds(clause, self.name)) - - def copy_container(self): - clauses = [clause.copy_container() for clause in self.clauses] - return _Function(self.name, type=self.type, packagenames=self.packagenames, engine=self._engine, *clauses) + def get_children(self, clone=False, **kwargs): + if clone: + self._clone_from_clause() + return _CalculatedClause.get_children(self, clone=clone, **kwargs) - def accept_visitor(self, visitor): - visitor.visit_function(self) + def append(self, clause): + self.clauses.append(_literal_as_binds(clause, self.name)) class _Cast(ColumnElement): + def __init__(self, clause, totype, **kwargs): if not hasattr(clause, 'label'): clause = literal(clause) @@ -2062,13 +1920,15 @@ class _Cast(ColumnElement): self.clause = clause self.typeclause = _TypeClause(self.type) - def get_children(self, **kwargs): + def get_children(self, clone=False, **kwargs): + if clone: + self.clause = self.clause._clone() + self.typeclause = self.typeclause._clone() + return self.clause, self.typeclause - def accept_visitor(self, visitor): - visitor.visit_cast(self) - def _get_from_objects(self): - return self.clause._get_from_objects() + def _get_from_objects(self, **modifiers): + return self.clause._get_from_objects(**modifiers) def _make_proxy(self, selectable, name=None): if name is not None: @@ -2085,22 +1945,18 @@ class _UnaryExpression(ColumnElement): self.operator = operator self.modifier = modifier - self.element = _literals_as_text(element).self_group(against=self.operator or self.modifier) + self.element = _literal_as_text(element).self_group(against=self.operator or self.modifier) self.type = sqltypes.to_instance(type) self.negate = negate - def copy_container(self): - return self.__class__(self.element.copy_container(), operator=self.operator, modifier=self.modifier, type=self.type, negate=self.negate) - - def _get_from_objects(self): - return self.element._get_from_objects() + def _get_from_objects(self, **modifiers): + return self.element._get_from_objects(**modifiers) - def get_children(self, **kwargs): + def get_children(self, clone=False, **kwargs): + if clone: + self.element = self.element._clone() return self.element, - def accept_visitor(self, visitor): - visitor.visit_unary(self) - def compare(self, other): """Compare this ``_UnaryClause`` against the given ``ClauseElement``.""" @@ -2109,6 +1965,7 @@ class _UnaryExpression(ColumnElement): self.modifier == other.modifier and self.element.compare(other.element) ) + def _negate(self): if self.negate is not None: return _UnaryExpression(self.element, operator=self.negate, negate=self.operator, modifier=self.modifier, type=self.type) @@ -2120,24 +1977,22 @@ class _BinaryExpression(ColumnElement): """Represent an expression that is ``LEFT RIGHT``.""" def __init__(self, left, right, operator, type=None, negate=None): - self.left = _literals_as_text(left).self_group(against=operator) - self.right = _literals_as_text(right).self_group(against=operator) + self.left = _literal_as_text(left).self_group(against=operator) + self.right = _literal_as_text(right).self_group(against=operator) self.operator = operator self.type = sqltypes.to_instance(type) self.negate = negate - def copy_container(self): - return self.__class__(self.left.copy_container(), self.right.copy_container(), self.operator) - - def _get_from_objects(self): - return self.left._get_from_objects() + self.right._get_from_objects() + def _get_from_objects(self, **modifiers): + return self.left._get_from_objects(**modifiers) + self.right._get_from_objects(**modifiers) - def get_children(self, **kwargs): + def get_children(self, clone=False, **kwargs): + if clone: + self.left = self.left._clone() + self.right = self.right._clone() + return self.left, self.right - def accept_visitor(self, visitor): - visitor.visit_binary(self) - def compare(self, other): """Compare this ``_BinaryExpression`` against the given ``_BinaryExpression``.""" @@ -2159,13 +2014,15 @@ class _BinaryExpression(ColumnElement): return super(_BinaryExpression, self)._negate() class _Exists(_UnaryExpression): + __visit_name__ = _UnaryExpression.__visit_name__ + def __init__(self, *args, **kwargs): kwargs['correlate'] = True s = select(*args, **kwargs).self_group() _UnaryExpression.__init__(self, s, operator="EXISTS") - def _hide_froms(self): - return self._get_from_objects() + def _hide_froms(self, **modifiers): + return self._get_from_objects(**modifiers) class Join(FromClause): """represent a ``JOIN`` construct between two ``FromClause`` @@ -2192,7 +2049,7 @@ class Join(FromClause): def _init_primary_key(self): pkcol = util.OrderedSet() - for col in self._adjusted_exportable_columns(): + for col in self._flatten_exportable_columns(): if col.primary_key: pkcol.add(col) for col in list(pkcol): @@ -2213,6 +2070,16 @@ class Join(FromClause): self._foreign_keys.add(f) return column + def get_children(self, clone=False, **kwargs): + if clone: + self._clone_from_clause() + self.left = self.left._clone() + self.right = self.right._clone() + self.onclause = self.onclause._clone() + self.__folded_equivalents = None + self._init_primary_key() + return self.left, self.right, self.onclause + def _match_primaries(self, primary, secondary): crit = [] constraints = util.Set() @@ -2300,12 +2167,6 @@ class Join(FromClause): return select(collist, whereclause, from_obj=[self], **kwargs) - def get_children(self, **kwargs): - return self.left, self.right, self.onclause - - def accept_visitor(self, visitor): - visitor.visit_join(self) - engine = property(lambda s:s.left.engine or s.right.engine) def alias(self, name=None): @@ -2316,11 +2177,11 @@ class Join(FromClause): return self.select(use_labels=True, correlate=False).alias(name) - def _hide_froms(self): - return self.left._get_from_objects() + self.right._get_from_objects() + def _hide_froms(self, **modifiers): + return self.left._get_from_objects(**modifiers) + self.right._get_from_objects(**modifiers) - def _get_from_objects(self): - return [self] + self.onclause._get_from_objects() + self.left._get_from_objects() + self.right._get_from_objects() + def _get_from_objects(self, **modifiers): + return [self] + self.onclause._get_from_objects(**modifiers) + self.left._get_from_objects(**modifiers) + self.right._get_from_objects(**modifiers) class Alias(FromClause): """represent an alias, as typically applied to any @@ -2351,6 +2212,14 @@ class Alias(FromClause): self.encodedname = alias.encode('ascii', 'backslashreplace') self.case_sensitive = getattr(baseselectable, "case_sensitive", True) + def is_derived_from(self, fromclause): + x = self.selectable + while isinstance(x, Alias): + x = x.selectable + if x is fromclause: + return True + return False + def supports_execution(self): return self.original.supports_execution() @@ -2367,14 +2236,18 @@ class Alias(FromClause): #return self.selectable._exportable_columns() return self.selectable.columns - def get_children(self, **kwargs): + def get_children(self, clone=False, **kwargs): + if clone: + self._clone_from_clause() + self.selectable = self.selectable._clone() + baseselectable = self.selectable + while isinstance(baseselectable, Alias): + baseselectable = baseselectable.selectable + self.original = baseselectable for c in self.c: yield c yield self.selectable - def accept_visitor(self, visitor): - visitor.visit_alias(self) - def _get_from_objects(self): return [self] @@ -2392,17 +2265,16 @@ class _Grouping(ColumnElement): _label = property(lambda s: s.elem._label) orig_set = property(lambda s:s.elem.orig_set) - def copy_container(self): - return _Grouping(self.elem.copy_container()) - - def accept_visitor(self, visitor): - visitor.visit_grouping(self) - def get_children(self, **kwargs): + def get_children(self, clone=False, **kwargs): + if clone: + self.elem = self.elem._clone() return self.elem, - def _hide_froms(self): - return self.elem._hide_froms() - def _get_from_objects(self): - return self.elem._get_from_objects() + + def _hide_froms(self, **modifiers): + return self.elem._hide_froms(**modifiers) + + def _get_from_objects(self, **modifiers): + return self.elem._get_from_objects(**modifiers) class _Label(ColumnElement): """represent a label, as typically applied to any column-level element @@ -2429,17 +2301,16 @@ class _Label(ColumnElement): def _compare_self(self): return self.obj - def get_children(self, **kwargs): + def get_children(self, clone=False, **kwargs): + if clone: + self.obj = self.obj._clone() return self.obj, - def accept_visitor(self, visitor): - visitor.visit_label(self) + def _get_from_objects(self, **modifiers): + return self.obj._get_from_objects(**modifiers) - def _get_from_objects(self): - return self.obj._get_from_objects() - - def _hide_froms(self): - return self.obj._hide_froms() + def _hide_froms(self, **modifiers): + return self.obj._hide_froms(**modifiers) def _make_proxy(self, selectable, name = None): if isinstance(self.obj, Selectable): @@ -2489,7 +2360,11 @@ class _ColumnClause(ColumnElement): self.__label = None self.case_sensitive = case_sensitive self.is_literal = is_literal - + + def _clone(self): + # ColumnClause is immutable + return self + def _get_label(self): """Generate a 'label' for this column. @@ -2527,10 +2402,7 @@ class _ColumnClause(ColumnElement): else: return super(_ColumnClause, self).label(name) - def accept_visitor(self, visitor): - visitor.visit_column(self) - - def _get_from_objects(self): + def _get_from_objects(self, **modifiers): if self.table is not None: return [self.table] else: @@ -2575,6 +2447,10 @@ class TableClause(FromClause): self.append_column(c) self._oid_column = _ColumnClause('oid', self, _is_oid=True) + def _clone(self): + # TableClause is immutable + return self + def named_with_column(self): return True @@ -2603,9 +2479,6 @@ class TableClause(FromClause): else: return [] - def accept_visitor(self, visitor): - visitor.visit_table(self) - def _exportable_columns(self): raise NotImplementedError() @@ -2640,67 +2513,95 @@ class TableClause(FromClause): def delete(self, whereclause = None): return delete(self, whereclause) - def _get_from_objects(self): + def _get_from_objects(self, **modifiers): return [self] class _SelectBaseMixin(object): """Base class for ``Select`` and ``CompoundSelects``.""" + def __init__(self, use_labels=False, for_update=False, limit=None, offset=None, order_by=None, group_by=None, connectable=None, scalar=False, engine=None): + self.use_labels = use_labels + self.for_update = for_update + self._limit = limit + self._offset = offset + self._engine = connectable or engine + self.is_scalar = scalar + if self.is_scalar: + # allow corresponding_column to return None + self.orig_set = util.Set() + + self.append_order_by(*util.to_list(order_by, [])) + self.append_group_by(*util.to_list(group_by, [])) + def supports_execution(self): return True + def _generate(self): + s = self._clone() + s._clone_from_clause() + return s + + def limit(self, limit): + s = self._generate() + s._limit = limit + return s + + def offset(self, offset): + s = self._generate() + s._offset = offset + return s + def order_by(self, *clauses): - if len(clauses) == 1 and clauses[0] is None: - self.order_by_clause = ClauseList() - elif getattr(self, 'order_by_clause', None): - self.order_by_clause = ClauseList(*(list(self.order_by_clause.clauses) + list(clauses))) - else: - self.order_by_clause = ClauseList(*clauses) + s = self._generate() + s.append_order_by(*clauses) + return s def group_by(self, *clauses): - if len(clauses) == 1 and clauses[0] is None: - self.group_by_clause = ClauseList() - elif getattr(self, 'group_by_clause', None): - self.group_by_clause = ClauseList(*(list(clauses)+list(self.group_by_clause.clauses))) + s = self._generate() + s.append_group_by(*clauses) + return s + + def append_order_by(self, *clauses): + if clauses == [None]: + self._order_by_clause = ClauseList() else: - self.group_by_clause = ClauseList(*clauses) + if getattr(self, '_order_by_clause', None): + clauses = list(self._order_by_clause) + list(clauses) + self._order_by_clause = ClauseList(*clauses) + def append_group_by(self, *clauses): + if clauses == [None]: + self._group_by_clause = ClauseList() + else: + if getattr(self, '_group_by_clause', None): + clauses = list(self._group_by_clause) + list(clauses) + self._group_by_clause = ClauseList(*clauses) + def select(self, whereclauses = None, **params): return select([self], whereclauses, **params) - def _get_from_objects(self): - if self.is_where or self.is_scalar: + def _get_from_objects(self, is_where=False, **modifiers): + if is_where or self.is_scalar: return [] else: return [self] class CompoundSelect(_SelectBaseMixin, FromClause): def __init__(self, keyword, *selects, **kwargs): - _SelectBaseMixin.__init__(self) + self._should_correlate = kwargs.pop('correlate', False) self.keyword = keyword - self.use_labels = kwargs.pop('use_labels', False) - self.should_correlate = kwargs.pop('correlate', False) - self.for_update = kwargs.pop('for_update', False) - self.nowait = kwargs.pop('nowait', False) - self.limit = kwargs.pop('limit', None) - self.offset = kwargs.pop('offset', None) - self.is_compound = True - self.is_where = False - self.is_scalar = False - self.is_subquery = False - - self.selects = selects + self.selects = [] # some DBs do not like ORDER BY in the inner queries of a UNION, etc. for s in selects: - s.order_by(None) + if len(s._order_by_clause): + s = s.order_by(None) + self.selects.append(s) - self.group_by(*kwargs.pop('group_by', [None])) - self.order_by(*kwargs.pop('order_by', [None])) - if len(kwargs): - raise TypeError("invalid keyword argument(s) for CompoundSelect: %s" % repr(kwargs.keys())) self._col_map = {} + _SelectBaseMixin.__init__(self, **kwargs) + name = property(lambda s:s.keyword + " statement") def self_group(self, against=None): @@ -2728,12 +2629,18 @@ class CompoundSelect(_SelectBaseMixin, FromClause): col.orig_set = colset return col - def get_children(self, column_collections=True, **kwargs): - return (column_collections and list(self.c) or []) + \ - [self.order_by_clause, self.group_by_clause] + list(self.selects) - def accept_visitor(self, visitor): - visitor.visit_compound_select(self) + def get_children(self, clone=False, column_collections=True, **kwargs): + if clone: + self._clone_from_clause() + self._col_map = {} + self.selects = [s._clone() for s in self.selects] + for attr in ('_order_by_clause', '_group_by_clause'): + if getattr(self, attr) is not None: + setattr(self, attr, getattr(self, attr)._clone()) + return (column_collections and list(self.c) or []) + \ + [self._order_by_clause, self._group_by_clause] + list(self.selects) + def _find_engine(self): for s in self.selects: e = s._find_engine() @@ -2748,127 +2655,212 @@ class Select(_SelectBaseMixin, FromClause): """ - def __init__(self, columns=None, whereclause=None, from_obj=[], - order_by=None, group_by=None, having=None, - use_labels=False, distinct=False, for_update=False, - engine=None, limit=None, offset=None, scalar=False, - correlate=True): + def __init__(self, columns, whereclause=None, from_obj=None, distinct=False, having=None, correlate=True, **kwargs): """construct a Select object. The public constructor for Select is the [sqlalchemy.sql#select()] function; see that function for argument descriptions. """ - _SelectBaseMixin.__init__(self) - self.__froms = util.OrderedSet() - self.__hide_froms = util.Set([self]) - self.use_labels = use_labels - self.whereclause = None - self.having = None - self._engine = engine - self.limit = limit - self.offset = offset - self.for_update = for_update - self.is_compound = False - # indicates that this select statement should not expand its columns - # into the column clause of an enclosing select, and should instead - # act like a single scalar column - self.is_scalar = scalar - if scalar: - # allow corresponding_column to return None - self.orig_set = util.Set() - - # indicates if this select statement, as a subquery, should automatically correlate - # its FROM clause to that of an enclosing select, update, or delete statement. - # note that the "correlate" method can be used to explicitly add a value to be correlated. - self.should_correlate = correlate - - # indicates if this select statement is a subquery inside another query - self.is_subquery = False - - # indicates if this select statement is in the from clause of another query - self.is_selected_from = False + self._should_correlate = correlate + self._distinct = distinct - # indicates if this select statement is a subquery as a criterion - # inside of a WHERE clause - self.is_where = False - - self.distinct = distinct self._raw_columns = [] - self.__correlated = {} - self.__correlator = Select._CorrelatedVisitor(self, False) - self.__wherecorrelator = Select._CorrelatedVisitor(self, True) - self.__fromvisitor = Select._FromVisitor(self) - - - self.order_by_clause = self.group_by_clause = None + self.__correlate = util.Set() + self._froms = util.OrderedSet() + self._whereclause = None + self._having = None if columns is not None: for c in columns: self.append_column(c) - if order_by: - order_by = util.to_list(order_by) - if group_by: - group_by = util.to_list(group_by) - self.order_by(*(order_by or [None])) - self.group_by(*(group_by or [None])) - for c in self.order_by_clause: - self.__correlator.traverse(c) - for c in self.group_by_clause: - self.__correlator.traverse(c) - - for f in from_obj: - self.append_from(f) - - # whereclauses must be appended after the columns/FROM, since it affects - # the correlation of subqueries. see test/sql/select.py SelectTest.testwheresubquery + if from_obj is not None: + for f in from_obj: + self.append_from(f) + if whereclause is not None: self.append_whereclause(whereclause) + if having is not None: self.append_having(having) + _SelectBaseMixin.__init__(self, **kwargs) - class _CorrelatedVisitor(NoColumnVisitor): - """Visit a clause, locate any ``Select`` clauses, and tell - them that they should correlate their ``FROM`` list to that of - their parent. - """ - - def __init__(self, select, is_where): - NoColumnVisitor.__init__(self) - self.select = select - self.is_where = is_where - - def visit_compound_select(self, cs): - self.visit_select(cs) - - def visit_column(self, c): - pass + def get_display_froms(self, correlation_state=None): + froms = util.Set() + hide_froms = util.Set() + + for col in self._raw_columns: + for f in col._hide_froms(): + hide_froms.add(f) + for f in col._get_from_objects(): + froms.add(f) - def visit_table(self, c): - pass + if self._whereclause is not None: + for f in self._whereclause._get_from_objects(is_where=True): + froms.add(f) + + for elem in self._froms: + froms.add(elem) + for f in elem._get_from_objects(): + froms.add(f) - def visit_select(self, select): - if select is self.select: - return - select.is_where = self.is_where - select.is_subquery = True - if not select.should_correlate: - return - [select.correlate(x) for x in self.select._Select__froms] + for elem in froms: + for f in elem._hide_froms(): + hide_froms.add(f) - class _FromVisitor(NoColumnVisitor): - def __init__(self, select): - NoColumnVisitor.__init__(self) - self.select = select + froms = froms.difference(hide_froms) + + if len(froms) > 1: + corr = self.__correlate + if correlation_state is not None: + corr = correlation_state[self].get('correlate', util.Set()).union(corr) + return froms.difference(corr) + else: + return froms + + def locate_all_froms(self): + froms = util.Set() + for col in self._raw_columns: + for f in col._get_from_objects(): + froms.add(f) + + if self._whereclause is not None: + for f in self._whereclause._get_from_objects(is_where=True): + froms.add(f) + + for elem in self._froms: + froms.add(elem) + for f in elem._get_from_objects(): + froms.add(f) + return froms + + def calculate_correlations(self, correlation_state): + if self not in correlation_state: + correlation_state[self] = {} + + display_froms = self.get_display_froms(correlation_state) + + class CorrelatedVisitor(NoColumnVisitor): + def __init__(self, is_where=False, is_column=False, is_from=False): + self.is_where = is_where + self.is_column = is_column + self.is_from = is_from + + def visit_compound_select(self, cs): + self.visit_select(cs) + + def visit_select(s, select): + if select not in correlation_state: + correlation_state[select] = {} + + if select is self: + return + + select_state = correlation_state[select] + if s.is_from: + select_state['is_selected_from'] = True + if s.is_where: + select_state['is_where'] = True + select_state['is_subquery'] = True + + if select._should_correlate: + corr = select_state.setdefault('correlate', util.Set()) + # not crazy about this part. need to be clearer on what elements in the + # subquery correspond to elements in the enclosing query. + for f in display_froms: + corr.add(f) + for f2 in f._get_from_objects(): + corr.add(f2) + + col_vis = CorrelatedVisitor(is_column=True) + where_vis = CorrelatedVisitor(is_where=True) + from_vis = CorrelatedVisitor(is_from=True) + + for col in self._raw_columns: + col_vis.traverse(col) + for f in col._get_from_objects(): + if f is not self: + from_vis.traverse(f) + + for col in list(self._order_by_clause) + list(self._group_by_clause): + col_vis.traverse(col) + + if self._whereclause is not None: + where_vis.traverse(self._whereclause) + for f in self._whereclause._get_from_objects(is_where=True): + if f is not self: + from_vis.traverse(f) + + for elem in self._froms: + from_vis.traverse(elem) + + def _get_inner_columns(self): + for c in self._raw_columns: + if hasattr(c, '_selectable'): + for co in c._selectable().columns: + yield co + else: + yield c + + inner_columns = property(_get_inner_columns) + + def get_children(self, clone=False, column_collections=True, **kwargs): + if clone: + self._clone_from_clause() + self._raw_columns = [c._clone() for c in self._raw_columns] + self._recorrelate_froms([f._clone() for f in self._froms]) + for attr in ('_whereclause', '_having', '_order_by_clause', '_group_by_clause'): + if getattr(self, attr) is not None: + setattr(self, attr, getattr(self, attr)._clone()) + + return (column_collections and list(self.columns) or []) + \ + list(self._froms) + \ + [x for x in (self._whereclause, self._having, self._order_by_clause, self._group_by_clause) if x is not None] + + def _recorrelate_froms(self, froms): + newcorrelate = util.Set() + for f in froms: + if f in self.__correlate: + newcorrelate.add(cl) + self.__correlate.remove(f) + self.__correlate = self.__correlate.union(newcorrelate) + self._froms = froms + + def column(self, column): + s = self._generate() + s.append_column(column) + return s + + def where(self, whereclause): + s = self._generate() + s.append_whereclause(whereclause) + return s + + def having(self, having): + s = self._generate() + s.append_having(having) + return s + + def distinct(self): + s = self._generate() + s.distinct = True + return s + + def select_from(self, fromclause): + s = self._generate() + s.append_from(fromclause) + return s + + def correlate_to(self, fromclause): + s = self._generate() + s.append_correlation(fromclause) + return s + + def append_correlation(self, fromclause): + self.__correlate.add(fromclause) - def visit_select(self, select): - if select is self.select: - return - select.is_selected_from = True - select.is_subquery = True - def append_column(self, column): if _is_literal(column): column = literal_column(str(column)) @@ -2878,22 +2870,26 @@ class Select(_SelectBaseMixin, FromClause): self._raw_columns.append(column) - if self.is_scalar and not hasattr(self, 'type'): - self.type = column.type - - # if the column is a Select statement itself, - # accept visitor - self.__correlator.traverse(column) + def append_whereclause(self, whereclause): + if self._whereclause is not None: + self._whereclause = and_(self._whereclause, _literal_as_text(whereclause)) + else: + self._whereclause = _literal_as_text(whereclause) + + def append_having(self, having): + if self._having is not None: + self._having = and_(self._having, _literal_as_text(having)) + else: + self._having = _literal_as_text(having) - # visit the FROM objects of the column looking for more Selects - for f in column._get_from_objects(): - if f is not self: - self.__correlator.traverse(f) - self._process_froms(column, False) + def append_from(self, fromclause): + if _is_literal(fromclause): + fromclause = FromClause(fromclause) + self._froms.add(fromclause) def _make_proxy(self, selectable, name): if self.is_scalar: - return self._raw_columns[0]._make_proxy(selectable, name) + return list(self.inner_columns)[0]._make_proxy(selectable, name) else: raise exceptions.InvalidRequestError("Not a scalar select statement") @@ -2903,6 +2899,13 @@ class Select(_SelectBaseMixin, FromClause): else: return label(name, self) + def _get_type(self): + if self.is_scalar: + return list(self.inner_columns)[0].type + else: + return None + type = property(_get_type) + def _exportable_columns(self): return [c for c in self._raw_columns if isinstance(c, Selectable)] @@ -2912,51 +2915,11 @@ class Select(_SelectBaseMixin, FromClause): else: return column._make_proxy(self) - def _process_froms(self, elem, asfrom): - for f in elem._get_from_objects(): - self.__fromvisitor.traverse(f) - self.__froms.add(f) - if asfrom: - self.__froms.add(elem) - for f in elem._hide_froms(): - self.__hide_froms.add(f) - def self_group(self, against=None): return _Grouping(self) - - def append_whereclause(self, whereclause): - self._append_condition('whereclause', whereclause) - - def append_having(self, having): - self._append_condition('having', having) - - def _append_condition(self, attribute, condition): - if type(condition) == str: - condition = _TextClause(condition) - self.__wherecorrelator.traverse(condition) - self._process_froms(condition, False) - if getattr(self, attribute) is not None: - setattr(self, attribute, and_(getattr(self, attribute), condition)) - else: - setattr(self, attribute, condition) - - def correlate(self, from_obj): - """Given a ``FROM`` object, correlate this ``SELECT`` statement to it. - - This basically means the given from object will not come out - in this select statement's ``FROM`` clause when printed. - """ - - self.__correlated[from_obj] = from_obj - - def append_from(self, fromclause): - if type(fromclause) == str: - fromclause = FromClause(fromclause) - self.__correlator.traverse(fromclause) - self._process_froms(fromclause, True) def _locate_oid_column(self): - for f in self.__froms: + for f in self.locate_all_froms(): if f is self: # we might be in our own _froms list if a column with us as the parent is attached, # which includes textual columns. @@ -2967,25 +2930,6 @@ class Select(_SelectBaseMixin, FromClause): else: return None - def _calc_froms(self): - f = self.__froms.difference(self.__hide_froms) - if (len(f) > 1): - return f.difference(self.__correlated) - else: - return f - - froms = property(_calc_froms, - doc="""A collection containing all elements - of the ``FROM`` clause.""") - - def get_children(self, column_collections=True, **kwargs): - return (column_collections and list(self.columns) or []) + \ - list(self.froms) + \ - [x for x in (self.whereclause, self.having, self.order_by_clause, self.group_by_clause) if x is not None] - - def accept_visitor(self, visitor): - visitor.visit_select(self) - def union(self, other, **kwargs): return union(self, other, **kwargs) @@ -2999,7 +2943,7 @@ class Select(_SelectBaseMixin, FromClause): if self._engine is not None: return self._engine - for f in self.__froms: + for f in self._froms: if f is self: continue e = f.engine @@ -3024,20 +2968,24 @@ class _UpdateBase(ClauseElement): def supports_execution(self): return True - class _SelectCorrelator(NoColumnVisitor): - def __init__(self, table): - NoColumnVisitor.__init__(self) - self.table = table - - def visit_select(self, select): - if select.should_correlate: - select.correlate(self.table) - - def _process_whereclause(self, whereclause): - if whereclause is not None: - _UpdateBase._SelectCorrelator(self.table).traverse(whereclause) - return whereclause - + def calculate_correlations(self, correlate_state): + class SelectCorrelator(NoColumnVisitor): + def visit_select(s, select): + if select._should_correlate: + select_state = correlate_state.setdefault(select, {}) + corr = select_state.setdefault('correlate', util.Set()) + corr.add(self.table) + + vis = SelectCorrelator() + + if self._whereclause is not None: + vis.traverse(self._whereclause) + + if getattr(self, 'parameters', None) is not None: + for key, value in self.parameters.items(): + if isinstance(value, ClauseElement): + vis.traverse(value) + def _process_colparams(self, parameters): """Receive the *values* of an ``INSERT`` or ``UPDATE`` statement and construct appropriate bind parameters. @@ -3054,11 +3002,10 @@ class _UpdateBase(ClauseElement): i +=1 parameters = pp - correlator = _UpdateBase._SelectCorrelator(self.table) for key in parameters.keys(): value = parameters[key] if isinstance(value, ClauseElement): - correlator.traverse(value) + pass elif _is_literal(value): if _is_literal(key): col = self.table.c[key] @@ -3073,7 +3020,7 @@ class _UpdateBase(ClauseElement): def _find_engine(self): return self.table.engine -class _Insert(_UpdateBase): +class Insert(_UpdateBase): def __init__(self, table, values=None): self.table = table self.select = None @@ -3084,32 +3031,26 @@ class _Insert(_UpdateBase): return self.select, else: return () - def accept_visitor(self, visitor): - visitor.visit_insert(self) -class _Update(_UpdateBase): +class Update(_UpdateBase): def __init__(self, table, whereclause, values=None): self.table = table - self.whereclause = self._process_whereclause(whereclause) + self._whereclause = whereclause self.parameters = self._process_colparams(values) def get_children(self, **kwargs): - if self.whereclause is not None: - return self.whereclause, + if self._whereclause is not None: + return self._whereclause, else: return () - def accept_visitor(self, visitor): - visitor.visit_update(self) -class _Delete(_UpdateBase): +class Delete(_UpdateBase): def __init__(self, table, whereclause): self.table = table - self.whereclause = self._process_whereclause(whereclause) + self._whereclause = whereclause def get_children(self, **kwargs): - if self.whereclause is not None: - return self.whereclause, + if self._whereclause is not None: + return self._whereclause, else: return () - def accept_visitor(self, visitor): - visitor.visit_delete(self) diff --git a/lib/sqlalchemy/sql_util.py b/lib/sqlalchemy/sql_util.py index 7a67402318..36d127c98c 100644 --- a/lib/sqlalchemy/sql_util.py +++ b/lib/sqlalchemy/sql_util.py @@ -125,7 +125,7 @@ class AbstractClauseProcessor(sql.NoColumnVisitor): process the new list. """ - list_ = [o.copy_container() for o in list_] + list_ = list(list_) self.process_list(list_) return list_ @@ -137,7 +137,7 @@ class AbstractClauseProcessor(sql.NoColumnVisitor): if elem is not None: list_[i] = elem else: - self.traverse(list_[i]) + list_[i] = self.traverse(list_[i], clone=True) def visit_grouping(self, grouping): elem = self.convert_element(grouping.elem) @@ -162,8 +162,25 @@ class AbstractClauseProcessor(sql.NoColumnVisitor): elem = self.convert_element(binary.right) if elem is not None: binary.right = elem - - # TODO: visit_select(). + + def visit_select(self, select): + fr = util.OrderedSet() + for elem in select._froms: + n = self.convert_element(elem) + if n is None: + fr.add(elem) + else: + fr.add(n) + select._recorrelate_froms(fr) + + col = [] + for elem in select._raw_columns: + n = self.convert_element(elem) + if n is None: + col.append(elem) + else: + col.append(n) + select._raw_columns = col class ClauseAdapter(AbstractClauseProcessor): """Given a clause (like as in a WHERE criterion), locate columns @@ -200,6 +217,9 @@ class ClauseAdapter(AbstractClauseProcessor): self.equivalents = equivalents def convert_element(self, col): + if isinstance(col, sql.FromClause): + if self.selectable.is_derived_from(col): + return self.selectable if not isinstance(col, sql.ColumnElement): return None if self.include is not None: diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index 38f06584fc..a0088f1366 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -55,9 +55,9 @@ else: self[key] = value = self.creator(key) return value -def to_list(x): +def to_list(x, default=None): if x is None: - return None + return default if not isinstance(x, list) and not isinstance(x, tuple): return [x] else: diff --git a/test/orm/generative.py b/test/orm/generative.py index ad07b5b21a..a83b81758a 100644 --- a/test/orm/generative.py +++ b/test/orm/generative.py @@ -177,7 +177,7 @@ class RelationsTest(AssertMixin): }) session = create_session() query = session.query(tables.User) - x = query.outerjoin('orders').outerjoin('items').filter(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2)) + x = query.outerjoin(['orders', 'items']).filter(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2)) print x.compile() self.assert_result(list(x), tables.User, *tables.user_result[1:3]) def test_outerjointo_count(self): @@ -189,7 +189,7 @@ class RelationsTest(AssertMixin): }) session = create_session() query = session.query(tables.User) - x = query.outerjoin('orders').outerjoin('items').filter(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2)).count() + x = query.outerjoin(['orders', 'items']).filter(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2)).count() assert x==2 def test_from(self): mapper(tables.User, tables.users, properties={ diff --git a/test/orm/inheritance/poly_linked_list.py b/test/orm/inheritance/poly_linked_list.py index a9482f28c9..7858689b1d 100644 --- a/test/orm/inheritance/poly_linked_list.py +++ b/test/orm/inheritance/poly_linked_list.py @@ -29,14 +29,15 @@ class PolymorphicCircularTest(testbase.ORMTest): Column('data', String(30)) ) - join = polymorphic_union( - { - 'table3' : table1.join(table3), - 'table2' : table1.join(table2), - 'table1' : table1.select(table1.c.type.in_('table1', 'table1b')), - }, None, 'pjoin') - - # still with us so far ? + #join = polymorphic_union( + # { + # 'table3' : table1.join(table3), + # 'table2' : table1.join(table2), + # 'table1' : table1.select(table1.c.type.in_('table1', 'table1b')), + # }, None, 'pjoin') + + join = table1.outerjoin(table2).outerjoin(table3).alias('pjoin') + #join = None class Table1(object): def __init__(self, name, data=None): @@ -62,10 +63,10 @@ class PolymorphicCircularTest(testbase.ORMTest): return "%s(%d, %s)" % (self.__class__.__name__, self.id, repr(str(self.data))) try: - # this is how the mapping used to work. insure that this raises an error now + # this is how the mapping used to work. ensure that this raises an error now table1_mapper = mapper(Table1, table1, select_table=join, - polymorphic_on=join.c.type, + polymorphic_on=table1.c.type, polymorphic_identity='table1', properties={ 'next': relation(Table1, @@ -86,8 +87,8 @@ class PolymorphicCircularTest(testbase.ORMTest): # exception now. since eager loading would never work for that relation anyway, its better that the user # gets an exception instead of it silently not eager loading. table1_mapper = mapper(Table1, table1, - select_table=join, - polymorphic_on=join.c.type, + #select_table=join, + polymorphic_on=table1.c.type, polymorphic_identity='table1', properties={ 'next': relation(Table1, @@ -104,7 +105,10 @@ class PolymorphicCircularTest(testbase.ORMTest): polymorphic_identity='table2') table3_mapper = mapper(Table3, table3, inherits=table1_mapper, polymorphic_identity='table3') - + + table1_mapper.compile() + assert table1_mapper.primary_key == [table1.c.id], table1_mapper.primary_key + def testone(self): self.do_testlist([Table1, Table2, Table1, Table2]) diff --git a/test/orm/mapper.py b/test/orm/mapper.py index b061416ae8..ab53e36c45 100644 --- a/test/orm/mapper.py +++ b/test/orm/mapper.py @@ -328,6 +328,10 @@ class MapperTest(MapperSuperTest): 'concat': column_property(f), 'count': column_property(select([func.count(addresses.c.address_id)], users.c.user_id==addresses.c.user_id, scalar=True).label('count')) }) + + mapper(Address, addresses, properties={ + 'user':relation(User, lazy=False) + }) sess = create_session() l = sess.query(User).select() @@ -336,24 +340,19 @@ class MapperTest(MapperSuperTest): assert l[0].concat == l[0].user_id * 2 == 14 assert l[1].concat == l[1].user_id * 2 == 16 - ### eager loads, not really working across all DBs, no column aliasing in place so - # results still wont be good for larger situations - clear_mappers() - mapper(Address, addresses, properties={ - 'user':relation(User, lazy=False) - }) - - mapper(User, users, properties={ - 'concat': column_property(f), - }) - - for x in range(0, 2): - sess.clear() - l = sess.query(Address).select() - for a in l: - print "User", a.user.user_id, a.user.user_name, a.user.concat - assert l[0].user.concat == l[0].user.user_id * 2 == 14 - assert l[1].user.concat == l[1].user.user_id * 2 == 16 + for option in (None, eagerload('user')): + for x in range(0, 2): + sess.clear() + l = sess.query(Address) + if option: + l = l.options(option) + l = l.all() + for a in l: + print "User", a.user.user_id, a.user.user_name, a.user.concat, a.user.count + assert l[0].user.concat == l[0].user.user_id * 2 == 14 + assert l[1].user.concat == l[1].user.user_id * 2 == 16 + assert l[0].user.count == 1 + assert l[1].user.count == 3 @testbase.unsupported('firebird') @@ -1114,6 +1113,7 @@ class EagerTest(MapperSuperTest): """test eager loading of a mapper which is against a select""" s = select([orders], orders.c.isopen==1).alias('openorders') + print "SELECT:", id(s), str(s) mapper(Order, s, properties={ 'user':relation(User, lazy=False) }) diff --git a/test/perf/masseagerload.py b/test/perf/masseagerload.py index dc3416089f..2d87b391e6 100644 --- a/test/perf/masseagerload.py +++ b/test/perf/masseagerload.py @@ -1,6 +1,7 @@ from testbase import PersistTest, AssertMixin import unittest, sys, os from sqlalchemy import * +from sqlalchemy.orm import * from testbase import Table, Column import StringIO import testbase diff --git a/test/sql/alltests.py b/test/sql/alltests.py index 7be1a3ffb6..ebb3fe34c6 100644 --- a/test/sql/alltests.py +++ b/test/sql/alltests.py @@ -7,6 +7,8 @@ def suite(): 'sql.testtypes', 'sql.constraints', + 'sql.generative', + # SQL syntax 'sql.select', 'sql.selectable', diff --git a/test/sql/generative.py b/test/sql/generative.py new file mode 100644 index 0000000000..befa5f9221 --- /dev/null +++ b/test/sql/generative.py @@ -0,0 +1,212 @@ +import testbase +from sqlalchemy import * + +class TraversalTest(testbase.AssertMixin): + """test ClauseVisitor's traversal, particularly its ability to copy and modify + a ClauseElement in place.""" + + def setUpAll(self): + global A, B + + # establish two ficticious ClauseElements. + # define deep equality semantics as well as deep identity semantics. + class A(ClauseElement): + def __init__(self, expr): + self.expr = expr + + def accept_visitor(self, visitor): + visitor.visit_a(self) + + def is_other(self, other): + return other is self + + def __eq__(self, other): + return other.expr == self.expr + + def __ne__(self, other): + return other.expr != self.expr + + def __str__(self): + return "A(%s)" % repr(self.expr) + + class B(ClauseElement): + def __init__(self, *items): + self.items = items + + def is_other(self, other): + if other is not self: + return False + for i1, i2 in zip(self.items, other.items): + if i1 is not i2: + return False + return True + + def __eq__(self, other): + for i1, i2 in zip(self.items, other.items): + if i1 != i2: + return False + return True + + def __ne__(self, other): + for i1, i2 in zip(self.items, other.items): + if i1 != i2: + return True + return False + + def get_children(self, clone=False, **kwargs): + if clone: + self.items = [i._clone() for i in self.items] + return self.items + + def accept_visitor(self, visitor): + visitor.visit_b(self) + + def __str__(self): + return "B(%s)" % repr([str(i) for i in self.items]) + + def test_test_classes(self): + a1 = A("expr1") + struct = B(a1, A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3")) + struct2 = B(a1, A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3")) + struct3 = B(a1, A("expr2"), B(A("expr1b"), A("expr2bmodified")), A("expr3")) + + assert a1.is_other(a1) + assert struct.is_other(struct) + assert struct == struct2 + assert struct != struct3 + assert not struct.is_other(struct2) + assert not struct.is_other(struct3) + + def test_clone(self): + struct = B(A("expr1"), A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3")) + + class Vis(ClauseVisitor): + def visit_a(self, a): + pass + def visit_b(self, b): + pass + + vis = Vis() + s2 = vis.traverse(struct, clone=True) + assert struct == s2 + assert not struct.is_other(s2) + + def test_no_clone(self): + struct = B(A("expr1"), A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3")) + + class Vis(ClauseVisitor): + def visit_a(self, a): + pass + def visit_b(self, b): + pass + + vis = Vis() + s2 = vis.traverse(struct, clone=False) + assert struct == s2 + assert struct.is_other(s2) + + def test_change_in_place(self): + struct = B(A("expr1"), A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3")) + struct2 = B(A("expr1"), A("expr2modified"), B(A("expr1b"), A("expr2b")), A("expr3")) + struct3 = B(A("expr1"), A("expr2"), B(A("expr1b"), A("expr2bmodified")), A("expr3")) + + class Vis(ClauseVisitor): + def visit_a(self, a): + if a.expr == "expr2": + a.expr = "expr2modified" + def visit_b(self, b): + pass + + vis = Vis() + s2 = vis.traverse(struct, clone=True) + assert struct != s2 + assert struct is not s2 + assert struct2 == s2 + + class Vis2(ClauseVisitor): + def visit_a(self, a): + if a.expr == "expr2b": + a.expr = "expr2bmodified" + def visit_b(self, b): + pass + + vis2 = Vis2() + s3 = vis2.traverse(struct, clone=True) + assert struct != s3 + assert struct3 == s3 + +class ClauseTest(testbase.AssertMixin): + def setUpAll(self): + global t1, t2 + t1 = table("table1", + column("col1"), + column("col2"), + column("col3"), + ) + t2 = table("table2", + column("col1"), + column("col2"), + column("col3"), + ) + + def test_binary(self): + clause = t1.c.col2 == t2.c.col2 + assert str(clause) == ClauseVisitor().traverse(clause, clone=True) + + def test_join(self): + clause = t1.join(t2, t1.c.col2==t2.c.col2) + c1 = str(clause) + assert str(clause) == str(ClauseVisitor().traverse(clause, clone=True)) + + class Vis(ClauseVisitor): + def visit_binary(self, binary): + binary.right = t2.c.col3 + + clause2 = Vis().traverse(clause, clone=True) + assert c1 == str(clause) + assert str(clause2) == str(t1.join(t2, t1.c.col2==t2.c.col3)) + + def test_select(self): + s = t1.select() + s2 = select([s]) + s2_assert = str(s2) + s3_assert = str(select([t1.select()], t1.c.col2==7)) + class Vis(ClauseVisitor): + def visit_select(self, select): + select.append_whereclause(t1.c.col2==7) + s3 = Vis().traverse(s2, clone=True) + assert str(s3) == s3_assert + assert str(s2) == s2_assert + print str(s2) + print str(s3) + Vis().traverse(s2) + assert str(s2) == s3_assert + + print "------------------" + + s4_assert = str(select([t1.select()], and_(t1.c.col2==7, t1.c.col3==9))) + class Vis(ClauseVisitor): + def visit_select(self, select): + select.append_whereclause(t1.c.col3==9) + s4 = Vis().traverse(s3, clone=True) + print str(s3) + print str(s4) + assert str(s4) == s4_assert + assert str(s3) == s3_assert + + print "------------------" + s5_assert = str(select([t1.select()], and_(t1.c.col2==7, t1.c.col1==9))) + class Vis(ClauseVisitor): + def visit_binary(self, binary): + if binary.left is t1.c.col3: + binary.left = t1.c.col1 + binary.right = bindparam("table1_col1") + s5 = Vis().traverse(s4, clone=True) + print str(s4) + print str(s5) + assert str(s5) == s5_assert + assert str(s4) == s4_assert + + +if __name__ == '__main__': + testbase.main() \ No newline at end of file diff --git a/test/sql/select.py b/test/sql/select.py index 157c623000..ad2fd13e3c 100644 --- a/test/sql/select.py +++ b/test/sql/select.py @@ -143,6 +143,11 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A self.runtest(select([table1, exists([1], from_obj=[table2]).label('foo')]), "SELECT mytable.myid, mytable.name, mytable.description, EXISTS (SELECT 1 FROM myothertable) AS foo FROM mytable", params={}) def testwheresubquery(self): + s = select([addresses.c.street], addresses.c.user_id==users.c.user_id, correlate=True).alias('s') + self.runtest( + select([users, s.c.street], from_obj=[s]), + """SELECT users.user_id, users.user_name, users.password, s.street FROM users, (SELECT addresses.street AS street FROM addresses WHERE addresses.user_id = users.user_id) AS s""") + # TODO: this tests that you dont get a "SELECT column" without a FROM but its not working yet. #self.runtest( # table1.select(table1.c.myid == select([table1.c.myid], table1.c.name=='jack')), "" @@ -223,6 +228,12 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A order_by = ['dist', places.c.nm] ) self.runtest(q, "SELECT places.id, places.nm, main_zip.zipcode, latlondist((SELECT zips.latitude FROM zips WHERE zips.zipcode = main_zip.zipcode), (SELECT zips.longitude FROM zips WHERE zips.zipcode = main_zip.zipcode)) AS dist FROM places, zips AS main_zip ORDER BY dist, places.nm") + + a1 = table2.alias('t2alias') + s1 = select([a1.c.otherid], table1.c.myid==a1.c.otherid, scalar=True) + j1 = table1.join(table2, table1.c.myid==table2.c.otherid) + s2 = select([table1, s1], from_obj=[j1]) + self.runtest(s2, "SELECT mytable.myid, mytable.name, mytable.description, (SELECT t2alias.otherid FROM myothertable AS t2alias WHERE mytable.myid = t2alias.otherid) FROM mytable JOIN myothertable ON mytable.myid = myothertable.otherid") def testlabelcomparison(self): x = func.lala(table1.c.myid).label('foo') @@ -410,7 +421,7 @@ WHERE mytable.myid = myothertable.otherid) AS t2view WHERE t2view.mytable_myid = s.append_column("column2") s.append_whereclause("column1=12") s.append_whereclause("column2=19") - s.order_by("column1") + s = s.order_by("column1") s.append_from("table1") self.runtest(s, "SELECT column1, column2 FROM table1 WHERE column1=12 AND column2=19 ORDER BY column1") @@ -850,16 +861,6 @@ UNION SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE ) - def testlateargs(self): - """tests that a SELECT clause will have extra "WHERE" clauses added to it at compile time if extra arguments - are sent""" - - self.runtest(table1.select(), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.name = :mytable_name AND mytable.myid = :mytable_myid", params={'myid':'3', 'name':'jack'}) - - self.runtest(table1.select(table1.c.name=='jack'), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :mytable_myid AND mytable.name = :mytable_name", params={'myid':'3'}) - - self.runtest(table1.select(table1.c.name=='jack'), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :mytable_myid AND mytable.name = :mytable_name", params={'myid':'3', 'name':'fred'}) - def testcast(self): tbl = table('casttest', column('id', Integer), @@ -969,7 +970,18 @@ class CRUDTest(SQLTest): def testdelete(self): self.runtest(delete(table1, table1.c.myid == 7), "DELETE FROM mytable WHERE mytable.myid = :mytable_myid") - + + def testcorrelateddelete(self): + # test a non-correlated WHERE clause + s = select([table2.c.othername], table2.c.otherid == 7) + u = delete(table1, table1.c.name==s) + self.runtest(u, "DELETE FROM mytable WHERE mytable.name = (SELECT myothertable.othername FROM myothertable WHERE myothertable.otherid = :myothertable_otherid)") + + # test one that is actually correlated... + s = select([table2.c.othername], table2.c.otherid == table1.c.myid) + u = table1.delete(table1.c.name==s) + self.runtest(u, "DELETE FROM mytable WHERE mytable.name = (SELECT myothertable.othername FROM myothertable WHERE myothertable.otherid = mytable.myid)") + class SchemaTest(SQLTest): def testselect(self): # these tests will fail with the MS-SQL compiler since it will alias schema-qualified tables -- 2.47.3