]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- named_with_column becomes an attribute
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 25 Nov 2007 03:28:49 +0000 (03:28 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 25 Nov 2007 03:28:49 +0000 (03:28 +0000)
- cleanup within compiler visit_select(), column labeling
- is_select() removed from dialects, replaced with returns_rows_text(), returns_rows_compiled()
- should_autocommit() removed from dialects, replaced with should_autocommit_text() and
should_autocommit_compiled()
- typemap and column_labels collections removed from Compiler, replaced with single "result_map" collection.
- ResultProxy uses more succinct logic in combination with result_map to target columns

16 files changed:
lib/sqlalchemy/databases/access.py
lib/sqlalchemy/databases/informix.py
lib/sqlalchemy/databases/mssql.py
lib/sqlalchemy/databases/mysql.py
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/databases/sqlite.py
lib/sqlalchemy/databases/sybase.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/expression.py
test/dialect/postgres.py
test/profiling/compiler.py
test/profiling/zoomark.py

index d57c9fa9f65b3e388e65b8caed041842a8d24932..354a8c33224e8ffd2f4ec4ca200fbd3c3b95bd9b 100644 (file)
@@ -356,11 +356,11 @@ class AccessCompiler(compiler.DefaultCompiler):
         """Access uses "mod" instead of "%" """
         return binary.operator == '%' and 'mod' or binary.operator
 
-    def label_select_column(self, select, column):
+    def label_select_column(self, select, column, asfrom):
         if isinstance(column, expression._Function):
-            return column.label(column.name + "_" + hex(random.randint(0, 65535))[2:])        
+            return column.label()
         else:
-            return super(AccessCompiler, self).label_select_column(select, column)
+            return super(AccessCompiler, self).label_select_column(select, column, asfrom)
 
     function_rewrites =  {'current_date':       'now',
                           'current_timestamp':  'now',
index 247ab2d41902bc143b68cfe632a4531a7c7c69dd..6b01bfc224193eb3169ace961c602bc94f01ce78 100644 (file)
@@ -409,15 +409,6 @@ class InfoCompiler(compiler.DefaultCompiler):
     def limit_clause(self, select):
         return ""
 
-    def __visit_label(self, label):
-        # TODO: whats this method for ?
-        if self.select_stack:
-            self.typemap.setdefault(label.name.lower(), label.obj.type)
-        if self.strings[label.obj]:
-            self.strings[label] = self.strings[label.obj] + " AS "  + label.name
-        else:
-            self.strings[label] = None
-
     def visit_function( self , func ):
         if func.name.lower() == 'current_date':
             return "today"
index 672f8d77cf69b1576ce0c53e3e1e3b2229619714..469355083bebe0736fa1cfbb14b5593076bcbf6e 100644 (file)
@@ -339,8 +339,8 @@ class MSSQLExecutionContext(default.DefaultExecutionContext):
     _ms_is_select = re.compile(r'\s*(?:SELECT|sp_columns)',
                                re.I | re.UNICODE)
     
-    def is_select(self):
-        return self._ms_is_select.match(self.statement) is not None
+    def returns_rows_text(self, statement):
+        return self._ms_is_select.match(statement) is not None
 
 
 class MSSQLExecutionContext_pyodbc (MSSQLExecutionContext):    
@@ -910,11 +910,11 @@ class MSSQLCompiler(compiler.DefaultCompiler):
         else:
             return super(MSSQLCompiler, self).visit_binary(binary, **kwargs)
 
-    def label_select_column(self, select, column):
+    def label_select_column(self, select, column, asfrom):
         if isinstance(column, expression._Function):
             return column.label(None)
         else:
-            return super(MSSQLCompiler, self).label_select_column(select, column)
+            return super(MSSQLCompiler, self).label_select_column(select, column, asfrom)
 
     function_rewrites =  {'current_date': 'getdate',
                           'length':     'len',
index 39bfc0beaa5363842f77d033a4fc81a0448c7697..03b9a749ce79214ee8803244c8e5b8ba9b12badd 100644 (file)
@@ -1378,9 +1378,6 @@ def descriptor():
 
 
 class MySQLExecutionContext(default.DefaultExecutionContext):
-    _my_is_select = re.compile(r'\s*(?:SELECT|SHOW|DESCRIBE|XA +RECOVER)',
-                               re.I | re.UNICODE)
-
     def post_exec(self):
         if self.compiled.isinsert and not self.executemany:
             if (not len(self._last_inserted_ids) or
@@ -1388,11 +1385,11 @@ class MySQLExecutionContext(default.DefaultExecutionContext):
                 self._last_inserted_ids = ([self.cursor.lastrowid] +
                                            self._last_inserted_ids[1:])
 
-    def is_select(self):
-        return SELECT_RE.match(self.statement)
+    def returns_rows_text(self, statement):
+        return SELECT_RE.match(statement)
 
-    def should_autocommit(self):
-        return AUTOCOMMIT_RE.match(self.statement)
+    def should_autocommit_text(self, statement):
+        return AUTOCOMMIT_RE.match(statement)
 
 
 class MySQLDialect(default.DefaultDialect):
@@ -1873,9 +1870,6 @@ class MySQLCompiler(compiler.DefaultCompiler):
         if type_ is None:
             return self.process(cast.clause)
 
-        if self.stack and self.stack[-1].get('select'):
-            # not sure if we want to set the typemap here...
-            self.typemap.setdefault("CAST", cast.type)
         return 'CAST(%s AS %s)' % (self.process(cast.clause), type_)
 
 
index 88ac0e2026033e57824fb25c1d2862f5631a63f1..1cae31b537a79c14d69429de561137125ba09cff 100644 (file)
@@ -233,16 +233,24 @@ RETURNING_QUOTED_RE = re.compile(
 
 class PGExecutionContext(default.DefaultExecutionContext):
 
-    def is_select(self):
-        m = SELECT_RE.match(self.statement)
-        return m and (not m.group(1) or (RETURNING_RE.search(self.statement)
-           and RETURNING_QUOTED_RE.match(self.statement)))
+    def returns_rows_text(self, statement):
+        m = SELECT_RE.match(statement)
+        return m and (not m.group(1) or (RETURNING_RE.search(statement)
+           and RETURNING_QUOTED_RE.match(statement)))
+    
+    def returns_rows_compiled(self, compiled):
+        return isinstance(compiled.statement, expression.Selectable) or \
+            (
+                (compiled.isupdate or compiled.isinsert) and "postgres_returning" in compiled.statement.kwargs
+            )
         
     def create_cursor(self):
         # executing a default or Sequence standalone creates an execution context without a statement.  
         # so slightly hacky "if no statement assume we're server side" logic
+        # TODO: dont use regexp if Compiled is used ?
         self.__is_server_side = \
-            self.dialect.server_side_cursors and (self.statement is None or \
+            self.dialect.server_side_cursors and \
+            (self.statement is None or \
             (SELECT_RE.match(self.statement) and not re.search(r'FOR UPDATE(?: NOWAIT)?\s*$', self.statement, re.I))
         )
 
index 19d0855ff3ade312e7f0e4798c64832704a8cd67..16dd9427c0b9bc9784f871d834c2667a9980792c 100644 (file)
@@ -185,8 +185,8 @@ class SQLiteExecutionContext(default.DefaultExecutionContext):
             if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None:
                 self._last_inserted_ids = [self.cursor.lastrowid] + self._last_inserted_ids[1:]
 
-    def is_select(self):
-        return SELECT_REGEXP.match(self.statement)
+    def returns_rows_text(self, statement):
+        return SELECT_REGEXP.match(statement)
         
 class SQLiteDialect(default.DefaultDialect):
     supports_alter = False
@@ -343,9 +343,6 @@ class SQLiteCompiler(compiler.DefaultCompiler):
         if self.dialect.supports_cast:
             return super(SQLiteCompiler, self).visit_cast(cast)
         else:
-            if self.stack and self.stack[-1].get('select'):
-                # not sure if we want to set the typemap here...
-                self.typemap.setdefault("CAST", cast.type)
             return self.process(cast.clause)
 
     def limit_clause(self, select):
index 87045d192612d4fbf087b875df39ad150a3addec..2209594ed7c3b498710bd912a54d225349241453 100644 (file)
@@ -778,11 +778,11 @@ class SybaseSQLCompiler(compiler.DefaultCompiler):
         else:
             return super(SybaseSQLCompiler, self).visit_binary(binary)
 
-    def label_select_column(self, select, column):
+    def label_select_column(self, select, column, asfrom):
         if isinstance(column, expression._Function):
-            return column.label(column.name + "_" + hex(random.randint(0, 65535))[2:])
+            return column.label(None)
         else:
-            return super(SybaseSQLCompiler, self).label_select_column(select, column)
+            return super(SybaseSQLCompiler, self).label_select_column(select, column, asfrom)
 
     function_rewrites =  {'current_date': 'getdate',
                          }
@@ -795,13 +795,7 @@ class SybaseSQLCompiler(compiler.DefaultCompiler):
             cast = expression._Cast(func, SybaseDate_mxodbc)
             # infinite recursion
             # res = self.visit_cast(cast)
-            if self.stack and self.stack[-1].get('select'):
-                # not sure if we want to set the typemap here...
-                self.typemap.setdefault("CAST", cast.type)
-#            res = "CAST(%s AS %s)" % (self.process(cast.clause), self.process(cast.typeclause))
             res = "CAST(%s AS %s)" % (res, self.process(cast.typeclause))
-#        elif func.name.lower() == 'count':
-#            res = 'count(*)'
         return res
 
     def for_update_clause(self, select):
index 21977b689b9807017cfc08c2294eaa18016ff380..9e30043253bc70d611f9af7277397b8721630cc5 100644 (file)
@@ -315,6 +315,12 @@ class ExecutionContext(object):
     isupdate
       True if the statement is an UPDATE.
 
+    should_autocommit
+      True if the statement is a "committable" statement
+      
+    returns_rows
+      True if the statement should return result rows
+      
     The Dialect should provide an ExecutionContext via the
     create_execution_context() method.  The `pre_exec` and `post_exec`
     methods will be called for compiled statements.
@@ -363,8 +369,13 @@ class ExecutionContext(object):
 
         raise NotImplementedError()
 
-    def should_autocommit(self):
-        """Return True if this context's statement should be 'committed' automatically in a non-transactional context"""
+    def should_autocommit_compiled(self, compiled):
+        """return True if the given Compiled object refers to a "committable" statement."""
+        
+        raise NotImplementedError()
+        
+    def should_autocommit_text(self, statement):
+        """Parse the given textual statement and return True if it refers to a "committable" statement"""
 
         raise NotImplementedError()
 
@@ -750,7 +761,7 @@ class Connection(Connectable):
 
         # TODO: have the dialect determine if autocommit can be set on
         # the connection directly without this extra step
-        if not self.in_transaction() and context.should_autocommit():
+        if not self.in_transaction() and context.should_autocommit:
             self._commit_impl()
 
     def _autorollback(self):
@@ -1305,7 +1316,7 @@ class ResultProxy(object):
         self.cursor = context.cursor
         self.connection = context.root_connection
         self.__echo = context.engine._should_log_info
-        if context.is_select():
+        if context.returns_rows:
             self._init_metadata()
             self._rowcount = None
         else:
@@ -1322,8 +1333,6 @@ class ResultProxy(object):
     out_parameters = property(lambda s:s.context.out_parameters)
 
     def _init_metadata(self):
-        if hasattr(self, '_ResultProxy__props'):
-            return
         self.__props = {}
         self._key_cache = self._create_key_cache()
         self.__keys = []
@@ -1336,20 +1345,24 @@ class ResultProxy(object):
                 # sqlite possibly prepending table name to colnames so strip
                 colname = (item[0].split('.')[-1]).decode(self.dialect.encoding)
 
-                if self.context.typemap is not None:
-                    type = self.context.typemap.get(colname.lower(), typemap.get(item[1], types.NULLTYPE))
+                if self.context.result_map:
+                    try:
+                        (name, obj, type_) = self.context.result_map[colname]
+                    except KeyError:
+                        (name, obj, type_) = (colname, None, typemap.get(item[1], types.NULLTYPE))
                 else:
-                    type = typemap.get(item[1], types.NULLTYPE)
+                    (name, obj, type_) = (colname, None, typemap.get(item[1], types.NULLTYPE))
 
-                rec = (type, type.dialect_impl(self.dialect).result_processor(self.dialect), i)
+                rec = (type_, type_.dialect_impl(self.dialect).result_processor(self.dialect), i)
 
-                if rec[0] is None:
-                    raise exceptions.InvalidRequestError(
-                        "None for metadata " + colname)
-                if self.__props.setdefault(colname.lower(), rec) is not rec:
-                    self.__props[colname.lower()] = (type, self.__ambiguous_processor(colname), 0)
+                if self.__props.setdefault(name.lower(), rec) is not rec:
+                    self.__props[name.lower()] = (type_, self.__ambiguous_processor(colname), 0)
+                    
                 self.__keys.append(colname)
                 self.__props[i] = rec
+                if obj:
+                    for o in obj:
+                        self.__props[o] = rec
 
             if self.__echo:
                 self.context.engine.logger.debug("Col " + repr(tuple([x[0] for x in metadata])))
@@ -1362,16 +1375,19 @@ class ResultProxy(object):
             """Given a key, which could be a ColumnElement, string, etc.,
             matches it to the appropriate key we got from the result set's
             metadata; then cache it locally for quick re-access."""
-
-            if isinstance(key, int) and key in props:
+            
+            if isinstance(key, basestring):
+                key = key.lower()
+            
+            try:
                 rec = props[key]
-            elif isinstance(key, basestring) and key.lower() in props:
-                rec = props[key.lower()]
-            elif isinstance(key, expression.ColumnElement):
-                label = context.column_labels.get(key._label, key.name).lower()
-                if label in props:
-                    rec = props[label]
-            if not "rec" in locals():
+            except KeyError:
+                # fallback for targeting a ColumnElement to a textual expression
+                if isinstance(key, expression.ColumnElement):
+                    if key._label.lower() in props:
+                        return props[key._label.lower()]
+                    elif key.name.lower() in props:
+                        return props[key.name.lower()]
                 raise exceptions.NoSuchColumnError("Could not locate column in row for column '%s'" % (str(key)))
 
             return rec
@@ -1470,18 +1486,20 @@ class ResultProxy(object):
 
     def _get_col(self, row, key):
         try:
-            rec = self._key_cache[key]
+            type_, processor, index = self._key_cache[key]
         except TypeError:
             # the 'slice' use case is very infrequent,
             # so we use an exception catch to reduce conditionals in _get_col
             if isinstance(key, slice):
                 indices = key.indices(len(row))
                 return tuple([self._get_col(row, i) for i in xrange(*indices)])
-
-        if rec[1]:
-            return rec[1](row[rec[2]])
+            else:
+                raise
+                
+        if processor:
+            return processor(row[index])
         else:
-            return row[rec[2]]
+            return row[index]
 
     def _fetchone_impl(self):
         return self.cursor.fetchone()
index a91d65b81f301c8aede1de80bcd06da870a3cd4d..19ab22c9e9d1857d9c91dd494010b045868a79be 100644 (file)
@@ -146,9 +146,8 @@ class DefaultExecutionContext(base.ExecutionContext):
                 if value is not None
             ])
             
-            self.typemap = compiled.typemap
-            self.column_labels = compiled.column_labels
-
+            self.result_map = compiled.result_map
+            
             if not dialect.supports_unicode_statements:
                 self.statement = unicode(compiled).encode(self.dialect.encoding)
             else:
@@ -156,6 +155,12 @@ class DefaultExecutionContext(base.ExecutionContext):
                 
             self.isinsert = compiled.isinsert
             self.isupdate = compiled.isupdate
+            if isinstance(compiled.statement, expression._TextClause):
+                self.returns_rows = self.returns_rows_text(self.statement)
+                self.should_autocommit = self.should_autocommit_text(self.statement)
+            else:
+                self.returns_rows = self.returns_rows_compiled(compiled)
+                self.should_autocommit = self.should_autocommit_compiled(compiled)
             
             if not parameters:
                 self.compiled_parameters = [compiled.construct_params()]
@@ -170,7 +175,7 @@ class DefaultExecutionContext(base.ExecutionContext):
 
         elif statement is not None:
             # plain text statement.  
-            self.typemap = self.column_labels = None
+            self.result_map = None
             self.parameters = self.__encode_param_keys(parameters)
             self.executemany = len(parameters) > 1
             if not dialect.supports_unicode_statements:
@@ -179,10 +184,12 @@ class DefaultExecutionContext(base.ExecutionContext):
                 self.statement = statement
             self.isinsert = self.isupdate = False
             self.cursor = self.create_cursor()
+            self.returns_rows = self.returns_rows_text(statement)
+            self.should_autocommit = self.should_autocommit_text(statement)
         else:
             # no statement. used for standalone ColumnDefault execution.
             self.statement = None
-            self.isinsert = self.isupdate = self.executemany = False
+            self.isinsert = self.isupdate = self.executemany = self.returns_rows = self.should_autocommit = False
             self.cursor = self.create_cursor()
     
     connection = property(lambda s:s._connection._branch())
@@ -244,10 +251,18 @@ class DefaultExecutionContext(base.ExecutionContext):
                 parameters.append(param)
         return parameters
                 
-    def is_select(self):
-        """return TRUE if the statement is expected to have result rows."""
+    def returns_rows_compiled(self, compiled):
+        return isinstance(compiled.statement, expression.Selectable)
         
-        return SELECT_REGEXP.match(self.statement)
+    def returns_rows_text(self, statement):
+        return SELECT_REGEXP.match(statement)
+
+    def should_autocommit_compiled(self, compiled):
+        return isinstance(compiled.statement, expression._UpdateBase)
+
+    def should_autocommit_text(self, statement):
+        return AUTOCOMMIT_REGEXP.match(statement)
+
 
     def create_cursor(self):
         return self._connection.connection.cursor()
@@ -261,9 +276,6 @@ class DefaultExecutionContext(base.ExecutionContext):
     def result(self):
         return self.get_result_proxy()
 
-    def should_autocommit(self):
-        return AUTOCOMMIT_REGEXP.match(self.statement)
-            
     def pre_exec(self):
         pass
 
index 1214025849e95554ddfab53ddfd01103c6010e43..3daf11ed0c5186ac09f59aa04a60073e9aacd480 100644 (file)
@@ -249,7 +249,7 @@ class Query(object):
         # alias non-labeled column elements. 
         if isinstance(column, sql.ColumnElement) and not hasattr(column, '_label'):
             column = column.label(None)
-
+            
         q._entities = q._entities + [(column, None, id)]
         return q
         
@@ -887,7 +887,7 @@ class Query(object):
                     context.exec_with_path(self.select_mapper, value.key, value.setup, context, parentclauses=clauses)
             elif isinstance(m, sql.ColumnElement):
                 if clauses is not None:
-                    m = clauses.adapt_clause(m)
+                    m = clauses.aliased_column(m)
                 context.secondary_columns.append(m)
             
         if self._eager_loaders and self._nestable(**self._select_args()):
index 0e1e5f7a9caf8e7df48197463e7c8c45f049ddf2..8179810033a33e3e8f43d185bb77c5350b2a6382 100644 (file)
@@ -456,7 +456,7 @@ class Column(SchemaItem, expression._ColumnClause):
 
     def __str__(self):
         if self.table is not None:
-            if self.table.named_with_column():
+            if self.table.named_with_column:
                 return (self.table.description + "." + self.description)
             else:
                 return self.description
index c1f3bc2a05daa18ea339a48c491937d69f875d41..a31997d1b3303a5a2bf7fd35a694cf7afa6a8476 100644 (file)
@@ -130,13 +130,11 @@ class DefaultCompiler(engine.Compiled):
         # a stack.  what recursive compiler doesn't have a stack ? :)
         self.stack = []
         
-        # a dictionary of result-set column names (strings) to TypeEngine instances,
-        # which will be passed to a ResultProxy and used for resultset-level value conversion
-        self.typemap = {}
-
-        # a dictionary of select columns labels mapped to their "generated" label
-        self.column_labels = {}
-
+        # relates label names in the final SQL to
+        # a tuple of local column/label name, ColumnElement object (if any) and TypeEngine.
+        # ResultProxy uses this for type processing and column targeting
+        self.result_map = {}
+        
         # a dictionary of ClauseElement subclasses to counters, which are used to
         # generate truncated identifier names or "anonymous" identifiers such as
         # for aliases
@@ -213,19 +211,15 @@ class DefaultCompiler(engine.Compiled):
     def visit_grouping(self, grouping, **kwargs):
         return "(" + self.process(grouping.elem) + ")"
         
-    def visit_label(self, label, typemap=None, column_labels=None):
+    def visit_label(self, label, result_map=None):
         labelname = self._truncated_identifier("colident", label.name)
         
-        if typemap is not None:
-            self.typemap.setdefault(labelname.lower(), label.obj.type)
+        if result_map is not None:
+            result_map[labelname] = (label.name, (label, label.obj), label.obj.type)
             
-        if column_labels is not None:
-            if isinstance(label.obj, sql._ColumnClause):
-                column_labels[label.obj._label] = labelname
-            column_labels[label.name] = labelname
         return " ".join([self.process(label.obj), self.operator_string(operators.as_), self.preparer.format_label(label, labelname)])
         
-    def visit_column(self, column, typemap=None, column_labels=None, **kwargs):
+    def visit_column(self, column, result_map=None, **kwargs):
         # there is actually somewhat of a ruleset when you would *not* necessarily
         # want to truncate a column identifier, if its mapped to the name of a 
         # physical column.  but thats very hard to identify at this point, and 
@@ -236,15 +230,13 @@ class DefaultCompiler(engine.Compiled):
         else:
             name = column.name
 
-        if typemap is not None:
-            typemap.setdefault(name.lower(), column.type)
-        if column_labels is not None:    
-            self.column_labels.setdefault(column._label, name.lower())
+        if result_map is not None:
+            result_map[name] = (name, (column, ), column.type)
         
         if column._is_oid:
             n = self.dialect.oid_column_name(column)
             if n is not None:
-                if column.table is None or not column.table.named_with_column():
+                if column.table is None or not column.table.named_with_column:
                     return n
                 else:
                     return self.preparer.quote(column.table, ANONYMOUS_LABEL.sub(self._process_anon, column.table.name)) + "." + n
@@ -254,7 +246,7 @@ class DefaultCompiler(engine.Compiled):
                 return self.preparer.quote(column.table, ANONYMOUS_LABEL.sub(self._process_anon, column.table.name)) + "." + self.preparer.quote(pk, pkname)
             else:
                 return None
-        elif column.table is None or not column.table.named_with_column():
+        elif column.table is None or not column.table.named_with_column:
             if getattr(column, "is_literal", False):
                 return name
             else:
@@ -277,8 +269,9 @@ class DefaultCompiler(engine.Compiled):
 
     def visit_textclause(self, textclause, **kwargs):
         if textclause.typemap is not None:
-            self.typemap.update(textclause.typemap)
-            
+            for colname, type_ in textclause.typemap.iteritems():
+                self.result_map[colname] = (colname, None, type_)
+                
         def do_bindparam(m):
             name = m.group(1)
             if name in textclause.bindparams:
@@ -302,7 +295,7 @@ class DefaultCompiler(engine.Compiled):
             sep = ', '
         else:
             sep = " " + self.operator_string(clauselist.operator) + " "
-        return string.join([s for s in [self.process(c) for c in clauselist.clauses] if s is not None], sep)
+        return sep.join([s for s in [self.process(c) for c in clauselist.clauses] if s is not None])
 
     def apply_function_parens(self, func):
         return func.name.upper() not in ANSI_FUNCS or len(func.clauses) > 0
@@ -310,12 +303,13 @@ class DefaultCompiler(engine.Compiled):
     def visit_calculatedclause(self, clause, **kwargs):
         return self.process(clause.clause_expr)
 
-    def visit_cast(self, cast, typemap=None, **kwargs):
+    def visit_cast(self, cast, **kwargs):
         return "CAST(%s AS %s)" % (self.process(cast.clause), self.process(cast.typeclause))
 
-    def visit_function(self, func, typemap=None, **kwargs):
-        if typemap is not None:
-            typemap.setdefault(func.name, func.type)
+    def visit_function(self, func, result_map=None, **kwargs):
+        if result_map is not None:
+            result_map[func.name] = (func.name, None, func.type)
+            
         if not self.apply_function_parens(func):
             return ".".join(func.packagenames + [func.name])
         else:
@@ -325,7 +319,7 @@ class DefaultCompiler(engine.Compiled):
         stack_entry = {'select':cs}
         
         if asfrom:
-            stack_entry['is_selected_from'] = stack_entry['is_subquery'] = True
+            stack_entry['is_subquery'] = True
         elif self.stack and self.stack[-1].get('select'):
             stack_entry['is_subquery'] = True
         self.stack.append(stack_entry)
@@ -353,7 +347,7 @@ class DefaultCompiler(engine.Compiled):
             s = s + " " + self.operator_string(unary.modifier)
         return s
         
-    def visit_binary(self, binary, typemap=None, **kwargs):
+    def visit_binary(self, binary, **kwargs):
         op = self.operator_string(binary.operator)
         if callable(op):
             return op(self.process(binary.left), self.process(binary.right))
@@ -438,22 +432,17 @@ class DefaultCompiler(engine.Compiled):
         else:
             return self.process(alias.original, **kwargs)
 
-    def label_select_column(self, select, column):
-        """convert a column from a select's "columns" clause.
+    def label_select_column(self, select, column, asfrom):
+        """label columns present in a select()."""
         
-        given a select() and a column element from its inner_columns collection, return a
-        Label object if this column should be labeled in the columns clause.  Otherwise,
-        return None and the column will be used as-is.
-        
-        The calling method will traverse the returned label to acquire its string
-        representation.
-        """
-        
-        # SQLite doesnt like selecting from a subquery where the column
-        # names look like table.colname. so if column is in a "selected from"
-        # subquery, label it synoymously with its column name
+        if isinstance(column, sql._Label):
+            return column
+            
+        if select.use_labels and column._label:
+            return column.label(column._label)
+                
         if \
-            (self.stack and self.stack[-1].get('is_selected_from')) and \
+            asfrom and \
             isinstance(column, sql._ColumnClause) and \
             not column.is_literal and \
             column.table is not None and \
@@ -462,20 +451,20 @@ class DefaultCompiler(engine.Compiled):
         elif not isinstance(column, (sql._UnaryExpression, sql._TextClause)) and not hasattr(column, 'name'):
             return column.label(None)
         else:
-            return None
+            return column
 
     def visit_select(self, select, asfrom=False, parens=True, **kwargs):
 
         stack_entry = {'select':select}
         
         if asfrom:
-            stack_entry['is_selected_from'] = stack_entry['is_subquery'] = True
+            stack_entry['is_subquery'] = True
             column_clause_args = {}
         elif self.stack and 'select' in self.stack[-1]:
             stack_entry['is_subquery'] = True
             column_clause_args = {}
         else:
-            column_clause_args = {'typemap':self.typemap, 'column_labels':self.column_labels}
+            column_clause_args = {'result_map':self.result_map}
             
         if self.stack and 'from' in self.stack[-1]:
             existingfroms = self.stack[-1]['from']
@@ -487,8 +476,7 @@ class DefaultCompiler(engine.Compiled):
         correlate_froms = util.Set()
         for f in froms:
             correlate_froms.add(f)
-            for f2 in f._get_from_objects():
-                correlate_froms.add(f2)
+            correlate_froms.update(f._get_from_objects())
 
         # TODO: might want to propigate existing froms for select(select(select))
         # where innermost select should correlate to outermost
@@ -501,19 +489,8 @@ class DefaultCompiler(engine.Compiled):
         inner_columns = util.OrderedSet()
                 
         for co in select.inner_columns:
-            if select.use_labels:
-                labelname = co._label
-                if labelname is not None:
-                    l = co.label(labelname)
-                    inner_columns.add(self.process(l, **column_clause_args))
-                else:
-                    inner_columns.add(self.process(co, **column_clause_args))
-            else:
-                l = self.label_select_column(select, co)
-                if l is not None:
-                    inner_columns.add(self.process(l, **column_clause_args))
-                else:
-                    inner_columns.add(self.process(co, **column_clause_args))
+            l = self.label_select_column(select, co, asfrom=asfrom)
+            inner_columns.add(self.process(l, **column_clause_args))
             
         collist = string.join(inner_columns.difference(util.Set([None])), ', ')
 
index b3200a7eba339a0a5ae41f738638b45fcce93be3..039145006ec39109ee270484019bde93e9ddd2e7 100644 (file)
@@ -1522,6 +1522,7 @@ class FromClause(Selectable):
     """Represent an element that can be used within the ``FROM`` clause of a ``SELECT`` statement."""
 
     __visit_name__ = 'fromclause'
+    named_with_column=False
 
     def __init__(self):
         self.oid_column = None
@@ -1562,13 +1563,6 @@ class FromClause(Selectable):
 
         return Alias(self, name)
 
-    def named_with_column(self):
-        """True if the name of this FromClause may be prepended to a
-        column in a generated SQL statement.
-        """
-
-        return False
-
     def is_derived_from(self, fromclause):
         """Return True if this FromClause is 'derived' from the given FromClause.
 
@@ -2379,6 +2373,8 @@ class Alias(FromClause):
     ``FromClause`` subclasses.
     """
 
+    named_with_column = True
+    
     def __init__(self, selectable, alias=None):
         baseselectable = selectable
         while isinstance(baseselectable, Alias):
@@ -2386,7 +2382,7 @@ class Alias(FromClause):
         self.original = baseselectable
         self.selectable = selectable
         if alias is None:
-            if self.original.named_with_column():
+            if self.original.named_with_column:
                 alias = getattr(self.original, 'name', None)
             alias = '{ANON %d %s}' % (id(self), alias or 'anon')
         self.name = alias
@@ -2408,9 +2404,6 @@ class Alias(FromClause):
     def _table_iterator(self):
         return self.original._table_iterator()
 
-    def named_with_column(self):
-        return True
-
     def _exportable_columns(self):
         #return self.selectable._exportable_columns()
         return self.selectable.columns
@@ -2602,7 +2595,7 @@ class _ColumnClause(ColumnElement):
         if self.is_literal:
             return None
         if self.__label is None:
-            if self.table is not None and self.table.named_with_column():
+            if self.table is not None and self.table.named_with_column:
                 self.__label = self.table.name + "_" + self.name
                 counter = 1
                 while self.__label in self.table.c:
@@ -2652,6 +2645,8 @@ class TableClause(FromClause):
     functionality.
     """
 
+    named_with_column = True
+    
     def __init__(self, name, *columns):
         super(TableClause, self).__init__()
         self.name = self.fullname = name
@@ -2666,9 +2661,6 @@ class TableClause(FromClause):
         # TableClause is immutable
         return self
 
-    def named_with_column(self):
-        return True
-
     def append_column(self, c):
         self._columns[c.name] = c
         c.table = self
@@ -3041,16 +3033,14 @@ class Select(_SelectBaseMixin, FromClause):
         froms = froms.difference(hide_froms)
         
         if len(froms) > 1:
-            corr = self.__correlate
+            if self.__correlate:
+                froms = froms.difference(self.__correlate)
             if self._should_correlate and existing_froms is not None:
-                corr.update(existing_froms)
+                froms = froms.difference(existing_froms)
                 
-            f = froms.difference(corr)
-            if not f:
+            if not froms:
                 raise exceptions.InvalidRequestError("Select statement '%s' is overcorrelated; returned no 'from' clauses" % str(self.__dont_correlate()))
-            return f
-        else:
-            return froms
+        return froms
 
     froms = property(_get_display_froms, doc="""Return a list of all FromClause elements which will be applied to the FROM clause of the resulting statement.""")
 
index 82f41f80a0eeab7abf5a6c731c255e36250a6683..4affabb6cddcbddc746120b1957c51296499c063 100644 (file)
@@ -101,6 +101,9 @@ class ReturningTest(AssertMixin):
             
             result3 = table.insert(postgres_returning=[(table.c.id*2).label('double_id')]).execute({'persons': 4, 'full': False})
             self.assertEqual([dict(row) for row in result3], [{'double_id':8}])
+            
+            result4 = testbase.db.execute('insert into tables (id, persons, "full") values (5, 10, true) returning persons')
+            self.assertEqual([dict(row) for row in result4], [{'persons': 10}])
         finally:
             table.drop()
     
index 544e674f3e2966fe03973431666cd8067339b358..6fa4f96590464e3dc3b01a7719a2b364564e5efc 100644 (file)
@@ -24,7 +24,7 @@ class CompileTest(AssertMixin):
         t1.update().compile()
 
     # TODO: this is alittle high
-    @profiling.profiled('ctest_select', call_range=(130, 150), always=True)        
+    @profiling.profiled('ctest_select', call_range=(110, 130), always=True)        
     def test_select(self):
         s = select([t1], t1.c.c2==t2.c.c1)
         s.compile()
index d18502c72aaa1a7c8dedfc90f23aedac9141a56f..48f0432cb55d45f914f4a4453a9515315d3fd016 100644 (file)
@@ -50,7 +50,7 @@ class ZooMarkTest(testing.AssertMixin):
         metadata.create_all()
         
     @testing.supported('postgres')
-    @profiling.profiled('populate', call_range=(2800, 3700), always=True)
+    @profiling.profiled('populate', call_range=(2700, 3700), always=True)
     def test_1a_populate(self):
         Zoo = metadata.tables['Zoo']
         Animal = metadata.tables['Animal']
@@ -126,7 +126,7 @@ class ZooMarkTest(testing.AssertMixin):
             tick = i.execute(Species='Tick', Name='Tick %d' % x, Legs=8)
     
     @testing.supported('postgres')
-    @profiling.profiled('properties', call_range=(2900, 3330), always=True)
+    @profiling.profiled('properties', call_range=(2300, 3030), always=True)
     def test_3_properties(self):
         Zoo = metadata.tables['Zoo']
         Animal = metadata.tables['Animal']
@@ -149,7 +149,7 @@ class ZooMarkTest(testing.AssertMixin):
             ticks = fullobject(Animal.select(Animal.c.Species=='Tick'))
     
     @testing.supported('postgres')
-    @profiling.profiled('expressions', call_range=(10350, 12200), always=True)
+    @profiling.profiled('expressions', call_range=(9200, 12050), always=True)
     def test_4_expressions(self):
         Zoo = metadata.tables['Zoo']
         Animal = metadata.tables['Animal']
@@ -203,7 +203,7 @@ class ZooMarkTest(testing.AssertMixin):
             assert len(fulltable(Animal.select(func.date_part('day', Animal.c.LastEscape) == 21))) == 1
     
     @testing.supported('postgres')
-    @profiling.profiled('aggregates', call_range=(960, 1170), always=True)
+    @profiling.profiled('aggregates', call_range=(800, 1170), always=True)
     def test_5_aggregates(self):
         Animal = metadata.tables['Animal']
         Zoo = metadata.tables['Zoo']
@@ -245,7 +245,7 @@ class ZooMarkTest(testing.AssertMixin):
             legs.sort()
     
     @testing.supported('postgres')
-    @profiling.profiled('editing', call_range=(1150, 1280), always=True)
+    @profiling.profiled('editing', call_range=(1050, 1180), always=True)
     def test_6_editing(self):
         Zoo = metadata.tables['Zoo']
         
@@ -274,7 +274,7 @@ class ZooMarkTest(testing.AssertMixin):
             assert SDZ['Founded'] == datetime.date(1935, 9, 13)
     
     @testing.supported('postgres')
-    @profiling.profiled('multiview', call_range=(2300, 2500), always=True)
+    @profiling.profiled('multiview', call_range=(1900, 2300), always=True)
     def test_7_multiview(self):
         Zoo = metadata.tables['Zoo']
         Animal = metadata.tables['Animal']