From: Mike Bayer Date: Mon, 27 Feb 2017 21:43:59 +0000 (-0500) Subject: Apply type processing to untyped preexec default clause X-Git-Tag: rel_1_2_0b1~125 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=4eb4010c1a1c3e5c2529b9be9d8d56f1d6a4ec00;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Apply type processing to untyped preexec default clause Fixed bug where a SQL-oriented Python-side column default could fail to be executed properly upon INSERT in the "pre-execute" codepath, if the SQL itself were an untyped expression, such as plain text. The "pre- execute" codepath is fairly uncommon however can apply to non-integer primary key columns with SQL defaults when RETURNING is not used. Tests exist here to ensure typing is applied to a typed expression for default, but in the case of an untyped SQL value, we know the type from the column, so apply this. Change-Id: I5d8b391611c137b9f700115a50a2bf5b30abfe94 Fixes: #3923 --- diff --git a/doc/build/changelog/changelog_12.rst b/doc/build/changelog/changelog_12.rst index 834dc074de..9fd4c1936a 100644 --- a/doc/build/changelog/changelog_12.rst +++ b/doc/build/changelog/changelog_12.rst @@ -13,6 +13,16 @@ .. changelog:: :version: 1.2.0b1 + .. change:: 3923 + :tags: bug, sql + :tickets: 3923 + + Fixed bug where a SQL-oriented Python-side column default could fail to + be executed properly upon INSERT in the "pre-execute" codepath, if the + SQL itself were an untyped expression, such as plain text. The "pre- + execute" codepath is fairly uncommon however can apply to non-integer + primary key columns with SQL defaults when RETURNING is not used. + .. change:: 3785 :tags: bug, sql :tickets: 3785 diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index c7d574a216..1c10f484f6 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -1073,7 +1073,11 @@ class DefaultExecutionContext(interfaces.ExecutionContext): # TODO: expensive branching here should be # pulled into _exec_scalar() conn = self.connection - c = expression.select([default.arg]).compile(bind=conn) + if not default._arg_is_typed: + default_arg = expression.type_coerce(default.arg, type_) + else: + default_arg = default.arg + c = expression.select([default_arg]).compile(bind=conn) return conn._execute_compiled(c, (), {}).scalar() else: return default.arg diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index f8d3209ef5..cf12ce965b 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -2067,6 +2067,14 @@ class ColumnDefault(DefaultGenerator): not self.is_clause_element and \ not self.is_sequence + @util.memoized_property + @util.dependencies("sqlalchemy.sql.sqltypes") + def _arg_is_typed(self, sqltypes): + if self.is_clause_element: + return not isinstance(self.arg.type, sqltypes.NullType) + else: + return False + def _maybe_wrap_callable(self, fn): """Wrap callables that don't accept a context. diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py index eff1026cd5..8437aca370 100644 --- a/test/engine/test_execute.py +++ b/test/engine/test_execute.py @@ -1350,34 +1350,19 @@ class EngineEventsTest(fixtures.TestBase): ('select * from t1', {}, None), ('DROP TABLE t1', {}, None)] - # or engine.dialect.preexecute_pk_sequences: - if not testing.against('oracle+zxjdbc'): - cursor = [ - ('CREATE TABLE t1', {}, ()), - ('INSERT INTO t1 (c1, c2)', { - 'c2': 'some data', 'c1': 5}, - (5, 'some data')), - ('SELECT lower', {'lower_2': 'Foo'}, - ('Foo', )), - ('INSERT INTO t1 (c1, c2)', - {'c2': 'foo', 'c1': 6}, - (6, 'foo')), - ('select * from t1', {}, ()), - ('DROP TABLE t1', {}, ()), - ] - else: - insert2_params = 6, 'Foo' - if testing.against('oracle+zxjdbc'): - insert2_params += (ReturningParam(12), ) - cursor = [('CREATE TABLE t1', {}, ()), - ('INSERT INTO t1 (c1, c2)', - {'c2': 'some data', 'c1': 5}, (5, 'some data')), - ('INSERT INTO t1 (c1, c2)', - {'c1': 6, 'lower_2': 'Foo'}, insert2_params), - ('select * from t1', {}, ()), - ('DROP TABLE t1', {}, ())] - # bind param name 'lower_2' might - # be incorrect + cursor = [ + ('CREATE TABLE t1', {}, ()), + ('INSERT INTO t1 (c1, c2)', { + 'c2': 'some data', 'c1': 5}, + (5, 'some data')), + ('SELECT lower', {'lower_1': 'Foo'}, + ('Foo', )), + ('INSERT INTO t1 (c1, c2)', + {'c2': 'foo', 'c1': 6}, + (6, 'foo')), + ('select * from t1', {}, ()), + ('DROP TABLE t1', {}, ()), + ] self._assert_stmts(compiled, stmts) self._assert_stmts(cursor, cursor_stmts) @@ -2363,31 +2348,19 @@ class ProxyConnectionTest(fixtures.TestBase): ('INSERT INTO t1 (c1, c2)', {'c1': 6}, None), ('select * from t1', {}, None), ('DROP TABLE t1', {}, None)] - # or engine.dialect.pr eexecute_pk_sequence s: - # original comment above moved here for pep8 fix - if not testing.against('oracle+zxjdbc'): - cursor = [ - ('CREATE TABLE t1', {}, ()), - ('INSERT INTO t1 (c1, c2)', { - 'c2': 'some data', 'c1': 5}, (5, 'some data')), - ('SELECT lower', {'lower_2': 'Foo'}, - ('Foo', )), - ('INSERT INTO t1 (c1, c2)', {'c2': 'foo', 'c1': 6}, - (6, 'foo')), - ('select * from t1', {}, ()), - ('DROP TABLE t1', {}, ()), - ] - else: - insert2_params = 6, 'Foo' - if testing.against('oracle+zxjdbc'): - insert2_params += (ReturningParam(12), ) - cursor = [('CREATE TABLE t1', {}, ()), - ('INSERT INTO t1 (c1, c2)', { - 'c2': 'some data', 'c1': 5}, (5, 'some data')), - ('INSERT INTO t1 (c1, c2)', - {'c1': 6, 'lower_2': 'Foo'}, insert2_params), - ('select * from t1', {}, ()), - ('DROP TABLE t1', {}, ())] + + cursor = [ + ('CREATE TABLE t1', {}, ()), + ('INSERT INTO t1 (c1, c2)', { + 'c2': 'some data', 'c1': 5}, (5, 'some data')), + ('SELECT lower', {'lower_1': 'Foo'}, + ('Foo', )), + ('INSERT INTO t1 (c1, c2)', {'c2': 'foo', 'c1': 6}, + (6, 'foo')), + ('select * from t1', {}, ()), + ('DROP TABLE t1', {}, ()), + ] + assert_stmts(compiled, stmts) assert_stmts(cursor, cursor_stmts) diff --git a/test/sql/test_defaults.py b/test/sql/test_defaults.py index dff423bf9f..3cc7e715d2 100644 --- a/test/sql/test_defaults.py +++ b/test/sql/test_defaults.py @@ -8,7 +8,7 @@ from sqlalchemy import testing from sqlalchemy.testing import engines from sqlalchemy import ( MetaData, Integer, String, ForeignKey, Boolean, exc, Sequence, func, - literal, Unicode, cast) + literal, Unicode, cast, DateTime) from sqlalchemy.types import TypeDecorator, TypeEngine from sqlalchemy.testing.schema import Table, Column from sqlalchemy.dialects import sqlite @@ -670,6 +670,13 @@ class PKDefaultTest(fixtures.TablesTest): default=sa.select([func.max(t2.c.nextid)]).as_scalar()), Column('data', String(30))) + Table( + 'date_table', metadata, + Column( + 'date_id', + DateTime, default=text("current_timestamp"), primary_key=True) + ) + @testing.requires.returning def test_with_implicit_returning(self): self._test(True) @@ -678,20 +685,26 @@ class PKDefaultTest(fixtures.TablesTest): self._test(False) def _test(self, returning): - t2, t1 = self.tables.t2, self.tables.t1 + t2, t1, date_table = ( + self.tables.t2, self.tables.t1, self.tables.date_table + ) if not returning and not testing.db.dialect.implicit_returning: engine = testing.db else: engine = engines.testing_engine( options={'implicit_returning': returning}) - engine.execute(t2.insert(), nextid=1) - r = engine.execute(t1.insert(), data='hi') - eq_([1], r.inserted_primary_key) + with engine.begin() as conn: + conn.execute(t2.insert(), nextid=1) + r = conn.execute(t1.insert(), data='hi') + eq_([1], r.inserted_primary_key) + + conn.execute(t2.insert(), nextid=2) + r = conn.execute(t1.insert(), data='there') + eq_([2], r.inserted_primary_key) - engine.execute(t2.insert(), nextid=2) - r = engine.execute(t1.insert(), data='there') - eq_([2], r.inserted_primary_key) + r = conn.execute(date_table.insert()) + assert isinstance(r.inserted_primary_key[0], datetime.datetime) class PKIncrementTest(fixtures.TablesTest): @@ -1353,9 +1366,15 @@ class SpecialTypePKTest(fixtures.TestBase): def test_literal_default_no_label(self): self._run_test(default=literal("INT_1", type_=self.MyInteger)) + def test_literal_column_default_no_label(self): + self._run_test(default=literal_column("1", type_=self.MyInteger)) + def test_sequence(self): self._run_test(Sequence('foo_seq')) + def test_text_clause_default_no_type(self): + self._run_test(default=text('1')) + def test_server_default(self): self._run_test(server_default='1',)