From: Mike Bayer Date: Tue, 11 Nov 2014 17:34:00 +0000 (-0500) Subject: - Fixed issue where the columns from a SELECT embedded in an X-Git-Tag: rel_1_0_0~19^2~24 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=b013fb82f5a5d891c6fd776e0e6ed926cdf2ffe1;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - Fixed issue where the columns from a SELECT embedded in an INSERT, either through the values clause or as a "from select", would pollute the column types used in the result set produced by the RETURNING clause when columns from both statements shared the same name, leading to potential errors or mis-adaptation when retrieving the returning rows. fixes #3248 --- diff --git a/doc/build/changelog/changelog_09.rst b/doc/build/changelog/changelog_09.rst index 8ed2ea7761..66a7da8dab 100644 --- a/doc/build/changelog/changelog_09.rst +++ b/doc/build/changelog/changelog_09.rst @@ -13,6 +13,18 @@ .. changelog:: :version: 0.9.9 + .. change:: + :tags: bug, sql + :versions: 1.0.0 + :tickets: 3248 + + Fixed issue where the columns from a SELECT embedded in an + INSERT, either through the values clause or as a "from select", + would pollute the column types used in the result set produced by + the RETURNING clause when columns from both statements shared the + same name, leading to potential errors or mis-adaptation when + retrieving the returning rows. + .. change:: :tags: bug, orm, sqlite :versions: 1.0.0 diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 5fa78ad0f4..8f3ede25f4 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1729,6 +1729,12 @@ class SQLCompiler(Compiled): ) def visit_insert(self, insert_stmt, **kw): + self.stack.append( + {'correlate_froms': set(), + "iswrapper": False, + "asfrom_froms": set(), + "selectable": insert_stmt}) + self.isinsert = True crud_params = crud._get_crud_params(self, insert_stmt, **kw) @@ -1812,6 +1818,8 @@ class SQLCompiler(Compiled): if self.returning and not self.returning_precedes_values: text += " " + returning_clause + self.stack.pop(-1) + return text def update_limit_clause(self, update_stmt): diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index 5d1afe6162..9e99a947b0 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -3437,3 +3437,32 @@ class ResultMapTest(fixtures.TestBase): is_( comp.result_map['t1_a'][1][2], t1.c.a ) + + def test_insert_with_select_values(self): + astring = Column('a', String) + aint = Column('a', Integer) + m = MetaData() + Table('t1', m, astring) + t2 = Table('t2', m, aint) + + stmt = t2.insert().values(a=select([astring])).returning(aint) + comp = stmt.compile(dialect=postgresql.dialect()) + eq_( + comp.result_map, + {'a': ('a', (aint, 'a', 'a'), aint.type)} + ) + + def test_insert_from_select(self): + astring = Column('a', String) + aint = Column('a', Integer) + m = MetaData() + Table('t1', m, astring) + t2 = Table('t2', m, aint) + + stmt = t2.insert().from_select(['a'], select([astring])).\ + returning(aint) + comp = stmt.compile(dialect=postgresql.dialect()) + eq_( + comp.result_map, + {'a': ('a', (aint, 'a', 'a'), aint.type)} + ) diff --git a/test/sql/test_returning.py b/test/sql/test_returning.py index 79a0b38a5d..cd9f632b9b 100644 --- a/test/sql/test_returning.py +++ b/test/sql/test_returning.py @@ -160,6 +160,39 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults): eq_(result2.fetchall(), [(2, False), ]) +class CompositeStatementTest(fixtures.TestBase): + __requires__ = 'returning', + __backend__ = True + + @testing.provide_metadata + def test_select_doesnt_pollute_result(self): + class MyType(TypeDecorator): + impl = Integer + + def process_result_value(self, value, dialect): + raise Exception("I have not been selected") + + t1 = Table( + 't1', self.metadata, + Column('x', MyType()) + ) + + t2 = Table( + 't2', self.metadata, + Column('x', Integer) + ) + + self.metadata.create_all(testing.db) + with testing.db.connect() as conn: + conn.execute(t1.insert().values(x=5)) + + stmt = t2.insert().values( + x=select([t1.c.x]).as_scalar()).returning(t2.c.x) + + result = conn.execute(stmt) + eq_(result.scalar(), 5) + + class SequenceReturningTest(fixtures.TestBase): __requires__ = 'returning', 'sequences' __backend__ = True