From: Mike Bayer Date: Sun, 19 Jul 2009 04:59:18 +0000 (+0000) Subject: generate the RETURNING col lists the same was as visit_select() does (except for... X-Git-Tag: rel_0_6_6~107 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=b87baebd9e68592bc6fee17e5eb42e2c03bdcced;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git generate the RETURNING col lists the same was as visit_select() does (except for oracle). mssql gets extra label stuff to deal with column adaption (not sure if column adaption should blow away labels like that...). fixes potential column targeting issues on all platforms + fixes mssql failures --- diff --git a/lib/sqlalchemy/dialects/firebird/base.py b/lib/sqlalchemy/dialects/firebird/base.py index 949289eb36..58fa19f50f 100644 --- a/lib/sqlalchemy/dialects/firebird/base.py +++ b/lib/sqlalchemy/dialects/firebird/base.py @@ -256,17 +256,13 @@ class FBCompiler(sql.compiler.SQLCompiler): def returning_clause(self, stmt): returning_cols = stmt._returning - def flatten_columnlist(collist): - for c in collist: - if isinstance(c, expression.Selectable): - for co in c.columns: - yield co - else: - yield c - columns = [ - self.process(c, within_columns_clause=True, result_map=self.result_map) - for c in flatten_columnlist(returning_cols) + self.process( + self.label_select_column(None, c, asfrom=False), + within_columns_clause=True, + result_map=self.result_map + ) + for c in expression._select_iterables(returning_cols) ] return 'RETURNING ' + ', '.join(columns) diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 9831b5134b..c58e32f01a 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -222,6 +222,7 @@ Known Issues """ import datetime, decimal, inspect, operator, sys, re +import itertools from sqlalchemy import sql, schema as sa_schema, exc, util from sqlalchemy.sql import select, compiler, expression, \ @@ -1063,25 +1064,27 @@ class MSSQLCompiler(compiler.SQLCompiler): def returning_clause(self, stmt): returning_cols = stmt._returning - def flatten_columnlist(collist): - for c in collist: - if isinstance(c, expression.Selectable): - for co in c.columns: - yield co - else: - yield c - if self.isinsert or self.isupdate: target = stmt.table.alias("inserted") else: target = stmt.table.alias("deleted") adapter = sql_util.ClauseAdapter(target) + def col_label(col): + adapted = adapter.traverse(c) + if isinstance(c, expression._Label): + return adapted.label(c.key) + else: + return self.label_select_column(None, adapted, asfrom=False) + columns = [ - self.process(adapter.traverse(c), within_columns_clause=True, result_map=self.result_map) - for c in flatten_columnlist(returning_cols) + self.process( + col_label(c), + within_columns_clause=True, + result_map=self.result_map + ) + for c in expression._select_iterables(returning_cols) ] - return 'OUTPUT ' + ', '.join(columns) def label_select_column(self, select, column, asfrom): diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 7c956f6bed..cc19541eb1 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -321,23 +321,17 @@ class OracleCompiler(compiler.SQLCompiler): def returning_clause(self, stmt): returning_cols = stmt._returning - def flatten_columnlist(collist): - for c in collist: - if isinstance(c, expression.Selectable): - for co in c.columns: - yield co - else: - yield c - def create_out_param(col, i): bindparam = sql.outparam("ret_%d" % i, type_=col.type) self.binds[bindparam.key] = bindparam return self.bindparam_string(self._truncate_bindparam(bindparam)) + columnlist = list(expression._select_iterables(returning_cols)) + # within_columns_clause =False so that labels (foo AS bar) don't render - columns = [self.process(c, within_columns_clause=False) for c in flatten_columnlist(returning_cols)] + columns = [self.process(c, within_columns_clause=False) for c in columnlist] - binds = [create_out_param(c, i) for i, c in enumerate(flatten_columnlist(returning_cols))] + binds = [create_out_param(c, i) for i, c in enumerate(columnlist)] return 'RETURNING ' + ', '.join(columns) + " INTO " + ", ".join(binds) diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 849ec50066..2b0ebf5f40 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -266,17 +266,12 @@ class PGCompiler(compiler.SQLCompiler): def returning_clause(self, stmt): returning_cols = stmt._returning - def flatten_columnlist(collist): - for c in collist: - if isinstance(c, expression.Selectable): - for co in c.columns: - yield co - else: - yield c - columns = [ - self.process(c, within_columns_clause=True, result_map=self.result_map) - for c in flatten_columnlist(returning_cols) + self.process( + self.label_select_column(None, c, asfrom=False), + within_columns_clause=True, + result_map=self.result_map) + for c in expression._select_iterables(returning_cols) ] return 'RETURNING ' + ', '.join(columns) @@ -374,7 +369,8 @@ class PGDefaultRunner(base.DefaultRunner): def visit_sequence(self, seq): if not seq.optional: - return self.execute_string(("select nextval('%s')" % self.dialect.identifier_preparer.format_sequence(seq))) + return self.execute_string(("select nextval('%s')" % \ + self.dialect.identifier_preparer.format_sequence(seq))) else: return None diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 332294729d..7fff18d023 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -1630,7 +1630,7 @@ class ResultProxy(object): self.rowcount self.close() # autoclose return - + self._props = util.populate_column_dict(None) self._props.creator = self.__key_fallback() self.keys = [] diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index b862c8c811..8899486546 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -528,7 +528,7 @@ class SQLCompiler(engine.Compiled): if isinstance(column, sql._Label): return column - if select.use_labels and column._label: + if select and select.use_labels and column._label: return _CompileLabel(column, column._label) if \ diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 142cdcbe5a..8cf6109cae 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -889,6 +889,13 @@ def _expand_cloned(elements): """ return itertools.chain(*[x._cloned_set for x in elements]) +def _select_iterables(elements): + """expand tables into individual columns in the + given list of column expressions. + + """ + return itertools.chain(*[c._select_iterable for c in elements]) + def _cloned_intersection(a, b): """return the intersection of sets a and b, counting any overlap between 'cloned' predecessors. @@ -3465,7 +3472,7 @@ class Select(_SelectBaseMixin, FromClause): be rendered into the columns clause of the resulting SELECT statement. """ - return itertools.chain(*[c._select_iterable for c in self._raw_columns]) + return _select_iterables(self._raw_columns) def is_derived_from(self, fromclause): if self in fromclause._cloned_set: diff --git a/test/dialect/test_firebird.py b/test/dialect/test_firebird.py index 0c19a4c7e7..2dc6af91b7 100644 --- a/test/dialect/test_firebird.py +++ b/test/dialect/test_firebird.py @@ -113,7 +113,7 @@ class CompileTest(TestBase, AssertsCompiledSQL): "RETURNING mytable.myid, mytable.name, mytable.description") u = update(table1, values=dict(name='foo')).returning(func.length(table1.c.name)) - self.assert_compile(u, "UPDATE mytable SET name=:name RETURNING char_length(mytable.name)") + self.assert_compile(u, "UPDATE mytable SET name=:name RETURNING char_length(mytable.name) AS length_1") def test_insert_returning(self): table1 = table('mytable', @@ -130,7 +130,7 @@ class CompileTest(TestBase, AssertsCompiledSQL): "RETURNING mytable.myid, mytable.name, mytable.description") i = insert(table1, values=dict(name='foo')).returning(func.length(table1.c.name)) - self.assert_compile(i, "INSERT INTO mytable (name) VALUES (:name) RETURNING char_length(mytable.name)") + self.assert_compile(i, "INSERT INTO mytable (name) VALUES (:name) RETURNING char_length(mytable.name) AS length_1") diff --git a/test/dialect/test_mssql.py b/test/dialect/test_mssql.py index f76e1c9fb8..2537eb695e 100644 --- a/test/dialect/test_mssql.py +++ b/test/dialect/test_mssql.py @@ -177,7 +177,7 @@ class CompileTest(TestBase, AssertsCompiledSQL): "inserted.name, inserted.description WHERE mytable.name = :name_1") u = update(table1, values=dict(name='foo')).returning(func.length(table1.c.name)) - self.assert_compile(u, "UPDATE mytable SET name=:name OUTPUT LEN(inserted.name)") + self.assert_compile(u, "UPDATE mytable SET name=:name OUTPUT LEN(inserted.name) AS length_1") def test_insert_returning(self): table1 = table('mytable', @@ -194,7 +194,7 @@ class CompileTest(TestBase, AssertsCompiledSQL): "inserted.name, inserted.description VALUES (:name)") i = insert(table1, values=dict(name='foo')).returning(func.length(table1.c.name)) - self.assert_compile(i, "INSERT INTO mytable (name) OUTPUT LEN(inserted.name) VALUES (:name)") + self.assert_compile(i, "INSERT INTO mytable (name) OUTPUT LEN(inserted.name) AS length_1 VALUES (:name)") diff --git a/test/dialect/test_postgresql.py b/test/dialect/test_postgresql.py index 2b9a687ebf..3d5b610548 100644 --- a/test/dialect/test_postgresql.py +++ b/test/dialect/test_postgresql.py @@ -40,7 +40,7 @@ class CompileTest(TestBase, AssertsCompiledSQL): "RETURNING mytable.myid, mytable.name, mytable.description", dialect=dialect) u = update(table1, values=dict(name='foo')).returning(func.length(table1.c.name)) - self.assert_compile(u, "UPDATE mytable SET name=%(name)s RETURNING length(mytable.name)", dialect=dialect) + self.assert_compile(u, "UPDATE mytable SET name=%(name)s RETURNING length(mytable.name) AS length_1", dialect=dialect) def test_insert_returning(self): @@ -59,7 +59,7 @@ class CompileTest(TestBase, AssertsCompiledSQL): "RETURNING mytable.myid, mytable.name, mytable.description", dialect=dialect) i = insert(table1, values=dict(name='foo')).returning(func.length(table1.c.name)) - self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) RETURNING length(mytable.name)", dialect=dialect) + self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) RETURNING length(mytable.name) AS length_1", dialect=dialect) @testing.uses_deprecated(r".*argument is deprecated. Please use statement.returning.*") def test_old_returning_names(self):