]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Feature enhancement: joined and subquery
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 17 Jul 2011 19:14:03 +0000 (15:14 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 17 Jul 2011 19:14:03 +0000 (15:14 -0400)
    loading will now traverse already-present related
    objects and collections in search of unpopulated
    attributes throughout the scope of the eager load
    being defined, so that the eager loading that is
    specified via mappings or query options
    unconditionally takes place for the full depth,
    populating whatever is not already populated.
    Previously, this traversal would stop if a related
    object or collection were already present leading
    to inconsistent behavior (though would save on
    loads/cycles for an already-loaded graph). For a
    subqueryload, this means that the additional
    SELECT statements emitted by subqueryload will
    invoke unconditionally, no matter how much of the
    existing graph is already present (hence the
    controversy). The previous behavior of "stopping"
    is still in effect when a query is the result of
    an attribute-initiated lazyload, as otherwise an
    "N+1" style of collection iteration can become
    needlessly expensive when the same related object
    is encountered repeatedly. There's also an
    as-yet-not-public generative Query method
    _with_invoke_all_eagers()
    which selects old/new behavior [ticket:2213]

CHANGES
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/strategies.py
test/orm/test_eager_relations.py
test/orm/test_subquery_relations.py

diff --git a/CHANGES b/CHANGES
index 5734405523d251002f189b46410351606abe8c08..cea7db4696501f3081f5f69d6b5ff6a96879d8d3 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -6,6 +6,32 @@ CHANGES
 0.7.2
 =====
 - orm
+  - Feature enhancement: joined and subquery
+    loading will now traverse already-present related
+    objects and collections in search of unpopulated
+    attributes throughout the scope of the eager load
+    being defined, so that the eager loading that is
+    specified via mappings or query options
+    unconditionally takes place for the full depth,
+    populating whatever is not already populated.
+    Previously, this traversal would stop if a related
+    object or collection were already present leading
+    to inconsistent behavior (though would save on
+    loads/cycles for an already-loaded graph). For a
+    subqueryload, this means that the additional
+    SELECT statements emitted by subqueryload will
+    invoke unconditionally, no matter how much of the
+    existing graph is already present (hence the
+    controversy). The previous behavior of "stopping"
+    is still in effect when a query is the result of
+    an attribute-initiated lazyload, as otherwise an
+    "N+1" style of collection iteration can become
+    needlessly expensive when the same related object
+    is encountered repeatedly. There's also an 
+    as-yet-not-public generative Query method 
+    _with_invoke_all_eagers()
+    which selects old/new behavior [ticket:2213]
+
   - Fixed subtle bug that caused SQL to blow
     up if: column_property() against subquery +
     joinedload + LIMIT + order by the column
index 19c302ec9e1028f829987c853e6c5e6de0d57346..4f8e6dbfda95b1cf333d5937ad6bd2b614b5be83 100644 (file)
@@ -631,7 +631,6 @@ class LoaderStrategy(object):
       on a particular mapped instance.
 
     """
-
     def __init__(self, parent):
         self.parent_property = parent
         self.is_class_level = False
index 6e97aaf25e3df21ecb4e3c3e02614b14080c283e..8477f5e2f442c5a6c222fb9ca792f700ebf8d4c8 100644 (file)
@@ -2418,6 +2418,7 @@ class Mapper(object):
 
         new_populators = []
         existing_populators = []
+        eager_populators = []
         load_path = context.query._current_path + path
 
         def populate_state(state, dict_, row, isnew, only_load_props):
@@ -2430,7 +2431,8 @@ class Mapper(object):
             if not new_populators:
                 self._populators(context, path, reduced_path, row, adapter,
                                 new_populators,
-                                existing_populators
+                                existing_populators,
+                                eager_populators
                 )
 
             if isnew:
@@ -2438,13 +2440,13 @@ class Mapper(object):
             else:
                 populators = existing_populators
 
-            if only_load_props:
+            if only_load_props is None:
+                for key, populator in populators:
+                    populator(state, dict_, row)
+            elif only_load_props:
                 for key, populator in populators:
                     if key in only_load_props:
                         populator(state, dict_, row)
-            else:
-                for key, populator in populators:
-                    populator(state, dict_, row)
 
         session_identity_map = context.session.identity_map
 
@@ -2455,12 +2457,21 @@ class Mapper(object):
         populate_instance = listeners.populate_instance or None
         append_result = listeners.append_result or None
         populate_existing = context.populate_existing or self.always_refresh
+        invoke_all_eagers = context.invoke_all_eagers
+
         if self.allow_partial_pks:
             is_not_primary_key = _none_set.issuperset
         else:
             is_not_primary_key = _none_set.issubset
 
         def _instance(row, result):
+            if not new_populators and invoke_all_eagers:
+                self._populators(context, path, reduced_path, row, adapter,
+                                new_populators,
+                                existing_populators,
+                                eager_populators
+                )
+
             if translate_row:
                 for fn in translate_row:
                     ret = fn(self, context, row)
@@ -2584,11 +2595,10 @@ class Mapper(object):
                 elif isnew:
                     state.manager.dispatch.refresh(state, context, only_load_props)
 
-            elif state in context.partials or state.unloaded:
+            elif state in context.partials or state.unloaded or eager_populators:
                 # state is having a partial set of its attributes
                 # refreshed.  Populate those attributes,
                 # and add to the "context.partials" collection.
-
                 if state in context.partials:
                     isnew = False
                     (d_, attrs) = context.partials[state]
@@ -2609,6 +2619,10 @@ class Mapper(object):
                 else:
                     populate_state(state, dict_, row, isnew, attrs)
 
+                for key, pop in eager_populators:
+                    if key not in state.unloaded:
+                        pop(state, dict_, row)
+
                 if isnew:
                     state.manager.dispatch.refresh(state, context, attrs)
 
@@ -2629,21 +2643,19 @@ class Mapper(object):
         return _instance
 
     def _populators(self, context, path, reduced_path, row, adapter,
-            new_populators, existing_populators):
+            new_populators, existing_populators, eager_populators):
         """Produce a collection of attribute level row processor callables."""
 
         delayed_populators = []
+        pops = (new_populators, existing_populators, delayed_populators, eager_populators)
         for prop in self._props.itervalues():
-            newpop, existingpop, delayedpop = prop.create_row_processor(
-                                                    context, path, 
-                                                    reduced_path,
-                                                    self, row, adapter)
-            if newpop:
-                new_populators.append((prop.key, newpop))
-            if existingpop:
-                existing_populators.append((prop.key, existingpop))
-            if delayedpop:
-                delayed_populators.append((prop.key, delayedpop))
+            for i, pop in enumerate(prop.create_row_processor(
+                                        context, path, 
+                                        reduced_path,
+                                        self, row, adapter)):
+                if pop is not None:
+                    pops[i].append((prop.key, pop))
+
         if delayed_populators:
             new_populators.extend(delayed_populators)
 
index 2a13ce32138fe33443fa77af0dd4436787f19b1c..5570e5a542827da1e052d42519c235c6954f6901 100644 (file)
@@ -82,6 +82,7 @@ class Query(object):
     _statement = None
     _correlate = frozenset()
     _populate_existing = False
+    _invoke_all_eagers = True
     _version_check = False
     _autoflush = True
     _current_path = ()
@@ -733,6 +734,17 @@ class Query(object):
         """
         self._populate_existing = True
 
+    @_generative()
+    def _with_invoke_all_eagers(self, value):
+        """Set the 'invoke all eagers' flag which causes joined- and
+        subquery loaders to traverse into already-loaded related objects
+        and collections.
+        
+        Default is that of :attr:`.Query._invoke_all_eagers`.
+
+        """
+        self._invoke_all_eagers = value
+
     def with_parent(self, instance, property=None):
         """Add filtering criterion that relates the given instance
         to a child object or collection, using its attribute state 
@@ -2908,6 +2920,7 @@ class QueryContext(object):
         self.query = query
         self.session = query.session
         self.populate_existing = query._populate_existing
+        self.invoke_all_eagers = query._invoke_all_eagers
         self.version_check = query._version_check
         self.refresh_state = query._refresh_state
         self.primary_columns = []
index 2adc5733a410a51f6c34e51e8ddd2a2b7073ee85..2d157aed5bfadf4bfc314042af7431216d09e4e6 100644 (file)
@@ -521,6 +521,8 @@ class LazyLoader(AbstractRelationshipLoader):
 
         q = session.query(prop_mapper)._adapt_all_clauses()
 
+        q = q._with_invoke_all_eagers(False)
+
         # don't autoflush on pending
         if pending:
             q = q.autoflush(False)
@@ -920,7 +922,6 @@ class JoinedLoader(AbstractRelationshipLoader):
     using joined eager loading.
     
     """
-
     def init(self):
         super(JoinedLoader, self).init()
         self.join_depth = self.parent_property.join_depth
@@ -1160,6 +1161,9 @@ class JoinedLoader(AbstractRelationshipLoader):
                                 our_reduced_path + (self.mapper.base_mapper,),
                                 eager_adapter)
 
+            def eager_exec(state, dict_, row):
+                _instance(row, None)
+
             if not self.uselist:
                 def new_execute(state, dict_, row):
                     # set a scalar object instance directly on the parent
@@ -1177,7 +1181,7 @@ class JoinedLoader(AbstractRelationshipLoader):
                             "Multiple rows returned with "
                             "uselist=False for eagerly-loaded attribute '%s' "
                             % self)
-                return new_execute, existing_execute, None
+                return new_execute, existing_execute, None, eager_exec
             else:
                 def new_execute(state, dict_, row):
                     collection = attributes.init_state_collection(
@@ -1202,7 +1206,7 @@ class JoinedLoader(AbstractRelationshipLoader):
                                                 'append_without_event')
                         context.attributes[(state, key)] = result_list
                     _instance(row, result_list)
-            return new_execute, existing_execute, None
+            return new_execute, existing_execute, None, eager_exec
         else:
             return self.parent_property.\
                             _get_strategy(LazyLoader).\
@@ -1243,8 +1247,6 @@ def factory(identifier):
     else:
         return LazyLoader
 
-
-
 class EagerJoinOption(PropertyOption):
 
     def __init__(self, key, innerjoin, chained=False):
index e3914e96c8af07b2965f8d9f9f8f955e2dcd9d77..f07949b769f6e8c39cedbf67ee0783ed4fcee880 100644 (file)
@@ -4,7 +4,7 @@ from test.lib.testing import eq_, is_, is_not_
 import sqlalchemy as sa
 from test.lib import testing
 from sqlalchemy.orm import joinedload, deferred, undefer, \
-    joinedload_all, backref, eagerload
+    joinedload_all, backref, eagerload, Session, immediateload
 from sqlalchemy import Integer, String, Date, ForeignKey, and_, select, \
     func
 from test.lib.schema import Table, Column
@@ -1330,6 +1330,8 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
             use_default_dialect=True
         )
 
+
+
 class SubqueryAliasingTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
     """test #2188"""
 
@@ -1504,6 +1506,107 @@ class SubqueryAliasingTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
         )
 
 
+class LoadOnExistingTest(_fixtures.FixtureTest):
+    """test that loaders from a base Query fully populate."""
+
+    run_inserts = 'once'
+    run_deletes = None
+
+    def _collection_to_scalar_fixture(self):
+        User, Address, Dingaling = self.classes.User, \
+            self.classes.Address, self.classes.Dingaling
+        mapper(User, self.tables.users, properties={
+            'addresses':relationship(Address),
+        })
+        mapper(Address, self.tables.addresses, properties={
+            'dingaling':relationship(Dingaling)
+        })
+        mapper(Dingaling, self.tables.dingalings)
+
+        sess = Session(autoflush=False)
+        return User, Address, Dingaling, sess
+
+    def _collection_to_collection_fixture(self):
+        User, Order, Item = self.classes.User, \
+            self.classes.Order, self.classes.Item
+        mapper(User, self.tables.users, properties={
+            'orders':relationship(Order), 
+        })
+        mapper(Order, self.tables.orders, properties={
+            'items':relationship(Item, secondary=self.tables.order_items),
+        })
+        mapper(Item, self.tables.items)
+
+        sess = Session(autoflush=False)
+        return User, Order, Item, sess
+
+    def _eager_config_fixture(self):
+        User, Address = self.classes.User, self.classes.Address
+        mapper(User, self.tables.users, properties={
+            'addresses':relationship(Address, lazy="joined"),
+        })
+        mapper(Address, self.tables.addresses)
+        sess = Session(autoflush=False)
+        return User, Address, sess
+
+    def test_no_query_on_refresh(self):
+        User, Address, sess = self._eager_config_fixture()
+
+        u1 = sess.query(User).get(8)
+        assert 'addresses' in u1.__dict__
+        sess.expire(u1)
+        def go():
+            eq_(u1.id, 8)
+        self.assert_sql_count(testing.db, go, 1)
+        assert 'addresses' not in u1.__dict__
+
+    def test_loads_second_level_collection_to_scalar(self):
+        User, Address, Dingaling, sess = self._collection_to_scalar_fixture()
+
+        u1 = sess.query(User).get(8)
+        a1 = Address()
+        u1.addresses.append(a1)
+        a2 = u1.addresses[0]
+        a2.email_address = 'foo'
+        sess.query(User).options(joinedload_all("addresses.dingaling")).\
+                            filter_by(id=8).all()
+        assert u1.addresses[-1] is a1
+        for a in u1.addresses:
+            if a is not a1:
+                assert 'dingaling' in a.__dict__
+            else:
+                assert 'dingaling' not in a.__dict__
+            if a is a2:
+                eq_(a2.email_address, 'foo')
+
+    def test_loads_second_level_collection_to_collection(self):
+        User, Order, Item, sess = self._collection_to_collection_fixture()
+
+        u1 = sess.query(User).get(7)
+        u1.orders
+        o1 = Order()
+        u1.orders.append(o1)
+        sess.query(User).options(joinedload_all("orders.items")).\
+                            filter_by(id=7).all()
+        for o in u1.orders:
+            if o is not o1:
+                assert 'items' in o.__dict__
+            else:
+                assert 'items' not in o.__dict__
+
+    def test_load_two_levels_collection_to_scalar(self):
+        User, Address, Dingaling, sess = self._collection_to_scalar_fixture()
+
+        u1 = sess.query(User).filter_by(id=8).options(joinedload("addresses")).one()
+        sess.query(User).filter_by(id=8).options(joinedload_all("addresses.dingaling")).first()
+        assert 'dingaling' in u1.addresses[0].__dict__
+
+    def test_load_two_levels_collection_to_collection(self):
+        User, Order, Item, sess = self._collection_to_collection_fixture()
+
+        u1 = sess.query(User).filter_by(id=7).options(joinedload("orders")).one()
+        sess.query(User).filter_by(id=7).options(joinedload_all("orders.items")).first()
+        assert 'items' in u1.orders[0].__dict__
 
 
 class AddEntityTest(_fixtures.FixtureTest):
index b84b4a3f5840c593a31ec7a46989473c063a8306..8673211f8b62d1262a474d8b201a850642c531a5 100644 (file)
@@ -4,7 +4,8 @@ from test.lib.schema import Table, Column
 from sqlalchemy import Integer, String, ForeignKey, bindparam
 from sqlalchemy.orm import backref, subqueryload, subqueryload_all, \
     mapper, relationship, clear_mappers, create_session, lazyload, \
-    aliased, joinedload, deferred, undefer, eagerload_all
+    aliased, joinedload, deferred, undefer, eagerload_all,\
+    Session
 from test.lib.testing import eq_, assert_raises, \
     assert_raises_message
 from test.lib.assertsql import CompiledSQL
@@ -766,6 +767,128 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
         assert_raises(sa.exc.SAWarning,
                 s.query(User).options(subqueryload(User.order)).all)
 
+class LoadOnExistingTest(_fixtures.FixtureTest):
+    """test that loaders from a base Query fully populate."""
+
+    run_inserts = 'once'
+    run_deletes = None
+
+    def _collection_to_scalar_fixture(self):
+        User, Address, Dingaling = self.classes.User, \
+            self.classes.Address, self.classes.Dingaling
+        mapper(User, self.tables.users, properties={
+            'addresses':relationship(Address),
+        })
+        mapper(Address, self.tables.addresses, properties={
+            'dingaling':relationship(Dingaling)
+        })
+        mapper(Dingaling, self.tables.dingalings)
+
+        sess = Session(autoflush=False)
+        return User, Address, Dingaling, sess
+
+    def _collection_to_collection_fixture(self):
+        User, Order, Item = self.classes.User, \
+            self.classes.Order, self.classes.Item
+        mapper(User, self.tables.users, properties={
+            'orders':relationship(Order), 
+        })
+        mapper(Order, self.tables.orders, properties={
+            'items':relationship(Item, secondary=self.tables.order_items),
+        })
+        mapper(Item, self.tables.items)
+
+        sess = Session(autoflush=False)
+        return User, Order, Item, sess
+
+    def _eager_config_fixture(self):
+        User, Address = self.classes.User, self.classes.Address
+        mapper(User, self.tables.users, properties={
+            'addresses':relationship(Address, lazy="subquery"),
+        })
+        mapper(Address, self.tables.addresses)
+        sess = Session(autoflush=False)
+        return User, Address, sess
+
+    def _deferred_config_fixture(self):
+        User, Address = self.classes.User, self.classes.Address
+        mapper(User, self.tables.users, properties={
+            'name':deferred(self.tables.users.c.name),
+            'addresses':relationship(Address, lazy="subquery"),
+        })
+        mapper(Address, self.tables.addresses)
+        sess = Session(autoflush=False)
+        return User, Address, sess
+
+    def test_no_query_on_refresh(self):
+        User, Address, sess = self._eager_config_fixture()
+
+        u1 = sess.query(User).get(8)
+        assert 'addresses' in u1.__dict__
+        sess.expire(u1)
+        def go():
+            eq_(u1.id, 8)
+        self.assert_sql_count(testing.db, go, 1)
+        assert 'addresses' not in u1.__dict__
+
+    def test_no_query_on_deferred(self):
+        User, Address, sess = self._deferred_config_fixture()
+        u1 = sess.query(User).get(8)
+        assert 'addresses' in u1.__dict__
+        sess.expire(u1, ['addresses'])
+        def go():
+            eq_(u1.name, 'ed')
+        self.assert_sql_count(testing.db, go, 1)
+        assert 'addresses' not in u1.__dict__
+
+    def test_loads_second_level_collection_to_scalar(self):
+        User, Address, Dingaling, sess = self._collection_to_scalar_fixture()
+
+        u1 = sess.query(User).get(8)
+        a1 = Address()
+        u1.addresses.append(a1)
+        a2 = u1.addresses[0]
+        a2.email_address = 'foo'
+        sess.query(User).options(subqueryload_all("addresses.dingaling")).\
+                            filter_by(id=8).all()
+        assert u1.addresses[-1] is a1
+        for a in u1.addresses:
+            if a is not a1:
+                assert 'dingaling' in a.__dict__
+            else:
+                assert 'dingaling' not in a.__dict__
+            if a is a2:
+                eq_(a2.email_address, 'foo')
+
+    def test_loads_second_level_collection_to_collection(self):
+        User, Order, Item, sess = self._collection_to_collection_fixture()
+
+        u1 = sess.query(User).get(7)
+        u1.orders
+        o1 = Order()
+        u1.orders.append(o1)
+        sess.query(User).options(subqueryload_all("orders.items")).\
+                            filter_by(id=7).all()
+        for o in u1.orders:
+            if o is not o1:
+                assert 'items' in o.__dict__
+            else:
+                assert 'items' not in o.__dict__
+
+    def test_load_two_levels_collection_to_scalar(self):
+        User, Address, Dingaling, sess = self._collection_to_scalar_fixture()
+
+        u1 = sess.query(User).filter_by(id=8).options(subqueryload("addresses")).one()
+        sess.query(User).filter_by(id=8).options(subqueryload_all("addresses.dingaling")).first()
+        assert 'dingaling' in u1.addresses[0].__dict__
+
+    def test_load_two_levels_collection_to_collection(self):
+        User, Order, Item, sess = self._collection_to_collection_fixture()
+
+        u1 = sess.query(User).filter_by(id=7).options(subqueryload("orders")).one()
+        sess.query(User).filter_by(id=7).options(subqueryload_all("orders.items")).first()
+        assert 'items' in u1.orders[0].__dict__
+
 class OrderBySecondaryTest(fixtures.MappedTest):
     @classmethod
     def define_tables(cls, metadata):