]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Work w/ prefetch even for selects, if present
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 12 Jul 2016 20:38:22 +0000 (16:38 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 12 Jul 2016 20:43:49 +0000 (16:43 -0400)
Fixed bug in new CTE feature for update/insert/delete stated
as a CTE inside of an enclosing statement (typically SELECT) whereby
oninsert and onupdate values weren't called upon for the embedded
statement.

This is accomplished by consulting prefetch
for all statements.  The collection is also broken into
separate insert/update collections so that we don't need to
consult toplevel self.isinsert to determine if the prefetch
is for an insert or an update.  What we don't yet test for
are CTE combinations that have both insert/update in one
statement, though these should now work in theory provided
the underlying database supports such a statement.

Change-Id: I3b6a860e22c86743c91c56a7ec751ff706f66f64
Fixes: #3745
doc/build/changelog/changelog_11.rst
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/crud.py
test/requirements.py
test/sql/test_defaults.py

index 8ed600639c3bec879ad5d46c95a031141bba6905..75c434da8e3f227a2fcfa04127a57d10b7e34874 100644 (file)
 .. changelog::
     :version: 1.1.0b3
 
+    .. change::
+        :tags: bug, sql
+        :tickets: 3745
+
+        Fixed bug in new CTE feature for update/insert/delete stated
+        as a CTE inside of an enclosing statement (typically SELECT) whereby
+        oninsert and onupdate values weren't called upon for the embedded
+        statement.
+
     .. change::
         :tags: bug, ext
 
index 3ed2d5ee8ceeb9fd0b694e3095ac51cdda4835f7..1bb575984ab2a58ee395f620352a5404f312adc2 100644 (file)
@@ -593,12 +593,11 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
             self._is_implicit_returning = bool(
                 compiled.returning and not compiled.statement._returning)
 
-            if not self.isdelete:
-                if self.compiled.prefetch:
-                    if self.executemany:
-                        self._process_executemany_defaults()
-                    else:
-                        self._process_executesingle_defaults()
+        if self.compiled.insert_prefetch or self.compiled.update_prefetch:
+            if self.executemany:
+                self._process_executemany_defaults()
+            else:
+                self._process_executesingle_defaults()
 
         processors = compiled._bind_processors
 
@@ -712,7 +711,12 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
 
     @util.memoized_property
     def prefetch_cols(self):
-        return self.compiled.prefetch
+        if self.isinsert:
+            return self.compiled.insert_prefetch
+        elif self.isupdate:
+            return self.compiled.update_prefetch
+        else:
+            return ()
 
     @util.memoized_property
     def returning_cols(self):
@@ -1007,46 +1011,57 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
     def _process_executemany_defaults(self):
         key_getter = self.compiled._key_getters_for_crud_column[2]
 
-        prefetch = self.compiled.prefetch
         scalar_defaults = {}
 
+        insert_prefetch = self.compiled.insert_prefetch
+        update_prefetch = self.compiled.update_prefetch
+
         # pre-determine scalar Python-side defaults
         # to avoid many calls of get_insert_default()/
         # get_update_default()
-        for c in prefetch:
-            if self.isinsert and c.default and c.default.is_scalar:
+        for c in insert_prefetch:
+            if c.default and c.default.is_scalar:
                 scalar_defaults[c] = c.default.arg
-            elif self.isupdate and c.onupdate and c.onupdate.is_scalar:
+        for c in update_prefetch:
+            if c.onupdate and c.onupdate.is_scalar:
                 scalar_defaults[c] = c.onupdate.arg
 
         for param in self.compiled_parameters:
             self.current_parameters = param
-            for c in prefetch:
+            for c in insert_prefetch:
                 if c in scalar_defaults:
                     val = scalar_defaults[c]
-                elif self.isinsert:
+                else:
                     val = self.get_insert_default(c)
+                if val is not None:
+                    param[key_getter(c)] = val
+            for c in update_prefetch:
+                if c in scalar_defaults:
+                    val = scalar_defaults[c]
                 else:
                     val = self.get_update_default(c)
                 if val is not None:
                     param[key_getter(c)] = val
+
         del self.current_parameters
 
     def _process_executesingle_defaults(self):
         key_getter = self.compiled._key_getters_for_crud_column[2]
-        prefetch = self.compiled.prefetch
         self.current_parameters = compiled_parameters = \
             self.compiled_parameters[0]
 
-        for c in prefetch:
-            if self.isinsert:
-                if c.default and \
-                        not c.default.is_sequence and c.default.is_scalar:
-                    val = c.default.arg
-                else:
-                    val = self.get_insert_default(c)
+        for c in self.compiled.insert_prefetch:
+            if c.default and \
+                    not c.default.is_sequence and c.default.is_scalar:
+                val = c.default.arg
             else:
-                val = self.get_update_default(c)
+                val = self.get_insert_default(c)
+
+            if val is not None:
+                compiled_parameters[key_getter(c)] = val
+
+        for c in self.compiled.update_prefetch:
+            val = self.get_update_default(c)
 
             if val is not None:
                 compiled_parameters[key_getter(c)] = val
index 16ca7f959b0880ca06b39f67042a68b3b9263c8e..095c84f03b9b27e618ecda6e4ab133443c9caafd 100644 (file)
@@ -359,6 +359,8 @@ class SQLCompiler(Compiled):
     True unless using an unordered TextAsFrom.
     """
 
+    insert_prefetch = update_prefetch = ()
+
     def __init__(self, dialect, statement, column_keys=None,
                  inline=False, **kwargs):
         """Construct a new :class:`.SQLCompiler` object.
@@ -428,6 +430,10 @@ class SQLCompiler(Compiled):
         if self.positional and dialect.paramstyle == 'numeric':
             self._apply_numbered_params()
 
+    @property
+    def prefetch(self):
+        return list(self.insert_prefetch + self.update_prefetch)
+
     @util.memoized_instancemethod
     def _init_cte_state(self):
         """Initialize collections related to CTEs only if
index 70e03d220de3f1d1893388afea9c915b94973962..f770fc5134ef204dd70ffab0c2fe0a24bdc3c556 100644 (file)
@@ -11,6 +11,7 @@ within INSERT and UPDATE statements.
 """
 from .. import util
 from .. import exc
+from . import dml
 from . import elements
 import operator
 
@@ -73,7 +74,8 @@ def _get_crud_params(compiler, stmt, **kw):
     """
 
     compiler.postfetch = []
-    compiler.prefetch = []
+    compiler.insert_prefetch = []
+    compiler.update_prefetch = []
     compiler.returning = []
 
     # no parameters in the statement, no parameters in the
@@ -370,7 +372,7 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw):
             compiler.returning.append(c)
         else:
             values.append(
-                (c, _create_prefetch_bind_param(compiler, c))
+                (c, _create_insert_prefetch_bind_param(compiler, c))
             )
     elif c is stmt.table._autoincrement_column or c.server_default is not None:
         compiler.returning.append(c)
@@ -380,9 +382,15 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw):
         _raise_pk_with_no_anticipated_value(c)
 
 
-def _create_prefetch_bind_param(compiler, c, process=True, name=None):
+def _create_insert_prefetch_bind_param(compiler, c, process=True, name=None):
     param = _create_bind_param(compiler, c, None, process=process, name=name)
-    compiler.prefetch.append(c)
+    compiler.insert_prefetch.append(c)
+    return param
+
+
+def _create_update_prefetch_bind_param(compiler, c, process=True, name=None):
+    param = _create_bind_param(compiler, c, None, process=process, name=name)
+    compiler.update_prefetch.append(c)
     return param
 
 
@@ -399,7 +407,7 @@ class _multiparam_column(elements.ColumnElement):
             other.original == self.original
 
 
-def _process_multiparam_default_bind(compiler, c, index, kw):
+def _process_multiparam_default_bind(compiler, stmt, c, index, kw):
 
     if not c.default:
         raise exc.CompileError(
@@ -410,7 +418,10 @@ def _process_multiparam_default_bind(compiler, c, index, kw):
         return compiler.process(c.default.arg.self_group(), **kw)
     else:
         col = _multiparam_column(c, index)
-        return _create_prefetch_bind_param(compiler, col)
+        if isinstance(stmt, dml.Insert):
+            return _create_insert_prefetch_bind_param(compiler, col)
+        else:
+            return _create_update_prefetch_bind_param(compiler, col)
 
 
 def _append_param_insert_pk(compiler, stmt, c, values, kw):
@@ -448,7 +459,7 @@ def _append_param_insert_pk(compiler, stmt, c, values, kw):
             )
     ):
         values.append(
-            (c, _create_prefetch_bind_param(compiler, c))
+            (c, _create_insert_prefetch_bind_param(compiler, c))
         )
     elif c.default is None and c.server_default is None and not c.nullable:
         # no .default, no .server_default, not autoincrement, we have
@@ -482,7 +493,7 @@ def _append_param_insert_hasdefault(
             compiler.postfetch.append(c)
     else:
         values.append(
-            (c, _create_prefetch_bind_param(compiler, c))
+            (c, _create_insert_prefetch_bind_param(compiler, c))
         )
 
 
@@ -500,7 +511,7 @@ def _append_param_insert_select_hasdefault(
         values.append((c, proc))
     else:
         values.append(
-            (c, _create_prefetch_bind_param(compiler, c, process=False))
+            (c, _create_insert_prefetch_bind_param(compiler, c, process=False))
         )
 
 
@@ -520,7 +531,7 @@ def _append_param_update(
                 compiler.postfetch.append(c)
         else:
             values.append(
-                (c, _create_prefetch_bind_param(compiler, c))
+                (c, _create_update_prefetch_bind_param(compiler, c))
             )
     elif c.server_onupdate is not None:
         if implicit_return_defaults and \
@@ -575,7 +586,7 @@ def _get_multitable_params(
                     compiler.postfetch.append(c)
                 else:
                     values.append(
-                        (c, _create_prefetch_bind_param(
+                        (c, _create_update_prefetch_bind_param(
                             compiler, c, name=_col_bind_name(c)))
                     )
             elif c.server_onupdate is not None:
@@ -597,7 +608,7 @@ def _extend_values_for_multiparams(compiler, stmt, values, kw):
                     else compiler.process(
                         row[c.key].self_group(), **kw))
                 if c.key in row else
-                _process_multiparam_default_bind(compiler, c, i, kw)
+                _process_multiparam_default_bind(compiler, stmt, c, i, kw)
             )
             for (c, param) in values_0
         ]
index d31088e16a3aadd7ed610bb20675e8d7b4f4e9d2..3c7a3fbb41bdef834b443172ec4c9535443dbe65 100644 (file)
@@ -350,6 +350,14 @@ class DefaultRequirements(SuiteRequirements):
 
         return skip_if(exclude('mysql', '<', (4, 1, 1)), 'no subquery support')
 
+    @property
+    def ctes(self):
+        """Target database supports CTEs"""
+
+        return only_if(
+            ['postgresql', 'mssql']
+        )
+
     @property
     def mod_operator_as_percent_sign(self):
         """target database must use a plain percent '%' as the 'modulus'
index db19e145bd0b40c66d03cd500397539f5b9e23ff..57af1e536e06f8b3ea3a5acabd44865ecf1b34f1 100644 (file)
@@ -539,6 +539,93 @@ class DefaultTest(fixtures.TestBase):
         eq_(55, l['col3'])
 
 
+class CTEDefaultTest(fixtures.TablesTest):
+    __requires__ = ('ctes',)
+    __backend__ = True
+
+    @classmethod
+    def define_tables(cls, metadata):
+        Table(
+            'q', metadata,
+            Column('x', Integer, default=2),
+            Column('y', Integer, onupdate=5),
+            Column('z', Integer)
+        )
+
+        Table(
+            'p', metadata,
+            Column('s', Integer),
+            Column('t', Integer),
+            Column('u', Integer, onupdate=1)
+        )
+
+    def _test_a_in_b(self, a, b):
+        q = self.tables.q
+        p = self.tables.p
+
+        with testing.db.connect() as conn:
+            if a == 'delete':
+                conn.execute(q.insert().values(y=10, z=1))
+                cte = q.delete().\
+                    where(q.c.z == 1).returning(q.c.z).cte('c')
+                expected = None
+            elif a == "insert":
+                cte = q.insert().values(z=1, y=10).returning(q.c.z).cte('c')
+                expected = (2, 10)
+            elif a == "update":
+                conn.execute(q.insert().values(x=5, y=10, z=1))
+                cte = q.update().\
+                    where(q.c.z == 1).values(x=7).returning(q.c.z).cte('c')
+                expected = (7, 5)
+            elif a == "select":
+                conn.execute(q.insert().values(x=5, y=10, z=1))
+                cte = sa.select([q.c.z]).cte('c')
+                expected = (5, 10)
+
+            if b == "select":
+                conn.execute(p.insert().values(s=1))
+                stmt = select([p.c.s, cte.c.z])
+            elif b == "insert":
+                sel = select([1, cte.c.z, ])
+                stmt = p.insert().from_select(['s', 't'], sel).returning(
+                    p.c.s, p.c.t)
+            elif b == "delete":
+                stmt = p.insert().values(s=1, t=cte.c.z).returning(
+                    p.c.s, cte.c.z)
+            elif b == "update":
+                conn.execute(p.insert().values(s=1))
+                stmt = p.update().values(t=5).\
+                    where(p.c.s == cte.c.z).\
+                    returning(p.c.u, cte.c.z)
+            eq_(
+                conn.execute(stmt).fetchall(),
+                [(1, 1)]
+            )
+
+            eq_(
+                conn.execute(select([q.c.x, q.c.y])).fetchone(),
+                expected
+            )
+
+    def test_update_in_select(self):
+        self._test_a_in_b("update", "select")
+
+    def test_delete_in_select(self):
+        self._test_a_in_b("update", "select")
+
+    def test_insert_in_select(self):
+        self._test_a_in_b("update", "select")
+
+    def test_select_in_update(self):
+        self._test_a_in_b("select", "update")
+
+    def test_select_in_insert(self):
+        self._test_a_in_b("select", "insert")
+
+    # TODO: updates / inserts can be run in one statement w/ CTE ?
+    # deletes?
+
+
 class PKDefaultTest(fixtures.TablesTest):
     __requires__ = ('subqueries',)
     __backend__ = True