git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- trimming down redundancy in lazyloader code
author    Mike Bayer <mike_mp@zzzcomputing.com>
          Sat, 28 Jul 2007 05:16:03 +0000 (05:16 +0000)
committer Mike Bayer <mike_mp@zzzcomputing.com>
          Sat, 28 Jul 2007 05:16:03 +0000 (05:16 +0000)
- fixups to ORM test fixture code
- fixups to dynamic relations; tests for autoflush session and delete-orphan
- made new dynamic_loader() function to create them
- removed old hasparent() call on AttributeHistory

lib/sqlalchemy/orm/__init__.py
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/dependency.py
lib/sqlalchemy/orm/dynamic.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/strategies.py
test/orm/dynamic.py
test/orm/fixtures.py
test/orm/lazy_relations.py
test/orm/query.py

diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py
index 1982a94f780e4c78ca23d27ed93f91f6c8a004ce..380d2957700d0cbd56d33c882fa6a80c5b3482a6 100644
@@ -24,7 +24,7 @@ from sqlalchemy.orm.session import object_session, attribute_manager
 __all__ = ['relation', 'column_property', 'composite', 'backref', 'eagerload',
            'eagerload_all', 'lazyload', 'noload', 'deferred', 'defer', 'undefer',
            'undefer_group', 'extension', 'mapper', 'clear_mappers',
-           'compile_mappers', 'class_mapper', 'object_mapper',
+           'compile_mappers', 'class_mapper', 'object_mapper', 'dynamic_loader',
            'MapperExtension', 'Query', 'polymorphic_union', 'create_session',
            'synonym', 'contains_alias', 'contains_eager', 'EXT_PASS',
            'object_session', 'PropComparator'
@@ -170,7 +170,26 @@ def relation(argument, secondary=None, **kwargs):
 
     return PropertyLoader(argument, secondary=secondary, **kwargs)
 
-#    return _relation_loader(argument, secondary=secondary, **kwargs)
+def dynamic_loader(argument, secondary=None, primaryjoin=None, secondaryjoin=None, entity_name=None, 
+    foreign_keys=None, backref=None, post_update=False, cascade=None, remote_side=None, enable_typechecks=True):
+    """construct a dynamically-loading mapper property.
+    
+    This property is similar to relation(), except read operations
+    return an active Query object, which reads from the database in all 
+    cases.  Items may be appended to the attribute via append(), or
+    removed via remove(); changes will be persisted
+    to the database during a flush().  However, no other list mutation
+    operations are available.
+    
+    A subset of the arguments available to relation() is available here.
+    """
+
+    from sqlalchemy.orm.strategies import DynaLoader
+    
+    return PropertyLoader(argument, secondary=secondary, primaryjoin=primaryjoin, 
+            secondaryjoin=secondaryjoin, entity_name=entity_name, foreign_keys=foreign_keys, backref=backref, 
+            post_update=post_update, cascade=cascade, remote_side=remote_side, enable_typechecks=enable_typechecks, 
+            strategy_class=DynaLoader)
 
 #def _relation_loader(mapper, secondary=None, primaryjoin=None, secondaryjoin=None, lazy=True, **kwargs):
 
diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py
index b903d5aa08d55663a4397a9d6da85d29a6fd7c70..f783a381fb5ed8df23e280289253c8c4ef3499ca 100644
@@ -633,12 +633,6 @@ class AttributeHistory(object):
     def deleted_items(self):
         return self._deleted_items
 
-    def hasparent(self, obj):
-        """Deprecated.  This should be called directly from the appropriate ``InstrumentedAttribute`` object.
-        """
-
-        return self.attr.hasparent(obj)
-
 class AttributeManager(object):
     """Allow the instrumentation of object attributes."""
 
diff --git a/lib/sqlalchemy/orm/dependency.py b/lib/sqlalchemy/orm/dependency.py
index c06db69631e7768b489e63d5fddafe7411dddead..cfe501695ab216bb8329b7b3655f0d74e5d49ae4 100644
@@ -51,6 +51,12 @@ class DependencyProcessor(object):
 
         return getattr(self.parent.class_, self.key)
 
+    def hasparent(self, obj):
+        """return True if the given object instance has a parent, 
+        according to the ``InstrumentedAttribute`` handled by this ``DependencyProcessor``."""
+        
+        return self._get_instrumented_attribute().hasparent(obj)
+        
     def register_dependencies(self, uowcommit):
         """Tell a ``UOWTransaction`` what mappers are dependent on
         which, with regards to the two or three mappers handled by
@@ -187,7 +193,7 @@ class OneToManyDP(DependencyProcessor):
                     childlist = self.get_object_dependencies(obj, uowcommit, passive=self.passive_deletes)
                     if childlist is not None:
                         for child in childlist.deleted_items():
-                            if child is not None and childlist.hasparent(child) is False:
+                            if child is not None and self.hasparent(child) is False:
                                 self._synchronize(obj, child, None, True, uowcommit)
                                 self._conditional_post_update(child, uowcommit, [obj])
                         for child in childlist.unchanged_items():
@@ -202,7 +208,7 @@ class OneToManyDP(DependencyProcessor):
                         self._synchronize(obj, child, None, False, uowcommit)
                         self._conditional_post_update(child, uowcommit, [obj])
                     for child in childlist.deleted_items():
-                        if not self.cascade.delete_orphan and not self._get_instrumented_attribute().hasparent(child):
+                        if not self.cascade.delete_orphan and not self.hasparent(child):
                             self._synchronize(obj, child, None, True, uowcommit)
 
     def preprocess_dependencies(self, task, deplist, uowcommit, delete = False):
@@ -216,7 +222,7 @@ class OneToManyDP(DependencyProcessor):
                     childlist = self.get_object_dependencies(obj, uowcommit, passive=self.passive_deletes)
                     if childlist is not None:
                         for child in childlist.deleted_items():
-                            if child is not None and childlist.hasparent(child) is False:
+                            if child is not None and self.hasparent(child) is False:
                                 uowcommit.register_object(child)
                         for child in childlist.unchanged_items():
                             if child is not None:
@@ -231,7 +237,7 @@ class OneToManyDP(DependencyProcessor):
                     for child in childlist.deleted_items():
                         if not self.cascade.delete_orphan:
                             uowcommit.register_object(child, isdelete=False)
-                        elif childlist.hasparent(child) is False:
+                        elif self.hasparent(child) is False:
                             uowcommit.register_object(child, isdelete=True)
                             for c in self.mapper.cascade_iterator('delete', child):
                                 uowcommit.register_object(c, isdelete=True)
@@ -285,7 +291,7 @@ class ManyToOneDP(DependencyProcessor):
                     childlist = self.get_object_dependencies(obj, uowcommit, passive=self.passive_deletes)
                     if childlist is not None:
                         for child in childlist.deleted_items() + childlist.unchanged_items():
-                            if child is not None and childlist.hasparent(child) is False:
+                            if child is not None and self.hasparent(child) is False:
                                 uowcommit.register_object(child, isdelete=True)
                                 for c in self.mapper.cascade_iterator('delete', child):
                                     uowcommit.register_object(c, isdelete=True)
@@ -296,7 +302,7 @@ class ManyToOneDP(DependencyProcessor):
                     childlist = self.get_object_dependencies(obj, uowcommit, passive=self.passive_deletes)
                     if childlist is not None:
                         for child in childlist.deleted_items():
-                            if childlist.hasparent(child) is False:
+                            if self.hasparent(child) is False:
                                 uowcommit.register_object(child, isdelete=True)
                                 for c in self.mapper.cascade_iterator('delete', child):
                                     uowcommit.register_object(c, isdelete=True)
@@ -382,7 +388,7 @@ class ManyToManyDP(DependencyProcessor):
                 childlist = self.get_object_dependencies(obj, uowcommit, passive=True)
                 if childlist is not None:
                     for child in childlist.deleted_items():
-                        if self.cascade.delete_orphan and childlist.hasparent(child) is False:
+                        if self.cascade.delete_orphan and self.hasparent(child) is False:
                             uowcommit.register_object(child, isdelete=True)
                             for c in self.mapper.cascade_iterator('delete', child):
                                 uowcommit.register_object(c, isdelete=True)
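
All of the hunks above funnel the orphan check through the new DependencyProcessor.hasparent() instead of the AttributeHistory method removed from attributes.py. As a plain-Python summary of the decision those call sites make for a child removed from a collection (orphan_action is a hypothetical helper for illustration, not SQLAlchemy code):

    def orphan_action(delete_orphan, has_other_parent):
        """Mirror the preprocess_dependencies() logic above: a removed
        child is only deleted when delete-orphan cascade is configured
        and no other parent still claims it via hasparent()."""
        if not delete_orphan:
            return 'keep'      # just de-associate; register_object(isdelete=False)
        elif not has_other_parent:
            return 'delete'    # register_object(isdelete=True) plus cascaded deletes
        else:
            return 'keep'      # re-parented elsewhere, leave it alone

    assert orphan_action(False, False) == 'keep'
    assert orphan_action(True, False) == 'delete'
    assert orphan_action(True, True) == 'keep'
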
diff --git a/lib/sqlalchemy/orm/dynamic.py b/lib/sqlalchemy/orm/dynamic.py
index b1293148288d86be920e594634fe8f1218566d97..d06b874c9878df3b990df4ec5e4d37eb8c321be4 100644
@@ -18,7 +18,8 @@ class DynamicCollectionAttribute(attributes.InstrumentedAttribute):
 
     def commit_to_state(self, state, obj, value=attributes.NO_VALUE):
         # we have our own AttributeHistory therefore dont need CommittedState
-        pass
+        # instead, we reset the history stored on the attribute
+        obj.__dict__[self.key] = CollectionHistory(self, obj)
     
     def set(self, obj, value, initiator):
         if initiator is self:
@@ -58,38 +59,47 @@ class AppenderQuery(Query):
         self.instance = instance
         self.attr = attr
     
-    def __len__(self):
+    def __session(self):
+        sess = object_session(self.instance)
+        if sess is not None and self.instance in sess and sess.autoflush:
+            sess.flush()
         if not has_identity(self.instance):
-            # TODO: all these various calls to _added_items should be more
-            # intelligently calculated from the CollectionHistory object 
-            # (i.e. account for deletes too)
+            return None
+        else:
+            return sess
+            
+    def __len__(self):
+        sess = self.__session()
+        if sess is None:
             return len(self.attr.get_history(self.instance)._added_items)
         else:
-            return self._clone().count()
+            return self._clone(sess).count()
         
     def __iter__(self):
-        if not has_identity(self.instance):
+        sess = self.__session()
+        if sess is None:
             return iter(self.attr.get_history(self.instance)._added_items)
         else:
-            return iter(self._clone())
+            return iter(self._clone(sess))
 
     def __getitem__(self, index):
-        if not has_identity(self.instance):
-            # TODO: hmm
+        sess = self.__session()
+        if sess is None:
             return self.attr.get_history(self.instance)._added_items.__getitem__(index)
         else:
-            return self._clone().__getitem__(index)
-        
-    def _clone(self):
+            return self._clone(sess).__getitem__(index)
+
+    def _clone(self, sess=None):
         # note we're returning an entirely new Query class instance here
         # without any assignment capabilities;
         # the class of this query is determined by the session.
-        sess = object_session(self.instance)
         if sess is None:
-            try:
-                sess = mapper.object_mapper(instance).get_session()
-            except exceptions.InvalidRequestError:
-                raise exceptions.InvalidRequestError("Parent instance %s is not bound to a Session, and no contextual session is established; lazy load operation of attribute '%s' cannot proceed" % (instance.__class__, self.key))
+            sess = object_session(self.instance)
+            if sess is None:
+                try:
+                    sess = mapper.object_mapper(instance).get_session()
+                except exceptions.InvalidRequestError:
+                    raise exceptions.InvalidRequestError("Parent instance %s is not bound to a Session, and no contextual session is established; lazy load operation of attribute '%s' cannot proceed" % (instance.__class__, self.key))
 
         return sess.query(self.attr.target_mapper).with_parent(self.instance)
 
@@ -102,11 +112,11 @@ class AppenderQuery(Query):
         return oldlist
         
     def append(self, item):
-        self.attr.append(self.instance, item, self.attr)
+        self.attr.append(self.instance, item, None)
 
-    # TODO:jek: I think this should probably be axed, time will tell.
     def remove(self, item):
-        self.attr.remove(self.instance, item, self.attr)
+        self.attr.remove(self.instance, item, None)
+
             
 class CollectionHistory(attributes.AttributeHistory): 
     """Overrides AttributeHistory to receive append/remove events directly."""
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index ae73f9c7c4bc207226d128c9d7f44e9cef70931d..a6bb1a371d67db0d8adbe176229534193713952e 100644
@@ -124,7 +124,7 @@ class PropertyLoader(StrategizedProperty):
     of items that correspond to a related database table.
     """
 
-    def __init__(self, argument, secondary=None, primaryjoin=None, secondaryjoin=None, entity_name=None, foreign_keys=None, foreignkey=None, uselist=None, private=False, association=None, order_by=False, attributeext=None, backref=None, is_backref=False, post_update=False, cascade=None, viewonly=False, lazy=True, collection_class=None, passive_deletes=False, remote_side=None, enable_typechecks=True, join_depth=None):
+    def __init__(self, argument, secondary=None, primaryjoin=None, secondaryjoin=None, entity_name=None, foreign_keys=None, foreignkey=None, uselist=None, private=False, association=None, order_by=False, attributeext=None, backref=None, is_backref=False, post_update=False, cascade=None, viewonly=False, lazy=True, collection_class=None, passive_deletes=False, remote_side=None, enable_typechecks=True, join_depth=None, strategy_class=None):
         self.uselist = uselist
         self.argument = argument
         self.entity_name = entity_name
@@ -144,6 +144,7 @@ class PropertyLoader(StrategizedProperty):
         self._parent_join_cache = {}
         self.comparator = PropertyLoader.Comparator(self)
         self.join_depth = join_depth
+        self.strategy_class = strategy_class
         
         if cascade is not None:
             self.cascade = mapperutil.CascadeOptions(cascade)
@@ -247,22 +248,13 @@ class PropertyLoader(StrategizedProperty):
             return op(self.comparator, value)
     
     def _optimized_compare(self, value, value_is_parent=False):
-        # optimized operation for ==, uses a lazy clause.
-        (criterion, lazybinds, rev) = strategies.LazyLoader._create_lazy_clause(self, reverse_direction=not value_is_parent)
-        bind_to_col = dict([(lazybinds[col].key, col) for col in lazybinds])
-
-        class Visitor(sql.ClauseVisitor):
-            def visit_bindparam(s, bindparam):
-                mapper = value_is_parent and self.parent or self.mapper
-                bindparam.value = mapper.get_attr_by_column(value, bind_to_col[bindparam.key])
-        Visitor().traverse(criterion)
-        return criterion
+        return self._get_strategy(strategies.LazyLoader).lazy_clause(value, reverse_direction=not value_is_parent)
         
     private = property(lambda s:s.cascade.delete_orphan)
 
     def create_strategy(self):
-        if self.lazy == 'dynamic':
-            return strategies.DynaLoader(self)
+        if self.strategy_class:
+            return self.strategy_class(self)
         elif self.lazy:
             return strategies.LazyLoader(self)
         elif self.lazy is False:
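
dynamic_loader() exercises the new strategy_class hook by passing DynaLoader straight through, and create_strategy() gives that class precedence over the lazy flag. A tiny stand-alone sketch of the dispatch order (throwaway stand-in classes, not the real loader strategies; the non-lazy branches truncated in the hunk above are collapsed into one placeholder):

    class FakeDynaLoader(object): pass
    class FakeLazyLoader(object): pass
    class FakeOtherLoader(object): pass   # stands in for the remaining lazy-flag branches

    def create_strategy(strategy_class=None, lazy=True):
        # mirrors PropertyLoader.create_strategy() above: an explicit
        # strategy_class wins, otherwise the lazy flag decides
        if strategy_class:
            return strategy_class()
        elif lazy:
            return FakeLazyLoader()
        else:
            return FakeOtherLoader()

    assert isinstance(create_strategy(strategy_class=FakeDynaLoader), FakeDynaLoader)
    assert isinstance(create_strategy(), FakeLazyLoader)
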
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index 9040655b286f2e51c254fb370b7c5866b39c36dc..a504aa9a67d23f73ea563a5cbbfed9ea4e2fe365 100644
@@ -621,6 +621,7 @@ class Query(object):
     def __iter__(self):
         statement = self.compile()
         statement.use_labels = True
+        print "ITER !", self.session.autoflush
         if self.session.autoflush:
             self.session.flush()
         return self._execute_and_instances(statement)
diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py
index beb8f2755db864fbe1bb68380744fb4933221332..54a4c65247b24b67e0d3d1b4aa06e642bc0b00ce 100644
@@ -282,6 +282,20 @@ class LazyLoader(AbstractRelationLoader):
         self.is_class_level = True
         self._register_attribute(self.parent.class_, callable_=lambda i: self.setup_loader(i))
 
+    def lazy_clause(self, instance, reverse_direction=False):
+        if not reverse_direction:
+            (criterion, lazybinds, rev) = (self.lazywhere, self.lazybinds, self.lazyreverse)
+        else:
+            (criterion, lazybinds, rev) = LazyLoader._create_lazy_clause(self.parent_property, reverse_direction=reverse_direction)
+        bind_to_col = dict([(lazybinds[col].key, col) for col in lazybinds])
+
+        class Visitor(sql.ClauseVisitor):
+            def visit_bindparam(s, bindparam):
+                mapper = reverse_direction and self.parent_property.mapper or self.parent_property.parent
+                if bindparam.key in bind_to_col:
+                    bindparam.value = mapper.get_attr_by_column(instance, bind_to_col[bindparam.key])
+        return Visitor().traverse(criterion, clone=True)
+    
     def setup_loader(self, instance, options=None):
         if not mapper.has_mapper(instance):
             return None
@@ -296,23 +310,8 @@ class LazyLoader(AbstractRelationLoader):
 
         def lazyload():
             self.logger.debug("lazy load attribute %s on instance %s" % (self.key, mapperutil.instance_str(instance)))
-            params = {}
-            allparams = True
-            # if the instance wasnt loaded from the database, then it cannot lazy load
-            # child items.  one reason for this is that a bi-directional relationship
-            # will not update properly, since bi-directional uses lazy loading functions
-            # in both directions, and this instance will not be present in the lazily-loaded
-            # results of the other objects since its not in the database
-            if not mapper.has_identity(instance):
-                return None
-            #print "setting up loader, lazywhere", str(self.lazywhere), "binds", self.lazybinds
-            for col, bind in self.lazybinds.iteritems():
-                params[bind.key] = self.parent.get_attr_by_column(instance, col)
-                if params[bind.key] is None:
-                    allparams = False
-                    break
 
-            if not allparams:
+            if not mapper.has_identity(instance):
                 return None
 
             session = sessionlib.object_session(instance)
@@ -326,14 +325,15 @@ class LazyLoader(AbstractRelationLoader):
             # to possibly save a DB round trip
             q = session.query(self.mapper)
             if self.use_get:
+                params = {}
+                for col, bind in self.lazybinds.iteritems():
+                    params[bind.key] = self.parent.get_attr_by_column(instance, col)
                 ident = []
-                # TODO: when options are added to allow switching between union-based and non-union
-                # based polymorphic loads on a per-query basis, this code needs to switch between "mapper" and "select_mapper",
-                # probably via the query's own "mapper" property, and also use one of two "lazy" clauses,
-                # one against the "union" the other not
                 for primary_key in self.select_mapper.primary_key: 
                     bind = self.lazyreverse[primary_key]
                     ident.append(params[bind.key])
+                if options:
+                    q = q.options(*options)
                 return q.get(ident)
             elif self.order_by is not False:
                 q = q.order_by(self.order_by)
@@ -342,7 +342,7 @@ class LazyLoader(AbstractRelationLoader):
 
             if options:
                 q = q.options(*options)
-            q = q.filter(self.lazywhere).params(**params)
+            q = q.filter(self.lazy_clause(instance))
 
             result = q.all()
             if self.uselist:
@@ -352,11 +352,6 @@ class LazyLoader(AbstractRelationLoader):
                     return result[0]
                 else:
                     return None
-            
-            if self.uselist:
-                return q.all()
-            else:
-                return q.first()
 
         return lazyload
 
@@ -382,7 +377,7 @@ class LazyLoader(AbstractRelationLoader):
                     sessionlib.attribute_manager.reset_instance_attribute(instance, self.key)
             return (execute, None)
 
-    def _create_lazy_clause(cls, prop, reverse_direction=False, op='=='):
+    def _create_lazy_clause(cls, prop, reverse_direction=False):
         (primaryjoin, secondaryjoin, remote_side) = (prop.polymorphic_primaryjoin, prop.polymorphic_secondaryjoin, prop.remote_side)
         
         binds = {}
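
lazy_clause() above takes the precomputed lazy criterion (or rebuilds it for the reverse direction), then clones it while a visitor fills each known bind parameter with the parent instance's value for the mapped column. A simplified pure-Python analogue of that fill step (fill_lazy_clause and the dict/list stand-ins are made up for illustration, not the real ClauseVisitor/BindParam machinery):

    def fill_lazy_clause(criterion, bind_to_col, get_attr_by_column):
        """criterion:          (bindparam_key, value) pairs standing in for the
                               bound parameters inside the lazy WHERE clause
        bind_to_col:           bindparam key -> parent column, as built above
        get_attr_by_column:    callable(column) returning the parent instance's
                               value, i.e. mapper.get_attr_by_column(instance, col)"""
        filled = []
        for key, default in criterion:    # work on a copy: the shared clause is never mutated
            if key in bind_to_col:
                filled.append((key, get_attr_by_column(bind_to_col[key])))
            else:
                filled.append((key, default))
        return filled

    # e.g. a one-column join  users.id == addresses.user_id  for parent id 7
    template = [('lazy_users_id', None)]
    bind_to_col = {'lazy_users_id': 'users.id'}
    parent_values = {'users.id': 7}
    assert fill_lazy_clause(template, bind_to_col, parent_values.get) == [('lazy_users_id', 7)]
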
diff --git a/test/orm/dynamic.py b/test/orm/dynamic.py
index 2cd616c12d62c68d74b0a66803c0fd6b401e8098..958f5b1598e2d1a6fff78ec4db1bd5ec05b9705d 100644
@@ -16,7 +16,7 @@ class DynamicTest(QueryTest):
 
     def test_basic(self):
         mapper(User, users, properties={
-            'addresses':relation(mapper(Address, addresses), lazy='dynamic')
+            'addresses':dynamic_loader(mapper(Address, addresses))
         })
         sess = create_session()
         q = sess.query(User)
@@ -29,8 +29,11 @@ class DynamicTest(QueryTest):
 
 class FlushTest(FixtureTest):
     def test_basic(self):
+        class Fixture(Base):
+            pass
+            
         mapper(User, users, properties={
-            'addresses':relation(mapper(Address, addresses), lazy='dynamic')
+            'addresses':dynamic_loader(mapper(Address, addresses))
         })
         sess = create_session()
         u1 = User(name='jack')
@@ -42,88 +45,115 @@ class FlushTest(FixtureTest):
         sess.flush()
         
         sess.clear()
+
+        # test the test fixture a little bit
+        assert User(name='jack', addresses=[Address(email_address='wrong')]) != sess.query(User).first()
+        assert User(name='jack', addresses=[Address(email_address='lala@hoho.com')]) == sess.query(User).first()
         
-        def go():
-            assert [
-                User(name='jack', addresses=[Address(email_address='lala@hoho.com')]),
-                User(name='ed', addresses=[Address(email_address='foo@bar.com')])
-            ] == sess.query(User).all()
-
-        # one query for the query(User).all(), one query for each address
-        # iter(), also one query for a count() on each address (the count()
-        # is an artifact of the fixtures.Base class, its not intrinsic to the
-        # property)
-        self.assert_sql_count(testbase.db, go, 5)
-
-    def test_backref_unsaved_u(self):
+        assert [
+            User(name='jack', addresses=[Address(email_address='lala@hoho.com')]),
+            User(name='ed', addresses=[Address(email_address='foo@bar.com')])
+        ] == sess.query(User).all()
+
+    def test_delete(self):
         mapper(User, users, properties={
-            'addresses':relation(mapper(Address, addresses), lazy='dynamic',
-                                 backref='user')
+            'addresses':dynamic_loader(mapper(Address, addresses), backref='user')
         })
-        sess = create_session()
-
-        u = User(name='buffy')
-
-        a = Address(email_address='foo@bar.com')
-        a.user = u
-
+        sess = create_session(autoflush=True)
+        u = User(name='ed')
+        u.addresses.append(Address(email_address='a'))
+        u.addresses.append(Address(email_address='b'))
+        u.addresses.append(Address(email_address='c'))
+        u.addresses.append(Address(email_address='d'))
+        u.addresses.append(Address(email_address='e'))
+        u.addresses.append(Address(email_address='f'))
         sess.save(u)
-        sess.flush()
+        
+        assert Address(email_address='c') == u.addresses[2]
+        sess.delete(u.addresses[2])
+        sess.delete(u.addresses[4])
+        sess.delete(u.addresses[3])
+        assert [Address(email_address='a'), Address(email_address='b'), Address(email_address='d')] == list(u.addresses)
+        
+        sess.close()
 
-    def test_backref_unsaved_a(self):
+    def test_remove_orphans(self):
         mapper(User, users, properties={
-            'addresses':relation(mapper(Address, addresses), lazy='dynamic',
-                                 backref='user')
+            'addresses':dynamic_loader(mapper(Address, addresses), cascade="all, delete-orphan", backref='user')
         })
-        sess = create_session()
-
-        u = User(name='buffy')
-
-        a = Address(email_address='foo@bar.com')
-        a.user = u
+        sess = create_session(autoflush=True)
+        u = User(name='ed')
+        u.addresses.append(Address(email_address='a'))
+        u.addresses.append(Address(email_address='b'))
+        u.addresses.append(Address(email_address='c'))
+        u.addresses.append(Address(email_address='d'))
+        u.addresses.append(Address(email_address='e'))
+        u.addresses.append(Address(email_address='f'))
+        sess.save(u)
 
-        self.assert_(list(u.addresses) == [a])
-        self.assert_(u.addresses[0] == a)
+        assert [Address(email_address='a'), Address(email_address='b'), Address(email_address='c'), 
+            Address(email_address='d'), Address(email_address='e'), Address(email_address='f')] == sess.query(Address).all()
 
-        sess.save(a)
-        sess.flush()
+        assert Address(email_address='c') == u.addresses[2]
         
-        self.assert_(list(u.addresses) == [a])
-
-        a.user = None
-        self.assert_(list(u.addresses) == [a])
+        try:
+            del u.addresses[3]
+            assert False
+        except TypeError, e:
+            assert str(e) == "object doesn't support item deletion"
+        
+        for a in u.addresses.filter(Address.email_address.in_('c', 'e', 'f')):
+            u.addresses.remove(a)
+            
+        assert [Address(email_address='a'), Address(email_address='b'), Address(email_address='d')] == list(u.addresses)
 
-        sess.flush()
-        self.assert_(list(u.addresses) == [])
+        assert [Address(email_address='a'), Address(email_address='b'), Address(email_address='d')] == sess.query(Address).all()
         
+        sess.close()
 
-    def test_backref_unsaved_u(self):
+def create_backref_test(autoflush, saveuser):
+    def test_backref(self):
         mapper(User, users, properties={
-            'addresses':relation(mapper(Address, addresses), lazy='dynamic',
-                                 backref='user')
+            'addresses':dynamic_loader(mapper(Address, addresses), backref='user')
         })
-        sess = create_session()
+        sess = create_session(autoflush=autoflush)
 
         u = User(name='buffy')
 
         a = Address(email_address='foo@bar.com')
         a.user = u
 
-        self.assert_(list(u.addresses) == [a])
-        self.assert_(u.addresses[0] == a)
+        if saveuser:
+            sess.save(u)
+        else:
+            sess.save(a)
 
-        sess.save(u)
-        sess.flush()
+        if not autoflush:
+            sess.flush()
+        
+        assert u in sess
+        assert a in sess
         
-        assert list(u.addresses) == [a]
+        self.assert_(list(u.addresses) == [a])
 
         a.user = None
-        self.assert_(list(u.addresses) == [a])
+        if not autoflush:
+            self.assert_(list(u.addresses) == [a])
 
-        sess.flush()
+        if not autoflush:
+            sess.flush()
         self.assert_(list(u.addresses) == [])
 
-        
+    test_backref.__name__ = "test_%s%s" % (
+        (autoflush and "autoflush" or ""),
+        (saveuser and "_saveuser" or "_savead"),
+    )
+    setattr(FlushTest, test_backref.__name__, test_backref)
+
+for autoflush in (False, True):
+    for saveuser in (False, True):   
+        create_backref_test(autoflush, saveuser) 
+
 if __name__ == '__main__':
     testbase.main()
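
create_backref_test() above stamps four generated methods onto FlushTest (autoflush x saveuser) by naming a closure and attaching it with setattr. The same pattern in isolation, for readers unfamiliar with it (a generic unittest example; ParamTest and create_case are made-up names, nothing SQLAlchemy-specific):

    import unittest

    class ParamTest(unittest.TestCase):
        pass

    def create_case(flag):
        def test(self):
            # the closed-over parameter drives the test body
            self.assert_(flag in (False, True))
        # give each generated method a distinct name so the runner
        # collects and reports it separately
        test.__name__ = "test_flag_%s" % (flag and "on" or "off")
        setattr(ParamTest, test.__name__, test)

    for flag in (False, True):
        create_case(flag)
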
     
diff --git a/test/orm/fixtures.py b/test/orm/fixtures.py
index 8b7312251c5ac3f98a3aaeb31ef907a166c06c19..73449eb96c65930a348f444415f6b120df6d7b5f 100644
@@ -28,8 +28,8 @@ class Base(object):
                     continue
                 value = getattr(self, attr)
                 if hasattr(value, '__iter__') and not isinstance(value, basestring):
-                    if len(value) == 0:
-                        continue
+                    if len(value) != len(getattr(other, attr)):
+                       return False
                     for (us, them) in zip(value, getattr(other, attr)):
                         if us != them:
                             return False
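
The fixtures.Base.__eq__ tweak above makes collection comparison strict: a length mismatch now fails the comparison, where the old code silently skipped the check whenever the collection on self was empty. The resulting rule, restated as a small self-contained helper (collections_equal is a made-up name for illustration):

    def collections_equal(ours, theirs):
        # mirrors the new behavior for iterable, non-string attributes:
        # differing lengths fail outright, then elements are compared pairwise
        if len(ours) != len(theirs):
            return False
        for us, them in zip(ours, theirs):
            if us != them:
                return False
        return True

    assert not collections_equal([], ['x'])          # the old code would have skipped this case
    assert collections_equal(['a', 'b'], ['a', 'b'])
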
diff --git a/test/orm/lazy_relations.py b/test/orm/lazy_relations.py
index 6684c628815d3a8d256cf9a7a966dcb4b27e8912..6baba83941495688898643f18b495602ccaa40db 100644
@@ -214,6 +214,11 @@ class LazyTest(QueryTest):
             User(id=10)
         
         ] == q.all()
+        
+        sess = create_session()
+        user = sess.query(User).get(7)
+        assert [Order(id=1), Order(id=5)] == create_session().query(Order, entity_name='closed').with_parent(user, property='closed_orders').all()
+        assert [Order(id=3)] == create_session().query(Order, entity_name='open').with_parent(user, property='open_orders').all()
 
     def test_many_to_many(self):
 
diff --git a/test/orm/query.py b/test/orm/query.py
index 5d17d7d817e03f5ba4dcbdcb2073a94739cbb802..e02c5a6432e2b9d275ba997471500e12d967fd00 100644
@@ -6,7 +6,7 @@ from sqlalchemy.orm import *
 from testlib import *
 from fixtures import *
 
-class QueryTest(ORMTest):
+class QueryTest(FixtureTest):
     keep_mappers = True
     keep_data = True
     
@@ -19,11 +19,6 @@ class QueryTest(ORMTest):
         clear_mappers()
         super(QueryTest, self).tearDownAll()
           
-    def define_tables(self, meta):
-        # a slight dirty trick here. 
-        meta.tables = metadata.tables
-        metadata.connect(meta.bind)
-        
     def setup_mappers(self):
         mapper(User, users, properties={
             'addresses':relation(Address, backref='user'),