]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- eagerload(), lazyload(), eagerload_all() take an optional
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 11 Dec 2007 19:33:36 +0000 (19:33 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 11 Dec 2007 19:33:36 +0000 (19:33 +0000)
second class-or-mapper argument, which will select the mapper
to apply the option towards.  This can select among other
mappers which were added using add_entity().

- eagerloading will work with mappers added via add_entity().

CHANGES
lib/sqlalchemy/orm/__init__.py
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/strategies.py
test/orm/eager_relations.py
test/orm/mapper.py
test/orm/merge.py

diff --git a/CHANGES b/CHANGES
index 5a9893368cff6933a77c71cacc188fb731670dd3..38dec541b1fe0b49019fbb7abb579dabfda7f442 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -78,7 +78,14 @@ CHANGES
      new behavior allows not just joins from the main table, but select 
      statements as well.  Filter criterion, order bys, eager load
      clauses will be "aliased" against the given statement.
-     
+
+   - eagerload(), lazyload(), eagerload_all() take an optional 
+     second class-or-mapper argument, which will select the mapper
+     to apply the option towards.  This can select among other
+     mappers which were added using add_entity().  
+
+   - eagerloading will work with mappers added via add_entity().
+          
    - added "cascade delete" behavior to "dynamic" relations just like
      that of regular relations.  if passive_deletes flag (also just added)
      is not set, a delete of the parent item will trigger a full load of 
index 56edc03f20d38200954c474f808a6c62b8ea8d3b..ac784ec08d6d90f2441c26328510e80d00f04818 100644 (file)
@@ -622,15 +622,15 @@ def extension(ext):
 
     return ExtensionOption(ext)
 
-def eagerload(name):
+def eagerload(name, mapper=None):
     """Return a ``MapperOption`` that will convert the property of the given name into an eager load.
 
     Used with ``query.options()``.
     """
 
-    return strategies.EagerLazyOption(name, lazy=False)
+    return strategies.EagerLazyOption(name, lazy=False, mapper=mapper)
 
-def eagerload_all(name):
+def eagerload_all(name, mapper=None):
     """Return a ``MapperOption`` that will convert all properties along the given dot-separated path into an eager load.
     
     For example, this::
@@ -643,16 +643,16 @@ def eagerload_all(name):
     Used with ``query.options()``.
     """
 
-    return strategies.EagerLazyOption(name, lazy=False, chained=True)
+    return strategies.EagerLazyOption(name, lazy=False, chained=True, mapper=mapper)
 
-def lazyload(name):
+def lazyload(name, mapper=None):
     """Return a ``MapperOption`` that will convert the property of the
     given name into a lazy load.
 
     Used with ``query.options()``.
     """
 
-    return strategies.EagerLazyOption(name, lazy=True)
+    return strategies.EagerLazyOption(name, lazy=True, mapper=mapper)
 
 def fetchmode(name, type):
     return strategies.FetchModeOption(name, type)
index 1ff019c78b763e82a61a5e914c9e0f861746076c..413a1af2cdb5900282aa84149dbf5099f983c4b3 100644 (file)
@@ -5,8 +5,9 @@
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
 
-from sqlalchemy import util, logging
+from sqlalchemy import util, logging, exceptions
 from sqlalchemy.sql import expression
+class_mapper = None
 
 __all__ = ['EXT_CONTINUE', 'EXT_STOP', 'EXT_PASS', 'MapperExtension',
            'MapperProperty', 'PropComparator', 'StrategizedProperty', 
@@ -503,6 +504,14 @@ class MapperOption(object):
 
     def process_query(self, query):
         pass
+    
+    def process_query_conditionally(self, query):
+        """same as process_query(), except that this option may not apply
+        to the given query.  
+        
+        Used when secondary loaders resend existing options to a new 
+        Query."""
+        self.process_query(query)
 
 class ExtensionOption(MapperOption):
     """a MapperOption that applies a MapperExtension to a query operation."""
@@ -520,30 +529,47 @@ class PropertyOption(MapperOption):
     one of its child mappers, identified by a dot-separated key.
     """
 
-    def __init__(self, key):
+    def __init__(self, key, mapper=None):
         self.key = key
-
+        self.mapper = mapper
+        
     def process_query(self, query):
+        self._process(query, True)
+        
+    def process_query_conditionally(self, query):
+        self._process(query, False)
+        
+    def _process(self, query, raiseerr):
         if self._should_log_debug:
             self.logger.debug("applying option to Query, property key '%s'" % self.key)
-        paths = self._get_paths(query)
+        paths = self._get_paths(query, raiseerr)
         if paths:
             self.process_query_property(query, paths)
 
     def process_query_property(self, query, paths):
         pass
 
-    def _get_paths(self, query):
+    def _get_paths(self, query, raiseerr):
         path = None
         l = []
         current_path = list(query._current_path)
         
-        mapper = query.mapper
+        if self.mapper:
+            global class_mapper
+            if class_mapper is None:
+                from sqlalchemy.orm import class_mapper
+            mapper = self.mapper
+            if isinstance(self.mapper, type):
+                mapper = class_mapper(mapper)
+            if mapper is not query.mapper and mapper not in [q[0] for q in query._entities]:
+                raise exceptions.ArgumentError("Can't find entity %s in Query.  Current list: %r" % (str(mapper), [str(m) for m in [query.mapper] + query._entities]))
+        else:
+            mapper = query.mapper
         for token in self.key.split('.'):
             if current_path and token == current_path[1]:
                 current_path = current_path[2:]
                 continue
-            prop = mapper.get_property(token, resolve_synonyms=True, raiseerr=False)
+            prop = mapper.get_property(token, resolve_synonyms=True, raiseerr=raiseerr)
             if prop is None:
                 return []
             path = build_path(mapper, prop.key, path)
index bbbdce69846b25b25b81f6e7575304a3b3c0ab67..d4e6ccb408119f959234924d2346aed3dd0a35ab 100644 (file)
@@ -279,16 +279,26 @@ class Query(object):
         MapperOptions.
         """
         
+        return self._options(False, *args)
+
+    def _conditional_options(self, *args):
+        return self._options(True, *args)
+        
+    def _options(self, conditional, *args):
         q = self._clone()
         # most MapperOptions write to the '_attributes' dictionary,
         # so copy that as well
         q._attributes = q._attributes.copy()
         opts = [o for o in util.flatten_iterator(args)]
         q._with_options = q._with_options + opts
-        for opt in opts:
-            opt.process_query(q)
+        if conditional:
+            for opt in opts:
+                opt.process_query_conditionally(q)
+        else:
+            for opt in opts:
+                opt.process_query(q)
         return q
-
+    
     def with_lockmode(self, mode):
         """Return a new Query object with the specified locking mode."""
         q = self._clone()
@@ -903,6 +913,11 @@ class Query(object):
         whereclause = self._criterion
 
         from_obj = self._from_obj
+        
+        # indicates if the "from" clause of the query does not include 
+        # the normally mapped table, i.e. the user issued select_from(somestatement)
+        # or similar.  all clauses which derive from the mapped table will need to
+        # be adapted to be relative to the user-supplied selectable.
         adapt_criterion = self.table not in self._get_joinable_tables()
 
         if not adapt_criterion and whereclause is not None and (self.mapper is not self.select_mapper):
@@ -947,7 +962,7 @@ class Query(object):
             clauses = self._get_entity_clauses(tup)
             if isinstance(m, mapper.Mapper):
                 for value in m.iterate_properties:
-                    context.exec_with_path(self.select_mapper, value.key, value.setup, context, parentclauses=clauses)
+                    context.exec_with_path(m, value.key, value.setup, context, parentclauses=clauses)
             elif isinstance(m, sql.ColumnElement):
                 if clauses is not None:
                     m = clauses.aliased_column(m)
index 027a323939766da4e4f2cabf003908d3850bd577..d9390345e9a97b3c0325ee77966c7786ea0bd949 100644 (file)
@@ -352,7 +352,7 @@ class LazyLoader(AbstractRelationLoader):
                 if not nonnulls:
                     return None
                 if options:
-                    q = q.options(*options)
+                    q = q._conditional_options(*options)
                 return q.get(ident)
             elif self.order_by is not False:
                 q = q.order_by(self.order_by)
@@ -360,7 +360,7 @@ class LazyLoader(AbstractRelationLoader):
                 q = q.order_by(self.secondary.default_order_by())
 
             if options:
-                q = q.options(*options)
+                q = q._conditional_options(*options)
             q = q.filter(self.lazy_clause(instance))
 
             result = q.all()
@@ -617,8 +617,8 @@ class EagerLoader(AbstractRelationLoader):
 EagerLoader.logger = logging.class_logger(EagerLoader)
 
 class EagerLazyOption(StrategizedOption):
-    def __init__(self, key, lazy=True, chained=False):
-        super(EagerLazyOption, self).__init__(key)
+    def __init__(self, key, lazy=True, chained=False, mapper=None):
+        super(EagerLazyOption, self).__init__(key, mapper)
         self.lazy = lazy
         self.chained = chained
         
index 6b2baad95d94e025ca0d5063839c1f07d3a6a580..192eafaed34cbbffa9b3ff57e59602a50ab0118a 100644 (file)
@@ -11,9 +11,6 @@ class EagerTest(FixtureTest):
     keep_mappers = False
     keep_data = True
     
-    def setup_mappers(self):
-        pass
-
     def test_basic(self):
         mapper(User, users, properties={
             'addresses':relation(mapper(Address, addresses), lazy=False)
@@ -541,6 +538,76 @@ class EagerTest(FixtureTest):
         l = q.filter(addresses.c.email_address == 'ed@lala.com').filter(Address.user_id==User.id)
         assert fixtures.user_address_result[1:2] == l.all()
 
+class AddEntityTest(FixtureTest):
+    keep_mappers = False
+    keep_data = True
+
+    def _assert_result(self):
+        return [
+            (
+                User(id=7, addresses=[Address(id=1)]),
+                Order(id=1, items=[Item(id=1), Item(id=2), Item(id=3)]),
+            ),
+            (
+                User(id=7, addresses=[Address(id=1)]),
+                Order(id=3, items=[Item(id=3), Item(id=4), Item(id=5)]),
+            ),
+            (
+                User(id=7, addresses=[Address(id=1)]),
+                Order(id=5, items=[Item(id=5)]),
+            ),
+            (
+                 User(id=9, addresses=[Address(id=5)]),
+                 Order(id=2, items=[Item(id=1), Item(id=2), Item(id=3)]),
+             ),
+             (
+                  User(id=9, addresses=[Address(id=5)]),
+                  Order(id=4, items=[Item(id=1), Item(id=5)]),
+              )
+        ]
+        
+    def test_basic(self):
+        mapper(User, users, properties={
+            'addresses':relation(Address, lazy=False),
+            'orders':relation(Order)
+        })
+        mapper(Address, addresses)
+        mapper(Order, orders, properties={
+            'items':relation(Item, secondary=order_items, lazy=False)
+        })
+        mapper(Item, items)
+
+
+        sess = create_session()
+        def go():
+            ret = sess.query(User).add_entity(Order).join('orders', aliased=True).all()
+            self.assertEquals(ret, self._assert_result())
+        self.assert_sql_count(testbase.db, go, 1)
+
+    def test_options(self):
+        mapper(User, users, properties={
+            'addresses':relation(Address),
+            'orders':relation(Order)
+        })
+        mapper(Address, addresses)
+        mapper(Order, orders, properties={
+            'items':relation(Item, secondary=order_items)
+        })
+        mapper(Item, items)
+
+        sess = create_session()
+
+        def go():
+            ret = sess.query(User).options(eagerload('addresses')).add_entity(Order).join('orders', aliased=True).all()
+            self.assertEquals(ret, self._assert_result())
+        self.assert_sql_count(testbase.db, go, 6)
+
+        sess.clear()
+        def go():
+            ret = sess.query(User).options(eagerload('addresses')).add_entity(Order).options(eagerload('items', Order)).join('orders', aliased=True).all()
+            self.assertEquals(ret, self._assert_result())
+        self.assert_sql_count(testbase.db, go, 1)
+
 class SelfReferentialEagerTest(ORMTest):
     def define_tables(self, metadata):
         global nodes
index 65a6ad8fa016ceb59d789810115dc81df960d934..3847a49a0294e086956dc497126ff20561c27335 100644 (file)
@@ -711,11 +711,15 @@ class OptionsTest(MapperSuperTest):
         u = q2.select()
         def go():
             print u[0].orders[1].items[0].keywords[1]
-        print "-------MARK3----------"
         self.assert_sql_count(testbase.db, go, 0)
-        print "-------MARK4----------"
 
         sess.clear()
+        
+        try:
+            sess.query(User).options(eagerload('items', Order))
+            assert False
+        except exceptions.ArgumentError, e:
+            assert str(e) == "Can't find entity Mapper|Order|orders in Query.  Current list: ['Mapper|User|users']"
 
         # eagerload "keywords" on items.  it will lazy load "orders", then lazy load
         # the "items" on the order, but on "items" it will eager load the "keywords"
index a58665136fb24eb31c5e4fb883d3d97dc4948263..8eeafb8e62d04fd1a6c4cf689e5b807a9754108f 100644 (file)
@@ -257,7 +257,7 @@ class MergeTest(AssertMixin):
         except exceptions.InvalidRequestError, e:
             assert "merge() with dont_load=True option does not support objects marked as 'dirty'.  flush() all changes on mapped instances before merging with dont_load=True." in str(e)
             
-        u2 = sess2.query(User).options(eagerload('addresses')).get(7)
+        u2 = sess2.query(User).get(7)
         
         sess3 = create_session()
         u3 = sess3.merge(u2, dont_load=True)