--- /dev/null
+import sqlalchemy.util as util
+import weakref
+
+class SmartProperty(object):
+ def __init__(self, manager):
+ self.manager = manager
+ def attribute_registry(self):
+ return self.manager
+ def property(self, key, uselist):
+ def set_prop(obj, value):
+ if uselist:
+ self.attribute_registry().set_list_attribute(obj, key, value)
+ else:
+ self.attribute_registry().set_attribute(obj, key, value)
+ def del_prop(obj):
+ if uselist:
+ # TODO: this probably doesnt work right, deleting the list off an item
+ self.attribute_registry().delete_list_attribute(obj, key)
+ else:
+ self.attribute_registry().delete_attribute(obj, key)
+ def get_prop(obj):
+ if uselist:
+ return self.attribute_registry().get_list_attribute(obj, key)
+ else:
+ return self.attribute_registry().get_attribute(obj, key)
+
+ return property(get_prop, set_prop, del_prop)
+
+class ListElement(util.HistoryArraySet):
+ """overrides HistoryArraySet to mark the parent object as dirty when changes occur"""
+
+ def __init__(self, obj, key, items = None):
+ self.obj = obj
+ self.key = key
+ util.HistoryArraySet.__init__(self, items)
+ print "listelement init"
+
+ def list_value_changed(self, obj, key, listval):
+ pass
+# uow().modified_lists.append(self)
+
+ def setattr(self, value):
+ self.obj.__dict__[self.key] = value
+ self.set_data(value)
+ def delattr(self, value):
+ pass
+ def _setrecord(self, item):
+ res = util.HistoryArraySet._setrecord(self, item)
+ if res:
+ self.list_value_changed(self.obj, self.key, self)
+ return res
+ def _delrecord(self, item):
+ res = util.HistoryArraySet._delrecord(self, item)
+ if res:
+ self.list_value_changed(self.obj, self.key, self)
+ return res
+
+class PropHistory(object):
+ # make our own NONE to distinguish from "None"
+ NONE = object()
+ def __init__(self, obj, key):
+ self.obj = obj
+ self.key = key
+ self.orig = PropHistory.NONE
+ def setattr_clean(self, value):
+ self.obj.__dict__[self.key] = value
+ def setattr(self, value):
+ self.orig = self.obj.__dict__.get(self.key, None)
+ self.obj.__dict__[self.key] = value
+ def delattr(self):
+ self.orig = self.obj.__dict__.get(self.key, None)
+ self.obj.__dict__[self.key] = None
+ def rollback(self):
+ if self.orig is not PropHistory.NONE:
+ self.obj.__dict__[self.key] = self.orig
+ self.orig = PropHistory.NONE
+ def clear_history(self):
+ self.orig = PropHistory.NONE
+ def added_items(self):
+ if self.orig is not PropHistory.NONE:
+ return [self.obj.__dict__[self.key]]
+ else:
+ return []
+ def deleted_items(self):
+ if self.orig is not PropHistory.NONE:
+ return [self.orig]
+ else:
+ return []
+ def unchanged_items(self):
+ if self.orig is PropHistory.NONE:
+ return [self.obj.__dict__[self.key]]
+ else:
+ return []
+
+class AttributeManager(object):
+ def __init__(self):
+ self.attribute_history = {}
+ def value_changed(self, obj, key, value):
+ pass
+# if hasattr(obj, '_instance_key'):
+# self.register_dirty(obj)
+# else:
+# self.register_new(obj)
+
+ def create_list(self, obj, key, list_):
+ return ListElement(obj, key, list_)
+
+ def get_attribute(self, obj, key):
+ try:
+ v = obj.__dict__[key]
+ except KeyError:
+ raise AttributeError(key)
+ if (callable(v)):
+ v = v()
+ obj.__dict__[key] = v
+ return v
+
+ def get_list_attribute(self, obj, key):
+ return self.get_list_history(obj, key)
+
+ def set_attribute(self, obj, key, value):
+ self.get_history(obj, key).setattr(value)
+ self.value_changed(obj, key, value)
+
+ def set_list_attribute(self, obj, key, value):
+ self.get_list_history(obj, key).setattr(value)
+
+ def delete_attribute(self, obj, key):
+ self.get_history(obj, key).delattr()
+ self.value_changed(obj, key, value)
+
+ def delete_list_attribute(self, obj, key):
+ pass
+
+ def rollback(self, obj):
+ try:
+ attributes = self.attribute_history[obj]
+ for hist in attributes.values():
+ hist.rollback()
+ except KeyError:
+ pass
+
+ def clear_history(self, obj):
+ try:
+ attributes = self.attribute_history[obj]
+ for hist in attributes.values():
+ hist.clear_history()
+ except KeyError:
+ pass
+
+ def get_history(self, obj, key):
+ try:
+ return self.attribute_history[obj][key]
+ except KeyError, e:
+ if e.args[0] is obj:
+ d = {}
+ self.attribute_history[obj] = d
+ p = PropHistory(obj, key)
+ d[key] = p
+ return p
+ else:
+ p = PropHistory(obj, key)
+ self.attribute_history[obj][key] = p
+ return p
+
+ def get_list_history(self, obj, key):
+ try:
+ return self.attribute_history[obj][key]
+ except KeyError, e:
+ list_ = obj.__dict__.get(key, None)
+ if callable(list_):
+ list_ = list_()
+ if e.args[0] is obj:
+ d = {}
+ self.attribute_history[obj] = d
+ p = self.create_list(obj, key, list_)
+ d[key] = p
+ return p
+ else:
+ p = self.create_list(obj, key, list_)
+ self.attribute_history[obj][key] = p
+ return p
+
+ def register_attribute(self, class_, key, uselist):
+ setattr(class_, key, SmartProperty(self).property(key, uselist))
self.init()
- engines = property(lambda s: [t.engine for t in self.tables])
+ engines = property(lambda s: [t.engine for t in s.tables])
def hash_key(self):
return self.hashkey
return instance
def rollback(self, obj):
- for prop in self.props.values():
- prop.rollback(obj)
-
+ objectstore.uow().rollback_object(obj)
+
class MapperOption:
"""describes a modification to a Mapper in the context of making a copy
of it. This is used to assist in the prototype pattern used by mapper.options()."""
self.key = key
# establish a SmartProperty property manager on the object for this key
if not hasattr(parent.class_, key):
- setattr(parent.class_, key, SmartProperty(key).property(usehistory = True))
+ objectstore.uow().register_attribute(parent.class_, key, uselist = False)
def execute(self, instance, row, identitykey, imap, isnew):
if isnew:
- setattr(instance, self.key, row[self.columns[0].label])
- def rollback(self, obj):
- objectstore.uow().rollback_attribute(obj, self.key)
+ instance.__dict__[self.key] = row[self.columns[0].label]
+ #setattr(instance, self.key, row[self.columns[0].label])
def hash_key(self):
return self._hash_key
- def rollback(self, obj):
- objectstore.uow().rollback_list_attribute(obj, self.key)
-
def init(self, key, parent):
self.key = key
self.parent = parent
self.foreignkey = w.dependent
if not hasattr(parent.class_, key):
- setattr(parent.class_, key, SmartProperty(key).property(usehistory = True, uselist = self.uselist))
+ objectstore.uow().register_attribute(parent.class_, key, uselist = self.uselist)
class FindDependent(sql.ClauseVisitor):
def __init__(self):
def getlist(obj):
if self.uselist:
- return uowcommit.uow.register_list_attribute(obj, self.key)
- else:
- return uowcommit.uow.register_attribute(obj, self.key)
+ return uowcommit.uow.manager.get_list_history(obj, self.key)
+ else:
+ return uowcommit.uow.manager.get_history(obj, self.key)
clearkeys = False
def init(self, key, parent):
PropertyLoader.init(self, key, parent)
- if not hasattr(parent.class_, key):
- if not issubclass(parent.class_, object):
- raise "LazyLoader can only be used with new-style classes"
- setattr(parent.class_, key, SmartProperty(key).property())
if self.secondaryjoin is not None:
self.lazywhere = sql.and_(self.primaryjoin, self.secondaryjoin)
else:
it to a list on the parent instance."""
if not self.uselist:
# TODO: check for multiple values on single-element child element ?
- setattr(instance, self.key, self.mapper._instance(row, imap))
+ instance.__dict__[self.key] = self.mapper._instance(row, imap)
+ #setattr(instance, self.key, self.mapper._instance(row, imap))
return
elif isnew:
- result_list = objectstore.uow().register_list_attribute(instance, self.key, data = [])
+ result_list = []
+ setattr(instance, self.key, result_list)
+ result_list = getattr(instance, self.key)
result_list.clear_history()
else:
result_list = getattr(instance, self.key)
-
+
self.mapper._instance(row, imap, result_list)
class EagerLazySwitcher(MapperOption):
def visit_binary(self, binary):
self.func(binary)
-class SmartProperty(object):
- def __init__(self, key):
- self.key = key
-
- def property(self, usehistory = False, uselist = False):
- def set_prop(s, value):
- if uselist:
- return objectstore.uow().register_list_attribute(s, self.key, value)
- else:
- objectstore.uow().set_attribute(s, self.key, value, usehistory)
- def del_prop(s):
- if uselist:
- # TODO: this probably doesnt work right, deleting the list off an item
- objectstore.uow().register_list_attribute(s, self.key, [])
- else:
- objectstore.uow().delete_attribute(s, self.key, value, usehistory)
- def get_prop(s):
- if uselist:
- return objectstore.uow().register_list_attribute(s, self.key)
- else:
- return objectstore.uow().get_attribute(s, self.key)
-
- return property(get_prop, set_prop, del_prop)
-
def hash_key(obj):
if obj is None:
import thread
import sqlalchemy.util as util
+import sqlalchemy.attributes as attributes
import weakref
def get_id_key(ident, class_, table):
else:
return False
-class UOWListElement(util.HistoryArraySet):
- """overrides HistoryArraySet to mark the parent object as dirty when changes occur"""
+class UOWListElement(attributes.ListElement):
+ def list_value_changed(self, obj, key, listval):
+ uow().modified_lists.append(self)
+
+class UOWAttributeManager(attributes.AttributeManager):
+ def __init__(self, uow):
+ attributes.AttributeManager.__init__(self)
+ self.uow = uow
- def __init__(self, obj, items = None):
- util.HistoryArraySet.__init__(self, items)
- self.obj = weakref.ref(obj)
+ def value_changed(self, obj, key, value):
+ if hasattr(obj, '_instance_key'):
+ self.uow.register_dirty(obj)
+ else:
+ self.uow.register_new(obj)
+
+ def create_list(self, obj, key, list_):
+ return UOWListElement(obj, key, list_)
- def _setrecord(self, item):
- res = util.HistoryArraySet._setrecord(self, item)
- if res:
- uow().modified_lists.append(self)
- return res
- def _delrecord(self, item):
- res = util.HistoryArraySet._delrecord(self, item)
- if res:
- uow().modified_lists.append(self)
- return res
-
class UnitOfWork(object):
def __init__(self, parent = None, is_begun = False):
self.is_begun = is_begun
+ self.attributes = UOWAttributeManager(self)
self.new = util.HashSet()
self.dirty = util.HashSet()
self.modified_lists = util.HashSet()
self.deleted = util.HashSet()
- self.attribute_history = weakref.WeakKeyDictionary()
self.parent = parent
+
+ def register_attribute(self, class_, key, uselist):
+ self.attributes.register_attribute(class_, key, uselist)
def attribute_set_callable(self, obj, key, func):
obj.__dict__[key] = func
- def get_attribute(self, obj, key):
- try:
- v = obj.__dict__[key]
- except KeyError:
- raise AttributeError(key)
- if (callable(v)):
- v = v()
- obj.__dict__[key] = v
- self.register_attribute(obj, key).setattr_clean(v)
- return v
+ def rollback_object(self, obj):
+ self.attributes.rollback(obj)
- def rollback_attribute(self, obj, key):
- if self.attribute_history.has_key(obj):
- h = self.attribute_history[obj][key]
- h.rollback()
- obj.__dict__[key] = h.current
-
- def set_attribute(self, obj, key, value, usehistory = False):
- if usehistory:
- self.register_attribute(obj, key).setattr(value)
- obj.__dict__[key] = value
- if hasattr(obj, '_instance_key'):
- self.register_dirty(obj)
- else:
- self.register_new(obj)
-
- def delete_attribute(self, obj, key, value, usehistory = False):
- if usehistory:
- self.register_attribute(obj, key).delattr(value)
- del obj.__dict__[key]
- if hasattr(obj, '_instance_key'):
- self.register_dirty(obj)
- else:
- self.register_new(obj)
-
- def rollback_obj(self, obj):
- try:
- attributes = self.attribute_history[obj]
- for key, hist in attributes.iteritems():
- hist.rollback()
- obj.__dict__[key] = hist.current
- except KeyError:
- pass
- for value in obj.__dict__.values():
- if isinstance(value, util.HistoryArraySet):
- value.rollback()
- def register_attribute(self, obj, key):
- try:
- attributes = self.attribute_history[obj]
- except KeyError:
- attributes = self.attribute_history.setdefault(obj, {})
- try:
- return attributes[key]
- except KeyError:
- return attributes.setdefault(key, util.PropHistory(obj.__dict__.get(key, None)))
-
- def register_list_attribute(self, obj, key, data = None):
- try:
- attributes = self.attribute_history[obj]
- except KeyError:
- attributes = self.attribute_history.setdefault(obj, {})
- try:
- childlist = attributes[key]
- except KeyError:
- try:
- list = obj.__dict__[key]
- if callable(list):
- list = list()
- except KeyError:
- list = []
- obj.__dict__[key] = list
-
- childlist = UOWListElement(obj, list)
-
- if data is not None and childlist.data != data:
- try:
- childlist.set_data(data)
- except TypeError:
- raise "object " + repr(data) + " is not an iterable object"
- return childlist
-
- def rollback_list_attribute(self, obj, key):
- try:
- childlist = obj.__dict__[key]
- if isinstance(childlist, util.HistoryArraySet):
- childlist.rollback()
- except KeyError:
- pass
def register_clean(self, obj, scope="thread"):
try:
del self.dirty[obj]
commit_context.append_task(obj)
engines = util.HashSet()
- for mapper in commit_context.mappers.keys():
+ for mapper in commit_context.mappers:
for e in mapper.engines:
engines.append(e)
class UOWTransaction(object):
def __init__(self, uow):
self.uow = uow
- self.mappers = {}
- self.engines = util.HashSet()
+ self.object_mappers = {}
+ self.mappers = util.HashSet()
self.dependencies = {}
self.tasks = {}
self.saved_objects = util.HashSet()
def object_mapper(self, obj):
import sqlalchemy.mapper
try:
- return self.mappers[obj]
+ return self.object_mappers[obj]
except KeyError:
mapper = sqlalchemy.mapper.object_mapper(obj)
- self.mappers[obj] = mapper
+ self.object_mappers[obj] = mapper
+ self.mappers.append(mapper)
return mapper
def execute(self):
return self.map[key]
class HistoryArraySet(UserList.UserList):
- def __init__(self, items = None, data = None):
- UserList.UserList.__init__(self, items)
+ def __init__(self, data = None):
# stores the array's items as keys, and a value of True, False or None indicating
# added, deleted, or unchanged for that item
+ self.records = OrderedDict()
if data is not None:
self.data = data
- self.records = {}
- for i in self.data:
- self.records[i] = True
+ for item in data:
+ self._setrecord(item)
+ else:
+ self.data = []
def set_data(self, data):
# first mark everything current as "deleted"
--- /dev/null
+from testbase import PersistTest
+import sqlalchemy.util as util
+import sqlalchemy.attributes as attributes
+import unittest, sys, os
+
+
+
+class AttributesTest(PersistTest):
+ def testbasic(self):
+ class User(object):pass
+ manager = attributes.AttributeManager()
+ manager.register_attribute(User, 'user_id', uselist = False)
+ manager.register_attribute(User, 'user_name', uselist = False)
+ manager.register_attribute(User, 'email_address', uselist = False)
+
+ u = User()
+ print repr(u.__dict__)
+
+ u.user_id = 7
+ u.user_name = 'john'
+ u.email_address = 'lala@123.com'
+
+ print repr(u.__dict__)
+ self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com')
+ manager.clear_history(u)
+ print repr(u.__dict__)
+ self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com')
+
+ u.user_name = 'heythere'
+ u.email_address = 'foo@bar.com'
+ print repr(u.__dict__)
+ self.assert_(u.user_id == 7 and u.user_name == 'heythere' and u.email_address == 'foo@bar.com')
+
+ manager.rollback(u)
+ print repr(u.__dict__)
+ self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com')
+
+ def testlist(self):
+ class User(object):pass
+ class Address(object):pass
+ manager = attributes.AttributeManager()
+ manager.register_attribute(User, 'user_id', uselist = False)
+ manager.register_attribute(User, 'user_name', uselist = False)
+ manager.register_attribute(User, 'addresses', uselist = True)
+ manager.register_attribute(Address, 'address_id', uselist = False)
+ manager.register_attribute(Address, 'email_address', uselist = False)
+
+ u = User()
+ print repr(u.__dict__)
+
+ u.user_id = 7
+ u.user_name = 'john'
+ u.addresses = []
+ a = Address()
+ a.address_id = 10
+ a.email_address = 'lala@123.com'
+ u.addresses.append(a)
+
+ print repr(u.__dict__)
+ self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com')
+ manager.clear_history(u)
+ print repr(u.__dict__)
+ self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com')
+
+ u.user_name = 'heythere'
+ a = Address()
+ a.address_id = 11
+ a.email_address = 'foo@bar.com'
+ u.addresses.append(a)
+ print repr(u.__dict__)
+ self.assert_(u.user_id == 7 and u.user_name == 'heythere' and u.addresses[0].email_address == 'lala@123.com' and u.addresses[1].email_address == 'foo@bar.com')
+
+ manager.rollback(u)
+ print repr(u.__dict__)
+ self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com')
+
+if __name__ == "__main__":
+ unittest.main()
self.assert_(a.deleted_items() == [])
self.assert_(a == ['hi'])
+ def testrollback(self):
+ a = util.HistoryArraySet()
+ a.append('hi')
+ a.append('there')
+ a.append('yo')
+ a.clear_history()
+ before = repr(a.data)
+ print repr(a.data)
+ a.remove('there')
+ a.append('lala')
+ a.remove('yo')
+ a.append('yo')
+ after = repr(a.data)
+ print repr(a.data)
+ a.rollback()
+ print repr(a.data)
+ self.assert_(before == repr(a.data))
+
def testarray(self):
a = util.HistoryArraySet()
a.append('hi')
u = User()
u.user_id = 7
u.user_name = 'afdas'
- u.addresses = [Address(), Address()]
+ u.addresses.append(Address())
u.addresses[0].email_address = 'hi'
+ u.addresses.append(Address())
u.addresses[1].email_address = 'there'
+ print repr(u.__dict__)
m.rollback(u)
print repr(u.__dict__)