From: Mike Bayer Date: Fri, 27 Jul 2007 23:02:20 +0000 (+0000) Subject: - an experimental feature that combines a Query with an InstrumentedAttribute, to... X-Git-Tag: rel_0_4beta1~163 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=3f1c33cdc715624dd7ee2f88e04872401e0d9c61;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - an experimental feature that combines a Query with an InstrumentedAttribute, to provide "always live" results in conjunction with mutator capability --- diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 1b081910f5..b903d5aa08 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -14,6 +14,7 @@ from sqlalchemy import logging, exceptions PASSIVE_NORESULT = object() ATTR_WAS_SET = object() +NO_VALUE = object() class InstrumentedAttribute(interfaces.PropComparator): """attribute access for instrumented classes.""" @@ -82,6 +83,15 @@ class InstrumentedAttribute(interfaces.PropComparator): return self return self.get(obj) + def commit_to_state(self, state, obj, value=NO_VALUE): + """commit the a copy of thte value of 'obj' to the given CommittedState""" + + if value is NO_VALUE: + if self.key in obj.__dict__: + value = obj.__dict__[self.key] + if value is not NO_VALUE: + state.data[self.key] = self.copy(value) + def clause_element(self): return self.comparator.clause_element() @@ -257,7 +267,7 @@ class InstrumentedAttribute(interfaces.PropComparator): state = obj._state orig = state.get('original', None) if orig is not None: - orig.commit_attribute(self, obj, value) + self.commit_to_state(orig, obj, value) # remove per-instance callable, if any state.pop(('callable', self), None) obj.__dict__[self.key] = value @@ -475,7 +485,7 @@ class InstrumentedCollectionAttribute(InstrumentedAttribute): value = user_data if orig is not None: - orig.commit_attribute(self, obj, value) + self.commit_to_state(orig, obj, value) # remove per-instance callable, if any state.pop(('callable', self), None) obj.__dict__[self.key] = value @@ -538,34 +548,11 @@ class CommittedState(object): method on the attribute manager is called. """ - NO_VALUE = object() def __init__(self, manager, obj): self.data = {} for attr in manager.managed_attributes(obj.__class__): - self.commit_attribute(attr, obj) - - def commit_attribute(self, attr, obj, value=NO_VALUE): - """Establish the value of attribute `attr` on instance `obj` - as *committed*. - - This corresponds to a previously saved state being restored. - """ - - if value is CommittedState.NO_VALUE: - if attr.key in obj.__dict__: - value = obj.__dict__[attr.key] - if value is not CommittedState.NO_VALUE: - self.data[attr.key] = attr.copy(value) - - # not tracking parent on lazy-loaded instances at the moment. - # its not needed since they will be "optimistically" tested - #if attr.uselist: - #if attr.trackparent: - # [attr.sethasparent(x, True) for x in self.data[attr.key] if x is not None] - #else: - #if attr.trackparent and value is not None: - # attr.sethasparent(value, True) + attr.commit_to_state(self, obj) def rollback(self, manager, obj): for attr in manager.managed_attributes(obj.__class__): @@ -761,6 +748,8 @@ class AttributeManager(object): return [] elif isinstance(attr, InstrumentedCollectionAttribute): return list(attr._get_collection(obj, x)) + elif isinstance(x, list): + return x else: return [x] @@ -832,8 +821,11 @@ class AttributeManager(object): ``InstrumentedAttribute``, which will communicate change events back to this ``AttributeManager``. """ - - if uselist: + + if kwargs.pop('dynamic', False): + from sqlalchemy.orm import dynamic + return dynamic.DynamicCollectionAttribute(class_, self, key, typecallable, **kwargs) + elif uselist: return InstrumentedCollectionAttribute(class_, self, key, callable_, typecallable, diff --git a/lib/sqlalchemy/orm/dynamic.py b/lib/sqlalchemy/orm/dynamic.py new file mode 100644 index 0000000000..5b61368671 --- /dev/null +++ b/lib/sqlalchemy/orm/dynamic.py @@ -0,0 +1,123 @@ +"""'dynamic' collection API. returns Query() objects on the 'read' side, alters +a special AttributeHistory on the 'write' side.""" + +from sqlalchemy import exceptions +from sqlalchemy.orm import attributes, Query, object_session +from sqlalchemy.orm.mapper import has_identity + +class DynamicCollectionAttribute(attributes.InstrumentedAttribute): + def __init__(self, class_, attribute_manager, key, typecallable, target_mapper, **kwargs): + super(DynamicCollectionAttribute, self).__init__(class_, attribute_manager, key, typecallable, **kwargs) + self.target_mapper = target_mapper + + def get(self, obj, passive=False): + if passive: + return self.get_history(obj, passive=True).added_items() + else: + return AppenderQuery(self, obj) + + def commit_to_state(self, state, obj, value=attributes.NO_VALUE): + # we have our own AttributeHistory therefore dont need CommittedState + pass + + def set(self, obj, value, initiator): + if initiator is self: + return + + state = obj._state + + old_collection = self.get(obj).assign(value) + + # TODO: emit events ??? + state['modified'] = True + + def delete(self, *args, **kwargs): + raise NotImplementedError() + + def get_history(self, obj, passive=False): + try: + return obj.__dict__[self.key] + except KeyError: + obj.__dict__[self.key] = c = CollectionHistory(self, obj) + return c + +class AppenderQuery(Query): + def __init__(self, attr, instance): + super(AppenderQuery, self).__init__(attr.target_mapper, None) + self.instance = instance + self.attr = attr + + def __len__(self): + if not has_identity(self.instance): + # TODO: all these various calls to _added_items should be more + # intelligently calculated from the CollectionHistory object + # (i.e. account for deletes too) + return len(self.attr.get_history(self.instance)._added_items) + else: + return self._clone().count() + + def __iter__(self): + if not has_identity(self.instance): + return iter(self.attr.get_history(self.instance)._added_items) + else: + return iter(self._clone()) + + def __getitem__(self, index): + if not has_identity(self.instance): + return iter(self.attr.get_history(self.instance)._added_items.__getitem__(index)) + else: + return self._clone().__getitem__(index) + + def _clone(self): + # note we're returning an entirely new query class here + # without any assignment capabilities; + # the class of this query is determined by the session. + sess = object_session(self.instance) + if sess is None: + try: + sess = mapper.object_mapper(instance).get_session() + except exceptions.InvalidRequestError: + raise exceptions.InvalidRequestError("Parent instance %s is not bound to a Session, and no contextual session is established; lazy load operation of attribute '%s' cannot proceed" % (instance.__class__, self.key)) + + return sess.query(self.attr.target_mapper).with_parent(self.instance) + + def assign(self, collection): + if has_identity(self.instance): + oldlist = list(self) + else: + oldlist = [] + self.attr.get_history(self.instance).replace(oldlist, collection) + return oldlist + + def append(self, item): + self.attr.get_history(self.instance)._added_items.append(item) + self.attr.fire_append_event(self.instance, item, self.attr) + + def remove(self, item): + self.attr.get_history(self.instance)._deleted_items.append(item) + self.attr.fire_remove_event(self.instance, item, self.attr) + +class CollectionHistory(attributes.AttributeHistory): + """override AttributeHistory to receive append/remove events directly""" + def __init__(self, attr, obj): + self._deleted_items = [] + self._added_items = [] + self._unchanged_items = [] + self._obj = obj + + def replace(self, olditems, newitems): + self._added_items = newitems + self._deleted_items = olditems + + def is_modified(self): + return len(self._deleted_items) > 0 or len(self._added_items) > 0 + + def added_items(self): + return self._added_items + + def unchanged_items(self): + return self._unchanged_items + + def deleted_items(self): + return self._deleted_items + \ No newline at end of file diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index d16b4b287e..ae73f9c7c4 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -261,7 +261,9 @@ class PropertyLoader(StrategizedProperty): private = property(lambda s:s.cascade.delete_orphan) def create_strategy(self): - if self.lazy: + if self.lazy == 'dynamic': + return strategies.DynaLoader(self) + elif self.lazy: return strategies.LazyLoader(self) elif self.lazy is False: return strategies.EagerLoader(self) diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 501926d499..beb8f2755d 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -235,10 +235,20 @@ class AbstractRelationLoader(LoaderStrategy): def _init_instance_attribute(self, instance, callable_=None): return sessionlib.attribute_manager.init_instance_attribute(instance, self.key, callable_=callable_) - def _register_attribute(self, class_, callable_=None): + def _register_attribute(self, class_, callable_=None, **kwargs): self.logger.info("register managed %s attribute %s on class %s" % ((self.uselist and "list-holding" or "scalar"), self.key, self.parent.class_.__name__)) - sessionlib.attribute_manager.register_attribute(class_, self.key, uselist = self.uselist, extension=self.attributeext, cascade=self.cascade, trackparent=True, typecallable=self.parent_property.collection_class, callable_=callable_, comparator=self.parent_property.comparator) + sessionlib.attribute_manager.register_attribute(class_, self.key, uselist = self.uselist, extension=self.attributeext, cascade=self.cascade, trackparent=True, typecallable=self.parent_property.collection_class, callable_=callable_, comparator=self.parent_property.comparator, **kwargs) +class DynaLoader(AbstractRelationLoader): + def init_class_attribute(self): + self.is_class_level = True + self._register_attribute(self.parent.class_, dynamic=True, target_mapper=self.parent_property.mapper) + + def create_row_processor(self, selectcontext, mapper, row): + return (None, None) + +DynaLoader.logger = logging.class_logger(DynaLoader) + class NoLoader(AbstractRelationLoader): def init_class_attribute(self): self.is_class_level = True diff --git a/test/orm/alltests.py b/test/orm/alltests.py index 9fcea88590..59357c7b71 100644 --- a/test/orm/alltests.py +++ b/test/orm/alltests.py @@ -32,6 +32,7 @@ def suite(): 'orm.compile', 'orm.manytomany', 'orm.onetoone', + 'orm.dynamic', ) alltests = unittest.TestSuite() for name in modules_to_test: diff --git a/test/orm/dynamic.py b/test/orm/dynamic.py new file mode 100644 index 0000000000..434ac22963 --- /dev/null +++ b/test/orm/dynamic.py @@ -0,0 +1,59 @@ +import testbase +import operator +from sqlalchemy import * +from sqlalchemy import ansisql +from sqlalchemy.orm import * +from testlib import * +from fixtures import * + +from query import QueryTest + +class DynamicTest(QueryTest): + keep_mappers = False + + def setup_mappers(self): + pass + + def test_basic(self): + mapper(User, users, properties={ + 'addresses':relation(mapper(Address, addresses), lazy='dynamic') + }) + sess = create_session() + q = sess.query(User) + + print q.filter(User.id==7).all() + u = q.filter(User.id==7).first() + print list(u.addresses) + assert [User(id=7, addresses=[Address(id=1, email_address='jack@bean.com')])] == q.filter(User.id==7).all() + assert fixtures.user_address_result == q.all() + +class FlushTest(FixtureTest): + def test_basic(self): + mapper(User, users, properties={ + 'addresses':relation(mapper(Address, addresses), lazy='dynamic') + }) + sess = create_session() + u1 = User(name='jack') + u2 = User(name='ed') + u2.addresses.append(Address(email_address='foo@bar.com')) + u1.addresses.append(Address(email_address='lala@hoho.com')) + sess.save(u1) + sess.save(u2) + sess.flush() + + sess.clear() + + def go(): + assert [ + User(name='jack', addresses=[Address(email_address='lala@hoho.com')]), + User(name='ed', addresses=[Address(email_address='foo@bar.com')]) + ] == sess.query(User).all() + + # one query for the query(User).all(), one query for each address iter(), + # also one query for a count() on each address (the count() is an artifact of the + # fixtures.Base class, its not intrinsic to the property) + self.assert_sql_count(testbase.db, go, 5) + +if __name__ == '__main__': + testbase.main() + \ No newline at end of file diff --git a/test/orm/fixtures.py b/test/orm/fixtures.py index 4a7d41459f..8b7312251c 100644 --- a/test/orm/fixtures.py +++ b/test/orm/fixtures.py @@ -161,6 +161,12 @@ def install_fixture_data(): dict(keyword_id=6, item_id=3) ) +class FixtureTest(ORMTest): + def define_tables(self, meta): + # a slight dirty trick here. + meta.tables = metadata.tables + metadata.connect(meta.bind) + class Fixtures(object): @property def user_address_result(self):