From: Mike Bayer Date: Sat, 3 Sep 2005 05:47:20 +0000 (+0000) Subject: (no commit message) X-Git-Tag: rel_0_1_0~810 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=65109a881f23f0a780aca68ff1ea4f22c2a7c9d2;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git --- diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index 7588b6aa16..1c86bf8c30 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -16,7 +16,7 @@ # Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. __ALL__ = ['OrderedProperties', 'OrderedDict'] -import thread, weakref +import thread, weakref, UserList class OrderedProperties(object): @@ -117,26 +117,92 @@ class HashSet(object): def __getitem__(self, key): return self.map[key] -class HistoryAwareArraySet(object): - def __init__(self): - self.elements = [] - self.data = {} - def __len__(self): - return len(self.elements) - def append(self, item): - if not hasattr(self.data, value): - self.data[value] = ['True', value] - self.elements.append(value) - def __setitem__(self, key, value): - if not hasattr(self.data, value): - self.data[value] = ['True', value] - self.elements[key] = value - def __getitem__(self, key): - return self.elements[key] - def __delitem__(self, key): - pass - def __iter__(self): - return iter(self.map.values()) +class HistoryArraySet(UserList.UserList): + def __init__(self, items = None): + UserList.UserList.__init__(self, items) + # stores the array's items as keys, and a value of True, False or None indicating + # added, deleted, or unchanged for that item + self.records = {} + if items is not None: + for i in items: + self.records[i] = True + def _setrecord(self, item): + try: + val = self.records[item] + if val is True or val is None: + return False + else: + self.records[item] = None + return True + except KeyError: + self.records[item] = True + return True + def _delrecord(self, item): + try: + val = self.records[item] + if val is None: + self.records[item] = False + elif val is True: + del self.records[item] + except KeyError: + pass + def clear_history(self): + for key in self.records.keys(): + value = self.records[key] + if value is False: + del self.records[key] + else: + self.records[key] = None + def added_items(self): + return [key for key, value in self.records.iteritems() if value is True] + def deleted_items(self): + return [key for key, value in self.records.iteritems() if value is False] + def append_nohistory(self, item): + if not self.records.has_key(item): + self.records[item] = None + self.data.append(item) + def has_item(self, item): + return self.records.has_key(item) + def __setitem__(self, i, item): + if self._setrecord(a): + self.data[i] = item + def __delitem__(self, i): + self._delrecord(self.data[i]) + del self.data[i] + def __setslice__(self, i, j, other): + i = max(i, 0); j = max(j, 0) + if isinstance(other, UserList.UserList): + l = other.data + elif isinstance(other, type(self.data)): + l = other + else: + l = list(other) + g = [a for a in l if self._setrecord(a)] + self.data[i:] = g + def __delslice__(self, i, j): + i = max(i, 0); j = max(j, 0) + for a in self.data[i:j]: + self._delrecord(a) + del self.data[i:j] + def append(self, item): + if self._setrecord(item): + self.data.append(item) + def insert(self, i, item): + if self._setrecord(item): + self.data.insert(i, item) + def pop(self, i=-1): + item = self.data[i] + self._delrecord(item) + return self.data.pop(i) + def remove(self, item): + self._delrecord(item) + self.data.remove(item) + def __add__(self, other): + raise NotImplementedError() + def __radd__(self, other): + raise NotImplementedError() + def __iadd__(self, other): + raise NotImplementedError() class ScopedRegistry(object): def __init__(self, createfunc): diff --git a/test/historyarray.py b/test/historyarray.py new file mode 100644 index 0000000000..6e987a5aba --- /dev/null +++ b/test/historyarray.py @@ -0,0 +1,53 @@ +from testbase import PersistTest +import sqlalchemy.util as util +import unittest, sys, os + +class HistoryArrayTest(PersistTest): + def testadd(self): + a = util.HistoryArraySet() + a.append('hi') + self.assert_(a == ['hi']) + self.assert_(a.added_items() == ['hi']) + + def testremove(self): + a = util.HistoryArraySet() + a.append('hi') + a.clear_history() + self.assert_(a == ['hi']) + self.assert_(a.added_items() == []) + a.remove('hi') + self.assert_(a == []) + self.assert_(a.deleted_items() == ['hi']) + + def testremoveadded(self): + a = util.HistoryArraySet() + a.append('hi') + a.remove('hi') + self.assert_(a.added_items() == []) + self.assert_(a.deleted_items() == []) + self.assert_(a == []) + + def testaddedremoved(self): + a = util.HistoryArraySet() + a.append('hi') + a.clear_history() + a.remove('hi') + self.assert_(a.deleted_items() == ['hi']) + a.append('hi') + self.assert_(a.added_items() == []) + self.assert_(a.deleted_items() == []) + self.assert_(a == ['hi']) + + def testarray(self): + a = util.HistoryArraySet() + a.append('hi') + a.append('there') + self.assert_(a[0] == 'hi' and a[1] == 'there') + del a[1] + self.assert_(a == ['hi']) + a.append('hi') + a.append('there') + a[3:4] = ['yo', 'hi'] + self.assert_(a == ['hi', 'there', 'yo']) +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/test/mapper.py b/test/mapper.py index 756ff066e5..70248e34bd 100644 --- a/test/mapper.py +++ b/test/mapper.py @@ -268,24 +268,30 @@ class SaveTest(PersistTest): """tests a save of an object where each instance spans two tables. also tests redefinition of the keynames for the column properties.""" usersaddresses = sql.join(users, addresses, users.c.user_id == addresses.c.user_id) - m = mapper(User, usersaddresses, table = users, echo = True, properties = dict(email = ColumnProperty(addresses.c.email_address), foo_id = ColumnProperty(users.c.user_id, addresses.c.user_id))) + m = mapper(User, usersaddresses, table = users, echo = True, + properties = dict( + email = ColumnProperty(addresses.c.email_address), + foo_id = ColumnProperty(users.c.user_id, addresses.c.user_id) + ) + ) + u = User() u.user_name = 'multitester' u.email = 'multi@test.org' m.save(u) - usertable = engine.ResultProxy(users.select(users.c.user_id.in_(10)).execute()).fetchall() - self.assert_(usertable[0].row == (10, 'multitester')) + usertable = engine.ResultProxy(users.select(users.c.user_id.in_(u.foo_id)).execute()).fetchall() + self.assert_(usertable[0].row == (u.foo_id, 'multitester')) addresstable = engine.ResultProxy(addresses.select(addresses.c.address_id.in_(4)).execute()).fetchall() - self.assert_(addresstable[0].row == (4, 10, 'multi@test.org')) + self.assert_(addresstable[0].row == (u.address_id, u.foo_id, 'multi@test.org')) u.email = 'lala@hey.com' u.user_name = 'imnew' m.save(u) - usertable = engine.ResultProxy(users.select(users.c.user_id.in_(u.user_id)).execute()).fetchall() - self.assert_(usertable[0].row == (u.user_id, 'imnew')) + usertable = engine.ResultProxy(users.select(users.c.user_id.in_(u.foo_id)).execute()).fetchall() + self.assert_(usertable[0].row == (u.foo_id, 'imnew')) addresstable = engine.ResultProxy(addresses.select(addresses.c.address_id.in_(u.address_id)).execute()).fetchall() - self.assert_(addresstable[0].row == (u.address_id, u.user_id, 'lala@hey.com')) + self.assert_(addresstable[0].row == (u.address_id, u.foo_id, 'lala@hey.com')) def testonetomany(self): m = mapper(User, users, properties = dict(