]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- fixed unfortunate mutating-dictionary glitch from previous checkin
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 23 Sep 2006 00:06:10 +0000 (00:06 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 23 Sep 2006 00:06:10 +0000 (00:06 +0000)
- added "batch=True" flag to mapper; if False, save_obj
will fully save one object at a time including calls
to before_XXXX and after_XXXX

CHANGES
lib/sqlalchemy/attributes.py
lib/sqlalchemy/orm/mapper.py
test/orm/unitofwork.py

diff --git a/CHANGES b/CHANGES
index d8b52b1751e12d061352f53ab4acc70e45b47265..10ef7968b1b989dc2c5bbd46edc5cdc1f843b093 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -31,6 +31,9 @@ kept separate from the normal mapper setup, thereby
 preventing conflicts with lazy loader operation, fixes 
 [ticket:308]
 - fix to deferred group loading
+- added "batch=True" flag to mapper; if False, save_obj
+will fully save one object at a time including calls
+to before_XXXX and after_XXXX
 
 0.2.8
 - cleanup on connection methods + documentation.  custom DBAPI
index 2d3f910d8c8e8d9f24530d82ce5dd30baa6ab004..84a1d58fb8741618cd63f6e7b4c3db5ecb722365 100644 (file)
@@ -616,7 +616,7 @@ class AttributeManager(object):
     def noninherited_managed_attributes(self, class_):
         if not isinstance(class_, type):
             raise repr(class_) + " is not a type"
-        for key in class_.__dict__:
+        for key in list(class_.__dict__):
             value = getattr(class_, key, None)
             if isinstance(value, InstrumentedAttribute):
                 yield value
index b42a79d8f35391a33e5f18ff0857eaeedfe2ec72..0f298a1fbe4c8b6e51d7063b719f79a1dd76a594 100644 (file)
@@ -53,7 +53,8 @@ class Mapper(object):
                 polymorphic_identity=None,
                 concrete=False,
                 select_table=None,
-                allow_null_pks=False):
+                allow_null_pks=False,
+                batch=True):
 
         if not issubclass(class_, object):
             raise exceptions.ArgumentError("Class '%s' is not a new-style class" % class_.__name__)
@@ -87,7 +88,7 @@ class Mapper(object):
         self.allow_column_override = allow_column_override
         self.allow_null_pks = allow_null_pks
         self.delete_orphans = []
-        
+        self.batch = batch
         # a Column which is used during a select operation to retrieve the 
         # "polymorphic identity" of the row, which indicates which Mapper should be used
         # to construct a new object instance from that row.
@@ -705,12 +706,18 @@ class Mapper(object):
     def _setattrbycolumn(self, obj, column, value):
         self.columntoproperty[column][0].setattr(obj, value)
     
-    def save_obj(self, objects, uow, postupdate=False, post_update_cols=None):
+    def save_obj(self, objects, uow, postupdate=False, post_update_cols=None, single=False):
         """called by a UnitOfWork object to save objects, which involves either an INSERT or
         an UPDATE statement for each table used by this mapper, for each element of the
         list."""
         #print "SAVE_OBJ MAPPER", self.class_.__name__, objects
         
+        # if batch=false, call save_obj separately for each object
+        if not single and not self.batch:
+            for obj in objects:
+                self.save_obj([obj], uow, postupdate=postupdate, post_update_cols=post_update_cols, single=True)
+            return
+            
         connection = uow.transaction.connection(self)
 
         if not postupdate:
@@ -818,6 +825,7 @@ class Mapper(object):
                         update.append((obj, params))
                 else:
                     insert.append((obj, params))
+                    
             if len(update):
                 clause = sql.and_()
                 for col in self.pks_by_table[table]:
index c3744abec3f82a3c46e70e524953ea1acd195417..35c5378fa526dff03458fa28065c6d9f37fa88d0 100644 (file)
@@ -734,6 +734,32 @@ class SaveTest(UnitOfWorkTest):
         k = ctx.current.query(KeywordUser).get(id)
         assert k.user_name == 'keyworduser'
         assert k.keyword_name == 'a keyword'
+    
+    def testbatchmode(self):
+        class TestExtension(MapperExtension):
+            def before_insert(self, mapper, connection, instance):
+                self.current_instance = instance
+            def after_insert(self, mapper, connection, instance):
+                assert instance is self.current_instance
+        m = mapper(User, users, extension=TestExtension(), batch=False)
+        u1 = User()
+        u1.username = 'user1'
+        u2 = User()
+        u2.username = 'user2'
+        ctx.current.flush()
+        
+        clear_mappers()
+        
+        m = mapper(User, users, extension=TestExtension())
+        u1 = User()
+        u1.username = 'user1'
+        u2 = User()
+        u2.username = 'user2'
+        try:
+            ctx.current.flush()
+            assert False
+        except AssertionError:
+            assert True
         
     def testonetoone(self):
         m = mapper(User, users, properties = dict(