]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
(no commit message)
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 3 Oct 2005 00:52:22 +0000 (00:52 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 3 Oct 2005 00:52:22 +0000 (00:52 +0000)
lib/sqlalchemy/attributes.py
lib/sqlalchemy/mapper.py
lib/sqlalchemy/objectstore.py
lib/sqlalchemy/sql.py
lib/sqlalchemy/util.py
test/select.py

index 59f68203bfb921da23916f328f0fc22d325a5446..7f9c708257dd45868f36669d4f984fa043c4f815 100644 (file)
@@ -26,23 +26,23 @@ class SmartProperty(object):
         self.manager = manager
     def attribute_registry(self):
         return self.manager
-    def property(self, key, uselist):
+    def property(self, key, uselist, **kwargs):
         def set_prop(obj, value):
             if uselist:
-                self.attribute_registry().set_list_attribute(obj, key, value)
+                self.attribute_registry().set_list_attribute(obj, key, value, **kwargs)
             else:
-                self.attribute_registry().set_attribute(obj, key, value)
+                self.attribute_registry().set_attribute(obj, key, value, **kwargs)
         def del_prop(obj):
             if uselist:
                 # TODO: this probably doesnt work right, deleting the list off an item
-                self.attribute_registry().delete_list_attribute(obj, key)
+                self.attribute_registry().delete_list_attribute(obj, key, **kwargs)
             else:
-                self.attribute_registry().delete_attribute(obj, key)
+                self.attribute_registry().delete_attribute(obj, key, **kwargs)
         def get_prop(obj):
             if uselist:
-                return self.attribute_registry().get_list_attribute(obj, key)
+                return self.attribute_registry().get_list_attribute(obj, key, **kwargs)
             else:
-                return self.attribute_registry().get_attribute(obj, key)
+                return self.attribute_registry().get_attribute(obj, key, **kwargs)
                 
         return property(get_prop, set_prop, del_prop)
 
@@ -50,7 +50,7 @@ class PropHistory(object):
     """manages the value of a particular scalar attribute on a particular object instance."""
     # make our own NONE to distinguish from "None"
     NONE = object()
-    def __init__(self, obj, key):
+    def __init__(self, obj, key, **kwargs):
         self.obj = obj
         self.key = key
         self.orig = PropHistory.NONE
@@ -116,7 +116,7 @@ class ListElement(util.HistoryArraySet):
         return self
     def __call__(self, *args, **kwargs):
         return self
-    def list_value_changed(self, obj, key, listval):
+    def list_value_changed(self, obj, key, item, listval, isdelete):
         pass    
     def setattr(self, value):
         self.obj.__dict__[self.key] = value
@@ -126,12 +126,12 @@ class ListElement(util.HistoryArraySet):
     def _setrecord(self, item):
         res = util.HistoryArraySet._setrecord(self, item)
         if res:
-            self.list_value_changed(self.obj, self.key, self)
+            self.list_value_changed(self.obj, self.key, item, self, False)
         return res
     def _delrecord(self, item):
         res = util.HistoryArraySet._delrecord(self, item)
         if res:
-            self.list_value_changed(self.obj, self.key, self)
+            self.list_value_changed(self.obj, self.key, item, self, True)
         return res
 
 class CallableProp(object):
@@ -140,12 +140,13 @@ class CallableProp(object):
     the AttributeManager.  When the attributemanager
     accesses the object attribute, either to get its history or its real value, the __call__ method
     is invoked which runs the underlying callable_ and sets the new value to the object attribute
-    via the manager."""
-    def __init__(self, callable_, obj, key, uselist = False):
+    via the manager, at which point the CallableProp itself is dereferenced."""
+    def __init__(self, callable_, obj, key, uselist = False, **kwargs):
         self.callable_ = callable_
         self.obj = obj
         self.key = key
         self.uselist = uselist
+        self.kwargs = kwargs
     def gethistory(self, manager, *args, **kwargs):
         self.__call__(manager, *args, **kwargs)
         return manager.attribute_history[self.obj][self.key]
@@ -154,12 +155,12 @@ class CallableProp(object):
             return None
         value = self.callable_()
         if self.uselist:
-            p = manager.create_list(self.obj, self.key, value)
+            p = manager.create_list(self.obj, self.key, value, **self.kwargs)
             manager.attribute_history[self.obj][self.key] = p
             return p
         else:
             self.obj.__dict__[self.key] = value
-            p = PropHistory(self.obj, self.key)
+            p = PropHistory(self.obj, self.key, **self.kwargs)
             manager.attribute_history[self.obj][self.key] = p
             return p
             
@@ -170,14 +171,14 @@ class AttributeManager(object):
 
     def value_changed(self, obj, key, value):
         pass
-    def create_prop(self, key, uselist):
-        return SmartProperty(self).property(key, uselist)
-    def create_list(self, obj, key, list_):
+    def create_prop(self, key, uselist, **kwargs):
+        return SmartProperty(self).property(key, uselist, **kwargs)
+    def create_list(self, obj, key, list_, **kwargs):
         return ListElement(obj, key, list_)
         
-    def get_attribute(self, obj, key):
+    def get_attribute(self, obj, key, **kwargs):
         try:
-            return self.get_history(obj, key)(self)
+            return self.get_history(obj, key, **kwargs)(self)
         except KeyError:
             pass
         try:
@@ -185,29 +186,29 @@ class AttributeManager(object):
         except KeyError:
             raise AttributeError(key)
 
-    def get_list_attribute(self, obj, key):
-        return self.get_list_history(obj, key)
+    def get_list_attribute(self, obj, key, **kwargs):
+        return self.get_list_history(obj, key, **kwargs)
         
-    def set_attribute(self, obj, key, value):
-        self.get_history(obj, key).setattr(value)
+    def set_attribute(self, obj, key, value, **kwargs):
+        self.get_history(obj, key, **kwargs).setattr(value)
         self.value_changed(obj, key, value)
     
-    def set_list_attribute(self, obj, key, value):
-        self.get_list_history(obj, key).setattr(value)
+    def set_list_attribute(self, obj, key, value, **kwargs):
+        self.get_list_history(obj, key, **kwargs).setattr(value)
         
-    def delete_attribute(self, obj, key):
-        self.get_history(obj, key).delattr()
+    def delete_attribute(self, obj, key, **kwargs):
+        self.get_history(obj, key, **kwargs).delattr()
         self.value_changed(obj, key, value)
 
-    def set_callable(self, obj, key, func, uselist):
+    def set_callable(self, obj, key, func, uselist, **kwargs):
         try:
             d = self.attribute_history[obj]
         except KeyError, e:
             d = {}
             self.attribute_history[obj] = d
-        d[key] = CallableProp(func, obj, key, uselist)
+        d[key] = CallableProp(func, obj, key, uselist, **kwargs)
         
-    def delete_list_attribute(self, obj, key):
+    def delete_list_attribute(self, obj, key, **kwargs):
         pass
         
     def rollback(self, obj = None):
@@ -242,22 +243,22 @@ class AttributeManager(object):
         except KeyError:
             pass
             
-    def get_history(self, obj, key):
+    def get_history(self, obj, key, **kwargs):
         try:
             return self.attribute_history[obj][key].gethistory(self)
         except KeyError, e:
             if e.args[0] is obj:
                 d = {}
                 self.attribute_history[obj] = d
-                p = PropHistory(obj, key)
+                p = PropHistory(obj, key, **kwargs)
                 d[key] = p
                 return p
             else:
-                p = PropHistory(obj, key)
+                p = PropHistory(obj, key, **kwargs)
                 self.attribute_history[obj][key] = p
                 return p
 
-    def get_list_history(self, obj, key, passive = False):
+    def get_list_history(self, obj, key, passive = False, **kwargs):
         try:
             return self.attribute_history[obj][key].gethistory(self, passive)
         except KeyError, e:
@@ -266,13 +267,13 @@ class AttributeManager(object):
             if e.args[0] is obj:
                 d = {}
                 self.attribute_history[obj] = d
-                p = self.create_list(obj, key, list_)
+                p = self.create_list(obj, key, list_, **kwargs)
                 d[key] = p
                 return p
             else:
-                p = self.create_list(obj, key, list_)
+                p = self.create_list(obj, key, list_, **kwargs)
                 self.attribute_history[obj][key] = p
                 return p
 
-    def register_attribute(self, class_, key, uselist):
-        setattr(class_, key, self.create_prop(key, uselist))
+    def register_attribute(self, class_, key, uselist, **kwargs):
+        setattr(class_, key, self.create_prop(key, uselist, **kwargs))
index 7feb24439bec9b5a8d4269117e96fbd44bebe575..dda89fe06ecfab05b525fe9483f26eded59286b0 100644 (file)
@@ -324,7 +324,10 @@ class Mapper(object):
 #                print "SAVE_OBJ we are " + hash_key(self) + " obj: " +  obj.__class__.__name__ + repr(id(obj))
                 params = {}
                 for col in table.columns:
-                    params[col.key] = self._getattrbycolumn(obj, col)
+                    if col.primary_key:
+                        params[col.table.name + "_" + col.key] = self._getattrbycolumn(obj, col)
+                    else:
+                        params[col.key] = self._getattrbycolumn(obj, col)
 
                 if hasattr(obj, "_instance_key"):
                     update.append(params)
@@ -335,12 +338,11 @@ class Mapper(object):
                 #print "REGULAR UPDATES"
                 clause = sql.and_()
                 for col in self.primary_keys[table]:
-                    clause.clauses.append(col == sql.bindparam(col.key))
+                    clause.clauses.append(col == sql.bindparam(col.table.name + "_" + col.key))
                 statement = table.update(clause)
                 c = statement.execute(*update)
                 if c.cursor.rowcount != len(update):
                     raise "ConcurrencyError - updated rowcount does not match number of objects updated"
-            
             if len(insert):
                 statement = table.insert()
                 for rec in insert:
@@ -598,7 +600,7 @@ class PropertyLoader(MapperProperty):
                 
         if not hasattr(parent.class_, key):
             #print "regiser list col on class %s key %s" % (parent.class_.__name__, key)
-            objectstore.uow().register_attribute(parent.class_, key, uselist = self.uselist)
+            objectstore.uow().register_attribute(parent.class_, key, uselist = self.uselist, deleteremoved = self.private)
 
     def get_direction(self):
         if self.thiscol is not None:
@@ -808,7 +810,7 @@ class PropertyLoader(MapperProperty):
 class LazyLoader(PropertyLoader):
     def execute(self, instance, row, identitykey, imap, isnew):
         if isnew:
-            objectstore.uow().register_callable(instance, self.key, LazyLoadInstance(self, row), uselist=self.uselist)
+            objectstore.uow().register_callable(instance, self.key, LazyLoadInstance(self, row), uselist=self.uselist, deleteremoved = self.private)
 
 def create_lazy_clause(table, primaryjoin, secondaryjoin, thiscol):
     binds = {}
index e2167cd58d13972ea33a8769e51ea0bc7f981822..c930e19ebf7d8adccab8ea694a62738741c8db67 100644 (file)
@@ -76,8 +76,13 @@ class UOWSmartProperty(attributes.SmartProperty):
         return uow().attributes
     
 class UOWListElement(attributes.ListElement):
-    def list_value_changed(self, obj, key, listval):
+    def __init__(self, obj, key, data=None, deleteremoved=False):
+        attributes.ListElement.__init__(self, obj, key, data=data)
+        self.deleteremoved = deleteremoved
+    def list_value_changed(self, obj, key, item, listval, isdelete):
         uow().modified_lists.append(self)
+        if isdelete and self.deleteremoved:
+            uow().register_deleted(item)
 
 class UOWAttributeManager(attributes.AttributeManager):
     def __init__(self, uow):
@@ -90,11 +95,11 @@ class UOWAttributeManager(attributes.AttributeManager):
         else:
             self.uow.register_new(obj)
 
-    def create_prop(self, key, uselist):
-        return UOWSmartProperty(self).property(key, uselist)
+    def create_prop(self, key, uselist, **kwargs):
+        return UOWSmartProperty(self).property(key, uselist, **kwargs)
 
-    def create_list(self, obj, key, list_):
-        return UOWListElement(obj, key, list_)
+    def create_list(self, obj, key, list_, **kwargs):
+        return UOWListElement(obj, key, list_, **kwargs)
         
 class UnitOfWork(object):
     def __init__(self, parent = None, is_begun = False):
@@ -136,11 +141,11 @@ class UnitOfWork(object):
         self._put(obj._instance_key, obj)
         self.register_dirty(obj)
         
-    def register_attribute(self, class_, key, uselist):
-        self.attributes.register_attribute(class_, key, uselist)
+    def register_attribute(self, class_, key, uselist, **kwargs):
+        self.attributes.register_attribute(class_, key, uselist, **kwargs)
 
-    def register_callable(self, obj, key, func, uselist):
-        self.attributes.set_callable(obj, key, func, uselist)
+    def register_callable(self, obj, key, func, uselist, **kwargs):
+        self.attributes.set_callable(obj, key, func, uselist, **kwargs)
         
     def register_clean(self, obj):
         try:
@@ -456,7 +461,7 @@ class UOWTask(object):
             if dependencies.has_key(node.item):
                 for processor, deptask in dependencies[node.item].iteritems():
                     parenttask.dependencies.append((processor, deptask))
-            t = d[node.item]
+            t = get_task(node.item)
             for n in node.children:
                 t2 = make_task_tree(n, t)
             return t
index 52ce5fdc9dbc34c8f8b550b30c9c544f982e6a95..ae74ff0ed1180aee4fe944c5e1e678846c2f9f77 100644 (file)
@@ -116,7 +116,10 @@ def subquery(alias, *args, **params):
     return Alias(Select(*args, **params), alias)
 
 def bindparam(key, value = None, type=None):
-    return BindParamClause(key, value, type=type)
+    if isinstance(key, schema.Column):
+        return BindParamClause(key.name, value, type=key.type)
+    else:
+        return BindParamClause(key, value, type=type)
 
 def text(text):
     return TextClause(text)
@@ -775,7 +778,7 @@ class UpdateBase(ClauseElement):
                 else:
                     col = key
                 try:
-                    parameters[key] = bindparam(col.name, value, type=col.type)
+                    parameters[key] = bindparam(col, value)
                 except KeyError:
                     del parameters[key]
         return parameters
index 682b9fe323d257c0d8bae1a462d8ad8bde01e156..ca6eb35304b8a6aa28b71596eaf49b5bbd67f481 100644 (file)
@@ -195,6 +195,7 @@ class HistoryArraySet(UserList.UserList):
                 del self.records[item]
         except KeyError:
             pass
+        return True
     def commit(self):
         for key in self.records.keys():
             value = self.records[key]
index 58bd809e3c9a2248940555641c354b3178cccbbe..0549d209306375fac6e8a2acf4b0d28039a0ef14 100644 (file)
@@ -354,7 +354,13 @@ class CRUDTest(SQLTest):
         self.runtest(update(table, values = {table.c.name : table.c.id}), "UPDATE mytable SET name=mytable.myid")
         self.runtest(update(table, whereclause = table.c.name == bindparam('crit'), values = {table.c.name : 'hi'}), "UPDATE mytable SET name=:name WHERE mytable.name = :crit", params = {'crit' : 'notthere'})
         self.runtest(update(table, table.c.id == 12, values = {table.c.name : table.c.id}), "UPDATE mytable SET name=mytable.myid, description=:description WHERE mytable.myid = :mytable_myid", params = {'description':'test'})
-
+        self.runtest(update(table, table.c.id == 12, values = {table.c.id : 9}), "UPDATE mytable SET myid=:myid, description=:description WHERE mytable.myid = :mytable_myid", params = {'mytable_myid': 12, 'myid': 9, 'description': 'test'})
+        s = table.update(table.c.id == 12, values = {table.c.name : 'lala'})
+        print str(s)
+        c = s.compile(bindparams = {'mytable_id':9,'name':'h0h0'})
+        print str(c)
+        self.assert_(str(s) == str(c))
+        
     def testcorrelatedupdate(self):
         # test against a straight text subquery
         u = update(table, values = {table.c.name : text("select name from mytable where id=mytable.id")})