From: Mike Bayer Date: Tue, 11 Nov 2014 17:34:00 +0000 (-0500) Subject: - Fixed issue where the columns from a SELECT embedded in an X-Git-Tag: rel_0_9_9~75 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=fc2d19537331e07ae8b75b9a356fbaebd8046a0b;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - Fixed issue where the columns from a SELECT embedded in an INSERT, either through the values clause or as a "from select", would pollute the column types used in the result set produced by the RETURNING clause when columns from both statements shared the same name, leading to potential errors or mis-adaptation when retrieving the returning rows. fixes #3248 --- diff --git a/doc/build/changelog/changelog_09.rst b/doc/build/changelog/changelog_09.rst index 8ed2ea7761..66a7da8dab 100644 --- a/doc/build/changelog/changelog_09.rst +++ b/doc/build/changelog/changelog_09.rst @@ -13,6 +13,18 @@ .. changelog:: :version: 0.9.9 + .. change:: + :tags: bug, sql + :versions: 1.0.0 + :tickets: 3248 + + Fixed issue where the columns from a SELECT embedded in an + INSERT, either through the values clause or as a "from select", + would pollute the column types used in the result set produced by + the RETURNING clause when columns from both statements shared the + same name, leading to potential errors or mis-adaptation when + retrieving the returning rows. + .. change:: :tags: bug, orm, sqlite :versions: 1.0.0 diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index a4877ce818..944308e333 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1669,6 +1669,12 @@ class SQLCompiler(Compiled): ) def visit_insert(self, insert_stmt, **kw): + self.stack.append( + {'correlate_froms': set(), + "iswrapper": False, + "asfrom_froms": set(), + "selectable": insert_stmt}) + self.isinsert = True colparams = self._get_colparams(insert_stmt, **kw) @@ -1752,6 +1758,8 @@ class SQLCompiler(Compiled): if self.returning and not self.returning_precedes_values: text += " " + returning_clause + self.stack.pop(-1) + return text def update_limit_clause(self, update_stmt): diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index 1238706f6a..1a41c63e17 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -3314,3 +3314,32 @@ class ResultMapTest(fixtures.TestBase): is_( comp.result_map['t1_a'][1][2], t1.c.a ) + + def test_insert_with_select_values(self): + astring = Column('a', String) + aint = Column('a', Integer) + m = MetaData() + Table('t1', m, astring) + t2 = Table('t2', m, aint) + + stmt = t2.insert().values(a=select([astring])).returning(aint) + comp = stmt.compile(dialect=postgresql.dialect()) + eq_( + comp.result_map, + {'a': ('a', (aint, 'a', 'a'), aint.type)} + ) + + def test_insert_from_select(self): + astring = Column('a', String) + aint = Column('a', Integer) + m = MetaData() + Table('t1', m, astring) + t2 = Table('t2', m, aint) + + stmt = t2.insert().from_select(['a'], select([astring])).\ + returning(aint) + comp = stmt.compile(dialect=postgresql.dialect()) + eq_( + comp.result_map, + {'a': ('a', (aint, 'a', 'a'), aint.type)} + ) diff --git a/test/sql/test_returning.py b/test/sql/test_returning.py index 26dbcdaa21..f6d03c0d73 100644 --- a/test/sql/test_returning.py +++ b/test/sql/test_returning.py @@ -153,6 +153,39 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults): eq_(result2.fetchall(), [(2, False), ]) +class CompositeStatementTest(fixtures.TestBase): + __requires__ = 'returning', + __backend__ = True + + @testing.provide_metadata + def test_select_doesnt_pollute_result(self): + class MyType(TypeDecorator): + impl = Integer + + def process_result_value(self, value, dialect): + raise Exception("I have not been selected") + + t1 = Table( + 't1', self.metadata, + Column('x', MyType()) + ) + + t2 = Table( + 't2', self.metadata, + Column('x', Integer) + ) + + self.metadata.create_all(testing.db) + with testing.db.connect() as conn: + conn.execute(t1.insert().values(x=5)) + + stmt = t2.insert().values( + x=select([t1.c.x]).as_scalar()).returning(t2.c.x) + + result = conn.execute(stmt) + eq_(result.scalar(), 5) + + class SequenceReturningTest(fixtures.TestBase): __requires__ = 'returning', 'sequences' __backend__ = True