]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
(no commit message)
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 3 Sep 2005 05:47:20 +0000 (05:47 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 3 Sep 2005 05:47:20 +0000 (05:47 +0000)
lib/sqlalchemy/util.py
test/historyarray.py [new file with mode: 0644]
test/mapper.py

index 7588b6aa1622cdc4c62c50b4d7bae0c4e0ddbdb4..1c86bf8c30848455785231d4ed700c12203995ee 100644 (file)
@@ -16,7 +16,7 @@
 # Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
 
 __ALL__ = ['OrderedProperties', 'OrderedDict']
-import thread, weakref
+import thread, weakref, UserList
 
 class OrderedProperties(object):
 
@@ -117,26 +117,92 @@ class HashSet(object):
     def __getitem__(self, key):
         return self.map[key]
         
-class HistoryAwareArraySet(object):
-    def __init__(self):
-        self.elements = []
-        self.data = {}
-    def __len__(self):
-        return len(self.elements)
-    def append(self, item):
-        if not hasattr(self.data, value):
-            self.data[value] = ['True', value]
-            self.elements.append(value)
-    def __setitem__(self, key, value):
-        if not hasattr(self.data, value):
-            self.data[value] = ['True', value]
-            self.elements[key] = value
-    def __getitem__(self, key):
-        return self.elements[key]
-    def __delitem__(self, key):
-        pass
-    def __iter__(self):
-        return iter(self.map.values())
+class HistoryArraySet(UserList.UserList):
+    def __init__(self, items = None):
+        UserList.UserList.__init__(self, items)
+        # stores the array's items as keys, and a value of True, False or None indicating
+        # added, deleted, or unchanged for that item
+        self.records = {}
+        if items is not None:
+            for i in items:
+                self.records[i] = True
+    def _setrecord(self, item):
+        try:
+            val = self.records[item]
+            if val is True or val is None:
+                return False
+            else:
+                self.records[item] = None
+                return True
+        except KeyError:
+            self.records[item] = True
+            return True
+    def _delrecord(self, item):
+        try:
+            val = self.records[item]
+            if val is None:
+                self.records[item] = False
+            elif val is True:
+                del self.records[item]
+        except KeyError:
+            pass
+    def clear_history(self):
+        for key in self.records.keys():
+            value = self.records[key]
+            if value is False:
+                del self.records[key]
+            else:
+                self.records[key] = None
+    def added_items(self):
+        return [key for key, value in self.records.iteritems() if value is True]
+    def deleted_items(self):
+        return [key for key, value in self.records.iteritems() if value is False]
+    def append_nohistory(self, item):
+        if not self.records.has_key(item):
+            self.records[item] = None
+            self.data.append(item)
+    def has_item(self, item):
+        return self.records.has_key(item)
+    def __setitem__(self, i, item): 
+        if self._setrecord(a):
+            self.data[i] = item
+    def __delitem__(self, i):
+        self._delrecord(self.data[i])
+        del self.data[i]
+    def __setslice__(self, i, j, other):
+        i = max(i, 0); j = max(j, 0)
+        if isinstance(other, UserList.UserList):
+            l = other.data
+        elif isinstance(other, type(self.data)):
+            l = other
+        else:
+            l = list(other)
+        g = [a for a in l if self._setrecord(a)]
+        self.data[i:] = g
+    def __delslice__(self, i, j):
+        i = max(i, 0); j = max(j, 0)
+        for a in self.data[i:j]:
+            self._delrecord(a)
+        del self.data[i:j]
+    def append(self, item): 
+        if self._setrecord(item):
+            self.data.append(item)
+    def insert(self, i, item): 
+        if self._setrecord(item):
+            self.data.insert(i, item)
+    def pop(self, i=-1):
+        item = self.data[i]
+        self._delrecord(item) 
+        return self.data.pop(i)
+    def remove(self, item): 
+        self._delrecord(item)
+        self.data.remove(item)
+    def __add__(self, other):
+        raise NotImplementedError()
+    def __radd__(self, other):
+        raise NotImplementedError()
+    def __iadd__(self, other):
+        raise NotImplementedError()
         
 class ScopedRegistry(object):
     def __init__(self, createfunc):
diff --git a/test/historyarray.py b/test/historyarray.py
new file mode 100644 (file)
index 0000000..6e987a5
--- /dev/null
@@ -0,0 +1,53 @@
+from testbase import PersistTest
+import sqlalchemy.util as util
+import unittest, sys, os
+
+class HistoryArrayTest(PersistTest):
+    def testadd(self):
+        a = util.HistoryArraySet()
+        a.append('hi')
+        self.assert_(a == ['hi'])
+        self.assert_(a.added_items() == ['hi'])
+    
+    def testremove(self):
+        a = util.HistoryArraySet()
+        a.append('hi')
+        a.clear_history()
+        self.assert_(a == ['hi'])
+        self.assert_(a.added_items() == [])
+        a.remove('hi')
+        self.assert_(a == [])
+        self.assert_(a.deleted_items() == ['hi'])
+        
+    def testremoveadded(self):
+        a = util.HistoryArraySet()
+        a.append('hi')
+        a.remove('hi')
+        self.assert_(a.added_items() == [])
+        self.assert_(a.deleted_items() == [])
+        self.assert_(a == [])
+
+    def testaddedremoved(self):
+        a = util.HistoryArraySet()
+        a.append('hi')
+        a.clear_history()
+        a.remove('hi')
+        self.assert_(a.deleted_items() == ['hi'])
+        a.append('hi')
+        self.assert_(a.added_items() == [])
+        self.assert_(a.deleted_items() == [])
+        self.assert_(a == ['hi'])
+    
+    def testarray(self):
+        a = util.HistoryArraySet()
+        a.append('hi')
+        a.append('there')
+        self.assert_(a[0] == 'hi' and a[1] == 'there')
+        del a[1]
+        self.assert_(a == ['hi'])
+        a.append('hi')
+        a.append('there')
+        a[3:4] = ['yo', 'hi']
+        self.assert_(a == ['hi', 'there', 'yo'])    
+if __name__ == "__main__":
+    unittest.main()
\ No newline at end of file
index 756ff066e5d0a4b97019009c3fc18ba8b99e643d..70248e34bddca51fa34176c67eb9d518e8e95717 100644 (file)
@@ -268,24 +268,30 @@ class SaveTest(PersistTest):
         """tests a save of an object where each instance spans two tables. also tests
         redefinition of the keynames for the column properties."""
         usersaddresses = sql.join(users, addresses, users.c.user_id == addresses.c.user_id)
-        m = mapper(User, usersaddresses, table = users, echo = True, properties = dict(email = ColumnProperty(addresses.c.email_address), foo_id = ColumnProperty(users.c.user_id, addresses.c.user_id)))
+        m = mapper(User, usersaddresses, table = users, echo = True, 
+            properties = dict(
+                email = ColumnProperty(addresses.c.email_address), 
+                foo_id = ColumnProperty(users.c.user_id, addresses.c.user_id)
+                )
+            )
+            
         u = User()
         u.user_name = 'multitester'
         u.email = 'multi@test.org'
         m.save(u)
 
-        usertable = engine.ResultProxy(users.select(users.c.user_id.in_(10)).execute()).fetchall()
-        self.assert_(usertable[0].row == (10, 'multitester'))
+        usertable = engine.ResultProxy(users.select(users.c.user_id.in_(u.foo_id)).execute()).fetchall()
+        self.assert_(usertable[0].row == (u.foo_id, 'multitester'))
         addresstable = engine.ResultProxy(addresses.select(addresses.c.address_id.in_(4)).execute()).fetchall()
-        self.assert_(addresstable[0].row == (4, 10, 'multi@test.org'))
+        self.assert_(addresstable[0].row == (u.address_id, u.foo_id, 'multi@test.org'))
 
         u.email = 'lala@hey.com'
         u.user_name = 'imnew'
         m.save(u)
-        usertable = engine.ResultProxy(users.select(users.c.user_id.in_(u.user_id)).execute()).fetchall()
-        self.assert_(usertable[0].row == (u.user_id, 'imnew'))
+        usertable = engine.ResultProxy(users.select(users.c.user_id.in_(u.foo_id)).execute()).fetchall()
+        self.assert_(usertable[0].row == (u.foo_id, 'imnew'))
         addresstable = engine.ResultProxy(addresses.select(addresses.c.address_id.in_(u.address_id)).execute()).fetchall()
-        self.assert_(addresstable[0].row == (u.address_id, u.user_id, 'lala@hey.com'))
+        self.assert_(addresstable[0].row == (u.address_id, u.foo_id, 'lala@hey.com'))
 
     def testonetomany(self):
         m = mapper(User, users, properties = dict(