]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
got mapper to receive the onupdates after updating an instance (also properly receive...
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 5 Mar 2006 21:01:21 +0000 (21:01 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 5 Mar 2006 21:01:21 +0000 (21:01 +0000)
lib/sqlalchemy/engine.py
lib/sqlalchemy/mapping/mapper.py
test/objectstore.py

index 3703169fa03288c03273847f57f8d85c1c9071c5..5f681d39e907fe544681efcb915ba3aed1f038c8 100644 (file)
@@ -484,20 +484,33 @@ class SQLEngine(schema.SchemaEngine):
                     self.context.last_inserted_ids = None
                 else:
                     self.context.last_inserted_ids = last_inserted_ids
+                self.context.last_inserted_params = param
         elif getattr(compiled, 'isupdate', False):
             if isinstance(parameters, list):
                 plist = parameters
             else:
                 plist = [parameters]
             drunner = self.defaultrunner(proxy)
+            self.context.lastrow_has_defaults = False
             for param in plist:
                 for c in compiled.statement.table.c:
                     if c.onupdate is not None and (not param.has_key(c.name) or param[c.name] is None):
                         value = drunner.get_column_onupdate(c)
                         if value is not None:
                             param[c.name] = value
-                        
+                self.context.last_updated_params = param
+    
+    def last_inserted_params(self):
+        """returns a dictionary of the full parameter dictionary for the last compiled INSERT statement,
+        including any ColumnDefaults or Sequences that were pre-executed.  this value is thread-local."""
+        return self.context.last_inserted_params
+    def last_updated_params(self):
+        """returns a dictionary of the full parameter dictionary for the last compiled UPDATE statement,
+        including any ColumnDefaults that were pre-executed. this value is thread-local."""
+        return self.context.last_updated_params                
     def lastrow_has_defaults(self):
+        """returns True if the last row INSERTED via a compiled insert statement contained PassiveDefaults,
+        indicating that the database inserted data beyond that which we gave it. this value is thread-local."""
         return self.context.lastrow_has_defaults
         
     def pre_exec(self, proxy, compiled, parameters, **kwargs):
index 85fe8dc696462b4053b7c306180123d257df59a0..8239df99ce87e09d6f562d7d7d89849131c821c7 100644 (file)
@@ -591,6 +591,7 @@ class Mapper(object):
                 for rec in update:
                     (obj, params) = rec
                     c = statement.execute(params)
+                    self._postfetch(table, obj, table.engine.last_updated_params())
                     self.extension.after_update(self, obj)
                     rows += c.cursor.rowcount
                 if table.engine.supports_sane_rowcount() and rows != len(update):
@@ -608,18 +609,30 @@ class Mapper(object):
                             if self._getattrbycolumn(obj, col) is None:
                                 self._setattrbycolumn(obj, col, primary_key[i])
                             i+=1
-                    if table.engine.lastrow_has_defaults():
-                        clause = sql.and_()
-                        for p in self.pks_by_table[table]:
-                            clause.clauses.append(p == self._getattrbycolumn(obj, p))
-                        row = table.select(clause).execute().fetchone()
-                        for c in table.c:
-                            if self._getattrbycolumn(obj, c) is None:
-                                self._setattrbycolumn(obj, c, row[c])
+                    self._postfetch(table, obj, table.engine.last_inserted_params())
                     if self._synchronizer is not None:
                         self._synchronizer.execute(obj, obj)
                     self.extension.after_insert(self, obj)
-                    
+
+    def _postfetch(self, table, obj, params):
+        """after an INSERT or UPDATE, asks the engine if PassiveDefaults fired off on the database side
+        which need to be post-fetched, *or* if pre-exec defaults like ColumnDefaults were fired off
+        and should be populated into the instance. this is only for non-primary key columns."""
+        if table.engine.lastrow_has_defaults():
+            clause = sql.and_()
+            for p in self.pks_by_table[table]:
+                clause.clauses.append(p == self._getattrbycolumn(obj, p))
+            row = table.select(clause).execute().fetchone()
+            for c in table.c:
+                if self._getattrbycolumn(obj, c) is None:
+                    self._setattrbycolumn(obj, c, row[c])
+        else:
+            for c in table.c:
+                if c.primary_key or not params.has_key(c.name):
+                    continue
+                if self._getattrbycolumn(obj, c) != params[c.name]:
+                    self._setattrbycolumn(obj, c, params[c.name])
+
     def delete_obj(self, objects, uow):
         """called by a UnitOfWork object to delete objects, which involves a
         DELETE statement for each table used by this mapper, for each object in the list."""
index 63a39641e829b96a11f767d934795fb764517a0b..a4d2e874cb8dc285d4b666feb5bc26ba54082424 100644 (file)
@@ -229,11 +229,13 @@ class DefaultTest(AssertMixin):
         Column('id', Integer, Sequence("dt_seq", optional=True), primary_key=True),
         Column('hoho', hohotype, PassiveDefault(str(self.hohoval))),
         Column('counter', Integer, PassiveDefault("7")),
-        Column('foober', String(30), default="im foober")
+        Column('foober', String(30), default="im foober", onupdate="im the update")
         )
         self.table.create()
     def tearDownAll(self):
         self.table.drop()
+    def setUp(self):
+        self.table = Table('default_test', db)
     def testbasic(self):
         
         class Hoho(object):pass
@@ -261,7 +263,17 @@ class DefaultTest(AssertMixin):
         self.assert_(h1.counter ==  h4.counter==h5.counter==7)
         self.assert_(h2.foober == h3.foober == h4.foober == 'im foober')
         self.assert_(h5.foober=='im the new foober')
-            
+    
+    def testupdate(self):
+        class Hoho(object):pass
+        assign_mapper(Hoho, self.table)
+        h1 = Hoho()
+        objectstore.commit()
+        self.assert_(h1.foober == 'im foober')
+        h1.counter = 19
+        objectstore.commit()
+        self.assert_(h1.foober == 'im the update')
+        
 class SaveTest(AssertMixin):
 
     def setUpAll(self):