isinsert = not instance_key in uowtransaction.uow.identity_map and not postupdate and not has_identity(obj)
params = {}
+ value_params = {}
hasdata = False
for col in table.columns:
if col is mapper.version_id_col:
if history:
a = history.added_items()
if len(a):
- params[col.key] = prop.get_col_value(col, a[0])
+ if isinstance(a[0], sql.ClauseElement):
+ value_params[col] = a[0]
+ else:
+ params[col.key] = prop.get_col_value(col, a[0])
hasdata = True
else:
# doing an INSERT, non primary key col ?
if value is NO_ATTRIBUTE:
continue
if col.default is None or value is not None:
- params[col.key] = value
+ if isinstance(value, sql.ClauseElement):
+ value_params[col] = value
+ else:
+ params[col.key] = value
if not isinsert:
if hasdata:
# if none of the attributes changed, dont even
# add the row to be updated.
- update.append((obj, params, mapper, connection))
+ update.append((obj, params, mapper, connection, value_params))
else:
- insert.append((obj, params, mapper, connection))
+ insert.append((obj, params, mapper, connection, value_params))
if len(update):
mapper = table_to_mapper[table]
return 0
update.sort(comparator)
for rec in update:
- (obj, params, mapper, connection) = rec
- c = connection.execute(statement, params)
- mapper._postfetch(connection, table, obj, c, c.last_updated_params())
+ (obj, params, mapper, connection, value_params) = rec
+ c = connection.execute(statement.values(value_params), params)
+ mapper._postfetch(connection, table, obj, c, c.last_updated_params(), value_params)
updated_objects.add((obj, connection))
rows += c.rowcount
return cmp(a[0]._sa_insert_order, b[0]._sa_insert_order)
insert.sort(comparator)
for rec in insert:
- (obj, params, mapper, connection) = rec
- c = connection.execute(statement, params)
+ (obj, params, mapper, connection, value_params) = rec
+ c = connection.execute(statement.values(value_params), params)
primary_key = c.last_inserted_ids()
if primary_key is not None:
i = 0
if mapper.get_attr_by_column(obj, col) is None and len(primary_key) > i:
mapper.set_attr_by_column(obj, col, primary_key[i])
i+=1
- mapper._postfetch(connection, table, obj, c, c.last_inserted_params())
+ mapper._postfetch(connection, table, obj, c, c.last_inserted_params(), value_params)
# synchronize newly inserted ids from one table to the next
# TODO: this fires off more than needed, try to organize syncrules
for mapper in object_mapper(obj).iterate_to_root():
mapper.extension.after_update(mapper, connection, obj)
- def _postfetch(self, connection, table, obj, resultproxy, params):
+ def _postfetch(self, connection, table, obj, resultproxy, params, value_params):
"""After an ``INSERT`` or ``UPDATE``, assemble newly generated
values on an instance. For columns which are marked as being generated
on the database side, set up a group-based "deferred" loader
which will populate those attributes in one query when next accessed.
"""
- postfetch_cols = resultproxy.context.postfetch_cols()
+ postfetch_cols = resultproxy.context.postfetch_cols().union(util.Set(value_params.keys()))
deferred_props = []
for c in table.c:
- if c in postfetch_cols and not c.key in params:
+ if c in postfetch_cols and (not c.key in params or c in value_params):
prop = self._getpropbycolumn(c, raiseerror=False)
if prop is None:
continue
deferred_props.append(prop)
+ continue
if c.primary_key or not c.key in params:
continue
v = self.get_attr_by_column(obj, c, False)
Session.commit()
assert people.count(people.c.person=='im the key').scalar() == peoplesites.count(peoplesites.c.person=='im the key').scalar() == 1
+class ClauseAttributesTest(UnitOfWorkTest):
+ def setUpAll(self):
+ UnitOfWorkTest.setUpAll(self)
+ global metadata, users_table
+ metadata = MetaData(testbase.db)
+ users_table = Table('users', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('name', String(30)),
+ Column('counter', Integer, default=1))
+ metadata.create_all()
+
+ def tearDown(self):
+ users_table.delete().execute()
+ UnitOfWorkTest.tearDown(self)
+
+ def tearDownAll(self):
+ metadata.drop_all()
+ UnitOfWorkTest.tearDownAll(self)
+
+ def test_update(self):
+ class User(object):
+ pass
+ mapper(User, users_table)
+ u = User(name='test')
+ sess = Session()
+ sess.save(u)
+ sess.flush()
+ assert u.counter == 1
+ u.counter = users_table.c.counter + 1
+ sess.flush()
+ def go():
+ assert u.counter == 2
+ self.assert_sql_count(testbase.db, go, 1)
+
+ def test_multi_update(self):
+ class User(object):
+ pass
+ mapper(User, users_table)
+ u = User(name='test')
+ sess = Session()
+ sess.save(u)
+ sess.flush()
+ assert u.counter == 1
+ u.name = 'test2'
+ u.counter = users_table.c.counter + 1
+ sess.flush()
+ def go():
+ assert u.name == 'test2'
+ assert u.counter == 2
+ self.assert_sql_count(testbase.db, go, 1)
+
+ sess.clear()
+ u = sess.query(User).get(u.id)
+ assert u.name == 'test2'
+ assert u.counter == 2
+
+ def test_insert(self):
+ class User(object):
+ pass
+ mapper(User, users_table)
+ u = User(name='test', counter=select([5]))
+ sess = Session()
+ sess.save(u)
+ sess.flush()
+ assert u.counter == 5
+
+
+
class PassiveDeletesTest(UnitOfWorkTest):
def setUpAll(self):
UnitOfWorkTest.setUpAll(self)