]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
(no commit message)
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 17 Sep 2005 17:53:20 +0000 (17:53 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 17 Sep 2005 17:53:20 +0000 (17:53 +0000)
lib/sqlalchemy/attributes.py [new file with mode: 0644]
lib/sqlalchemy/mapper.py
lib/sqlalchemy/objectstore.py
lib/sqlalchemy/util.py
test/attributes.py [new file with mode: 0644]
test/historyarray.py
test/objectstore.py

diff --git a/lib/sqlalchemy/attributes.py b/lib/sqlalchemy/attributes.py
new file mode 100644 (file)
index 0000000..9bc72c6
--- /dev/null
@@ -0,0 +1,185 @@
+import sqlalchemy.util as util
+import weakref
+
+class SmartProperty(object):
+    def __init__(self, manager):
+        self.manager = manager
+    def attribute_registry(self):
+        return self.manager
+    def property(self, key, uselist):
+        def set_prop(obj, value):
+            if uselist:
+                self.attribute_registry().set_list_attribute(obj, key, value)
+            else:
+                self.attribute_registry().set_attribute(obj, key, value)
+        def del_prop(obj):
+            if uselist:
+                # TODO: this probably doesnt work right, deleting the list off an item
+                self.attribute_registry().delete_list_attribute(obj, key)
+            else:
+                self.attribute_registry().delete_attribute(obj, key)
+        def get_prop(obj):
+            if uselist:
+                return self.attribute_registry().get_list_attribute(obj, key)
+            else:
+                return self.attribute_registry().get_attribute(obj, key)
+                
+        return property(get_prop, set_prop, del_prop)
+
+class ListElement(util.HistoryArraySet):
+    """overrides HistoryArraySet to mark the parent object as dirty when changes occur"""
+
+    def __init__(self, obj, key, items = None):
+        self.obj = obj
+        self.key = key
+        util.HistoryArraySet.__init__(self, items)
+        print "listelement init"
+
+    def list_value_changed(self, obj, key, listval):
+        pass    
+#        uow().modified_lists.append(self)
+
+    def setattr(self, value):
+        self.obj.__dict__[self.key] = value
+        self.set_data(value)
+    def delattr(self, value):
+        pass    
+    def _setrecord(self, item):
+        res = util.HistoryArraySet._setrecord(self, item)
+        if res:
+            self.list_value_changed(self.obj, self.key, self)
+        return res
+    def _delrecord(self, item):
+        res = util.HistoryArraySet._delrecord(self, item)
+        if res:
+            self.list_value_changed(self.obj, self.key, self)
+        return res
+
+class PropHistory(object):
+    # make our own NONE to distinguish from "None"
+    NONE = object()
+    def __init__(self, obj, key):
+        self.obj = obj
+        self.key = key
+        self.orig = PropHistory.NONE
+    def setattr_clean(self, value):
+        self.obj.__dict__[self.key] = value
+    def setattr(self, value):
+        self.orig = self.obj.__dict__.get(self.key, None)
+        self.obj.__dict__[self.key] = value
+    def delattr(self):
+        self.orig = self.obj.__dict__.get(self.key, None)
+        self.obj.__dict__[self.key] = None
+    def rollback(self):
+        if self.orig is not PropHistory.NONE:
+            self.obj.__dict__[self.key] = self.orig
+            self.orig = PropHistory.NONE
+    def clear_history(self):
+        self.orig = PropHistory.NONE
+    def added_items(self):
+        if self.orig is not PropHistory.NONE:
+            return [self.obj.__dict__[self.key]]
+        else:
+            return []
+    def deleted_items(self):
+        if self.orig is not PropHistory.NONE:
+            return [self.orig]
+        else:
+            return []
+    def unchanged_items(self):
+        if self.orig is PropHistory.NONE:
+            return [self.obj.__dict__[self.key]]
+        else:
+            return []
+
+class AttributeManager(object):
+    def __init__(self):
+        self.attribute_history = {}
+    def value_changed(self, obj, key, value):
+        pass
+#        if hasattr(obj, '_instance_key'):
+#            self.register_dirty(obj)
+#        else:
+#            self.register_new(obj)
+
+    def create_list(self, obj, key, list_):
+        return ListElement(obj, key, list_)
+        
+    def get_attribute(self, obj, key):
+        try:
+            v = obj.__dict__[key]
+        except KeyError:
+            raise AttributeError(key)
+        if (callable(v)):
+            v = v()
+            obj.__dict__[key] = v
+        return v
+
+    def get_list_attribute(self, obj, key):
+        return self.get_list_history(obj, key)
+        
+    def set_attribute(self, obj, key, value):
+        self.get_history(obj, key).setattr(value)
+        self.value_changed(obj, key, value)
+    
+    def set_list_attribute(self, obj, key, value):
+        self.get_list_history(obj, key).setattr(value)
+        
+    def delete_attribute(self, obj, key):
+        self.get_history(obj, key).delattr()
+        self.value_changed(obj, key, value)
+
+    def delete_list_attribute(self, obj, key):
+        pass
+        
+    def rollback(self, obj):
+        try:
+            attributes = self.attribute_history[obj]
+            for hist in attributes.values():
+                hist.rollback()
+        except KeyError:
+            pass
+
+    def clear_history(self, obj):
+        try:
+            attributes = self.attribute_history[obj]
+            for hist in attributes.values():
+                hist.clear_history()
+        except KeyError:
+            pass
+
+    def get_history(self, obj, key):
+        try:
+            return self.attribute_history[obj][key]
+        except KeyError, e:
+            if e.args[0] is obj:
+                d = {}
+                self.attribute_history[obj] = d
+                p = PropHistory(obj, key)
+                d[key] = p
+                return p
+            else:
+                p = PropHistory(obj, key)
+                self.attribute_history[obj][key] = p
+                return p
+
+    def get_list_history(self, obj, key):
+        try:
+            return self.attribute_history[obj][key]
+        except KeyError, e:
+            list_ = obj.__dict__.get(key, None)
+            if callable(list_):
+                list_ = list_()
+            if e.args[0] is obj:
+                d = {}
+                self.attribute_history[obj] = d
+                p = self.create_list(obj, key, list_)
+                d[key] = p
+                return p
+            else:
+                p = self.create_list(obj, key, list_)
+                self.attribute_history[obj][key] = p
+                return p
+
+    def register_attribute(self, class_, key, uselist):
+        setattr(class_, key, SmartProperty(self).property(key, uselist))
index a01e4b727c0d7dbb9626a28203fdb7d61c2dc001..953c02d382d5fe168e90d7e33f61866d505a117a 100644 (file)
@@ -140,7 +140,7 @@ class Mapper(object):
 
         self.init()
 
-    engines = property(lambda s: [t.engine for t in self.tables])
+    engines = property(lambda s: [t.engine for t in s.tables])
 
     def hash_key(self):
         return self.hashkey
@@ -365,9 +365,8 @@ class Mapper(object):
         return instance
 
     def rollback(self, obj):
-        for prop in self.props.values():
-            prop.rollback(obj)
-            
+        objectstore.uow().rollback_object(obj)
+        
 class MapperOption:
     """describes a modification to a Mapper in the context of making a copy
     of it.  This is used to assist in the prototype pattern used by mapper.options()."""
@@ -418,13 +417,12 @@ class ColumnProperty(MapperProperty):
         self.key = key
         # establish a SmartProperty property manager on the object for this key
         if not hasattr(parent.class_, key):
-            setattr(parent.class_, key, SmartProperty(key).property(usehistory = True))
+            objectstore.uow().register_attribute(parent.class_, key, uselist = False)
 
     def execute(self, instance, row, identitykey, imap, isnew):
         if isnew:
-            setattr(instance, self.key, row[self.columns[0].label])
-    def rollback(self, obj):
-        objectstore.uow().rollback_attribute(obj, self.key)
+            instance.__dict__[self.key] = row[self.columns[0].label]
+            #setattr(instance, self.key, row[self.columns[0].label])
         
 
 
@@ -444,9 +442,6 @@ class PropertyLoader(MapperProperty):
     def hash_key(self):
         return self._hash_key
 
-    def rollback(self, obj):
-        objectstore.uow().rollback_list_attribute(obj, self.key)
-
     def init(self, key, parent):
         self.key = key
         self.parent = parent
@@ -475,7 +470,7 @@ class PropertyLoader(MapperProperty):
                 self.foreignkey = w.dependent
                 
         if not hasattr(parent.class_, key):
-            setattr(parent.class_, key, SmartProperty(key).property(usehistory = True, uselist = self.uselist))
+            objectstore.uow().register_attribute(parent.class_, key, uselist = self.uselist)
 
     class FindDependent(sql.ClauseVisitor):
         def __init__(self):
@@ -530,9 +525,9 @@ class PropertyLoader(MapperProperty):
 
         def getlist(obj):
             if self.uselist:
-                return uowcommit.uow.register_list_attribute(obj, self.key)
-            else:
-                return uowcommit.uow.register_attribute(obj, self.key)
+                return uowcommit.uow.manager.get_list_history(obj, self.key)
+            else: 
+                return uowcommit.uow.manager.get_history(obj, self.key)
 
         clearkeys = False
         
@@ -607,10 +602,6 @@ class LazyLoader(PropertyLoader):
 
     def init(self, key, parent):
         PropertyLoader.init(self, key, parent)
-        if not hasattr(parent.class_, key):
-            if not issubclass(parent.class_, object):
-                raise "LazyLoader can only be used with new-style classes"
-            setattr(parent.class_, key, SmartProperty(key).property())
         if self.secondaryjoin is not None:
             self.lazywhere = sql.and_(self.primaryjoin, self.secondaryjoin)
         else:
@@ -704,14 +695,17 @@ class EagerLoader(PropertyLoader):
         it to a list on the parent instance."""
         if not self.uselist:
             # TODO: check for multiple values on single-element child element ?
-            setattr(instance, self.key, self.mapper._instance(row, imap))
+            instance.__dict__[self.key] = self.mapper._instance(row, imap)
+            #setattr(instance, self.key, self.mapper._instance(row, imap))
             return
         elif isnew:
-            result_list = objectstore.uow().register_list_attribute(instance, self.key, data = [])
+            result_list = []
+            setattr(instance, self.key, result_list)
+            result_list = getattr(instance, self.key)
             result_list.clear_history()
         else:
             result_list = getattr(instance, self.key)
-
+            
         self.mapper._instance(row, imap, result_list)
             
 class EagerLazySwitcher(MapperOption):
@@ -769,30 +763,6 @@ class BinaryVisitor(sql.ClauseVisitor):
     def visit_binary(self, binary):
         self.func(binary)
         
-class SmartProperty(object):
-    def __init__(self, key):
-        self.key = key
-
-    def property(self, usehistory = False, uselist = False):
-        def set_prop(s, value):
-            if uselist:
-                return objectstore.uow().register_list_attribute(s, self.key, value)
-            else:
-                objectstore.uow().set_attribute(s, self.key, value, usehistory)
-        def del_prop(s):
-            if uselist:
-                # TODO: this probably doesnt work right, deleting the list off an item
-                objectstore.uow().register_list_attribute(s, self.key, [])
-            else:
-                objectstore.uow().delete_attribute(s, self.key, value, usehistory)
-        def get_prop(s):
-            if uselist:
-                return objectstore.uow().register_list_attribute(s, self.key)
-            else:
-                return objectstore.uow().get_attribute(s, self.key)
-                
-        return property(get_prop, set_prop, del_prop)
-
   
 def hash_key(obj):
     if obj is None:
index dffa4afa217375e1ea7e60dd64c875f556590cca..a0c49ac5d1c08a72c48d8d4d531d8f0f076a33b5 100644 (file)
@@ -22,6 +22,7 @@ to objects so that they may be properly persisted within a transactional scope."
 
 import thread
 import sqlalchemy.util as util
+import sqlalchemy.attributes as attributes
 import weakref
 
 def get_id_key(ident, class_, table):
@@ -95,125 +96,43 @@ def has_key(key):
     else:
         return False
 
-class UOWListElement(util.HistoryArraySet):
-    """overrides HistoryArraySet to mark the parent object as dirty when changes occur"""
+class UOWListElement(attributes.ListElement):
+    def list_value_changed(self, obj, key, listval):
+        uow().modified_lists.append(self)
+
+class UOWAttributeManager(attributes.AttributeManager):
+    def __init__(self, uow):
+        attributes.AttributeManager.__init__(self)
+        self.uow = uow
         
-    def __init__(self, obj, items = None):
-        util.HistoryArraySet.__init__(self, items)
-        self.obj = weakref.ref(obj)
+    def value_changed(self, obj, key, value):
+        if hasattr(obj, '_instance_key'):
+            self.uow.register_dirty(obj)
+        else:
+            self.uow.register_new(obj)
+
+    def create_list(self, obj, key, list_):
+        return UOWListElement(obj, key, list_)
         
-    def _setrecord(self, item):
-        res = util.HistoryArraySet._setrecord(self, item)
-        if res:
-            uow().modified_lists.append(self)
-        return res
-    def _delrecord(self, item):
-        res = util.HistoryArraySet._delrecord(self, item)
-        if res:
-            uow().modified_lists.append(self)
-        return res
-    
 class UnitOfWork(object):
     def __init__(self, parent = None, is_begun = False):
         self.is_begun = is_begun
+        self.attributes = UOWAttributeManager(self)
         self.new = util.HashSet()
         self.dirty = util.HashSet()
         self.modified_lists = util.HashSet()
         self.deleted = util.HashSet()
-        self.attribute_history = weakref.WeakKeyDictionary()
         self.parent = parent
+
+    def register_attribute(self, class_, key, uselist):
+        self.attributes.register_attribute(class_, key, uselist)
         
     def attribute_set_callable(self, obj, key, func):
         obj.__dict__[key] = func
 
-    def get_attribute(self, obj, key):
-        try:
-            v = obj.__dict__[key]
-        except KeyError:
-            raise AttributeError(key)
-        if (callable(v)):
-            v = v()
-            obj.__dict__[key] = v
-            self.register_attribute(obj, key).setattr_clean(v)
-        return v
+    def rollback_object(self, obj):
+        self.attributes.rollback(obj)
     
-    def rollback_attribute(self, obj, key):
-        if self.attribute_history.has_key(obj):
-            h = self.attribute_history[obj][key]
-            h.rollback()
-            obj.__dict__[key] = h.current
-            
-    def set_attribute(self, obj, key, value, usehistory = False):
-        if usehistory:
-            self.register_attribute(obj, key).setattr(value)
-        obj.__dict__[key] = value
-        if hasattr(obj, '_instance_key'):
-            self.register_dirty(obj)
-        else:
-            self.register_new(obj)
-        
-    def delete_attribute(self, obj, key, value, usehistory = False):
-        if usehistory:
-            self.register_attribute(obj, key).delattr(value)    
-        del obj.__dict__[key]
-        if hasattr(obj, '_instance_key'):
-            self.register_dirty(obj)
-        else:
-            self.register_new(obj)
-    
-    def rollback_obj(self, obj):
-        try:
-            attributes = self.attribute_history[obj]
-            for key, hist in attributes.iteritems():
-                hist.rollback()
-                obj.__dict__[key] = hist.current
-        except KeyError:
-            pass
-        for value in obj.__dict__.values():
-            if isinstance(value, util.HistoryArraySet):
-                value.rollback()
-    def register_attribute(self, obj, key):
-        try:
-            attributes = self.attribute_history[obj]
-        except KeyError:
-            attributes = self.attribute_history.setdefault(obj, {})
-        try:
-            return attributes[key]
-        except KeyError:
-            return attributes.setdefault(key, util.PropHistory(obj.__dict__.get(key, None)))
-
-    def register_list_attribute(self, obj, key, data = None):
-        try:
-            attributes = self.attribute_history[obj]
-        except KeyError:
-            attributes = self.attribute_history.setdefault(obj, {})
-        try:
-            childlist = attributes[key]
-        except KeyError:
-            try:
-                list = obj.__dict__[key]
-                if callable(list):
-                    list = list()
-            except KeyError:
-                list = []
-                obj.__dict__[key] = list
-
-            childlist = UOWListElement(obj, list)
-            
-        if data is not None and childlist.data != data:
-            try:
-                childlist.set_data(data)
-            except TypeError:
-                raise "object " + repr(data) + " is not an iterable object"
-        return childlist
-    
-    def rollback_list_attribute(self, obj, key):
-        try:
-            childlist = obj.__dict__[key]
-            if isinstance(childlist, util.HistoryArraySet):
-                childlist.rollback()
-        except KeyError:
-            pass    
     def register_clean(self, obj, scope="thread"):
         try:
             del self.dirty[obj]
@@ -263,7 +182,7 @@ class UnitOfWork(object):
                 commit_context.append_task(obj)
 
         engines = util.HashSet()
-        for mapper in commit_context.mappers.keys():
+        for mapper in commit_context.mappers:
             for e in mapper.engines:
                 engines.append(e)
                 
@@ -288,8 +207,8 @@ class UnitOfWork(object):
 class UOWTransaction(object):
     def __init__(self, uow):
         self.uow = uow
-        self.mappers = {}
-        self.engines = util.HashSet()
+        self.object_mappers = {}
+        self.mappers = util.HashSet()
         self.dependencies = {}
         self.tasks = {}
         self.saved_objects = util.HashSet()
@@ -322,10 +241,11 @@ class UOWTransaction(object):
     def object_mapper(self, obj):
         import sqlalchemy.mapper
         try:
-            return self.mappers[obj]
+            return self.object_mappers[obj]
         except KeyError:
             mapper = sqlalchemy.mapper.object_mapper(obj)
-            self.mappers[obj] = mapper
+            self.object_mappers[obj] = mapper
+            self.mappers.append(mapper)
             return mapper
             
     def execute(self):
index 23c2420adf7d416585147466f3813e8a67de52f0..20d45142f248d27b68bd69ac276ac12da2513f69 100644 (file)
@@ -128,15 +128,16 @@ class HashSet(object):
         return self.map[key]
         
 class HistoryArraySet(UserList.UserList):
-    def __init__(self, items = None, data = None):
-        UserList.UserList.__init__(self, items)
+    def __init__(self, data = None):
         # stores the array's items as keys, and a value of True, False or None indicating
         # added, deleted, or unchanged for that item
+        self.records = OrderedDict()
         if data is not None:
             self.data = data
-        self.records = {}
-        for i in self.data:
-            self.records[i] = True
+            for item in data:
+                self._setrecord(item)
+        else:
+            self.data = []
 
     def set_data(self, data):
         # first mark everything current as "deleted"
diff --git a/test/attributes.py b/test/attributes.py
new file mode 100644 (file)
index 0000000..718f302
--- /dev/null
@@ -0,0 +1,78 @@
+from testbase import PersistTest
+import sqlalchemy.util as util
+import sqlalchemy.attributes as attributes
+import unittest, sys, os
+
+
+    
+class AttributesTest(PersistTest):
+    def testbasic(self):
+        class User(object):pass
+        manager = attributes.AttributeManager()
+        manager.register_attribute(User, 'user_id', uselist = False)
+        manager.register_attribute(User, 'user_name', uselist = False)
+        manager.register_attribute(User, 'email_address', uselist = False)
+        
+        u = User()
+        print repr(u.__dict__)
+        
+        u.user_id = 7
+        u.user_name = 'john'
+        u.email_address = 'lala@123.com'
+        
+        print repr(u.__dict__)
+        self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com')
+        manager.clear_history(u)
+        print repr(u.__dict__)
+        self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com')
+
+        u.user_name = 'heythere'
+        u.email_address = 'foo@bar.com'
+        print repr(u.__dict__)
+        self.assert_(u.user_id == 7 and u.user_name == 'heythere' and u.email_address == 'foo@bar.com')
+        
+        manager.rollback(u)
+        print repr(u.__dict__)
+        self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com')
+
+    def testlist(self):
+        class User(object):pass
+        class Address(object):pass
+        manager = attributes.AttributeManager()
+        manager.register_attribute(User, 'user_id', uselist = False)
+        manager.register_attribute(User, 'user_name', uselist = False)
+        manager.register_attribute(User, 'addresses', uselist = True)
+        manager.register_attribute(Address, 'address_id', uselist = False)
+        manager.register_attribute(Address, 'email_address', uselist = False)
+        
+        u = User()
+        print repr(u.__dict__)
+
+        u.user_id = 7
+        u.user_name = 'john'
+        u.addresses = []
+        a = Address()
+        a.address_id = 10
+        a.email_address = 'lala@123.com'
+        u.addresses.append(a)
+
+        print repr(u.__dict__)
+        self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com')
+        manager.clear_history(u)
+        print repr(u.__dict__)
+        self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com')
+
+        u.user_name = 'heythere'
+        a = Address()
+        a.address_id = 11
+        a.email_address = 'foo@bar.com'
+        u.addresses.append(a)
+        print repr(u.__dict__)
+        self.assert_(u.user_id == 7 and u.user_name == 'heythere' and u.addresses[0].email_address == 'lala@123.com' and u.addresses[1].email_address == 'foo@bar.com')
+
+        manager.rollback(u)
+        print repr(u.__dict__)
+        self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com')
+
+if __name__ == "__main__":
+    unittest.main()
index 6e987a5aba02c26ab45e8710d2b84aac94e1b76f..bd5c81ad9bf4813f8d0b6470298d99517738440e 100644 (file)
@@ -38,6 +38,24 @@ class HistoryArrayTest(PersistTest):
         self.assert_(a.deleted_items() == [])
         self.assert_(a == ['hi'])
     
+    def testrollback(self):
+        a = util.HistoryArraySet()
+        a.append('hi')
+        a.append('there')
+        a.append('yo')
+        a.clear_history()
+        before = repr(a.data)
+        print repr(a.data)
+        a.remove('there')
+        a.append('lala')
+        a.remove('yo')
+        a.append('yo')
+        after = repr(a.data)
+        print repr(a.data)
+        a.rollback()
+        print repr(a.data)
+        self.assert_(before == repr(a.data))
+        
     def testarray(self):
         a = util.HistoryArraySet()
         a.append('hi')
index 49a310b69d6aab11d2a88e01cfee059ad7d38132..a257b4b4fc0d1470308a935ad300091080f71148 100644 (file)
@@ -27,9 +27,11 @@ class HistoryTest(AssertMixin):
         u = User()
         u.user_id = 7
         u.user_name = 'afdas'
-        u.addresses = [Address(), Address()]
+        u.addresses.append(Address())
         u.addresses[0].email_address = 'hi'
+        u.addresses.append(Address())
         u.addresses[1].email_address = 'there'
+        print repr(u.__dict__)
         m.rollback(u)
         print repr(u.__dict__)