From a7024527fbd1747f6b063de352e7b09ee18715e1 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 17 Sep 2005 17:53:20 +0000 Subject: [PATCH] --- lib/sqlalchemy/attributes.py | 185 ++++++++++++++++++++++++++++++++++ lib/sqlalchemy/mapper.py | 62 +++--------- lib/sqlalchemy/objectstore.py | 138 ++++++------------------- lib/sqlalchemy/util.py | 11 +- test/attributes.py | 78 ++++++++++++++ test/historyarray.py | 18 ++++ test/objectstore.py | 4 +- 7 files changed, 335 insertions(+), 161 deletions(-) create mode 100644 lib/sqlalchemy/attributes.py create mode 100644 test/attributes.py diff --git a/lib/sqlalchemy/attributes.py b/lib/sqlalchemy/attributes.py new file mode 100644 index 0000000000..9bc72c6e25 --- /dev/null +++ b/lib/sqlalchemy/attributes.py @@ -0,0 +1,185 @@ +import sqlalchemy.util as util +import weakref + +class SmartProperty(object): + def __init__(self, manager): + self.manager = manager + def attribute_registry(self): + return self.manager + def property(self, key, uselist): + def set_prop(obj, value): + if uselist: + self.attribute_registry().set_list_attribute(obj, key, value) + else: + self.attribute_registry().set_attribute(obj, key, value) + def del_prop(obj): + if uselist: + # TODO: this probably doesnt work right, deleting the list off an item + self.attribute_registry().delete_list_attribute(obj, key) + else: + self.attribute_registry().delete_attribute(obj, key) + def get_prop(obj): + if uselist: + return self.attribute_registry().get_list_attribute(obj, key) + else: + return self.attribute_registry().get_attribute(obj, key) + + return property(get_prop, set_prop, del_prop) + +class ListElement(util.HistoryArraySet): + """overrides HistoryArraySet to mark the parent object as dirty when changes occur""" + + def __init__(self, obj, key, items = None): + self.obj = obj + self.key = key + util.HistoryArraySet.__init__(self, items) + print "listelement init" + + def list_value_changed(self, obj, key, listval): + pass +# uow().modified_lists.append(self) + + def setattr(self, value): + self.obj.__dict__[self.key] = value + self.set_data(value) + def delattr(self, value): + pass + def _setrecord(self, item): + res = util.HistoryArraySet._setrecord(self, item) + if res: + self.list_value_changed(self.obj, self.key, self) + return res + def _delrecord(self, item): + res = util.HistoryArraySet._delrecord(self, item) + if res: + self.list_value_changed(self.obj, self.key, self) + return res + +class PropHistory(object): + # make our own NONE to distinguish from "None" + NONE = object() + def __init__(self, obj, key): + self.obj = obj + self.key = key + self.orig = PropHistory.NONE + def setattr_clean(self, value): + self.obj.__dict__[self.key] = value + def setattr(self, value): + self.orig = self.obj.__dict__.get(self.key, None) + self.obj.__dict__[self.key] = value + def delattr(self): + self.orig = self.obj.__dict__.get(self.key, None) + self.obj.__dict__[self.key] = None + def rollback(self): + if self.orig is not PropHistory.NONE: + self.obj.__dict__[self.key] = self.orig + self.orig = PropHistory.NONE + def clear_history(self): + self.orig = PropHistory.NONE + def added_items(self): + if self.orig is not PropHistory.NONE: + return [self.obj.__dict__[self.key]] + else: + return [] + def deleted_items(self): + if self.orig is not PropHistory.NONE: + return [self.orig] + else: + return [] + def unchanged_items(self): + if self.orig is PropHistory.NONE: + return [self.obj.__dict__[self.key]] + else: + return [] + +class AttributeManager(object): + def __init__(self): + self.attribute_history = {} + def value_changed(self, obj, key, value): + pass +# if hasattr(obj, '_instance_key'): +# self.register_dirty(obj) +# else: +# self.register_new(obj) + + def create_list(self, obj, key, list_): + return ListElement(obj, key, list_) + + def get_attribute(self, obj, key): + try: + v = obj.__dict__[key] + except KeyError: + raise AttributeError(key) + if (callable(v)): + v = v() + obj.__dict__[key] = v + return v + + def get_list_attribute(self, obj, key): + return self.get_list_history(obj, key) + + def set_attribute(self, obj, key, value): + self.get_history(obj, key).setattr(value) + self.value_changed(obj, key, value) + + def set_list_attribute(self, obj, key, value): + self.get_list_history(obj, key).setattr(value) + + def delete_attribute(self, obj, key): + self.get_history(obj, key).delattr() + self.value_changed(obj, key, value) + + def delete_list_attribute(self, obj, key): + pass + + def rollback(self, obj): + try: + attributes = self.attribute_history[obj] + for hist in attributes.values(): + hist.rollback() + except KeyError: + pass + + def clear_history(self, obj): + try: + attributes = self.attribute_history[obj] + for hist in attributes.values(): + hist.clear_history() + except KeyError: + pass + + def get_history(self, obj, key): + try: + return self.attribute_history[obj][key] + except KeyError, e: + if e.args[0] is obj: + d = {} + self.attribute_history[obj] = d + p = PropHistory(obj, key) + d[key] = p + return p + else: + p = PropHistory(obj, key) + self.attribute_history[obj][key] = p + return p + + def get_list_history(self, obj, key): + try: + return self.attribute_history[obj][key] + except KeyError, e: + list_ = obj.__dict__.get(key, None) + if callable(list_): + list_ = list_() + if e.args[0] is obj: + d = {} + self.attribute_history[obj] = d + p = self.create_list(obj, key, list_) + d[key] = p + return p + else: + p = self.create_list(obj, key, list_) + self.attribute_history[obj][key] = p + return p + + def register_attribute(self, class_, key, uselist): + setattr(class_, key, SmartProperty(self).property(key, uselist)) diff --git a/lib/sqlalchemy/mapper.py b/lib/sqlalchemy/mapper.py index a01e4b727c..953c02d382 100644 --- a/lib/sqlalchemy/mapper.py +++ b/lib/sqlalchemy/mapper.py @@ -140,7 +140,7 @@ class Mapper(object): self.init() - engines = property(lambda s: [t.engine for t in self.tables]) + engines = property(lambda s: [t.engine for t in s.tables]) def hash_key(self): return self.hashkey @@ -365,9 +365,8 @@ class Mapper(object): return instance def rollback(self, obj): - for prop in self.props.values(): - prop.rollback(obj) - + objectstore.uow().rollback_object(obj) + class MapperOption: """describes a modification to a Mapper in the context of making a copy of it. This is used to assist in the prototype pattern used by mapper.options().""" @@ -418,13 +417,12 @@ class ColumnProperty(MapperProperty): self.key = key # establish a SmartProperty property manager on the object for this key if not hasattr(parent.class_, key): - setattr(parent.class_, key, SmartProperty(key).property(usehistory = True)) + objectstore.uow().register_attribute(parent.class_, key, uselist = False) def execute(self, instance, row, identitykey, imap, isnew): if isnew: - setattr(instance, self.key, row[self.columns[0].label]) - def rollback(self, obj): - objectstore.uow().rollback_attribute(obj, self.key) + instance.__dict__[self.key] = row[self.columns[0].label] + #setattr(instance, self.key, row[self.columns[0].label]) @@ -444,9 +442,6 @@ class PropertyLoader(MapperProperty): def hash_key(self): return self._hash_key - def rollback(self, obj): - objectstore.uow().rollback_list_attribute(obj, self.key) - def init(self, key, parent): self.key = key self.parent = parent @@ -475,7 +470,7 @@ class PropertyLoader(MapperProperty): self.foreignkey = w.dependent if not hasattr(parent.class_, key): - setattr(parent.class_, key, SmartProperty(key).property(usehistory = True, uselist = self.uselist)) + objectstore.uow().register_attribute(parent.class_, key, uselist = self.uselist) class FindDependent(sql.ClauseVisitor): def __init__(self): @@ -530,9 +525,9 @@ class PropertyLoader(MapperProperty): def getlist(obj): if self.uselist: - return uowcommit.uow.register_list_attribute(obj, self.key) - else: - return uowcommit.uow.register_attribute(obj, self.key) + return uowcommit.uow.manager.get_list_history(obj, self.key) + else: + return uowcommit.uow.manager.get_history(obj, self.key) clearkeys = False @@ -607,10 +602,6 @@ class LazyLoader(PropertyLoader): def init(self, key, parent): PropertyLoader.init(self, key, parent) - if not hasattr(parent.class_, key): - if not issubclass(parent.class_, object): - raise "LazyLoader can only be used with new-style classes" - setattr(parent.class_, key, SmartProperty(key).property()) if self.secondaryjoin is not None: self.lazywhere = sql.and_(self.primaryjoin, self.secondaryjoin) else: @@ -704,14 +695,17 @@ class EagerLoader(PropertyLoader): it to a list on the parent instance.""" if not self.uselist: # TODO: check for multiple values on single-element child element ? - setattr(instance, self.key, self.mapper._instance(row, imap)) + instance.__dict__[self.key] = self.mapper._instance(row, imap) + #setattr(instance, self.key, self.mapper._instance(row, imap)) return elif isnew: - result_list = objectstore.uow().register_list_attribute(instance, self.key, data = []) + result_list = [] + setattr(instance, self.key, result_list) + result_list = getattr(instance, self.key) result_list.clear_history() else: result_list = getattr(instance, self.key) - + self.mapper._instance(row, imap, result_list) class EagerLazySwitcher(MapperOption): @@ -769,30 +763,6 @@ class BinaryVisitor(sql.ClauseVisitor): def visit_binary(self, binary): self.func(binary) -class SmartProperty(object): - def __init__(self, key): - self.key = key - - def property(self, usehistory = False, uselist = False): - def set_prop(s, value): - if uselist: - return objectstore.uow().register_list_attribute(s, self.key, value) - else: - objectstore.uow().set_attribute(s, self.key, value, usehistory) - def del_prop(s): - if uselist: - # TODO: this probably doesnt work right, deleting the list off an item - objectstore.uow().register_list_attribute(s, self.key, []) - else: - objectstore.uow().delete_attribute(s, self.key, value, usehistory) - def get_prop(s): - if uselist: - return objectstore.uow().register_list_attribute(s, self.key) - else: - return objectstore.uow().get_attribute(s, self.key) - - return property(get_prop, set_prop, del_prop) - def hash_key(obj): if obj is None: diff --git a/lib/sqlalchemy/objectstore.py b/lib/sqlalchemy/objectstore.py index dffa4afa21..a0c49ac5d1 100644 --- a/lib/sqlalchemy/objectstore.py +++ b/lib/sqlalchemy/objectstore.py @@ -22,6 +22,7 @@ to objects so that they may be properly persisted within a transactional scope." import thread import sqlalchemy.util as util +import sqlalchemy.attributes as attributes import weakref def get_id_key(ident, class_, table): @@ -95,125 +96,43 @@ def has_key(key): else: return False -class UOWListElement(util.HistoryArraySet): - """overrides HistoryArraySet to mark the parent object as dirty when changes occur""" +class UOWListElement(attributes.ListElement): + def list_value_changed(self, obj, key, listval): + uow().modified_lists.append(self) + +class UOWAttributeManager(attributes.AttributeManager): + def __init__(self, uow): + attributes.AttributeManager.__init__(self) + self.uow = uow - def __init__(self, obj, items = None): - util.HistoryArraySet.__init__(self, items) - self.obj = weakref.ref(obj) + def value_changed(self, obj, key, value): + if hasattr(obj, '_instance_key'): + self.uow.register_dirty(obj) + else: + self.uow.register_new(obj) + + def create_list(self, obj, key, list_): + return UOWListElement(obj, key, list_) - def _setrecord(self, item): - res = util.HistoryArraySet._setrecord(self, item) - if res: - uow().modified_lists.append(self) - return res - def _delrecord(self, item): - res = util.HistoryArraySet._delrecord(self, item) - if res: - uow().modified_lists.append(self) - return res - class UnitOfWork(object): def __init__(self, parent = None, is_begun = False): self.is_begun = is_begun + self.attributes = UOWAttributeManager(self) self.new = util.HashSet() self.dirty = util.HashSet() self.modified_lists = util.HashSet() self.deleted = util.HashSet() - self.attribute_history = weakref.WeakKeyDictionary() self.parent = parent + + def register_attribute(self, class_, key, uselist): + self.attributes.register_attribute(class_, key, uselist) def attribute_set_callable(self, obj, key, func): obj.__dict__[key] = func - def get_attribute(self, obj, key): - try: - v = obj.__dict__[key] - except KeyError: - raise AttributeError(key) - if (callable(v)): - v = v() - obj.__dict__[key] = v - self.register_attribute(obj, key).setattr_clean(v) - return v + def rollback_object(self, obj): + self.attributes.rollback(obj) - def rollback_attribute(self, obj, key): - if self.attribute_history.has_key(obj): - h = self.attribute_history[obj][key] - h.rollback() - obj.__dict__[key] = h.current - - def set_attribute(self, obj, key, value, usehistory = False): - if usehistory: - self.register_attribute(obj, key).setattr(value) - obj.__dict__[key] = value - if hasattr(obj, '_instance_key'): - self.register_dirty(obj) - else: - self.register_new(obj) - - def delete_attribute(self, obj, key, value, usehistory = False): - if usehistory: - self.register_attribute(obj, key).delattr(value) - del obj.__dict__[key] - if hasattr(obj, '_instance_key'): - self.register_dirty(obj) - else: - self.register_new(obj) - - def rollback_obj(self, obj): - try: - attributes = self.attribute_history[obj] - for key, hist in attributes.iteritems(): - hist.rollback() - obj.__dict__[key] = hist.current - except KeyError: - pass - for value in obj.__dict__.values(): - if isinstance(value, util.HistoryArraySet): - value.rollback() - def register_attribute(self, obj, key): - try: - attributes = self.attribute_history[obj] - except KeyError: - attributes = self.attribute_history.setdefault(obj, {}) - try: - return attributes[key] - except KeyError: - return attributes.setdefault(key, util.PropHistory(obj.__dict__.get(key, None))) - - def register_list_attribute(self, obj, key, data = None): - try: - attributes = self.attribute_history[obj] - except KeyError: - attributes = self.attribute_history.setdefault(obj, {}) - try: - childlist = attributes[key] - except KeyError: - try: - list = obj.__dict__[key] - if callable(list): - list = list() - except KeyError: - list = [] - obj.__dict__[key] = list - - childlist = UOWListElement(obj, list) - - if data is not None and childlist.data != data: - try: - childlist.set_data(data) - except TypeError: - raise "object " + repr(data) + " is not an iterable object" - return childlist - - def rollback_list_attribute(self, obj, key): - try: - childlist = obj.__dict__[key] - if isinstance(childlist, util.HistoryArraySet): - childlist.rollback() - except KeyError: - pass def register_clean(self, obj, scope="thread"): try: del self.dirty[obj] @@ -263,7 +182,7 @@ class UnitOfWork(object): commit_context.append_task(obj) engines = util.HashSet() - for mapper in commit_context.mappers.keys(): + for mapper in commit_context.mappers: for e in mapper.engines: engines.append(e) @@ -288,8 +207,8 @@ class UnitOfWork(object): class UOWTransaction(object): def __init__(self, uow): self.uow = uow - self.mappers = {} - self.engines = util.HashSet() + self.object_mappers = {} + self.mappers = util.HashSet() self.dependencies = {} self.tasks = {} self.saved_objects = util.HashSet() @@ -322,10 +241,11 @@ class UOWTransaction(object): def object_mapper(self, obj): import sqlalchemy.mapper try: - return self.mappers[obj] + return self.object_mappers[obj] except KeyError: mapper = sqlalchemy.mapper.object_mapper(obj) - self.mappers[obj] = mapper + self.object_mappers[obj] = mapper + self.mappers.append(mapper) return mapper def execute(self): diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index 23c2420adf..20d45142f2 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -128,15 +128,16 @@ class HashSet(object): return self.map[key] class HistoryArraySet(UserList.UserList): - def __init__(self, items = None, data = None): - UserList.UserList.__init__(self, items) + def __init__(self, data = None): # stores the array's items as keys, and a value of True, False or None indicating # added, deleted, or unchanged for that item + self.records = OrderedDict() if data is not None: self.data = data - self.records = {} - for i in self.data: - self.records[i] = True + for item in data: + self._setrecord(item) + else: + self.data = [] def set_data(self, data): # first mark everything current as "deleted" diff --git a/test/attributes.py b/test/attributes.py new file mode 100644 index 0000000000..718f302143 --- /dev/null +++ b/test/attributes.py @@ -0,0 +1,78 @@ +from testbase import PersistTest +import sqlalchemy.util as util +import sqlalchemy.attributes as attributes +import unittest, sys, os + + + +class AttributesTest(PersistTest): + def testbasic(self): + class User(object):pass + manager = attributes.AttributeManager() + manager.register_attribute(User, 'user_id', uselist = False) + manager.register_attribute(User, 'user_name', uselist = False) + manager.register_attribute(User, 'email_address', uselist = False) + + u = User() + print repr(u.__dict__) + + u.user_id = 7 + u.user_name = 'john' + u.email_address = 'lala@123.com' + + print repr(u.__dict__) + self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com') + manager.clear_history(u) + print repr(u.__dict__) + self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com') + + u.user_name = 'heythere' + u.email_address = 'foo@bar.com' + print repr(u.__dict__) + self.assert_(u.user_id == 7 and u.user_name == 'heythere' and u.email_address == 'foo@bar.com') + + manager.rollback(u) + print repr(u.__dict__) + self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com') + + def testlist(self): + class User(object):pass + class Address(object):pass + manager = attributes.AttributeManager() + manager.register_attribute(User, 'user_id', uselist = False) + manager.register_attribute(User, 'user_name', uselist = False) + manager.register_attribute(User, 'addresses', uselist = True) + manager.register_attribute(Address, 'address_id', uselist = False) + manager.register_attribute(Address, 'email_address', uselist = False) + + u = User() + print repr(u.__dict__) + + u.user_id = 7 + u.user_name = 'john' + u.addresses = [] + a = Address() + a.address_id = 10 + a.email_address = 'lala@123.com' + u.addresses.append(a) + + print repr(u.__dict__) + self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com') + manager.clear_history(u) + print repr(u.__dict__) + self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com') + + u.user_name = 'heythere' + a = Address() + a.address_id = 11 + a.email_address = 'foo@bar.com' + u.addresses.append(a) + print repr(u.__dict__) + self.assert_(u.user_id == 7 and u.user_name == 'heythere' and u.addresses[0].email_address == 'lala@123.com' and u.addresses[1].email_address == 'foo@bar.com') + + manager.rollback(u) + print repr(u.__dict__) + self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com') + +if __name__ == "__main__": + unittest.main() diff --git a/test/historyarray.py b/test/historyarray.py index 6e987a5aba..bd5c81ad9b 100644 --- a/test/historyarray.py +++ b/test/historyarray.py @@ -38,6 +38,24 @@ class HistoryArrayTest(PersistTest): self.assert_(a.deleted_items() == []) self.assert_(a == ['hi']) + def testrollback(self): + a = util.HistoryArraySet() + a.append('hi') + a.append('there') + a.append('yo') + a.clear_history() + before = repr(a.data) + print repr(a.data) + a.remove('there') + a.append('lala') + a.remove('yo') + a.append('yo') + after = repr(a.data) + print repr(a.data) + a.rollback() + print repr(a.data) + self.assert_(before == repr(a.data)) + def testarray(self): a = util.HistoryArraySet() a.append('hi') diff --git a/test/objectstore.py b/test/objectstore.py index 49a310b69d..a257b4b4fc 100644 --- a/test/objectstore.py +++ b/test/objectstore.py @@ -27,9 +27,11 @@ class HistoryTest(AssertMixin): u = User() u.user_id = 7 u.user_name = 'afdas' - u.addresses = [Address(), Address()] + u.addresses.append(Address()) u.addresses[0].email_address = 'hi' + u.addresses.append(Address()) u.addresses[1].email_address = 'there' + print repr(u.__dict__) m.rollback(u) print repr(u.__dict__) -- 2.47.2