From: Mike Bayer Date: Sun, 25 Nov 2007 03:28:49 +0000 (+0000) Subject: - named_with_column becomes an attribute X-Git-Tag: rel_0_4_2~149 X-Git-Url: http://git.ipfire.org/gitweb/gitweb.cgi?a=commitdiff_plain;h=cf18eecd704f5eb6fde4e0c362cfdb322e3e559a;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - named_with_column becomes an attribute - cleanup within compiler visit_select(), column labeling - is_select() removed from dialects, replaced with returns_rows_text(), returns_rows_compiled() - should_autocommit() removed from dialects, replaced with should_autocommit_text() and should_autocommit_compiled() - typemap and column_labels collections removed from Compiler, replaced with single "result_map" collection. - ResultProxy uses more succinct logic in combination with result_map to target columns --- diff --git a/lib/sqlalchemy/databases/access.py b/lib/sqlalchemy/databases/access.py index d57c9fa9f6..354a8c3322 100644 --- a/lib/sqlalchemy/databases/access.py +++ b/lib/sqlalchemy/databases/access.py @@ -356,11 +356,11 @@ class AccessCompiler(compiler.DefaultCompiler): """Access uses "mod" instead of "%" """ return binary.operator == '%' and 'mod' or binary.operator - def label_select_column(self, select, column): + def label_select_column(self, select, column, asfrom): if isinstance(column, expression._Function): - return column.label(column.name + "_" + hex(random.randint(0, 65535))[2:]) + return column.label() else: - return super(AccessCompiler, self).label_select_column(select, column) + return super(AccessCompiler, self).label_select_column(select, column, asfrom) function_rewrites = {'current_date': 'now', 'current_timestamp': 'now', diff --git a/lib/sqlalchemy/databases/informix.py b/lib/sqlalchemy/databases/informix.py index 247ab2d419..6b01bfc224 100644 --- a/lib/sqlalchemy/databases/informix.py +++ b/lib/sqlalchemy/databases/informix.py @@ -409,15 +409,6 @@ class InfoCompiler(compiler.DefaultCompiler): def limit_clause(self, select): return "" - def __visit_label(self, label): - # TODO: whats this method for ? - if self.select_stack: - self.typemap.setdefault(label.name.lower(), label.obj.type) - if self.strings[label.obj]: - self.strings[label] = self.strings[label.obj] + " AS " + label.name - else: - self.strings[label] = None - def visit_function( self , func ): if func.name.lower() == 'current_date': return "today" diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py index 672f8d77cf..469355083b 100644 --- a/lib/sqlalchemy/databases/mssql.py +++ b/lib/sqlalchemy/databases/mssql.py @@ -339,8 +339,8 @@ class MSSQLExecutionContext(default.DefaultExecutionContext): _ms_is_select = re.compile(r'\s*(?:SELECT|sp_columns)', re.I | re.UNICODE) - def is_select(self): - return self._ms_is_select.match(self.statement) is not None + def returns_rows_text(self, statement): + return self._ms_is_select.match(statement) is not None class MSSQLExecutionContext_pyodbc (MSSQLExecutionContext): @@ -910,11 +910,11 @@ class MSSQLCompiler(compiler.DefaultCompiler): else: return super(MSSQLCompiler, self).visit_binary(binary, **kwargs) - def label_select_column(self, select, column): + def label_select_column(self, select, column, asfrom): if isinstance(column, expression._Function): return column.label(None) else: - return super(MSSQLCompiler, self).label_select_column(select, column) + return super(MSSQLCompiler, self).label_select_column(select, column, asfrom) function_rewrites = {'current_date': 'getdate', 'length': 'len', diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index 39bfc0beaa..03b9a749ce 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -1378,9 +1378,6 @@ def descriptor(): class MySQLExecutionContext(default.DefaultExecutionContext): - _my_is_select = re.compile(r'\s*(?:SELECT|SHOW|DESCRIBE|XA +RECOVER)', - re.I | re.UNICODE) - def post_exec(self): if self.compiled.isinsert and not self.executemany: if (not len(self._last_inserted_ids) or @@ -1388,11 +1385,11 @@ class MySQLExecutionContext(default.DefaultExecutionContext): self._last_inserted_ids = ([self.cursor.lastrowid] + self._last_inserted_ids[1:]) - def is_select(self): - return SELECT_RE.match(self.statement) + def returns_rows_text(self, statement): + return SELECT_RE.match(statement) - def should_autocommit(self): - return AUTOCOMMIT_RE.match(self.statement) + def should_autocommit_text(self, statement): + return AUTOCOMMIT_RE.match(statement) class MySQLDialect(default.DefaultDialect): @@ -1873,9 +1870,6 @@ class MySQLCompiler(compiler.DefaultCompiler): if type_ is None: return self.process(cast.clause) - if self.stack and self.stack[-1].get('select'): - # not sure if we want to set the typemap here... - self.typemap.setdefault("CAST", cast.type) return 'CAST(%s AS %s)' % (self.process(cast.clause), type_) diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index 88ac0e2026..1cae31b537 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -233,16 +233,24 @@ RETURNING_QUOTED_RE = re.compile( class PGExecutionContext(default.DefaultExecutionContext): - def is_select(self): - m = SELECT_RE.match(self.statement) - return m and (not m.group(1) or (RETURNING_RE.search(self.statement) - and RETURNING_QUOTED_RE.match(self.statement))) + def returns_rows_text(self, statement): + m = SELECT_RE.match(statement) + return m and (not m.group(1) or (RETURNING_RE.search(statement) + and RETURNING_QUOTED_RE.match(statement))) + + def returns_rows_compiled(self, compiled): + return isinstance(compiled.statement, expression.Selectable) or \ + ( + (compiled.isupdate or compiled.isinsert) and "postgres_returning" in compiled.statement.kwargs + ) def create_cursor(self): # executing a default or Sequence standalone creates an execution context without a statement. # so slightly hacky "if no statement assume we're server side" logic + # TODO: dont use regexp if Compiled is used ? self.__is_server_side = \ - self.dialect.server_side_cursors and (self.statement is None or \ + self.dialect.server_side_cursors and \ + (self.statement is None or \ (SELECT_RE.match(self.statement) and not re.search(r'FOR UPDATE(?: NOWAIT)?\s*$', self.statement, re.I)) ) diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py index 19d0855ff3..16dd9427c0 100644 --- a/lib/sqlalchemy/databases/sqlite.py +++ b/lib/sqlalchemy/databases/sqlite.py @@ -185,8 +185,8 @@ class SQLiteExecutionContext(default.DefaultExecutionContext): if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None: self._last_inserted_ids = [self.cursor.lastrowid] + self._last_inserted_ids[1:] - def is_select(self): - return SELECT_REGEXP.match(self.statement) + def returns_rows_text(self, statement): + return SELECT_REGEXP.match(statement) class SQLiteDialect(default.DefaultDialect): supports_alter = False @@ -343,9 +343,6 @@ class SQLiteCompiler(compiler.DefaultCompiler): if self.dialect.supports_cast: return super(SQLiteCompiler, self).visit_cast(cast) else: - if self.stack and self.stack[-1].get('select'): - # not sure if we want to set the typemap here... - self.typemap.setdefault("CAST", cast.type) return self.process(cast.clause) def limit_clause(self, select): diff --git a/lib/sqlalchemy/databases/sybase.py b/lib/sqlalchemy/databases/sybase.py index 87045d1926..2209594ed7 100644 --- a/lib/sqlalchemy/databases/sybase.py +++ b/lib/sqlalchemy/databases/sybase.py @@ -778,11 +778,11 @@ class SybaseSQLCompiler(compiler.DefaultCompiler): else: return super(SybaseSQLCompiler, self).visit_binary(binary) - def label_select_column(self, select, column): + def label_select_column(self, select, column, asfrom): if isinstance(column, expression._Function): - return column.label(column.name + "_" + hex(random.randint(0, 65535))[2:]) + return column.label(None) else: - return super(SybaseSQLCompiler, self).label_select_column(select, column) + return super(SybaseSQLCompiler, self).label_select_column(select, column, asfrom) function_rewrites = {'current_date': 'getdate', } @@ -795,13 +795,7 @@ class SybaseSQLCompiler(compiler.DefaultCompiler): cast = expression._Cast(func, SybaseDate_mxodbc) # infinite recursion # res = self.visit_cast(cast) - if self.stack and self.stack[-1].get('select'): - # not sure if we want to set the typemap here... - self.typemap.setdefault("CAST", cast.type) -# res = "CAST(%s AS %s)" % (self.process(cast.clause), self.process(cast.typeclause)) res = "CAST(%s AS %s)" % (res, self.process(cast.typeclause)) -# elif func.name.lower() == 'count': -# res = 'count(*)' return res def for_update_clause(self, select): diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 21977b689b..9e30043253 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -315,6 +315,12 @@ class ExecutionContext(object): isupdate True if the statement is an UPDATE. + should_autocommit + True if the statement is a "committable" statement + + returns_rows + True if the statement should return result rows + The Dialect should provide an ExecutionContext via the create_execution_context() method. The `pre_exec` and `post_exec` methods will be called for compiled statements. @@ -363,8 +369,13 @@ class ExecutionContext(object): raise NotImplementedError() - def should_autocommit(self): - """Return True if this context's statement should be 'committed' automatically in a non-transactional context""" + def should_autocommit_compiled(self, compiled): + """return True if the given Compiled object refers to a "committable" statement.""" + + raise NotImplementedError() + + def should_autocommit_text(self, statement): + """Parse the given textual statement and return True if it refers to a "committable" statement""" raise NotImplementedError() @@ -750,7 +761,7 @@ class Connection(Connectable): # TODO: have the dialect determine if autocommit can be set on # the connection directly without this extra step - if not self.in_transaction() and context.should_autocommit(): + if not self.in_transaction() and context.should_autocommit: self._commit_impl() def _autorollback(self): @@ -1305,7 +1316,7 @@ class ResultProxy(object): self.cursor = context.cursor self.connection = context.root_connection self.__echo = context.engine._should_log_info - if context.is_select(): + if context.returns_rows: self._init_metadata() self._rowcount = None else: @@ -1322,8 +1333,6 @@ class ResultProxy(object): out_parameters = property(lambda s:s.context.out_parameters) def _init_metadata(self): - if hasattr(self, '_ResultProxy__props'): - return self.__props = {} self._key_cache = self._create_key_cache() self.__keys = [] @@ -1336,20 +1345,24 @@ class ResultProxy(object): # sqlite possibly prepending table name to colnames so strip colname = (item[0].split('.')[-1]).decode(self.dialect.encoding) - if self.context.typemap is not None: - type = self.context.typemap.get(colname.lower(), typemap.get(item[1], types.NULLTYPE)) + if self.context.result_map: + try: + (name, obj, type_) = self.context.result_map[colname] + except KeyError: + (name, obj, type_) = (colname, None, typemap.get(item[1], types.NULLTYPE)) else: - type = typemap.get(item[1], types.NULLTYPE) + (name, obj, type_) = (colname, None, typemap.get(item[1], types.NULLTYPE)) - rec = (type, type.dialect_impl(self.dialect).result_processor(self.dialect), i) + rec = (type_, type_.dialect_impl(self.dialect).result_processor(self.dialect), i) - if rec[0] is None: - raise exceptions.InvalidRequestError( - "None for metadata " + colname) - if self.__props.setdefault(colname.lower(), rec) is not rec: - self.__props[colname.lower()] = (type, self.__ambiguous_processor(colname), 0) + if self.__props.setdefault(name.lower(), rec) is not rec: + self.__props[name.lower()] = (type_, self.__ambiguous_processor(colname), 0) + self.__keys.append(colname) self.__props[i] = rec + if obj: + for o in obj: + self.__props[o] = rec if self.__echo: self.context.engine.logger.debug("Col " + repr(tuple([x[0] for x in metadata]))) @@ -1362,16 +1375,19 @@ class ResultProxy(object): """Given a key, which could be a ColumnElement, string, etc., matches it to the appropriate key we got from the result set's metadata; then cache it locally for quick re-access.""" - - if isinstance(key, int) and key in props: + + if isinstance(key, basestring): + key = key.lower() + + try: rec = props[key] - elif isinstance(key, basestring) and key.lower() in props: - rec = props[key.lower()] - elif isinstance(key, expression.ColumnElement): - label = context.column_labels.get(key._label, key.name).lower() - if label in props: - rec = props[label] - if not "rec" in locals(): + except KeyError: + # fallback for targeting a ColumnElement to a textual expression + if isinstance(key, expression.ColumnElement): + if key._label.lower() in props: + return props[key._label.lower()] + elif key.name.lower() in props: + return props[key.name.lower()] raise exceptions.NoSuchColumnError("Could not locate column in row for column '%s'" % (str(key))) return rec @@ -1470,18 +1486,20 @@ class ResultProxy(object): def _get_col(self, row, key): try: - rec = self._key_cache[key] + type_, processor, index = self._key_cache[key] except TypeError: # the 'slice' use case is very infrequent, # so we use an exception catch to reduce conditionals in _get_col if isinstance(key, slice): indices = key.indices(len(row)) return tuple([self._get_col(row, i) for i in xrange(*indices)]) - - if rec[1]: - return rec[1](row[rec[2]]) + else: + raise + + if processor: + return processor(row[index]) else: - return row[rec[2]] + return row[index] def _fetchone_impl(self): return self.cursor.fetchone() diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index a91d65b81f..19ab22c9e9 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -146,9 +146,8 @@ class DefaultExecutionContext(base.ExecutionContext): if value is not None ]) - self.typemap = compiled.typemap - self.column_labels = compiled.column_labels - + self.result_map = compiled.result_map + if not dialect.supports_unicode_statements: self.statement = unicode(compiled).encode(self.dialect.encoding) else: @@ -156,6 +155,12 @@ class DefaultExecutionContext(base.ExecutionContext): self.isinsert = compiled.isinsert self.isupdate = compiled.isupdate + if isinstance(compiled.statement, expression._TextClause): + self.returns_rows = self.returns_rows_text(self.statement) + self.should_autocommit = self.should_autocommit_text(self.statement) + else: + self.returns_rows = self.returns_rows_compiled(compiled) + self.should_autocommit = self.should_autocommit_compiled(compiled) if not parameters: self.compiled_parameters = [compiled.construct_params()] @@ -170,7 +175,7 @@ class DefaultExecutionContext(base.ExecutionContext): elif statement is not None: # plain text statement. - self.typemap = self.column_labels = None + self.result_map = None self.parameters = self.__encode_param_keys(parameters) self.executemany = len(parameters) > 1 if not dialect.supports_unicode_statements: @@ -179,10 +184,12 @@ class DefaultExecutionContext(base.ExecutionContext): self.statement = statement self.isinsert = self.isupdate = False self.cursor = self.create_cursor() + self.returns_rows = self.returns_rows_text(statement) + self.should_autocommit = self.should_autocommit_text(statement) else: # no statement. used for standalone ColumnDefault execution. self.statement = None - self.isinsert = self.isupdate = self.executemany = False + self.isinsert = self.isupdate = self.executemany = self.returns_rows = self.should_autocommit = False self.cursor = self.create_cursor() connection = property(lambda s:s._connection._branch()) @@ -244,10 +251,18 @@ class DefaultExecutionContext(base.ExecutionContext): parameters.append(param) return parameters - def is_select(self): - """return TRUE if the statement is expected to have result rows.""" + def returns_rows_compiled(self, compiled): + return isinstance(compiled.statement, expression.Selectable) - return SELECT_REGEXP.match(self.statement) + def returns_rows_text(self, statement): + return SELECT_REGEXP.match(statement) + + def should_autocommit_compiled(self, compiled): + return isinstance(compiled.statement, expression._UpdateBase) + + def should_autocommit_text(self, statement): + return AUTOCOMMIT_REGEXP.match(statement) + def create_cursor(self): return self._connection.connection.cursor() @@ -261,9 +276,6 @@ class DefaultExecutionContext(base.ExecutionContext): def result(self): return self.get_result_proxy() - def should_autocommit(self): - return AUTOCOMMIT_REGEXP.match(self.statement) - def pre_exec(self): pass diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 1214025849..3daf11ed0c 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -249,7 +249,7 @@ class Query(object): # alias non-labeled column elements. if isinstance(column, sql.ColumnElement) and not hasattr(column, '_label'): column = column.label(None) - + q._entities = q._entities + [(column, None, id)] return q @@ -887,7 +887,7 @@ class Query(object): context.exec_with_path(self.select_mapper, value.key, value.setup, context, parentclauses=clauses) elif isinstance(m, sql.ColumnElement): if clauses is not None: - m = clauses.adapt_clause(m) + m = clauses.aliased_column(m) context.secondary_columns.append(m) if self._eager_loaders and self._nestable(**self._select_args()): diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 0e1e5f7a9c..8179810033 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -456,7 +456,7 @@ class Column(SchemaItem, expression._ColumnClause): def __str__(self): if self.table is not None: - if self.table.named_with_column(): + if self.table.named_with_column: return (self.table.description + "." + self.description) else: return self.description diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index c1f3bc2a05..a31997d1b3 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -130,13 +130,11 @@ class DefaultCompiler(engine.Compiled): # a stack. what recursive compiler doesn't have a stack ? :) self.stack = [] - # a dictionary of result-set column names (strings) to TypeEngine instances, - # which will be passed to a ResultProxy and used for resultset-level value conversion - self.typemap = {} - - # a dictionary of select columns labels mapped to their "generated" label - self.column_labels = {} - + # relates label names in the final SQL to + # a tuple of local column/label name, ColumnElement object (if any) and TypeEngine. + # ResultProxy uses this for type processing and column targeting + self.result_map = {} + # a dictionary of ClauseElement subclasses to counters, which are used to # generate truncated identifier names or "anonymous" identifiers such as # for aliases @@ -213,19 +211,15 @@ class DefaultCompiler(engine.Compiled): def visit_grouping(self, grouping, **kwargs): return "(" + self.process(grouping.elem) + ")" - def visit_label(self, label, typemap=None, column_labels=None): + def visit_label(self, label, result_map=None): labelname = self._truncated_identifier("colident", label.name) - if typemap is not None: - self.typemap.setdefault(labelname.lower(), label.obj.type) + if result_map is not None: + result_map[labelname] = (label.name, (label, label.obj), label.obj.type) - if column_labels is not None: - if isinstance(label.obj, sql._ColumnClause): - column_labels[label.obj._label] = labelname - column_labels[label.name] = labelname return " ".join([self.process(label.obj), self.operator_string(operators.as_), self.preparer.format_label(label, labelname)]) - def visit_column(self, column, typemap=None, column_labels=None, **kwargs): + def visit_column(self, column, result_map=None, **kwargs): # there is actually somewhat of a ruleset when you would *not* necessarily # want to truncate a column identifier, if its mapped to the name of a # physical column. but thats very hard to identify at this point, and @@ -236,15 +230,13 @@ class DefaultCompiler(engine.Compiled): else: name = column.name - if typemap is not None: - typemap.setdefault(name.lower(), column.type) - if column_labels is not None: - self.column_labels.setdefault(column._label, name.lower()) + if result_map is not None: + result_map[name] = (name, (column, ), column.type) if column._is_oid: n = self.dialect.oid_column_name(column) if n is not None: - if column.table is None or not column.table.named_with_column(): + if column.table is None or not column.table.named_with_column: return n else: return self.preparer.quote(column.table, ANONYMOUS_LABEL.sub(self._process_anon, column.table.name)) + "." + n @@ -254,7 +246,7 @@ class DefaultCompiler(engine.Compiled): return self.preparer.quote(column.table, ANONYMOUS_LABEL.sub(self._process_anon, column.table.name)) + "." + self.preparer.quote(pk, pkname) else: return None - elif column.table is None or not column.table.named_with_column(): + elif column.table is None or not column.table.named_with_column: if getattr(column, "is_literal", False): return name else: @@ -277,8 +269,9 @@ class DefaultCompiler(engine.Compiled): def visit_textclause(self, textclause, **kwargs): if textclause.typemap is not None: - self.typemap.update(textclause.typemap) - + for colname, type_ in textclause.typemap.iteritems(): + self.result_map[colname] = (colname, None, type_) + def do_bindparam(m): name = m.group(1) if name in textclause.bindparams: @@ -302,7 +295,7 @@ class DefaultCompiler(engine.Compiled): sep = ', ' else: sep = " " + self.operator_string(clauselist.operator) + " " - return string.join([s for s in [self.process(c) for c in clauselist.clauses] if s is not None], sep) + return sep.join([s for s in [self.process(c) for c in clauselist.clauses] if s is not None]) def apply_function_parens(self, func): return func.name.upper() not in ANSI_FUNCS or len(func.clauses) > 0 @@ -310,12 +303,13 @@ class DefaultCompiler(engine.Compiled): def visit_calculatedclause(self, clause, **kwargs): return self.process(clause.clause_expr) - def visit_cast(self, cast, typemap=None, **kwargs): + def visit_cast(self, cast, **kwargs): return "CAST(%s AS %s)" % (self.process(cast.clause), self.process(cast.typeclause)) - def visit_function(self, func, typemap=None, **kwargs): - if typemap is not None: - typemap.setdefault(func.name, func.type) + def visit_function(self, func, result_map=None, **kwargs): + if result_map is not None: + result_map[func.name] = (func.name, None, func.type) + if not self.apply_function_parens(func): return ".".join(func.packagenames + [func.name]) else: @@ -325,7 +319,7 @@ class DefaultCompiler(engine.Compiled): stack_entry = {'select':cs} if asfrom: - stack_entry['is_selected_from'] = stack_entry['is_subquery'] = True + stack_entry['is_subquery'] = True elif self.stack and self.stack[-1].get('select'): stack_entry['is_subquery'] = True self.stack.append(stack_entry) @@ -353,7 +347,7 @@ class DefaultCompiler(engine.Compiled): s = s + " " + self.operator_string(unary.modifier) return s - def visit_binary(self, binary, typemap=None, **kwargs): + def visit_binary(self, binary, **kwargs): op = self.operator_string(binary.operator) if callable(op): return op(self.process(binary.left), self.process(binary.right)) @@ -438,22 +432,17 @@ class DefaultCompiler(engine.Compiled): else: return self.process(alias.original, **kwargs) - def label_select_column(self, select, column): - """convert a column from a select's "columns" clause. + def label_select_column(self, select, column, asfrom): + """label columns present in a select().""" - given a select() and a column element from its inner_columns collection, return a - Label object if this column should be labeled in the columns clause. Otherwise, - return None and the column will be used as-is. - - The calling method will traverse the returned label to acquire its string - representation. - """ - - # SQLite doesnt like selecting from a subquery where the column - # names look like table.colname. so if column is in a "selected from" - # subquery, label it synoymously with its column name + if isinstance(column, sql._Label): + return column + + if select.use_labels and column._label: + return column.label(column._label) + if \ - (self.stack and self.stack[-1].get('is_selected_from')) and \ + asfrom and \ isinstance(column, sql._ColumnClause) and \ not column.is_literal and \ column.table is not None and \ @@ -462,20 +451,20 @@ class DefaultCompiler(engine.Compiled): elif not isinstance(column, (sql._UnaryExpression, sql._TextClause)) and not hasattr(column, 'name'): return column.label(None) else: - return None + return column def visit_select(self, select, asfrom=False, parens=True, **kwargs): stack_entry = {'select':select} if asfrom: - stack_entry['is_selected_from'] = stack_entry['is_subquery'] = True + stack_entry['is_subquery'] = True column_clause_args = {} elif self.stack and 'select' in self.stack[-1]: stack_entry['is_subquery'] = True column_clause_args = {} else: - column_clause_args = {'typemap':self.typemap, 'column_labels':self.column_labels} + column_clause_args = {'result_map':self.result_map} if self.stack and 'from' in self.stack[-1]: existingfroms = self.stack[-1]['from'] @@ -487,8 +476,7 @@ class DefaultCompiler(engine.Compiled): correlate_froms = util.Set() for f in froms: correlate_froms.add(f) - for f2 in f._get_from_objects(): - correlate_froms.add(f2) + correlate_froms.update(f._get_from_objects()) # TODO: might want to propigate existing froms for select(select(select)) # where innermost select should correlate to outermost @@ -501,19 +489,8 @@ class DefaultCompiler(engine.Compiled): inner_columns = util.OrderedSet() for co in select.inner_columns: - if select.use_labels: - labelname = co._label - if labelname is not None: - l = co.label(labelname) - inner_columns.add(self.process(l, **column_clause_args)) - else: - inner_columns.add(self.process(co, **column_clause_args)) - else: - l = self.label_select_column(select, co) - if l is not None: - inner_columns.add(self.process(l, **column_clause_args)) - else: - inner_columns.add(self.process(co, **column_clause_args)) + l = self.label_select_column(select, co, asfrom=asfrom) + inner_columns.add(self.process(l, **column_clause_args)) collist = string.join(inner_columns.difference(util.Set([None])), ', ') diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index b3200a7eba..039145006e 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -1522,6 +1522,7 @@ class FromClause(Selectable): """Represent an element that can be used within the ``FROM`` clause of a ``SELECT`` statement.""" __visit_name__ = 'fromclause' + named_with_column=False def __init__(self): self.oid_column = None @@ -1562,13 +1563,6 @@ class FromClause(Selectable): return Alias(self, name) - def named_with_column(self): - """True if the name of this FromClause may be prepended to a - column in a generated SQL statement. - """ - - return False - def is_derived_from(self, fromclause): """Return True if this FromClause is 'derived' from the given FromClause. @@ -2379,6 +2373,8 @@ class Alias(FromClause): ``FromClause`` subclasses. """ + named_with_column = True + def __init__(self, selectable, alias=None): baseselectable = selectable while isinstance(baseselectable, Alias): @@ -2386,7 +2382,7 @@ class Alias(FromClause): self.original = baseselectable self.selectable = selectable if alias is None: - if self.original.named_with_column(): + if self.original.named_with_column: alias = getattr(self.original, 'name', None) alias = '{ANON %d %s}' % (id(self), alias or 'anon') self.name = alias @@ -2408,9 +2404,6 @@ class Alias(FromClause): def _table_iterator(self): return self.original._table_iterator() - def named_with_column(self): - return True - def _exportable_columns(self): #return self.selectable._exportable_columns() return self.selectable.columns @@ -2602,7 +2595,7 @@ class _ColumnClause(ColumnElement): if self.is_literal: return None if self.__label is None: - if self.table is not None and self.table.named_with_column(): + if self.table is not None and self.table.named_with_column: self.__label = self.table.name + "_" + self.name counter = 1 while self.__label in self.table.c: @@ -2652,6 +2645,8 @@ class TableClause(FromClause): functionality. """ + named_with_column = True + def __init__(self, name, *columns): super(TableClause, self).__init__() self.name = self.fullname = name @@ -2666,9 +2661,6 @@ class TableClause(FromClause): # TableClause is immutable return self - def named_with_column(self): - return True - def append_column(self, c): self._columns[c.name] = c c.table = self @@ -3041,16 +3033,14 @@ class Select(_SelectBaseMixin, FromClause): froms = froms.difference(hide_froms) if len(froms) > 1: - corr = self.__correlate + if self.__correlate: + froms = froms.difference(self.__correlate) if self._should_correlate and existing_froms is not None: - corr.update(existing_froms) + froms = froms.difference(existing_froms) - f = froms.difference(corr) - if not f: + if not froms: raise exceptions.InvalidRequestError("Select statement '%s' is overcorrelated; returned no 'from' clauses" % str(self.__dont_correlate())) - return f - else: - return froms + return froms froms = property(_get_display_froms, doc="""Return a list of all FromClause elements which will be applied to the FROM clause of the resulting statement.""") diff --git a/test/dialect/postgres.py b/test/dialect/postgres.py index 82f41f80a0..4affabb6cd 100644 --- a/test/dialect/postgres.py +++ b/test/dialect/postgres.py @@ -101,6 +101,9 @@ class ReturningTest(AssertMixin): result3 = table.insert(postgres_returning=[(table.c.id*2).label('double_id')]).execute({'persons': 4, 'full': False}) self.assertEqual([dict(row) for row in result3], [{'double_id':8}]) + + result4 = testbase.db.execute('insert into tables (id, persons, "full") values (5, 10, true) returning persons') + self.assertEqual([dict(row) for row in result4], [{'persons': 10}]) finally: table.drop() diff --git a/test/profiling/compiler.py b/test/profiling/compiler.py index 544e674f3e..6fa4f96590 100644 --- a/test/profiling/compiler.py +++ b/test/profiling/compiler.py @@ -24,7 +24,7 @@ class CompileTest(AssertMixin): t1.update().compile() # TODO: this is alittle high - @profiling.profiled('ctest_select', call_range=(130, 150), always=True) + @profiling.profiled('ctest_select', call_range=(110, 130), always=True) def test_select(self): s = select([t1], t1.c.c2==t2.c.c1) s.compile() diff --git a/test/profiling/zoomark.py b/test/profiling/zoomark.py index d18502c72a..48f0432cb5 100644 --- a/test/profiling/zoomark.py +++ b/test/profiling/zoomark.py @@ -50,7 +50,7 @@ class ZooMarkTest(testing.AssertMixin): metadata.create_all() @testing.supported('postgres') - @profiling.profiled('populate', call_range=(2800, 3700), always=True) + @profiling.profiled('populate', call_range=(2700, 3700), always=True) def test_1a_populate(self): Zoo = metadata.tables['Zoo'] Animal = metadata.tables['Animal'] @@ -126,7 +126,7 @@ class ZooMarkTest(testing.AssertMixin): tick = i.execute(Species='Tick', Name='Tick %d' % x, Legs=8) @testing.supported('postgres') - @profiling.profiled('properties', call_range=(2900, 3330), always=True) + @profiling.profiled('properties', call_range=(2300, 3030), always=True) def test_3_properties(self): Zoo = metadata.tables['Zoo'] Animal = metadata.tables['Animal'] @@ -149,7 +149,7 @@ class ZooMarkTest(testing.AssertMixin): ticks = fullobject(Animal.select(Animal.c.Species=='Tick')) @testing.supported('postgres') - @profiling.profiled('expressions', call_range=(10350, 12200), always=True) + @profiling.profiled('expressions', call_range=(9200, 12050), always=True) def test_4_expressions(self): Zoo = metadata.tables['Zoo'] Animal = metadata.tables['Animal'] @@ -203,7 +203,7 @@ class ZooMarkTest(testing.AssertMixin): assert len(fulltable(Animal.select(func.date_part('day', Animal.c.LastEscape) == 21))) == 1 @testing.supported('postgres') - @profiling.profiled('aggregates', call_range=(960, 1170), always=True) + @profiling.profiled('aggregates', call_range=(800, 1170), always=True) def test_5_aggregates(self): Animal = metadata.tables['Animal'] Zoo = metadata.tables['Zoo'] @@ -245,7 +245,7 @@ class ZooMarkTest(testing.AssertMixin): legs.sort() @testing.supported('postgres') - @profiling.profiled('editing', call_range=(1150, 1280), always=True) + @profiling.profiled('editing', call_range=(1050, 1180), always=True) def test_6_editing(self): Zoo = metadata.tables['Zoo'] @@ -274,7 +274,7 @@ class ZooMarkTest(testing.AssertMixin): assert SDZ['Founded'] == datetime.date(1935, 9, 13) @testing.supported('postgres') - @profiling.profiled('multiview', call_range=(2300, 2500), always=True) + @profiling.profiled('multiview', call_range=(1900, 2300), always=True) def test_7_multiview(self): Zoo = metadata.tables['Zoo'] Animal = metadata.tables['Animal']