From: Mike Bayer Date: Sat, 23 Sep 2006 00:06:10 +0000 (+0000) Subject: - fixed unfortunate mutating-dictionary glitch from previous checkin X-Git-Tag: rel_0_3_0~140 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=f6baed941d8f9c5a076e7d28c9ef1b1a94bfc3fd;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - fixed unfortunate mutating-dictionary glitch from previous checkin - added "batch=True" flag to mapper; if False, save_obj will fully save one object at a time including calls to before_XXXX and after_XXXX --- diff --git a/CHANGES b/CHANGES index d8b52b1751..10ef7968b1 100644 --- a/CHANGES +++ b/CHANGES @@ -31,6 +31,9 @@ kept separate from the normal mapper setup, thereby preventing conflicts with lazy loader operation, fixes [ticket:308] - fix to deferred group loading +- added "batch=True" flag to mapper; if False, save_obj +will fully save one object at a time including calls +to before_XXXX and after_XXXX 0.2.8 - cleanup on connection methods + documentation. custom DBAPI diff --git a/lib/sqlalchemy/attributes.py b/lib/sqlalchemy/attributes.py index 2d3f910d8c..84a1d58fb8 100644 --- a/lib/sqlalchemy/attributes.py +++ b/lib/sqlalchemy/attributes.py @@ -616,7 +616,7 @@ class AttributeManager(object): def noninherited_managed_attributes(self, class_): if not isinstance(class_, type): raise repr(class_) + " is not a type" - for key in class_.__dict__: + for key in list(class_.__dict__): value = getattr(class_, key, None) if isinstance(value, InstrumentedAttribute): yield value diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index b42a79d8f3..0f298a1fbe 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -53,7 +53,8 @@ class Mapper(object): polymorphic_identity=None, concrete=False, select_table=None, - allow_null_pks=False): + allow_null_pks=False, + batch=True): if not issubclass(class_, object): raise exceptions.ArgumentError("Class '%s' is not a new-style class" % class_.__name__) @@ -87,7 +88,7 @@ class Mapper(object): self.allow_column_override = allow_column_override self.allow_null_pks = allow_null_pks self.delete_orphans = [] - + self.batch = batch # a Column which is used during a select operation to retrieve the # "polymorphic identity" of the row, which indicates which Mapper should be used # to construct a new object instance from that row. @@ -705,12 +706,18 @@ class Mapper(object): def _setattrbycolumn(self, obj, column, value): self.columntoproperty[column][0].setattr(obj, value) - def save_obj(self, objects, uow, postupdate=False, post_update_cols=None): + def save_obj(self, objects, uow, postupdate=False, post_update_cols=None, single=False): """called by a UnitOfWork object to save objects, which involves either an INSERT or an UPDATE statement for each table used by this mapper, for each element of the list.""" #print "SAVE_OBJ MAPPER", self.class_.__name__, objects + # if batch=false, call save_obj separately for each object + if not single and not self.batch: + for obj in objects: + self.save_obj([obj], uow, postupdate=postupdate, post_update_cols=post_update_cols, single=True) + return + connection = uow.transaction.connection(self) if not postupdate: @@ -818,6 +825,7 @@ class Mapper(object): update.append((obj, params)) else: insert.append((obj, params)) + if len(update): clause = sql.and_() for col in self.pks_by_table[table]: diff --git a/test/orm/unitofwork.py b/test/orm/unitofwork.py index c3744abec3..35c5378fa5 100644 --- a/test/orm/unitofwork.py +++ b/test/orm/unitofwork.py @@ -734,6 +734,32 @@ class SaveTest(UnitOfWorkTest): k = ctx.current.query(KeywordUser).get(id) assert k.user_name == 'keyworduser' assert k.keyword_name == 'a keyword' + + def testbatchmode(self): + class TestExtension(MapperExtension): + def before_insert(self, mapper, connection, instance): + self.current_instance = instance + def after_insert(self, mapper, connection, instance): + assert instance is self.current_instance + m = mapper(User, users, extension=TestExtension(), batch=False) + u1 = User() + u1.username = 'user1' + u2 = User() + u2.username = 'user2' + ctx.current.flush() + + clear_mappers() + + m = mapper(User, users, extension=TestExtension()) + u1 = User() + u1.username = 'user1' + u2 = User() + u2.username = 'user2' + try: + ctx.current.flush() + assert False + except AssertionError: + assert True def testonetoone(self): m = mapper(User, users, properties = dict(