elif not (params or value_params):
continue
+ has_all_pks = True
if bulk:
pk_params = dict(
(propkey_to_col[propkey]._label, state_dict.get(propkey))
else:
# else, use the old value to locate the row
pk_params[col._label] = history.deleted[0]
- params[col.key] = history.added[0]
+ if col in value_params:
+ has_all_pks = False
else:
pk_params[col._label] = history.unchanged[0]
if pk_params[col._label] is None:
params.update(pk_params)
yield (
state, state_dict, params, mapper,
- connection, value_params, has_all_defaults)
+ connection, value_params, has_all_defaults, has_all_pks)
def _collect_post_update_commands(base_mapper, uowtransaction, table,
cached_stmt = base_mapper._memo(('update', table), update_stmt)
- for (connection, paramkeys, hasvalue, has_all_defaults), \
+ for (connection, paramkeys, hasvalue, has_all_defaults, has_all_pks), \
records in groupby(
update,
lambda rec: (
rec[4], # connection
set(rec[2]), # set of parameter keys
bool(rec[5]), # whether or not we have "value" parameters
- rec[6] # has_all_defaults
+ rec[6], # has_all_defaults
+ rec[7] # has all pks
)
):
rows = 0
connection.dialect.supports_sane_multi_rowcount
allow_multirow = has_all_defaults and not needs_version_id
- if bookkeeping and not has_all_defaults and \
+ if not has_all_pks:
+ statement = statement.return_defaults()
+ elif bookkeeping and not has_all_defaults and \
mapper.base_mapper.eager_defaults:
statement = statement.return_defaults()
elif mapper.version_id_col is not None:
if hasvalue:
for state, state_dict, params, mapper, \
- connection, value_params, has_all_defaults in records:
+ connection, value_params, \
+ has_all_defaults, has_all_pks in records:
c = connection.execute(
statement.values(value_params),
params)
if not allow_multirow:
check_rowcount = assert_singlerow
for state, state_dict, params, mapper, \
- connection, value_params, has_all_defaults in records:
+ connection, value_params, has_all_defaults, \
+ has_all_pks in records:
c = cached_connections[connection].\
execute(statement, params)
rows += c.rowcount
for state, state_dict, params, mapper, \
- connection, value_params, has_all_defaults in records:
+ connection, value_params, \
+ has_all_defaults, has_all_pks in records:
if bookkeeping:
_postfetch(
mapper,
row = result.context.returned_defaults
if row is not None:
for col in returning_cols:
- if col.primary_key:
+ # pk cols returned from insert are handled
+ # distinctly, don't step on the values here
+ if col.primary_key and result.context.isinsert:
continue
dict_[mapper._columntoproperty[col].key] = row[col]
if refresh_flush:
assert sess.query(User).get('jack') is None
assert sess.query(User).get('ed').fullname == 'jack'
+ @testing.requires.returning
+ def test_update_to_sql_expr(self):
+ users, User = self.tables.users, self.classes.User
+
+ mapper(User, users)
+
+ sess = create_session()
+ u1 = User(username='jack', fullname='jack')
+
+ sess.add(u1)
+ sess.flush()
+
+ u1.username = User.username + ' jones'
+
+ sess.flush()
+
+ eq_(u1.username, 'jack jones')
+
+ def test_update_to_self_sql_expr(self):
+ # SQL expression where the PK won't actually change,
+ # such as to bump a server side trigger
+ users, User = self.tables.users, self.classes.User
+
+ mapper(User, users)
+
+ sess = create_session()
+ u1 = User(username='jack', fullname='jack')
+
+ sess.add(u1)
+ sess.flush()
+
+ u1.username = User.username + ''
+
+ sess.flush()
+
+ eq_(u1.username, 'jack')
+
def test_flush_new_pk_after_expire(self):
User, users = self.classes.User, self.tables.users