]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
(no commit message)
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 13 Sep 2005 04:49:42 +0000 (04:49 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 13 Sep 2005 04:49:42 +0000 (04:49 +0000)
lib/sqlalchemy/mapper.py

index e1ab2e47c1e404391818de12d0f4e11d50e34c2c..f285ba1939922455ded0d775280d9065f9dad868 100644 (file)
@@ -214,6 +214,8 @@ class Mapper(object):
         self.columntoproperty[column][0].setattr(obj, value)
 
     def save_obj(self, obj):
+        # TODO: start breaking down the individual updates/inserts into the UOW or something,
+        # so they can be combined into a multiple execute
         for table in self.tables:
             params = {}
             for primary_key in table.primary_keys:
@@ -460,7 +462,11 @@ class PropertyLoader(MapperProperty):
             else:
                 return uow.register_attribute(obj, self.key)
 
-        setter = ForeignKeySetter(self.parent, self.mapper, self.parent.table, self.target, self.secondary)
+        clearkeys = False
+        
+        def sync_foreign_keys(binary):
+            self._sync_foreign_keys(binary, obj, child, associationrow, clearkeys)
+        setter = BinaryVisitor(sync_foreign_keys)
 
         if self.secondaryjoin is not None:
             secondary_delete = []
@@ -468,20 +474,16 @@ class PropertyLoader(MapperProperty):
             for obj in deplist:
                 childlist = getlist(obj)
                 for child in childlist.added_items():
-                    setter.obj = obj
-                    setter.child = child
-                    setter.associationrow = {}
+                    associationrow = {}
                     self.primaryjoin.accept_visitor(setter)
                     self.secondaryjoin.accept_visitor(setter)
-                    secondary_insert.append(setter.associationrow)
+                    secondary_insert.append(associationrow)
                 for child in childlist.deleted_items():
-                    setter.obj = obj
-                    setter.child = child
-                    setter.associationrow = {}
-                    setter.clearkeys = True
+                    associationrow = {}
+                    clearkeys = True
                     self.primaryjoin.accept_visitor(setter)
                     self.secondaryjoin.accept_visitor(setter)
-                    secondary_delete.append(setter.associationrow)
+                    secondary_delete.append(associationrow)
             if len(secondary_delete):
                 statement = self.secondary.delete(sql.and_(*[c == sql.bindparam(c.key) for c in self.secondary.c]))
                 statement.echo = self.mapper.echo
@@ -494,21 +496,34 @@ class PropertyLoader(MapperProperty):
             for obj in deplist:
                 childlist = getlist(obj)
                 for child in childlist.added_items():
-                    setter.obj = obj
-                    setter.child = child
-                    setter.associationrow = {}
+                    associationrow = {}
                     self.primaryjoin.accept_visitor(setter)
+                # TODO: deleted items
         elif self.foreignkey.table == self.parent.table:
-            for obj in deplist:
-                childlist = getlist(obj)
-                for child in childlist.added_items():
-                    setter.obj = child
-                    setter.child = obj
-                    setter.associationrow = {}
+            for child in deplist:
+                childlist = getlist(child)
+                for obj in childlist.added_items():
+                    associationrow = {}
                     self.primaryjoin.accept_visitor(setter)
+                # TODO: deleted items
         else:
             raise " no foreign key ?"
-        
+
+    def _sync_foreign_keys(self, binary, obj, child, associationrow, clearkeys):
+        """given a binary clause with an = operator joining two table columns, synchronizes the values 
+        of the corresponding attributes within a parent object and a child object, or the attributes within an 
+        an "association row" that represents an association link between the 'parent' and 'child' object."""
+        if binary.operator == '=':
+            colmap = {binary.left.table : binary.left, binary.right.table : binary.right}
+            if colmap.has_key(self.parent.table) and colmap.has_key(self.target):
+                if clearkeys:
+                    self.mapper._setattrbycolumn(child, colmap[self.target], None)
+                else:
+                    self.mapper._setattrbycolumn(child, colmap[self.target], self.parent._getattrbycolumn(obj, colmap[self.parent.table]))
+            elif colmap.has_key(self.parent.table) and colmap.has_key(self.secondary):
+                associationrow[colmap[self.secondary].key] = self.parent._getattrbycolumn(obj, colmap[self.parent.table])
+            elif colmap.has_key(self.target) and colmap.has_key(self.secondary):
+                associationrow[colmap[self.secondary].key] = self.mapper._getattrbycolumn(child, colmap[self.target])
             
     def delete(self):
         self.mapper.delete()
@@ -663,35 +678,12 @@ class TableFinder(sql.ClauseVisitor):
     def visit_table(self, table):
         self.tables.append(table)
 
-class ForeignKeySetter(sql.ClauseVisitor):
-    """traverses a join condition of a parent/child object or two objects attached by
-    an association table and sets properties on either the child object or an 
-    association table row according to the join properties."""
-    def __init__(self, parentmapper, childmapper, primarytable, secondarytable, associationtable):
-        self.parentmapper = parentmapper
-        self.childmapper = childmapper
-        self.primarytable = primarytable
-        self.secondarytable = secondarytable
-        self.associationtable = associationtable
-        self.obj = None
-        self.associationrow = {}
-        self.clearkeys = False
-        self.child = None
-
+class BinaryVisitor(sql.ClauseVisitor):
+    def __init__(self, func):
+        self.func = func
     def visit_binary(self, binary):
-        if binary.operator == '=':
-            # play a little rock/paper/scissors here    
-            colmap = {binary.left.table : binary.left, binary.right.table : binary.right}
-            if colmap.has_key(self.primarytable) and colmap.has_key(self.secondarytable):
-                if self.clearkeys:
-                    self.childmapper._setattrbycolumn(self.child, colmap[self.secondarytable], None)
-                else:
-                    self.childmapper._setattrbycolumn(self.child, colmap[self.secondarytable], self.parentmapper._getattrbycolumn(self.obj, colmap[self.primarytable]))
-            elif colmap.has_key(self.primarytable) and colmap.has_key(self.associationtable):
-                self.associationrow[colmap[self.associationtable].key] = self.parentmapper._getattrbycolumn(self.obj, colmap[self.primarytable])
-            elif colmap.has_key(self.secondarytable) and colmap.has_key(self.associationtable):
-                self.associationrow[colmap[self.associationtable].key] = self.childmapper._getattrbycolumn(self.child, colmap[self.secondarytable])
-                
+        self.func(binary)
+        
 class LazyIzer(sql.ClauseVisitor):
     """converts an expression which refers to a table column into an
     expression refers to a Bind Param, i.e. a specific value.