the ``CLUSTERED`` keyword to the constraint construct within DDL.
Pullreq courtesy Derek Harland.
+ .. change::
+ :tags: bug, sql, orm
+ :tickets: 2912
+
+ Fixed the multiple-table "UPDATE..FROM" construct, only usable on
+ MySQL, to correctly render the SET clause among multiple columns
+ with the same name across tables. This also changes the name used for
+ the bound parameter in the SET clause to "<tablename>_<colname>" for
+ the non-primary table only; as this parameter is typically specified
+ using the :class:`.Column` object directly this should not have an
+ impact on applications. The fix takes effect for both
+ :meth:`.Table.update` as well as :meth:`.Query.update` in the ORM.
+
.. change::
:tags: bug, oracle
:tickets: 2911
and generate inserted_primary_key collection.
"""
+ key_getter = self.compiled._key_getters_for_crud_column[2]
+
if self.executemany:
if len(self.compiled.prefetch):
scalar_defaults = {}
else:
val = self.get_update_default(c)
if val is not None:
- param[c.key] = val
+ param[key_getter(c)] = val
del self.current_parameters
else:
self.current_parameters = compiled_parameters = \
val = self.get_update_default(c)
if val is not None:
- compiled_parameters[c.key] = val
+ compiled_parameters[key_getter(c)] = val
del self.current_parameters
if self.isinsert:
self.inserted_primary_key = [
- self.compiled_parameters[0].get(c.key, None)
+ self.compiled_parameters[0].get(key_getter(c), None)
for c in self.compiled.\
statement.table.primary_key
]
from .. import util, exc
import decimal
import itertools
+import operator
RESERVED_WORDS = set([
'all', 'analyse', 'analyze', 'and', 'any', 'array',
table_text = self.update_tables_clause(update_stmt, update_stmt.table,
extra_froms, **kw)
- colparams = self._get_colparams(update_stmt, extra_froms, **kw)
+ colparams = self._get_colparams(update_stmt, **kw)
if update_stmt._hints:
dialect_hints = dict([
bindparam._is_crud = True
return bindparam._compiler_dispatch(self)
- def _get_colparams(self, stmt, extra_tables=None, **kw):
+ @util.memoized_property
+ def _key_getters_for_crud_column(self):
+ if self.isupdate and self.statement._extra_froms:
+ # when extra tables are present, refer to the columns
+ # in those extra tables as table-qualified, including in
+ # dictionaries and when rendering bind param names.
+ # the "main" table of the statement remains unqualified,
+ # allowing the most compatibility with a non-multi-table
+ # statement.
+ _et = set(self.statement._extra_froms)
+ def _column_as_key(key):
+ str_key = elements._column_as_key(key)
+ if hasattr(key, 'table') and key.table in _et:
+ return (key.table.name, str_key)
+ else:
+ return str_key
+ def _getattr_col_key(col):
+ if col.table in _et:
+ return (col.table.name, col.key)
+ else:
+ return col.key
+ def _col_bind_name(col):
+ if col.table in _et:
+ return "%s_%s" % (col.table.name, col.key)
+ else:
+ return col.key
+
+ else:
+ _column_as_key = elements._column_as_key
+ _getattr_col_key = _col_bind_name = operator.attrgetter("key")
+
+ return _column_as_key, _getattr_col_key, _col_bind_name
+
+ def _get_colparams(self, stmt, **kw):
"""create a set of tuples representing column/string pairs for use
in an INSERT or UPDATE statement.
else:
stmt_parameters = stmt.parameters
+ # getters - these are normally just column.key,
+ # but in the case of mysql multi-table update, the rules for
+ # .key must conditionally take tablename into account
+ _column_as_key, _getattr_col_key, _col_bind_name = \
+ self._key_getters_for_crud_column
+
# if we have statement parameters - set defaults in the
# compiled params
if self.column_keys is None:
parameters = {}
else:
- parameters = dict((elements._column_as_key(key), REQUIRED)
+ parameters = dict((_column_as_key(key), REQUIRED)
for key in self.column_keys
if not stmt_parameters or
key not in stmt_parameters)
if stmt_parameters is not None:
for k, v in stmt_parameters.items():
- colkey = elements._column_as_key(k)
+ colkey = _column_as_key(k)
if colkey is not None:
parameters.setdefault(colkey, v)
else:
# add it to values() in an "as-is" state,
# coercing right side to bound param
if elements._is_literal(v):
- v = self.process(elements.BindParameter(None, v, type_=k.type), **kw)
+ v = self.process(
+ elements.BindParameter(None, v, type_=k.type),
+ **kw)
else:
v = self.process(v.self_group(), **kw)
postfetch_lastrowid = need_pks and self.dialect.postfetch_lastrowid
check_columns = {}
+
# special logic that only occurs for multi-table UPDATE
# statements
- if extra_tables and stmt_parameters:
+ if self.isupdate and stmt._extra_froms and stmt_parameters:
normalized_params = dict(
(elements._clause_element_as_expr(c), param)
for c, param in stmt_parameters.items()
)
- assert self.isupdate
affected_tables = set()
- for t in extra_tables:
+ for t in stmt._extra_froms:
for c in t.c:
if c in normalized_params:
affected_tables.add(t)
- check_columns[c.key] = c
+ check_columns[_getattr_col_key(c)] = c
value = normalized_params[c]
if elements._is_literal(value):
value = self._create_crud_bind_param(
- c, value, required=value is REQUIRED)
+ c, value, required=value is REQUIRED,
+ name=_col_bind_name(c))
else:
self.postfetch.append(c)
value = self.process(value.self_group(), **kw)
elif c.onupdate is not None and not c.onupdate.is_sequence:
if c.onupdate.is_clause_element:
values.append(
- (c, self.process(c.onupdate.arg.self_group(), **kw))
+ (c, self.process(
+ c.onupdate.arg.self_group(),
+ **kw)
+ )
)
self.postfetch.append(c)
else:
values.append(
- (c, self._create_crud_bind_param(c, None))
+ (c, self._create_crud_bind_param(
+ c, None, name=_col_bind_name(c)
+ )
+ )
)
self.prefetch.append(c)
elif c.server_onupdate is not None:
if self.isinsert and stmt.select_names:
# for an insert from select, we can only use names that
# are given, so only select for those names.
- cols = (stmt.table.c[elements._column_as_key(name)]
+ cols = (stmt.table.c[_column_as_key(name)]
for name in stmt.select_names)
else:
# iterate through all table columns to maintain
cols = stmt.table.columns
for c in cols:
- if c.key in parameters and c.key not in check_columns:
- value = parameters.pop(c.key)
+ col_key = _getattr_col_key(c)
+ if col_key in parameters and col_key not in check_columns:
+ value = parameters.pop(col_key)
if elements._is_literal(value):
value = self._create_crud_bind_param(
c, value, required=value is REQUIRED,
- name=c.key
+ name=_col_bind_name(c)
if not stmt._has_multi_parameters
- else "%s_0" % c.key
+ else "%s_0" % _col_bind_name(c)
)
else:
if isinstance(value, elements.BindParameter) and \
if parameters and stmt_parameters:
check = set(parameters).intersection(
- elements._column_as_key(k) for k in stmt.parameters
+ _column_as_key(k) for k in stmt.parameters
).difference(check_columns)
if check:
raise exc.CompileError(
"Unconsumed column names: %s" %
- (", ".join(check))
+ (", ".join("%s" % c for c in check))
)
if stmt._has_multi_parameters:
values.extend(
[
- (
- c,
- self._create_crud_bind_param(
- c, row[c.key],
- name="%s_%d" % (c.key, i + 1)
- )
- if c.key in row else param
- )
- for (c, param) in values_0
- ]
- for i, row in enumerate(stmt.parameters[1:])
+ (
+ c,
+ self._create_crud_bind_param(
+ c, row[c.key],
+ name="%s_%d" % (c.key, i + 1)
+ )
+ if c.key in row else param
+ )
+ for (c, param) in values_0
+ ]
+ for i, row in enumerate(stmt.parameters[1:])
)
return values
def define_tables(cls, metadata):
Table('users', metadata,
Column('id', Integer, primary_key=True),
+ Column('samename', String(10)),
)
Table('documents', metadata,
Column('id', Integer, primary_key=True),
Column('user_id', None, ForeignKey('users.id')),
Column('title', String(32)),
- Column('flag', Boolean)
+ Column('flag', Boolean),
+ Column('samename', String(10)),
)
@classmethod
])
)
+ @testing.only_on('mysql', 'Multi table update')
+ def test_update_from_multitable_same_names(self):
+ Document = self.classes.Document
+ User = self.classes.User
+
+ s = Session()
+
+ s.query(Document).\
+ filter(User.id == Document.user_id).\
+ filter(User.id == 2).update({
+ Document.samename: 'd_samename',
+ User.samename: 'u_samename'
+ }
+ )
+ eq_(
+ s.query(User.id, Document.samename, User.samename).
+ filter(User.id == Document.user_id).
+ order_by(User.id).all(),
+ [
+ (1, None, None),
+ (1, None, None),
+ (2, 'd_samename', 'u_samename'),
+ (2, 'd_samename', 'u_samename'),
+ (3, None, None),
+ (3, None, None),
+ ]
+ )
+
class ExpressionUpdateTest(fixtures.MappedTest):
@classmethod
def define_tables(cls, metadata):
set(s.query(Person.name, Engineer.engineer_name)),
set([('e1', 'e1', ), ('e22', 'e55')])
)
+
+
'UPDATE A B C D mytable SET myid=%s, name=%s, description=%s',
dialect=mysql.dialect())
- def test_alias(self):
- table1 = self.tables.mytable
- talias1 = table1.alias('t1')
-
- self.assert_compile(update(talias1, talias1.c.myid == 7),
- 'UPDATE mytable AS t1 '
- 'SET name=:name '
- 'WHERE t1.myid = :myid_1',
- params={table1.c.name: 'fred'})
-
- self.assert_compile(update(talias1, table1.c.myid == 7),
- 'UPDATE mytable AS t1 '
- 'SET name=:name '
- 'FROM mytable '
- 'WHERE mytable.myid = :myid_1',
- params={table1.c.name: 'fred'})
def test_update_to_expression(self):
"""test update from an expression.
run_create_tables = run_inserts = run_deletes = None
+ def test_alias_one(self):
+ table1 = self.tables.mytable
+ talias1 = table1.alias('t1')
+
+ # this case is nonsensical. the UPDATE is entirely
+ # against the alias, but we name the table-bound column
+ # in values. The behavior here isn't really defined
+ self.assert_compile(
+ update(talias1, talias1.c.myid == 7).
+ values({table1.c.name: "fred"}),
+ 'UPDATE mytable AS t1 '
+ 'SET name=:name '
+ 'WHERE t1.myid = :myid_1')
+
+ def test_alias_two(self):
+ table1 = self.tables.mytable
+ talias1 = table1.alias('t1')
+
+ # Here, compared to
+ # test_alias_one(), here we actually have UPDATE..FROM,
+ # which is causing the "table1.c.name" param to be handled
+ # as an "extra table", hence we see the full table name rendered.
+ self.assert_compile(
+ update(talias1, table1.c.myid == 7).
+ values({table1.c.name: 'fred'}),
+ 'UPDATE mytable AS t1 '
+ 'SET name=:mytable_name '
+ 'FROM mytable '
+ 'WHERE mytable.myid = :myid_1',
+ checkparams={'mytable_name': 'fred', 'myid_1': 7},
+ )
+
+ def test_alias_two_mysql(self):
+ table1 = self.tables.mytable
+ talias1 = table1.alias('t1')
+
+ self.assert_compile(
+ update(talias1, table1.c.myid == 7).
+ values({table1.c.name: 'fred'}),
+ "UPDATE mytable AS t1, mytable SET mytable.name=%s "
+ "WHERE mytable.myid = %s",
+ checkparams={'mytable_name': 'fred', 'myid_1': 7},
+ dialect='mysql')
+
+ def test_update_from_multitable_same_name_mysql(self):
+ users, addresses = self.tables.users, self.tables.addresses
+
+ self.assert_compile(
+ users.update().
+ values(name='newname').\
+ values({addresses.c.name: "new address"}).\
+ where(users.c.id == addresses.c.user_id),
+ "UPDATE users, addresses SET addresses.name=%s, "
+ "users.name=%s WHERE users.id = addresses.user_id",
+ checkparams={u'addresses_name': 'new address', 'name': 'newname'},
+ dialect='mysql'
+ )
+
def test_render_table(self):
users, addresses = self.tables.users, self.tables.addresses
(10, 'chuck')]
self._assert_users(users, expected)
+ @testing.only_on('mysql', 'Multi table update')
+ def test_exec_multitable_same_name(self):
+ users, addresses = self.tables.users, self.tables.addresses
+
+ values = {
+ addresses.c.name: 'ad_ed2',
+ users.c.name: 'ed2'
+ }
+
+ testing.db.execute(
+ addresses.update().
+ values(values).
+ where(users.c.id == addresses.c.user_id).
+ where(users.c.name == 'ed'))
+
+ expected = [
+ (1, 7, 'x', 'jack@bean.com'),
+ (2, 8, 'ad_ed2', 'ed@wood.com'),
+ (3, 8, 'ad_ed2', 'ed@bettyboop.com'),
+ (4, 8, 'ad_ed2', 'ed@lala.com'),
+ (5, 9, 'x', 'fred@fred.com')]
+ self._assert_addresses(addresses, expected)
+
+ expected = [
+ (7, 'jack'),
+ (8, 'ed2'),
+ (9, 'fred'),
+ (10, 'chuck')]
+ self._assert_users(users, expected)
+
def _assert_addresses(self, addresses, expected):
stmt = addresses.select().order_by(addresses.c.id)
eq_(testing.db.execute(stmt).fetchall(), expected)
Column('id', Integer, primary_key=True,
test_needs_autoincrement=True),
Column('user_id', None, ForeignKey('users.id')),
- Column('email_address', String(50), nullable=False))
+ Column('email_address', String(50), nullable=False),
+ )
+
+ Table('foobar', metadata,
+ Column('id', Integer, primary_key=True,
+ test_needs_autoincrement=True),
+ Column('user_id', None, ForeignKey('users.id')),
+ Column('data', String(30)),
+ Column('some_update', String(30), onupdate='im the other update')
+ )
@classmethod
def fixtures(cls):
(3, 8, 'ed@bettyboop.com'),
(4, 9, 'fred@fred.com')
),
+ foobar=(
+ ('id', 'user_id', 'data'),
+ (2, 8, 'd1'),
+ (3, 8, 'd2'),
+ (4, 9, 'd3')
+ )
)
@testing.only_on('mysql', 'Multi table update')
(9, 'fred', 'value')]
self._assert_users(users, expected)
+ @testing.only_on('mysql', 'Multi table update')
+ def test_defaults_second_table_same_name(self):
+ users, foobar = self.tables.users, self.tables.foobar
+
+ values = {
+ foobar.c.data: foobar.c.data + 'a',
+ users.c.name: 'ed2'
+ }
+
+ ret = testing.db.execute(
+ users.update().
+ values(values).
+ where(users.c.id == foobar.c.user_id).
+ where(users.c.name == 'ed'))
+
+ eq_(
+ set(ret.prefetch_cols()),
+ set([users.c.some_update, foobar.c.some_update])
+ )
+
+ expected = [
+ (2, 8, 'd1a', 'im the other update'),
+ (3, 8, 'd2a', 'im the other update'),
+ (4, 9, 'd3', None)]
+ self._assert_foobar(foobar, expected)
+
+ expected = [
+ (8, 'ed2', 'im the update'),
+ (9, 'fred', 'value')]
+ self._assert_users(users, expected)
+
@testing.only_on('mysql', 'Multi table update')
def test_no_defaults_second_table(self):
users, addresses = self.tables.users, self.tables.addresses
(9, 'fred', 'value')]
self._assert_users(users, expected)
+ def _assert_foobar(self, foobar, expected):
+ stmt = foobar.select().order_by(foobar.c.id)
+ eq_(testing.db.execute(stmt).fetchall(), expected)
+
def _assert_addresses(self, addresses, expected):
stmt = addresses.select().order_by(addresses.c.id)
eq_(testing.db.execute(stmt).fetchall(), expected)