]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Apply type processing to untyped preexec default clause
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 27 Feb 2017 21:43:59 +0000 (16:43 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 30 Mar 2017 18:58:50 +0000 (14:58 -0400)
Fixed bug where a SQL-oriented Python-side column default could fail to
be executed properly upon INSERT in the "pre-execute" codepath, if the
SQL itself were an untyped expression, such as plain text.  The "pre-
execute" codepath is fairly uncommon however can apply to non-integer
primary key columns with SQL defaults when RETURNING is not used.

Tests exist here to ensure typing is applied to
a typed expression for default, but in the case of
an untyped SQL value, we know the type from the column,
so apply this.

Change-Id: I5d8b391611c137b9f700115a50a2bf5b30abfe94
Fixes: #3923
doc/build/changelog/changelog_12.rst
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/sql/schema.py
test/engine/test_execute.py
test/sql/test_defaults.py

index 834dc074de8f69c54aa156fc3a26d30f6c4571bb..9fd4c1936a64819bbb1bfc74ce04a811687c931e 100644 (file)
 .. changelog::
     :version: 1.2.0b1
 
+    .. change:: 3923
+        :tags: bug, sql
+        :tickets: 3923
+
+        Fixed bug where a SQL-oriented Python-side column default could fail to
+        be executed properly upon INSERT in the "pre-execute" codepath, if the
+        SQL itself were an untyped expression, such as plain text.  The "pre-
+        execute" codepath is fairly uncommon however can apply to non-integer
+        primary key columns with SQL defaults when RETURNING is not used.
+
     .. change:: 3785
         :tags: bug, sql
         :tickets: 3785
index c7d574a2164679b8c77677356071f3cad4a95cff..1c10f484f6b0ddfb8e52971b2dda290778a4dfef 100644 (file)
@@ -1073,7 +1073,11 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
             # TODO: expensive branching here should be
             # pulled into _exec_scalar()
             conn = self.connection
-            c = expression.select([default.arg]).compile(bind=conn)
+            if not default._arg_is_typed:
+                default_arg = expression.type_coerce(default.arg, type_)
+            else:
+                default_arg = default.arg
+            c = expression.select([default_arg]).compile(bind=conn)
             return conn._execute_compiled(c, (), {}).scalar()
         else:
             return default.arg
index f8d3209ef5c932ed834932ac2b142ef3b37cf146..cf12ce965b607cce2da0efade07acb5ebe28a132 100644 (file)
@@ -2067,6 +2067,14 @@ class ColumnDefault(DefaultGenerator):
             not self.is_clause_element and \
             not self.is_sequence
 
+    @util.memoized_property
+    @util.dependencies("sqlalchemy.sql.sqltypes")
+    def _arg_is_typed(self, sqltypes):
+        if self.is_clause_element:
+            return not isinstance(self.arg.type, sqltypes.NullType)
+        else:
+            return False
+
     def _maybe_wrap_callable(self, fn):
         """Wrap callables that don't accept a context.
 
index eff1026cd51f1d94271eb25dae8697c0c99eddab..8437aca37037ac090599c601487d2ce68d556e67 100644 (file)
@@ -1350,34 +1350,19 @@ class EngineEventsTest(fixtures.TestBase):
                         ('select * from t1', {}, None),
                         ('DROP TABLE t1', {}, None)]
 
-            # or engine.dialect.preexecute_pk_sequences:
-            if not testing.against('oracle+zxjdbc'):
-                cursor = [
-                    ('CREATE TABLE t1', {}, ()),
-                    ('INSERT INTO t1 (c1, c2)', {
-                        'c2': 'some data', 'c1': 5},
-                        (5, 'some data')),
-                    ('SELECT lower', {'lower_2': 'Foo'},
-                        ('Foo', )),
-                    ('INSERT INTO t1 (c1, c2)',
-                     {'c2': 'foo', 'c1': 6},
-                     (6, 'foo')),
-                    ('select * from t1', {}, ()),
-                    ('DROP TABLE t1', {}, ()),
-                ]
-            else:
-                insert2_params = 6, 'Foo'
-                if testing.against('oracle+zxjdbc'):
-                    insert2_params += (ReturningParam(12), )
-                cursor = [('CREATE TABLE t1', {}, ()),
-                          ('INSERT INTO t1 (c1, c2)',
-                           {'c2': 'some data', 'c1': 5}, (5, 'some data')),
-                          ('INSERT INTO t1 (c1, c2)',
-                           {'c1': 6, 'lower_2': 'Foo'}, insert2_params),
-                          ('select * from t1', {}, ()),
-                          ('DROP TABLE t1', {}, ())]
-                # bind param name 'lower_2' might
-                # be incorrect
+            cursor = [
+                ('CREATE TABLE t1', {}, ()),
+                ('INSERT INTO t1 (c1, c2)', {
+                    'c2': 'some data', 'c1': 5},
+                    (5, 'some data')),
+                ('SELECT lower', {'lower_1': 'Foo'},
+                    ('Foo', )),
+                ('INSERT INTO t1 (c1, c2)',
+                 {'c2': 'foo', 'c1': 6},
+                 (6, 'foo')),
+                ('select * from t1', {}, ()),
+                ('DROP TABLE t1', {}, ()),
+            ]
             self._assert_stmts(compiled, stmts)
             self._assert_stmts(cursor, cursor_stmts)
 
@@ -2363,31 +2348,19 @@ class ProxyConnectionTest(fixtures.TestBase):
                         ('INSERT INTO t1 (c1, c2)', {'c1': 6}, None),
                         ('select * from t1', {}, None),
                         ('DROP TABLE t1', {}, None)]
-            # or engine.dialect.pr eexecute_pk_sequence s:
-            # original comment above moved here for pep8 fix
-            if not testing.against('oracle+zxjdbc'):
-                cursor = [
-                    ('CREATE TABLE t1', {}, ()),
-                    ('INSERT INTO t1 (c1, c2)', {
-                     'c2': 'some data', 'c1': 5}, (5, 'some data')),
-                    ('SELECT lower', {'lower_2': 'Foo'},
-                        ('Foo', )),
-                    ('INSERT INTO t1 (c1, c2)', {'c2': 'foo', 'c1': 6},
-                     (6, 'foo')),
-                    ('select * from t1', {}, ()),
-                    ('DROP TABLE t1', {}, ()),
-                ]
-            else:
-                insert2_params = 6, 'Foo'
-                if testing.against('oracle+zxjdbc'):
-                    insert2_params += (ReturningParam(12), )
-                cursor = [('CREATE TABLE t1', {}, ()),
-                          ('INSERT INTO t1 (c1, c2)', {
-                           'c2': 'some data', 'c1': 5}, (5, 'some data')),
-                          ('INSERT INTO t1 (c1, c2)',
-                           {'c1': 6, 'lower_2': 'Foo'}, insert2_params),
-                          ('select * from t1', {}, ()),
-                          ('DROP TABLE t1', {}, ())]
+
+            cursor = [
+                ('CREATE TABLE t1', {}, ()),
+                ('INSERT INTO t1 (c1, c2)', {
+                 'c2': 'some data', 'c1': 5}, (5, 'some data')),
+                ('SELECT lower', {'lower_1': 'Foo'},
+                    ('Foo', )),
+                ('INSERT INTO t1 (c1, c2)', {'c2': 'foo', 'c1': 6},
+                 (6, 'foo')),
+                ('select * from t1', {}, ()),
+                ('DROP TABLE t1', {}, ()),
+            ]
+
             assert_stmts(compiled, stmts)
             assert_stmts(cursor, cursor_stmts)
 
index dff423bf9fb0ea4e48b3411d59adfc85c67a87ba..3cc7e715d2851bdf263250989b95dec05cdd17f9 100644 (file)
@@ -8,7 +8,7 @@ from sqlalchemy import testing
 from sqlalchemy.testing import engines
 from sqlalchemy import (
     MetaData, Integer, String, ForeignKey, Boolean, exc, Sequence, func,
-    literal, Unicode, cast)
+    literal, Unicode, cast, DateTime)
 from sqlalchemy.types import TypeDecorator, TypeEngine
 from sqlalchemy.testing.schema import Table, Column
 from sqlalchemy.dialects import sqlite
@@ -670,6 +670,13 @@ class PKDefaultTest(fixtures.TablesTest):
                 default=sa.select([func.max(t2.c.nextid)]).as_scalar()),
             Column('data', String(30)))
 
+        Table(
+            'date_table', metadata,
+            Column(
+                'date_id',
+                DateTime, default=text("current_timestamp"), primary_key=True)
+        )
+
     @testing.requires.returning
     def test_with_implicit_returning(self):
         self._test(True)
@@ -678,20 +685,26 @@ class PKDefaultTest(fixtures.TablesTest):
         self._test(False)
 
     def _test(self, returning):
-        t2, t1 = self.tables.t2, self.tables.t1
+        t2, t1, date_table = (
+            self.tables.t2, self.tables.t1, self.tables.date_table
+        )
 
         if not returning and not testing.db.dialect.implicit_returning:
             engine = testing.db
         else:
             engine = engines.testing_engine(
                 options={'implicit_returning': returning})
-        engine.execute(t2.insert(), nextid=1)
-        r = engine.execute(t1.insert(), data='hi')
-        eq_([1], r.inserted_primary_key)
+        with engine.begin() as conn:
+            conn.execute(t2.insert(), nextid=1)
+            r = conn.execute(t1.insert(), data='hi')
+            eq_([1], r.inserted_primary_key)
+
+            conn.execute(t2.insert(), nextid=2)
+            r = conn.execute(t1.insert(), data='there')
+            eq_([2], r.inserted_primary_key)
 
-        engine.execute(t2.insert(), nextid=2)
-        r = engine.execute(t1.insert(), data='there')
-        eq_([2], r.inserted_primary_key)
+            r = conn.execute(date_table.insert())
+            assert isinstance(r.inserted_primary_key[0], datetime.datetime)
 
 
 class PKIncrementTest(fixtures.TablesTest):
@@ -1353,9 +1366,15 @@ class SpecialTypePKTest(fixtures.TestBase):
     def test_literal_default_no_label(self):
         self._run_test(default=literal("INT_1", type_=self.MyInteger))
 
+    def test_literal_column_default_no_label(self):
+        self._run_test(default=literal_column("1", type_=self.MyInteger))
+
     def test_sequence(self):
         self._run_test(Sequence('foo_seq'))
 
+    def test_text_clause_default_no_type(self):
+        self._run_test(default=text('1'))
+
     def test_server_default(self):
         self._run_test(server_default='1',)