]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
generate the RETURNING col lists the same was as visit_select() does (except for...
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 19 Jul 2009 04:59:18 +0000 (04:59 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 19 Jul 2009 04:59:18 +0000 (04:59 +0000)
mssql gets extra label stuff to deal with column adaption (not sure if column adaption should
blow away labels like that...).   fixes potential column targeting issues on all platforms
+ fixes mssql failures

lib/sqlalchemy/dialects/firebird/base.py
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/expression.py
test/dialect/test_firebird.py
test/dialect/test_mssql.py
test/dialect/test_postgresql.py

index 949289eb360b75f4c7acb0d3c64f31e3fbf00f43..58fa19f50fd63b85007ef59de3f9c5117e1d788b 100644 (file)
@@ -256,17 +256,13 @@ class FBCompiler(sql.compiler.SQLCompiler):
     def returning_clause(self, stmt):
         returning_cols = stmt._returning
 
-        def flatten_columnlist(collist):
-            for c in collist:
-                if isinstance(c, expression.Selectable):
-                    for co in c.columns:
-                        yield co
-                else:
-                    yield c
-
         columns = [
-                self.process(c, within_columns_clause=True, result_map=self.result_map) 
-                for c in flatten_columnlist(returning_cols)
+                self.process(
+                    self.label_select_column(None, c, asfrom=False), 
+                    within_columns_clause=True, 
+                    result_map=self.result_map
+                ) 
+                for c in expression._select_iterables(returning_cols)
             ]
         return 'RETURNING ' + ', '.join(columns)
 
index 9831b5134bbe2286f410cec5dec9b363cc03d4cd..c58e32f01a38a336b1fdbfe0843dc1780333959e 100644 (file)
@@ -222,6 +222,7 @@ Known Issues
 
 """
 import datetime, decimal, inspect, operator, sys, re
+import itertools
 
 from sqlalchemy import sql, schema as sa_schema, exc, util
 from sqlalchemy.sql import select, compiler, expression, \
@@ -1063,25 +1064,27 @@ class MSSQLCompiler(compiler.SQLCompiler):
     def returning_clause(self, stmt):
         returning_cols = stmt._returning
 
-        def flatten_columnlist(collist):
-            for c in collist:
-                if isinstance(c, expression.Selectable):
-                    for co in c.columns:
-                        yield co
-                else:
-                    yield c
-                    
         if self.isinsert or self.isupdate:
             target = stmt.table.alias("inserted")
         else:
             target = stmt.table.alias("deleted")
         
         adapter = sql_util.ClauseAdapter(target)
+        def col_label(col):
+            adapted = adapter.traverse(c)
+            if isinstance(c, expression._Label):
+                return adapted.label(c.key)
+            else:
+                return self.label_select_column(None, adapted, asfrom=False)
+            
         columns = [
-            self.process(adapter.traverse(c), within_columns_clause=True, result_map=self.result_map) 
-            for c in flatten_columnlist(returning_cols)
+            self.process(
+                col_label(c), 
+                within_columns_clause=True, 
+                result_map=self.result_map
+            ) 
+            for c in expression._select_iterables(returning_cols)
         ]
-
         return 'OUTPUT ' + ', '.join(columns)
 
     def label_select_column(self, select, column, asfrom):
index 7c956f6bede3b8bce38b0854745021faba6f4792..cc19541eb11df6c14b76f82d54537ba69f00cfbe 100644 (file)
@@ -321,23 +321,17 @@ class OracleCompiler(compiler.SQLCompiler):
     def returning_clause(self, stmt):
         returning_cols = stmt._returning
             
-        def flatten_columnlist(collist):
-            for c in collist:
-                if isinstance(c, expression.Selectable):
-                    for co in c.columns:
-                        yield co
-                else:
-                    yield c
-
         def create_out_param(col, i):
             bindparam = sql.outparam("ret_%d" % i, type_=col.type)
             self.binds[bindparam.key] = bindparam
             return self.bindparam_string(self._truncate_bindparam(bindparam))
         
+        columnlist = list(expression._select_iterables(returning_cols))
+        
         # within_columns_clause =False so that labels (foo AS bar) don't render
-        columns = [self.process(c, within_columns_clause=False) for c in flatten_columnlist(returning_cols)]
+        columns = [self.process(c, within_columns_clause=False) for c in columnlist]
         
-        binds = [create_out_param(c, i) for i, c in enumerate(flatten_columnlist(returning_cols))]
+        binds = [create_out_param(c, i) for i, c in enumerate(columnlist)]
         
         return 'RETURNING ' + ', '.join(columns) +  " INTO " + ", ".join(binds)
 
index 849ec500668a9d1d199308239db483033e2a8fdb..2b0ebf5f40e02a6dfbe03cab619afce046118513 100644 (file)
@@ -266,17 +266,12 @@ class PGCompiler(compiler.SQLCompiler):
     def returning_clause(self, stmt):
         returning_cols = stmt._returning
         
-        def flatten_columnlist(collist):
-            for c in collist:
-                if isinstance(c, expression.Selectable):
-                    for co in c.columns:
-                        yield co
-                else:
-                    yield c
-    
         columns = [
-                self.process(c, within_columns_clause=True, result_map=self.result_map) 
-                for c in flatten_columnlist(returning_cols)
+                self.process(
+                    self.label_select_column(None, c, asfrom=False), 
+                    within_columns_clause=True, 
+                    result_map=self.result_map) 
+                for c in expression._select_iterables(returning_cols)
             ]
             
         return 'RETURNING ' + ', '.join(columns)
@@ -374,7 +369,8 @@ class PGDefaultRunner(base.DefaultRunner):
 
     def visit_sequence(self, seq):
         if not seq.optional:
-            return self.execute_string(("select nextval('%s')" % self.dialect.identifier_preparer.format_sequence(seq)))
+            return self.execute_string(("select nextval('%s')" % \
+                        self.dialect.identifier_preparer.format_sequence(seq)))
         else:
             return None
 
index 332294729d694244dee8388f1bbd63e6bd647437..7fff18d0239fbea58dcd886d32cd7805875555be 100644 (file)
@@ -1630,7 +1630,7 @@ class ResultProxy(object):
             self.rowcount
             self.close() # autoclose
             return
-
+        
         self._props = util.populate_column_dict(None)
         self._props.creator = self.__key_fallback()
         self.keys = []
index b862c8c8114f3970bd1d44467f84d83d5dd3c444..8899486546fe76d8ee4954e66ca67c60ea5d53a2 100644 (file)
@@ -528,7 +528,7 @@ class SQLCompiler(engine.Compiled):
         if isinstance(column, sql._Label):
             return column
 
-        if select.use_labels and column._label:
+        if select and select.use_labels and column._label:
             return _CompileLabel(column, column._label)
 
         if \
index 142cdcbe5a251638d68f84157c1a082a4afc7451..8cf6109cae3fee7f472af91d884f4da0c85d1937 100644 (file)
@@ -889,6 +889,13 @@ def _expand_cloned(elements):
     """
     return itertools.chain(*[x._cloned_set for x in elements])
 
+def _select_iterables(elements):
+    """expand tables into individual columns in the 
+    given list of column expressions.
+    
+    """
+    return itertools.chain(*[c._select_iterable for c in elements])
+    
 def _cloned_intersection(a, b):
     """return the intersection of sets a and b, counting
     any overlap between 'cloned' predecessors.
@@ -3465,7 +3472,7 @@ class Select(_SelectBaseMixin, FromClause):
         be rendered into the columns clause of the resulting SELECT statement.
 
         """
-        return itertools.chain(*[c._select_iterable for c in self._raw_columns])
+        return _select_iterables(self._raw_columns)
 
     def is_derived_from(self, fromclause):
         if self in fromclause._cloned_set:
index 0c19a4c7e737186353152684e3a817e076acea2c..2dc6af91b76ae7c4c8246e63f6632b845a0fbcd1 100644 (file)
@@ -113,7 +113,7 @@ class CompileTest(TestBase, AssertsCompiledSQL):
             "RETURNING mytable.myid, mytable.name, mytable.description")
 
         u = update(table1, values=dict(name='foo')).returning(func.length(table1.c.name))
-        self.assert_compile(u, "UPDATE mytable SET name=:name RETURNING char_length(mytable.name)")
+        self.assert_compile(u, "UPDATE mytable SET name=:name RETURNING char_length(mytable.name) AS length_1")
 
     def test_insert_returning(self):
         table1 = table('mytable',
@@ -130,7 +130,7 @@ class CompileTest(TestBase, AssertsCompiledSQL):
             "RETURNING mytable.myid, mytable.name, mytable.description")
 
         i = insert(table1, values=dict(name='foo')).returning(func.length(table1.c.name))
-        self.assert_compile(i, "INSERT INTO mytable (name) VALUES (:name) RETURNING char_length(mytable.name)")
+        self.assert_compile(i, "INSERT INTO mytable (name) VALUES (:name) RETURNING char_length(mytable.name) AS length_1")
 
 
 
index f76e1c9fb8a16170b81809508ac3d3cdab64d4d9..2537eb695e2ce67612346ed8cf5aa7a7d73c8750 100644 (file)
@@ -177,7 +177,7 @@ class CompileTest(TestBase, AssertsCompiledSQL):
                             "inserted.name, inserted.description WHERE mytable.name = :name_1")
         
         u = update(table1, values=dict(name='foo')).returning(func.length(table1.c.name))
-        self.assert_compile(u, "UPDATE mytable SET name=:name OUTPUT LEN(inserted.name)")
+        self.assert_compile(u, "UPDATE mytable SET name=:name OUTPUT LEN(inserted.name) AS length_1")
 
     def test_insert_returning(self):
         table1 = table('mytable',
@@ -194,7 +194,7 @@ class CompileTest(TestBase, AssertsCompiledSQL):
                                 "inserted.name, inserted.description VALUES (:name)")
 
         i = insert(table1, values=dict(name='foo')).returning(func.length(table1.c.name))
-        self.assert_compile(i, "INSERT INTO mytable (name) OUTPUT LEN(inserted.name) VALUES (:name)")
+        self.assert_compile(i, "INSERT INTO mytable (name) OUTPUT LEN(inserted.name) AS length_1 VALUES (:name)")
 
 
 
index 2b9a687ebf087cfb12ae58184abdcc09ce52fef6..3d5b61054860e19685ad4bf188353e8b7f84a375 100644 (file)
@@ -40,7 +40,7 @@ class CompileTest(TestBase, AssertsCompiledSQL):
             "RETURNING mytable.myid, mytable.name, mytable.description", dialect=dialect)
 
         u = update(table1, values=dict(name='foo')).returning(func.length(table1.c.name))
-        self.assert_compile(u, "UPDATE mytable SET name=%(name)s RETURNING length(mytable.name)", dialect=dialect)
+        self.assert_compile(u, "UPDATE mytable SET name=%(name)s RETURNING length(mytable.name) AS length_1", dialect=dialect)
 
         
     def test_insert_returning(self):
@@ -59,7 +59,7 @@ class CompileTest(TestBase, AssertsCompiledSQL):
             "RETURNING mytable.myid, mytable.name, mytable.description", dialect=dialect)
 
         i = insert(table1, values=dict(name='foo')).returning(func.length(table1.c.name))
-        self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) RETURNING length(mytable.name)", dialect=dialect)
+        self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) RETURNING length(mytable.name) AS length_1", dialect=dialect)
     
     @testing.uses_deprecated(r".*argument is deprecated.  Please use statement.returning.*")
     def test_old_returning_names(self):