]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
deletes...
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 19 Sep 2005 02:27:39 +0000 (02:27 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 19 Sep 2005 02:27:39 +0000 (02:27 +0000)
lib/sqlalchemy/mapper.py
lib/sqlalchemy/sql.py
test/objectstore.py
test/select.py

index b56a2753da7d1eedd82a40fa85e936af0b44cce9..7d6e0dd6716eb90061b96f50e7eccfca1411c6ba 100644 (file)
@@ -491,6 +491,8 @@ class PropertyLoader(MapperProperty):
             else:
                 self.foreignkey = w.dependent
                 
+        (self.lazywhere, self.lazybinds) = create_lazy_clause(self.parent.selectable, self.primaryjoin, self.secondaryjoin)
+                
         if not hasattr(parent.class_, key):
             objectstore.uow().register_attribute(parent.class_, key, uselist = self.uselist)
 
@@ -539,7 +541,7 @@ class PropertyLoader(MapperProperty):
         elif self.foreignkey.table == self.target:
             uowcommit.register_dependency(self.parent, self.mapper)
             uowcommit.register_task(self.parent, False, self, self.parent, False)
-            uowcommit.register_task(self.parent, True, self, self.mapper, False)
+            uowcommit.register_task(self.parent, True, self, self.parent, True)
                 
         elif self.foreignkey.table == self.parent.table:
             uowcommit.register_dependency(self.mapper, self.parent)
@@ -590,11 +592,25 @@ class PropertyLoader(MapperProperty):
         elif self.foreignkey.table == self.target:
             if delete:
                 updates = []
+                clearkeys = True
                 for obj in deplist:
+                    params = {}
+                    for bind in self.lazybinds.values():
+                        params[bind.key] = self.parent._getattrbycolumn(obj, self.parent.selectable.c[bind.shortname])
+                    updates.append(params)
                     childlist = getlist(obj, False)
-                #if len(updates):
-                #    statement = self.secondary.update(sql.and_(*[c == sql.bindparam(c.key) for c in self.secondary.c]))
-                #    statement.execute(*secondary_delete)
+                    for child in childlist.deleted_items() + childlist.unchanged_items():
+                        self.primaryjoin.accept_visitor(setter)
+                if len(updates):
+                    if self.private:
+                        statement = self.target.delete(lazywhere)
+                    else:
+                        parameters = {}
+                        for bind in self.lazybinds.values():
+                            parameters[bind.shortname] = None
+                        statement = self.target.update(self.lazywhere, parameters = parameters)
+                    statement.execute(**updates[0])
+                    #statement.execute(*updates)
             else:
                 for obj in deplist:
                     childlist = getlist(obj)
@@ -645,20 +661,6 @@ class PropertyLoader(MapperProperty):
 # to do child deletes
 class LazyLoader(PropertyLoader):
 
-    def init(self, key, parent):
-        PropertyLoader.init(self, key, parent)
-        if self.secondaryjoin is not None:
-            self.lazywhere = sql.and_(self.primaryjoin, self.secondaryjoin)
-        else:
-            self.lazywhere = self.primaryjoin
-
-        # we dont want to screw with the primaryjoin and secondary join of the PropertyLoader,
-        # so create a copy
-        self.lazywhere = self.lazywhere.copy_container()
-        self.binds = {}
-        li = BinaryVisitor(lambda b: self._create_lazy_clause(b, self.binds))
-        self.lazywhere.accept_visitor(li)
-
     def execute(self, instance, row, identitykey, imap, isnew):
         if isnew:
             # TODO: get lazy callables to be stored within the unit of work?
@@ -666,20 +668,33 @@ class LazyLoader(PropertyLoader):
             # when u deserialize tho
             objectstore.uow().attribute_set_callable(instance, self.key, LazyLoadInstance(self, row))
 
-    def _create_lazy_clause(self, binary, binds):
-        if isinstance(binary.left, schema.Column) and binary.left.table == self.parent.selectable:
-            binary.left = binds.setdefault(self.parent.selectable.name + "_" + binary.left.name,
-                    sql.BindParamClause(self.parent.selectable.name + "_" + binary.left.name, None, shortname = binary.left.name))
 
-        if isinstance(binary.right, schema.Column) and binary.right.table == self.parent.selectable:
-            binary.right = binds.setdefault(self.parent.selectable.name + "_" + binary.right.name,
-                    sql.BindParamClause(self.parent.selectable.name + "_" + binary.right.name, None, shortname = binary.right.name))
+def create_lazy_clause(table, primaryjoin, secondaryjoin):
+    binds = {}
+    def visit_binary(binary):
+        if isinstance(binary.left, schema.Column) and binary.left.table == table:
+            binary.left = binds.setdefault(table.name + "_" + binary.left.name,
+                    sql.BindParamClause(table.name + "_" + binary.left.name, None, shortname = binary.left.name))
+            binary.swap()
 
+        if isinstance(binary.right, schema.Column) and binary.right.table == table:
+            binary.right = binds.setdefault(table.name + "_" + binary.right.name,
+                    sql.BindParamClause(table.name + "_" + binary.right.name, None, shortname = binary.right.name))
+                    
+    if secondaryjoin is not None:
+        lazywhere = sql.and_(primaryjoin, secondaryjoin)
+    else:
+        lazywhere = primaryjoin
+    lazywhere = lazywhere.copy_container()
+    li = BinaryVisitor(visit_binary)
+    lazywhere.accept_visitor(li)
+    return (lazywhere, binds)
+        
 class LazyLoadInstance(object):
     """attached to a specific object instance to load related rows."""
     def __init__(self, lazyloader, row):
         self.params = {}
-        for key, value in lazyloader.binds.iteritems():
+        for key in lazyloader.lazybinds.keys():
             self.params[key] = row[key]
         self.mapper = lazyloader.mapper
         self.lazywhere = lazyloader.lazywhere
index d6247f57685ab91a3f02f419d382e86260fa938b..dc25a141100be68f67fd52e21915a77e457b9adc 100644 (file)
@@ -388,6 +388,10 @@ class BinaryClause(ClauseElement):
         self.right.accept_visitor(visitor)
         visitor.visit_binary(self)
 
+    def swap(self):
+        c = self.left
+        self.left = self.right
+        self.right = c
         
 class Selectable(FromClause):
     """represents a column list-holding object, like a table or subquery.  can be used anywhere
@@ -725,12 +729,14 @@ class UpdateBase(ClauseElement):
             if isinstance(value, Select):
                 value.clear_from(self.table.id)
             elif _is_literal(value):
-                try:
+                if _is_literal(key):
                     col = self.table.c[key]
+                else:
+                    col = key
+                try:
                     parameters[key] = bindparam(col.name, value)
                 except KeyError:
                     del parameters[key]
-
         return parameters
         
     def get_colparams(self, parameters):
@@ -743,6 +749,8 @@ class UpdateBase(ClauseElement):
         # compiled params
         if parameters is None:
             parameters = {}
+        else:
+            parameters = parameters.copy()
             
         if self.parameters is not None:
             for k, v in self.parameters.iteritems():
@@ -754,7 +762,10 @@ class UpdateBase(ClauseElement):
             if isinstance(key, schema.Column):
                 d[key] = value
             else:
-                d[self.table.columns[str(key)]] = value
+                try:
+                    d[self.table.columns[str(key)]] = value
+                except AttributeError:
+                    pass
 
         # create a list of column assignment clauses as tuples
         values = []
index 1c14f9aaaf928062b8d81f340c101e54ca918f88..dbf4c025ce7f7dd31811f185ef7054099c956fb5 100644 (file)
@@ -128,17 +128,19 @@ class SaveTest(AssertMixin):
 
     def testdelete(self):
         m = mapper(User, users, properties = dict(
-            address = relation(Address, addresses, lazy = True, uselist = False)
+            address = relation(Address, addresses, lazy = True, uselist = False, private = False)
         ))
         u = User()
+        a = Address()
         u.user_name = 'one2onetester'
-        u.address = Address()
+        u.address = a
         u.address.email_address = 'myonlyaddress@foo.com'
         objectstore.uow().commit()
 
         print "OK"
         objectstore.uow().register_deleted(u)
         objectstore.uow().commit()
+        self.assert_(a.address_id is not None and a.user_id is None)
         
     def testbackwardsonetoone(self):
         # test 'backwards'
index 884044c78d6683e94dccdf596f79f185837bc315..d74df1953ed43d2a247f30f9bb7deb95a70bbfc8 100644 (file)
@@ -343,6 +343,7 @@ class CRUDTest(SQLTest):
         self.runtest(update(table, table.c.id == 7), "UPDATE mytable SET name=:name WHERE mytable.myid = :mytable_myid", params = {table.c.name:'fred'})
         self.runtest(update(table, table.c.id == 7), "UPDATE mytable SET name=:name WHERE mytable.myid = :mytable_myid", params = {'name':'fred'})
         self.runtest(update(table, values = {table.c.name : table.c.id}), "UPDATE mytable SET name=mytable.myid")
+        self.runtest(update(table, whereclause = table.c.name == bindparam('crit'), values = {table.c.name : 'hi'}), "UPDATE mytable SET name=:name WHERE mytable.name = :crit", params = {'crit' : 'notthere'})
         self.runtest(update(table, table.c.id == 12, values = {table.c.name : table.c.id}), "UPDATE mytable SET name=mytable.myid, description=:description WHERE mytable.myid = :mytable_myid", params = {'description':'test'})
 
     def testcorrelatedupdate(self):