]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
(no commit message)
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 10 Sep 2005 23:56:45 +0000 (23:56 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 10 Sep 2005 23:56:45 +0000 (23:56 +0000)
lib/sqlalchemy/mapper.py
lib/sqlalchemy/objectstore.py
test/mapper.py

index 341006faf20665ed65927b43ca9ce235bb4eb220..4d679beebdda24f9fdede4a34bf2c847d3de326d 100644 (file)
@@ -212,65 +212,6 @@ class Mapper(object):
 
     def _setattrbycolumn(self, obj, column, value):
         self.columntoproperty[column][0].setattr(obj, value)
-        
-    def save(self, obj, traverse = True):
-        """saves the object across all its primary tables.  
-        based on the existence of the primary key for each table, either inserts or updates.
-        primary key is determined by the underlying database engine's sequence methodology.
-        the traverse flag indicates attached objects should be saved as well.
-        
-        if smart attributes are being used for the object, the "dirty" flag, or the absense 
-        of the attribute, determines if the item is saved.  if smart attributes are not being 
-        used, the item is saved unconditionally.
-        """
-
-        if objectstore.uow().is_dirty(obj):
-            def foo():
-                # TODO: unitofwork.begin()
-                
-                # TODO: put a registry for statements in the unitofwork,
-                # where we can store insert/update statements and pre-compile them
-                insert_statement = None
-                update_statement = None
-                for table in self.tables:
-                    params = {}
-                    for primary_key in table.primary_keys:
-                        if self._getattrbycolumn(obj, primary_key) is None:
-                            statement = table.insert()
-                            for col in table.columns:
-                                params[col.key] = self._getattrbycolumn(obj, col)
-                            break
-                    else:
-                        clause = sql.and_()
-                        for col in table.columns:
-                            if col.primary_key:
-                                clause.clauses.append(col == self._getattrbycolumn(obj, col))
-                            else:
-                                params[col.key] = self._getattrbycolumn(obj, col)
-                        statement = table.update(clause)
-                    statement.echo = self.echo
-                    statement.execute(**params)
-                    if isinstance(statement, sql.Insert):
-                        primary_keys = table.engine.last_inserted_ids()
-                        index = 0
-                        for col in table.primary_keys:
-                            newid = primary_keys[index]
-                            index += 1
-                            self._setattrbycolumn(obj, col, newid)
-                        self.put(obj)
-                # TODO: make this "register_saved", which gets committed
-                # to "clean" when you call "unitofwork.commit()"
-                # also put a reentrant "begin/commit" onto unitofwork to handle nests
-                objectstore.uow().register_clean(obj)
-                for prop in self.props.values():
-                    if not isinstance(prop, ColumnProperty):
-                        prop.save(obj, traverse)
-                        
-                # TODO: unitofwork.commit()
-            self.transaction(foo)
-        else:
-            for prop in self.props.values():
-                prop.save(obj, traverse)
 
     def save_obj(self, obj):
         for table in self.tables:
@@ -301,7 +242,6 @@ class Mapper(object):
                 self.put(obj)
 
     def register_dependencies(self, obj, uow):
-        print "hi1"
         for prop in self.props.values():
             prop.register_dependencies(obj, uow)
             
@@ -395,10 +335,6 @@ class MapperProperty:
         """called when the MapperProperty is first attached to a new parent Mapper."""
         pass
 
-    def save(self, object, traverse):
-        """called when the instance is being saved"""
-        pass
-
     def delete(self, object):
         """called when the instance is being deleted"""
         pass
@@ -520,11 +456,13 @@ class PropertyLoader(MapperProperty):
 
         setter = ForeignKeySetter(self.parent, self.mapper, self.parent.table, self.target, self.secondary)
 
+        print "procdep " + repr(deplist)
         if self.secondaryjoin is not None:
             secondary_delete = []
             secondary_insert = []
             for obj in deplist:
                 childlist = getlist(obj)
+                print "added! " + repr(childlist.added_items())
                 for child in childlist.added_items():
                     setter.obj = obj
                     setter.child = child
@@ -567,59 +505,6 @@ class PropertyLoader(MapperProperty):
         else:
             raise " no foreign key ?"
         
-    def save(self, obj, traverse):
-        # saves child objects
-        
-        if self.secondary is not None:
-            secondary_delete = []
-            secondary_insert = []
-             
-        setter = ForeignKeySetter(self.parent, self.mapper, self.parent.table, self.target, self.secondary, obj)
-        
-        if self.uselist:
-            childlist = objectstore.uow().register_list_attribute(obj, self.key)
-        else:
-            childlist = objectstore.uow().register_attribute(obj, self.key)
-
-        for child in childlist.deleted_items():
-            setter.child = child
-            setter.associationrow = {}
-            setter.clearkeys = True
-            self.primaryjoin.accept_visitor(setter)
-            self.mapper.save(child, traverse)
-            if self.secondary is not None:
-                self.secondaryjoin.accept_visitor(setter)
-                secondary_delete.append(setter.associationrow)
-                
-        for child in childlist.added_items():
-            setter.child = child
-            setter.associationrow = {}
-            self.primaryjoin.accept_visitor(setter)
-            self.mapper.save(child, traverse)
-            if self.secondary is not None:
-                self.secondaryjoin.accept_visitor(setter)
-                secondary_insert.append(setter.associationrow)
-
-        if self.secondary is not None:
-            # TODO: use unitofwork statement repository thing to get these
-            # delete/insert statements
-            # then, see if unitofwork can even bunch these all up at the end to do an even 
-            # bigger grouping within the "commit"
-            if len(secondary_delete):
-                statement = self.secondary.delete(sql.and_(*[c == sql.bindparam(c.key) for c in self.secondary.c]))
-                statement.echo = self.mapper.echo
-                statement.execute(*secondary_delete)
-            if len(secondary_insert):
-                statement = self.secondary.insert()
-                statement.echo = self.mapper.echo
-                statement.execute(*secondary_insert)
-
-        for child in childlist.unchanged_items():
-            self.mapper.save(child, traverse)
-        # TODO: make this "register_saved_property", or something similar, which gets 
-        # a "clear_history" when you call "unitofwork.commit()"
-        # also put a reentrant "begin/commit" onto unitofwork to handle nests
-        childlist.clear_history()
             
     def delete(self):
         self.mapper.delete()
index 2a9617e74367a3ed8a3f0663dcf3cda6196675ce..334ee42e9fba47c61c9ad0630050115043de988e 100644 (file)
@@ -106,11 +106,34 @@ def has_key(key):
             return True
     else:
         return False
+
+class UOWListElement(util.HistoryArraySet):
+    class listpointer(object): pass
+        
+    def __init__(self, obj, items = None):
+        util.HistoryArraySet.__init__(self, items)
+        self.obj = weakref.ref(obj)
+        
+        # cant hash a UserList, so make a bullshit pointer to us
+        self.listpointer = UOWListElement.listpointer()
+        self.listpointer.list = self
+        
+    def _setrecord(self, item):
+        res = util.HistoryArraySet._setrecord(self, item)
+        if res:
+            uow().modified_lists.append(self.listpointer)
+        return res
+    def _delrecord(self, item):
+        res = util.HistoryArraySet._delrecord(self, item)
+        if res:
+            uow().modified_lists.append(self.listpointer)
+        return res
     
 class UnitOfWork(object):
     def __init__(self):
         self.new = util.HashSet()
         self.dirty = util.HashSet()
+        self.modified_lists = util.HashSet()
         self.deleted = util.HashSet()
         self.attribute_history = weakref.WeakKeyDictionary()
         
@@ -154,14 +177,14 @@ class UnitOfWork(object):
         try:
             childlist = obj.__dict__[key]
         except KeyError:
-            childlist = util.HistoryArraySet()
+            childlist = UOWListElement(obj)
             obj.__dict__[key] = childlist
         
         if callable(childlist):
             childlist = childlist()
             
         if not isinstance(childlist, util.HistoryArraySet):
-            childlist = util.HistoryArraySet(childlist)
+            childlist = UOWListElement(obj, childlist)
             obj.__dict__[key] = childlist
         if data is not None and childlist.data != data:
             childlist.set_data(data)
@@ -204,9 +227,14 @@ class UnitOfWork(object):
             mapper = sqlalchemy.mapper.object_mapper(obj)
             mapperlist = mappers.setdefault(mapper, [])
             mapperlist.append(obj)
-
+        for array in self.modified_lists:
+            mapper = sqlalchemy.mapper.object_mapper(array.list.obj())
+            mapperlist = mappers.setdefault(mapper, [])
+            mapperlist.append(array.list.obj())
+            
         for mapper in mappers.keys():
             mapperlist = mappers[mapper]
+            print repr(mapperlist)
             mapper.register_dependencies(mapperlist, self)
             
         mapperlist = mappers.keys()
@@ -221,7 +249,9 @@ class UnitOfWork(object):
         
         for mapper in mapperlist:
             obj_list = mappers[mapper]
+            print "mapper " + mapper.table.name
             deplist = self.dependencies.get(mapper, [])
+            print "deps " + repr(deplist)
             for obj in obj_list:
                 mapper.save_obj(obj)
             for dep in deplist:
@@ -230,6 +260,10 @@ class UnitOfWork(object):
 
         self.new.clear()
         self.dirty.clear()
+        for item in self.modified_lists:
+            item = item.list
+            item.clear_history()
+        self.modified_lists.clear()
 
     def register_dependency(self, obj, dependency, processor, stuff_to_process):
         self.dependencies[(obj, dependency)] = True
index 1c1d9bb6045dec4a496da2170ea73d8c50fd18d1..2c7b0b1e0905630f59cdcda1f7489865b5c3510c 100644 (file)
@@ -520,9 +520,10 @@ class SaveTest(AssertMixin):
                 item.keywords.append(k)
 
         objectstore.uow().commit()
-
+        print "OK!"
         l = m.select(items.c.item_name.in_(*[e['item_name'] for e in data[1:]]))
         self.assert_result(l, data)
+        print "OK!"
 
         objects[4].item_name = 'item4updated'
         k = Keyword()
@@ -530,6 +531,10 @@ class SaveTest(AssertMixin):
         objects[5].keywords.append(k)
         
         objectstore.uow().commit()
+        print "OK!"
+        objects[2].keywords.append(k)
+        print "added: " + repr(objects[2].keywords.added_items())
+        objectstore.uow().commit()
         
 if __name__ == "__main__":
     unittest.main()