]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- an experimental feature that combines a Query with an InstrumentedAttribute, to...
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 27 Jul 2007 23:02:20 +0000 (23:02 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 27 Jul 2007 23:02:20 +0000 (23:02 +0000)
"always live" results in conjunction with mutator capability

lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/dynamic.py [new file with mode: 0644]
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/strategies.py
test/orm/alltests.py
test/orm/dynamic.py [new file with mode: 0644]
test/orm/fixtures.py

index 1b081910f5d26ab50ff6662dc5c39dd6927922a0..b903d5aa08d55663a4397a9d6da85d29a6fd7c70 100644 (file)
@@ -14,6 +14,7 @@ from sqlalchemy import logging, exceptions
 
 PASSIVE_NORESULT = object()
 ATTR_WAS_SET = object()
+NO_VALUE = object()
 
 class InstrumentedAttribute(interfaces.PropComparator):
     """attribute access for instrumented classes."""
@@ -82,6 +83,15 @@ class InstrumentedAttribute(interfaces.PropComparator):
             return self
         return self.get(obj)
 
+    def commit_to_state(self, state, obj, value=NO_VALUE):
+        """commit the a copy of thte value of 'obj' to the given CommittedState"""
+
+        if value is NO_VALUE:
+            if self.key in obj.__dict__:
+                value = obj.__dict__[self.key]
+        if value is not NO_VALUE:
+            state.data[self.key] = self.copy(value)
+
     def clause_element(self):
         return self.comparator.clause_element()
 
@@ -257,7 +267,7 @@ class InstrumentedAttribute(interfaces.PropComparator):
         state = obj._state
         orig = state.get('original', None)
         if orig is not None:
-            orig.commit_attribute(self, obj, value)
+            self.commit_to_state(orig, obj, value)
         # remove per-instance callable, if any
         state.pop(('callable', self), None)
         obj.__dict__[self.key] = value
@@ -475,7 +485,7 @@ class InstrumentedCollectionAttribute(InstrumentedAttribute):
         value = user_data
 
         if orig is not None:
-            orig.commit_attribute(self, obj, value)
+            self.commit_to_state(orig, obj, value)
         # remove per-instance callable, if any
         state.pop(('callable', self), None)
         obj.__dict__[self.key] = value
@@ -538,34 +548,11 @@ class CommittedState(object):
     method on the attribute manager is called.
     """
 
-    NO_VALUE = object()
 
     def __init__(self, manager, obj):
         self.data = {}
         for attr in manager.managed_attributes(obj.__class__):
-            self.commit_attribute(attr, obj)
-
-    def commit_attribute(self, attr, obj, value=NO_VALUE):
-        """Establish the value of attribute `attr` on instance `obj`
-        as *committed*.
-
-        This corresponds to a previously saved state being restored.
-        """
-
-        if value is CommittedState.NO_VALUE:
-            if attr.key in obj.__dict__:
-                value = obj.__dict__[attr.key]
-        if value is not CommittedState.NO_VALUE:
-            self.data[attr.key] = attr.copy(value)
-
-            # not tracking parent on lazy-loaded instances at the moment.
-            # its not needed since they will be "optimistically" tested
-            #if attr.uselist:
-                #if attr.trackparent:
-                #    [attr.sethasparent(x, True) for x in self.data[attr.key] if x is not None]
-            #else:
-                #if attr.trackparent and value is not None:
-                #    attr.sethasparent(value, True)
+            attr.commit_to_state(self, obj)
 
     def rollback(self, manager, obj):
         for attr in manager.managed_attributes(obj.__class__):
@@ -761,6 +748,8 @@ class AttributeManager(object):
             return []
         elif isinstance(attr, InstrumentedCollectionAttribute):
             return list(attr._get_collection(obj, x))
+        elif isinstance(x, list):
+            return x
         else:
             return [x]
 
@@ -832,8 +821,11 @@ class AttributeManager(object):
         ``InstrumentedAttribute``, which will communicate change
         events back to this ``AttributeManager``.
         """
-
-        if uselist:
+        
+        if kwargs.pop('dynamic', False):
+            from sqlalchemy.orm import dynamic
+            return dynamic.DynamicCollectionAttribute(class_, self, key, typecallable, **kwargs)
+        elif uselist:
             return InstrumentedCollectionAttribute(class_, self, key,
                                                    callable_,
                                                    typecallable,
diff --git a/lib/sqlalchemy/orm/dynamic.py b/lib/sqlalchemy/orm/dynamic.py
new file mode 100644 (file)
index 0000000..5b61368
--- /dev/null
@@ -0,0 +1,123 @@
+"""'dynamic' collection API.  returns Query() objects on the 'read' side, alters
+a special AttributeHistory on the 'write' side."""
+
+from sqlalchemy import exceptions
+from sqlalchemy.orm import attributes, Query, object_session
+from sqlalchemy.orm.mapper import has_identity
+
+class DynamicCollectionAttribute(attributes.InstrumentedAttribute):
+    def __init__(self, class_, attribute_manager, key, typecallable, target_mapper, **kwargs):
+        super(DynamicCollectionAttribute, self).__init__(class_, attribute_manager, key, typecallable, **kwargs)
+        self.target_mapper = target_mapper
+        
+    def get(self, obj, passive=False):
+        if passive:
+            return self.get_history(obj, passive=True).added_items()
+        else:
+            return AppenderQuery(self, obj)
+
+    def commit_to_state(self, state, obj, value=attributes.NO_VALUE):
+        # we have our own AttributeHistory therefore dont need CommittedState
+        pass
+    
+    def set(self, obj, value, initiator):
+        if initiator is self:
+            return
+
+        state = obj._state
+
+        old_collection = self.get(obj).assign(value)
+
+        # TODO: emit events ???
+        state['modified'] = True
+    
+    def delete(self, *args, **kwargs):
+        raise NotImplementedError()
+        
+    def get_history(self, obj, passive=False):
+        try:
+            return obj.__dict__[self.key]
+        except KeyError:
+            obj.__dict__[self.key] = c = CollectionHistory(self, obj)
+            return c
+            
+class AppenderQuery(Query):
+    def __init__(self, attr, instance):
+        super(AppenderQuery, self).__init__(attr.target_mapper, None)
+        self.instance = instance
+        self.attr = attr
+    
+    def __len__(self):
+        if not has_identity(self.instance):
+            # TODO: all these various calls to _added_items should be more
+            # intelligently calculated from the CollectionHistory object 
+            # (i.e. account for deletes too)
+            return len(self.attr.get_history(self.instance)._added_items)
+        else:
+            return self._clone().count()
+        
+    def __iter__(self):
+        if not has_identity(self.instance):
+            return iter(self.attr.get_history(self.instance)._added_items)
+        else:
+            return iter(self._clone())
+    
+    def __getitem__(self, index):
+        if not has_identity(self.instance):
+            return iter(self.attr.get_history(self.instance)._added_items.__getitem__(index))
+        else:
+            return self._clone().__getitem__(index)
+        
+    def _clone(self):
+        # note we're returning an entirely new query class here
+        # without any assignment capabilities;
+        # the class of this query is determined by the session.
+        sess = object_session(self.instance)
+        if sess is None:
+            try:
+                sess = mapper.object_mapper(instance).get_session()
+            except exceptions.InvalidRequestError:
+                raise exceptions.InvalidRequestError("Parent instance %s is not bound to a Session, and no contextual session is established; lazy load operation of attribute '%s' cannot proceed" % (instance.__class__, self.key))
+
+        return sess.query(self.attr.target_mapper).with_parent(self.instance)
+
+    def assign(self, collection):
+        if has_identity(self.instance):
+            oldlist = list(self)
+        else:
+            oldlist = []
+        self.attr.get_history(self.instance).replace(oldlist, collection)
+        return oldlist
+        
+    def append(self, item):
+        self.attr.get_history(self.instance)._added_items.append(item)
+        self.attr.fire_append_event(self.instance, item, self.attr)
+    
+    def remove(self, item):
+        self.attr.get_history(self.instance)._deleted_items.append(item)
+        self.attr.fire_remove_event(self.instance, item, self.attr)
+            
+class CollectionHistory(attributes.AttributeHistory): 
+    """override AttributeHistory to receive append/remove events directly"""
+    def __init__(self, attr, obj):
+        self._deleted_items = []
+        self._added_items = []
+        self._unchanged_items = []
+        self._obj = obj
+        
+    def replace(self, olditems, newitems):
+        self._added_items = newitems
+        self._deleted_items = olditems
+        
+    def is_modified(self):
+        return len(self._deleted_items) > 0 or len(self._added_items) > 0
+
+    def added_items(self):
+        return self._added_items
+
+    def unchanged_items(self):
+        return self._unchanged_items
+
+    def deleted_items(self):
+        return self._deleted_items
+    
\ No newline at end of file
index d16b4b287e59e5d724adbf3631ca437906d2a498..ae73f9c7c4bc207226d128c9d7f44e9cef70931d 100644 (file)
@@ -261,7 +261,9 @@ class PropertyLoader(StrategizedProperty):
     private = property(lambda s:s.cascade.delete_orphan)
 
     def create_strategy(self):
-        if self.lazy:
+        if self.lazy == 'dynamic':
+            return strategies.DynaLoader(self)
+        elif self.lazy:
             return strategies.LazyLoader(self)
         elif self.lazy is False:
             return strategies.EagerLoader(self)
index 501926d499a4513ccc782c1b0d22eb07ae8d9fd1..beb8f2755db864fbe1bb68380744fb4933221332 100644 (file)
@@ -235,10 +235,20 @@ class AbstractRelationLoader(LoaderStrategy):
     def _init_instance_attribute(self, instance, callable_=None):
         return sessionlib.attribute_manager.init_instance_attribute(instance, self.key, callable_=callable_)
         
-    def _register_attribute(self, class_, callable_=None):
+    def _register_attribute(self, class_, callable_=None, **kwargs):
         self.logger.info("register managed %s attribute %s on class %s" % ((self.uselist and "list-holding" or "scalar"), self.key, self.parent.class_.__name__))
-        sessionlib.attribute_manager.register_attribute(class_, self.key, uselist = self.uselist, extension=self.attributeext, cascade=self.cascade,  trackparent=True, typecallable=self.parent_property.collection_class, callable_=callable_, comparator=self.parent_property.comparator)
+        sessionlib.attribute_manager.register_attribute(class_, self.key, uselist = self.uselist, extension=self.attributeext, cascade=self.cascade,  trackparent=True, typecallable=self.parent_property.collection_class, callable_=callable_, comparator=self.parent_property.comparator, **kwargs)
 
+class DynaLoader(AbstractRelationLoader):
+    def init_class_attribute(self):
+        self.is_class_level = True
+        self._register_attribute(self.parent.class_, dynamic=True, target_mapper=self.parent_property.mapper)
+
+    def create_row_processor(self, selectcontext, mapper, row):
+        return (None, None)
+
+DynaLoader.logger = logging.class_logger(DynaLoader)
+        
 class NoLoader(AbstractRelationLoader):
     def init_class_attribute(self):
         self.is_class_level = True
index 9fcea88590861101ffc45f3c9c0eb6d441b8d9fe..59357c7b71095db163cdee4c1f0cd942a059d564 100644 (file)
@@ -32,6 +32,7 @@ def suite():
         'orm.compile',
         'orm.manytomany',
         'orm.onetoone',
+        'orm.dynamic',
         )
     alltests = unittest.TestSuite()
     for name in modules_to_test:
diff --git a/test/orm/dynamic.py b/test/orm/dynamic.py
new file mode 100644 (file)
index 0000000..434ac22
--- /dev/null
@@ -0,0 +1,59 @@
+import testbase
+import operator
+from sqlalchemy import *
+from sqlalchemy import ansisql
+from sqlalchemy.orm import *
+from testlib import *
+from fixtures import *
+
+from query import QueryTest
+
+class DynamicTest(QueryTest):
+    keep_mappers = False
+    
+    def setup_mappers(self):
+        pass
+
+    def test_basic(self):
+        mapper(User, users, properties={
+            'addresses':relation(mapper(Address, addresses), lazy='dynamic')
+        })
+        sess = create_session()
+        q = sess.query(User)
+
+        print q.filter(User.id==7).all()
+        u = q.filter(User.id==7).first()
+        print list(u.addresses)
+        assert [User(id=7, addresses=[Address(id=1, email_address='jack@bean.com')])] == q.filter(User.id==7).all()
+        assert fixtures.user_address_result == q.all()
+
+class FlushTest(FixtureTest):
+    def test_basic(self):
+        mapper(User, users, properties={
+            'addresses':relation(mapper(Address, addresses), lazy='dynamic')
+        })
+        sess = create_session()
+        u1 = User(name='jack')
+        u2 = User(name='ed')
+        u2.addresses.append(Address(email_address='foo@bar.com'))
+        u1.addresses.append(Address(email_address='lala@hoho.com'))
+        sess.save(u1)
+        sess.save(u2)
+        sess.flush()
+        
+        sess.clear()
+        
+        def go():
+            assert [
+                User(name='jack', addresses=[Address(email_address='lala@hoho.com')]),
+                User(name='ed', addresses=[Address(email_address='foo@bar.com')])
+            ] == sess.query(User).all()
+
+        # one query for the query(User).all(), one query for each address iter(),
+        # also one query for a count() on each address (the count() is an artifact of the
+        # fixtures.Base class, its not intrinsic to the property)
+        self.assert_sql_count(testbase.db, go, 5)
+        
+if __name__ == '__main__':
+    testbase.main()
+    
\ No newline at end of file
index 4a7d41459f74c086bd032ab728da0f78b6b07afe..8b7312251c5ac3f98a3aaeb31ef907a166c06c19 100644 (file)
@@ -161,6 +161,12 @@ def install_fixture_data():
         dict(keyword_id=6, item_id=3)
     )
 
+class FixtureTest(ORMTest):
+    def define_tables(self, meta):
+        # a slight dirty trick here. 
+        meta.tables = metadata.tables
+        metadata.connect(meta.bind)
+    
 class Fixtures(object):
     @property
     def user_address_result(self):