From: Mike Bayer Date: Sun, 5 Mar 2006 21:01:21 +0000 (+0000) Subject: got mapper to receive the onupdates after updating an instance (also properly receive... X-Git-Tag: rel_0_1_4~41 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=9446e490cb14d456c96b4731993e2a965a090b1a;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git got mapper to receive the onupdates after updating an instance (also properly receives defaults on inserts)... --- diff --git a/lib/sqlalchemy/engine.py b/lib/sqlalchemy/engine.py index 3703169fa0..5f681d39e9 100644 --- a/lib/sqlalchemy/engine.py +++ b/lib/sqlalchemy/engine.py @@ -484,20 +484,33 @@ class SQLEngine(schema.SchemaEngine): self.context.last_inserted_ids = None else: self.context.last_inserted_ids = last_inserted_ids + self.context.last_inserted_params = param elif getattr(compiled, 'isupdate', False): if isinstance(parameters, list): plist = parameters else: plist = [parameters] drunner = self.defaultrunner(proxy) + self.context.lastrow_has_defaults = False for param in plist: for c in compiled.statement.table.c: if c.onupdate is not None and (not param.has_key(c.name) or param[c.name] is None): value = drunner.get_column_onupdate(c) if value is not None: param[c.name] = value - + self.context.last_updated_params = param + + def last_inserted_params(self): + """returns a dictionary of the full parameter dictionary for the last compiled INSERT statement, + including any ColumnDefaults or Sequences that were pre-executed. this value is thread-local.""" + return self.context.last_inserted_params + def last_updated_params(self): + """returns a dictionary of the full parameter dictionary for the last compiled UPDATE statement, + including any ColumnDefaults that were pre-executed. this value is thread-local.""" + return self.context.last_updated_params def lastrow_has_defaults(self): + """returns True if the last row INSERTED via a compiled insert statement contained PassiveDefaults, + indicating that the database inserted data beyond that which we gave it. this value is thread-local.""" return self.context.lastrow_has_defaults def pre_exec(self, proxy, compiled, parameters, **kwargs): diff --git a/lib/sqlalchemy/mapping/mapper.py b/lib/sqlalchemy/mapping/mapper.py index 85fe8dc696..8239df99ce 100644 --- a/lib/sqlalchemy/mapping/mapper.py +++ b/lib/sqlalchemy/mapping/mapper.py @@ -591,6 +591,7 @@ class Mapper(object): for rec in update: (obj, params) = rec c = statement.execute(params) + self._postfetch(table, obj, table.engine.last_updated_params()) self.extension.after_update(self, obj) rows += c.cursor.rowcount if table.engine.supports_sane_rowcount() and rows != len(update): @@ -608,18 +609,30 @@ class Mapper(object): if self._getattrbycolumn(obj, col) is None: self._setattrbycolumn(obj, col, primary_key[i]) i+=1 - if table.engine.lastrow_has_defaults(): - clause = sql.and_() - for p in self.pks_by_table[table]: - clause.clauses.append(p == self._getattrbycolumn(obj, p)) - row = table.select(clause).execute().fetchone() - for c in table.c: - if self._getattrbycolumn(obj, c) is None: - self._setattrbycolumn(obj, c, row[c]) + self._postfetch(table, obj, table.engine.last_inserted_params()) if self._synchronizer is not None: self._synchronizer.execute(obj, obj) self.extension.after_insert(self, obj) - + + def _postfetch(self, table, obj, params): + """after an INSERT or UPDATE, asks the engine if PassiveDefaults fired off on the database side + which need to be post-fetched, *or* if pre-exec defaults like ColumnDefaults were fired off + and should be populated into the instance. this is only for non-primary key columns.""" + if table.engine.lastrow_has_defaults(): + clause = sql.and_() + for p in self.pks_by_table[table]: + clause.clauses.append(p == self._getattrbycolumn(obj, p)) + row = table.select(clause).execute().fetchone() + for c in table.c: + if self._getattrbycolumn(obj, c) is None: + self._setattrbycolumn(obj, c, row[c]) + else: + for c in table.c: + if c.primary_key or not params.has_key(c.name): + continue + if self._getattrbycolumn(obj, c) != params[c.name]: + self._setattrbycolumn(obj, c, params[c.name]) + def delete_obj(self, objects, uow): """called by a UnitOfWork object to delete objects, which involves a DELETE statement for each table used by this mapper, for each object in the list.""" diff --git a/test/objectstore.py b/test/objectstore.py index 63a39641e8..a4d2e874cb 100644 --- a/test/objectstore.py +++ b/test/objectstore.py @@ -229,11 +229,13 @@ class DefaultTest(AssertMixin): Column('id', Integer, Sequence("dt_seq", optional=True), primary_key=True), Column('hoho', hohotype, PassiveDefault(str(self.hohoval))), Column('counter', Integer, PassiveDefault("7")), - Column('foober', String(30), default="im foober") + Column('foober', String(30), default="im foober", onupdate="im the update") ) self.table.create() def tearDownAll(self): self.table.drop() + def setUp(self): + self.table = Table('default_test', db) def testbasic(self): class Hoho(object):pass @@ -261,7 +263,17 @@ class DefaultTest(AssertMixin): self.assert_(h1.counter == h4.counter==h5.counter==7) self.assert_(h2.foober == h3.foober == h4.foober == 'im foober') self.assert_(h5.foober=='im the new foober') - + + def testupdate(self): + class Hoho(object):pass + assign_mapper(Hoho, self.table) + h1 = Hoho() + objectstore.commit() + self.assert_(h1.foober == 'im foober') + h1.counter = 19 + objectstore.commit() + self.assert_(h1.foober == 'im the update') + class SaveTest(AssertMixin): def setUpAll(self):