self.context.last_inserted_ids = None
else:
self.context.last_inserted_ids = last_inserted_ids
+ self.context.last_inserted_params = param
elif getattr(compiled, 'isupdate', False):
if isinstance(parameters, list):
plist = parameters
else:
plist = [parameters]
drunner = self.defaultrunner(proxy)
+ self.context.lastrow_has_defaults = False
for param in plist:
for c in compiled.statement.table.c:
if c.onupdate is not None and (not param.has_key(c.name) or param[c.name] is None):
value = drunner.get_column_onupdate(c)
if value is not None:
param[c.name] = value
-
+ self.context.last_updated_params = param
+
+ def last_inserted_params(self):
+ """returns a dictionary of the full parameter dictionary for the last compiled INSERT statement,
+ including any ColumnDefaults or Sequences that were pre-executed. this value is thread-local."""
+ return self.context.last_inserted_params
+ def last_updated_params(self):
+ """returns a dictionary of the full parameter dictionary for the last compiled UPDATE statement,
+ including any ColumnDefaults that were pre-executed. this value is thread-local."""
+ return self.context.last_updated_params
def lastrow_has_defaults(self):
+ """returns True if the last row INSERTED via a compiled insert statement contained PassiveDefaults,
+ indicating that the database inserted data beyond that which we gave it. this value is thread-local."""
return self.context.lastrow_has_defaults
def pre_exec(self, proxy, compiled, parameters, **kwargs):
for rec in update:
(obj, params) = rec
c = statement.execute(params)
+ self._postfetch(table, obj, table.engine.last_updated_params())
self.extension.after_update(self, obj)
rows += c.cursor.rowcount
if table.engine.supports_sane_rowcount() and rows != len(update):
if self._getattrbycolumn(obj, col) is None:
self._setattrbycolumn(obj, col, primary_key[i])
i+=1
- if table.engine.lastrow_has_defaults():
- clause = sql.and_()
- for p in self.pks_by_table[table]:
- clause.clauses.append(p == self._getattrbycolumn(obj, p))
- row = table.select(clause).execute().fetchone()
- for c in table.c:
- if self._getattrbycolumn(obj, c) is None:
- self._setattrbycolumn(obj, c, row[c])
+ self._postfetch(table, obj, table.engine.last_inserted_params())
if self._synchronizer is not None:
self._synchronizer.execute(obj, obj)
self.extension.after_insert(self, obj)
-
+
+ def _postfetch(self, table, obj, params):
+ """after an INSERT or UPDATE, asks the engine if PassiveDefaults fired off on the database side
+ which need to be post-fetched, *or* if pre-exec defaults like ColumnDefaults were fired off
+ and should be populated into the instance. this is only for non-primary key columns."""
+ if table.engine.lastrow_has_defaults():
+ clause = sql.and_()
+ for p in self.pks_by_table[table]:
+ clause.clauses.append(p == self._getattrbycolumn(obj, p))
+ row = table.select(clause).execute().fetchone()
+ for c in table.c:
+ if self._getattrbycolumn(obj, c) is None:
+ self._setattrbycolumn(obj, c, row[c])
+ else:
+ for c in table.c:
+ if c.primary_key or not params.has_key(c.name):
+ continue
+ if self._getattrbycolumn(obj, c) != params[c.name]:
+ self._setattrbycolumn(obj, c, params[c.name])
+
def delete_obj(self, objects, uow):
"""called by a UnitOfWork object to delete objects, which involves a
DELETE statement for each table used by this mapper, for each object in the list."""
Column('id', Integer, Sequence("dt_seq", optional=True), primary_key=True),
Column('hoho', hohotype, PassiveDefault(str(self.hohoval))),
Column('counter', Integer, PassiveDefault("7")),
- Column('foober', String(30), default="im foober")
+ Column('foober', String(30), default="im foober", onupdate="im the update")
)
self.table.create()
def tearDownAll(self):
self.table.drop()
+ def setUp(self):
+ self.table = Table('default_test', db)
def testbasic(self):
class Hoho(object):pass
self.assert_(h1.counter == h4.counter==h5.counter==7)
self.assert_(h2.foober == h3.foober == h4.foober == 'im foober')
self.assert_(h5.foober=='im the new foober')
-
+
+ def testupdate(self):
+ class Hoho(object):pass
+ assign_mapper(Hoho, self.table)
+ h1 = Hoho()
+ objectstore.commit()
+ self.assert_(h1.foober == 'im foober')
+ h1.counter = 19
+ objectstore.commit()
+ self.assert_(h1.foober == 'im the update')
+
class SaveTest(AssertMixin):
def setUpAll(self):