.. changelog::
:version: 1.1.0b3
+ .. change::
+ :tags: bug, sql
+ :tickets: 3745
+
+ Fixed bug in new CTE feature for update/insert/delete stated
+ as a CTE inside of an enclosing statement (typically SELECT) whereby
+ oninsert and onupdate values weren't called upon for the embedded
+ statement.
+
.. change::
:tags: bug, ext
self._is_implicit_returning = bool(
compiled.returning and not compiled.statement._returning)
- if not self.isdelete:
- if self.compiled.prefetch:
- if self.executemany:
- self._process_executemany_defaults()
- else:
- self._process_executesingle_defaults()
+ if self.compiled.insert_prefetch or self.compiled.update_prefetch:
+ if self.executemany:
+ self._process_executemany_defaults()
+ else:
+ self._process_executesingle_defaults()
processors = compiled._bind_processors
@util.memoized_property
def prefetch_cols(self):
- return self.compiled.prefetch
+ if self.isinsert:
+ return self.compiled.insert_prefetch
+ elif self.isupdate:
+ return self.compiled.update_prefetch
+ else:
+ return ()
@util.memoized_property
def returning_cols(self):
def _process_executemany_defaults(self):
key_getter = self.compiled._key_getters_for_crud_column[2]
- prefetch = self.compiled.prefetch
scalar_defaults = {}
+ insert_prefetch = self.compiled.insert_prefetch
+ update_prefetch = self.compiled.update_prefetch
+
# pre-determine scalar Python-side defaults
# to avoid many calls of get_insert_default()/
# get_update_default()
- for c in prefetch:
- if self.isinsert and c.default and c.default.is_scalar:
+ for c in insert_prefetch:
+ if c.default and c.default.is_scalar:
scalar_defaults[c] = c.default.arg
- elif self.isupdate and c.onupdate and c.onupdate.is_scalar:
+ for c in update_prefetch:
+ if c.onupdate and c.onupdate.is_scalar:
scalar_defaults[c] = c.onupdate.arg
for param in self.compiled_parameters:
self.current_parameters = param
- for c in prefetch:
+ for c in insert_prefetch:
if c in scalar_defaults:
val = scalar_defaults[c]
- elif self.isinsert:
+ else:
val = self.get_insert_default(c)
+ if val is not None:
+ param[key_getter(c)] = val
+ for c in update_prefetch:
+ if c in scalar_defaults:
+ val = scalar_defaults[c]
else:
val = self.get_update_default(c)
if val is not None:
param[key_getter(c)] = val
+
del self.current_parameters
def _process_executesingle_defaults(self):
key_getter = self.compiled._key_getters_for_crud_column[2]
- prefetch = self.compiled.prefetch
self.current_parameters = compiled_parameters = \
self.compiled_parameters[0]
- for c in prefetch:
- if self.isinsert:
- if c.default and \
- not c.default.is_sequence and c.default.is_scalar:
- val = c.default.arg
- else:
- val = self.get_insert_default(c)
+ for c in self.compiled.insert_prefetch:
+ if c.default and \
+ not c.default.is_sequence and c.default.is_scalar:
+ val = c.default.arg
else:
- val = self.get_update_default(c)
+ val = self.get_insert_default(c)
+
+ if val is not None:
+ compiled_parameters[key_getter(c)] = val
+
+ for c in self.compiled.update_prefetch:
+ val = self.get_update_default(c)
if val is not None:
compiled_parameters[key_getter(c)] = val
True unless using an unordered TextAsFrom.
"""
+ insert_prefetch = update_prefetch = ()
+
def __init__(self, dialect, statement, column_keys=None,
inline=False, **kwargs):
"""Construct a new :class:`.SQLCompiler` object.
if self.positional and dialect.paramstyle == 'numeric':
self._apply_numbered_params()
+ @property
+ def prefetch(self):
+ return list(self.insert_prefetch + self.update_prefetch)
+
@util.memoized_instancemethod
def _init_cte_state(self):
"""Initialize collections related to CTEs only if
"""
from .. import util
from .. import exc
+from . import dml
from . import elements
import operator
"""
compiler.postfetch = []
- compiler.prefetch = []
+ compiler.insert_prefetch = []
+ compiler.update_prefetch = []
compiler.returning = []
# no parameters in the statement, no parameters in the
compiler.returning.append(c)
else:
values.append(
- (c, _create_prefetch_bind_param(compiler, c))
+ (c, _create_insert_prefetch_bind_param(compiler, c))
)
elif c is stmt.table._autoincrement_column or c.server_default is not None:
compiler.returning.append(c)
_raise_pk_with_no_anticipated_value(c)
-def _create_prefetch_bind_param(compiler, c, process=True, name=None):
+def _create_insert_prefetch_bind_param(compiler, c, process=True, name=None):
param = _create_bind_param(compiler, c, None, process=process, name=name)
- compiler.prefetch.append(c)
+ compiler.insert_prefetch.append(c)
+ return param
+
+
+def _create_update_prefetch_bind_param(compiler, c, process=True, name=None):
+ param = _create_bind_param(compiler, c, None, process=process, name=name)
+ compiler.update_prefetch.append(c)
return param
other.original == self.original
-def _process_multiparam_default_bind(compiler, c, index, kw):
+def _process_multiparam_default_bind(compiler, stmt, c, index, kw):
if not c.default:
raise exc.CompileError(
return compiler.process(c.default.arg.self_group(), **kw)
else:
col = _multiparam_column(c, index)
- return _create_prefetch_bind_param(compiler, col)
+ if isinstance(stmt, dml.Insert):
+ return _create_insert_prefetch_bind_param(compiler, col)
+ else:
+ return _create_update_prefetch_bind_param(compiler, col)
def _append_param_insert_pk(compiler, stmt, c, values, kw):
)
):
values.append(
- (c, _create_prefetch_bind_param(compiler, c))
+ (c, _create_insert_prefetch_bind_param(compiler, c))
)
elif c.default is None and c.server_default is None and not c.nullable:
# no .default, no .server_default, not autoincrement, we have
compiler.postfetch.append(c)
else:
values.append(
- (c, _create_prefetch_bind_param(compiler, c))
+ (c, _create_insert_prefetch_bind_param(compiler, c))
)
values.append((c, proc))
else:
values.append(
- (c, _create_prefetch_bind_param(compiler, c, process=False))
+ (c, _create_insert_prefetch_bind_param(compiler, c, process=False))
)
compiler.postfetch.append(c)
else:
values.append(
- (c, _create_prefetch_bind_param(compiler, c))
+ (c, _create_update_prefetch_bind_param(compiler, c))
)
elif c.server_onupdate is not None:
if implicit_return_defaults and \
compiler.postfetch.append(c)
else:
values.append(
- (c, _create_prefetch_bind_param(
+ (c, _create_update_prefetch_bind_param(
compiler, c, name=_col_bind_name(c)))
)
elif c.server_onupdate is not None:
else compiler.process(
row[c.key].self_group(), **kw))
if c.key in row else
- _process_multiparam_default_bind(compiler, c, i, kw)
+ _process_multiparam_default_bind(compiler, stmt, c, i, kw)
)
for (c, param) in values_0
]
return skip_if(exclude('mysql', '<', (4, 1, 1)), 'no subquery support')
+ @property
+ def ctes(self):
+ """Target database supports CTEs"""
+
+ return only_if(
+ ['postgresql', 'mssql']
+ )
+
@property
def mod_operator_as_percent_sign(self):
"""target database must use a plain percent '%' as the 'modulus'
eq_(55, l['col3'])
+class CTEDefaultTest(fixtures.TablesTest):
+ __requires__ = ('ctes',)
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ 'q', metadata,
+ Column('x', Integer, default=2),
+ Column('y', Integer, onupdate=5),
+ Column('z', Integer)
+ )
+
+ Table(
+ 'p', metadata,
+ Column('s', Integer),
+ Column('t', Integer),
+ Column('u', Integer, onupdate=1)
+ )
+
+ def _test_a_in_b(self, a, b):
+ q = self.tables.q
+ p = self.tables.p
+
+ with testing.db.connect() as conn:
+ if a == 'delete':
+ conn.execute(q.insert().values(y=10, z=1))
+ cte = q.delete().\
+ where(q.c.z == 1).returning(q.c.z).cte('c')
+ expected = None
+ elif a == "insert":
+ cte = q.insert().values(z=1, y=10).returning(q.c.z).cte('c')
+ expected = (2, 10)
+ elif a == "update":
+ conn.execute(q.insert().values(x=5, y=10, z=1))
+ cte = q.update().\
+ where(q.c.z == 1).values(x=7).returning(q.c.z).cte('c')
+ expected = (7, 5)
+ elif a == "select":
+ conn.execute(q.insert().values(x=5, y=10, z=1))
+ cte = sa.select([q.c.z]).cte('c')
+ expected = (5, 10)
+
+ if b == "select":
+ conn.execute(p.insert().values(s=1))
+ stmt = select([p.c.s, cte.c.z])
+ elif b == "insert":
+ sel = select([1, cte.c.z, ])
+ stmt = p.insert().from_select(['s', 't'], sel).returning(
+ p.c.s, p.c.t)
+ elif b == "delete":
+ stmt = p.insert().values(s=1, t=cte.c.z).returning(
+ p.c.s, cte.c.z)
+ elif b == "update":
+ conn.execute(p.insert().values(s=1))
+ stmt = p.update().values(t=5).\
+ where(p.c.s == cte.c.z).\
+ returning(p.c.u, cte.c.z)
+ eq_(
+ conn.execute(stmt).fetchall(),
+ [(1, 1)]
+ )
+
+ eq_(
+ conn.execute(select([q.c.x, q.c.y])).fetchone(),
+ expected
+ )
+
+ def test_update_in_select(self):
+ self._test_a_in_b("update", "select")
+
+ def test_delete_in_select(self):
+ self._test_a_in_b("update", "select")
+
+ def test_insert_in_select(self):
+ self._test_a_in_b("update", "select")
+
+ def test_select_in_update(self):
+ self._test_a_in_b("select", "update")
+
+ def test_select_in_insert(self):
+ self._test_a_in_b("select", "insert")
+
+ # TODO: updates / inserts can be run in one statement w/ CTE ?
+ # deletes?
+
+
class PKDefaultTest(fixtures.TablesTest):
__requires__ = ('subqueries',)
__backend__ = True