]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- refactor a bit the loader options system to make it a bit more
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 10 Jan 2016 22:47:38 +0000 (17:47 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 10 Jan 2016 22:47:38 +0000 (17:47 -0500)
intelligible, given the fixes for ref #3623.  unfortunately the system
is still quite weird even though it was rewritten to be... less weird

lib/sqlalchemy/orm/strategy_options.py
test/orm/test_options.py

index f08367941667739bab487ea3561a3dae5916fa55..aa818258a523e6afb40b9c66d56775c1941ed9f8 100644 (file)
@@ -80,6 +80,8 @@ class Load(Generative, MapperOption):
     def __init__(self, entity):
         insp = inspect(entity)
         self.path = insp._path_registry
+        # note that this .context is shared among all descendant
+        # Load objects
         self.context = {}
         self.local_opts = {}
 
@@ -88,7 +90,7 @@ class Load(Generative, MapperOption):
         cloned.local_opts = {}
         return cloned
 
-    _merge_into_path = False
+    is_opts_only = False
     strategy = None
     propagate_to_loaders = False
 
@@ -201,7 +203,7 @@ class Load(Generative, MapperOption):
             self._set_path_strategy()
 
     @_generative
-    def set_column_strategy(self, attrs, strategy, opts=None):
+    def set_column_strategy(self, attrs, strategy, opts=None, opts_only=False):
         strategy = self._coerce_strat(strategy)
 
         for attr in attrs:
@@ -212,21 +214,34 @@ class Load(Generative, MapperOption):
             cloned.propagate_to_loaders = True
             if opts:
                 cloned.local_opts.update(opts)
+            if opts_only:
+                cloned.is_opts_only = True
             cloned._set_path_strategy()
 
-    def _set_path_strategy(self):
-        if self._merge_into_path:
-            # special helper for undefer_group
-            existing = self.path.get(self.context, "loader")
+    def _set_for_path(self, context, path, replace=True, merge_opts=False):
+        if merge_opts or not replace:
+            existing = path.get(self.context, "loader")
+
             if existing:
-                existing.local_opts.update(self.local_opts)
+                if merge_opts:
+                    existing.local_opts.update(self.local_opts)
             else:
-                self.path.set(self.context, "loader", self)
+                path.set(context, "loader", self)
+        else:
+            existing = path.get(self.context, "loader")
+            path.set(context, "loader", self)
+            if existing and existing.is_opts_only:
+                self.local_opts.update(existing.local_opts)
 
-        elif self.path.has_entity:
-            self.path.parent.set(self.context, "loader", self)
+    def _set_path_strategy(self):
+        if self.path.has_entity:
+            effective_path = self.path.parent
         else:
-            self.path.set(self.context, "loader", self)
+            effective_path = self.path
+
+        self._set_for_path(
+            self.context, effective_path, replace=True,
+            merge_opts=self.is_opts_only)
 
     def __getstate__(self):
         d = self.__dict__.copy()
@@ -314,7 +329,7 @@ class _UnboundLoad(Load):
             val._bind_loader(query, query._attributes, raiseerr)
 
     @classmethod
-    def _from_keys(self, meth, keys, chained, kw):
+    def _from_keys(cls, meth, keys, chained, kw):
         opt = _UnboundLoad()
 
         def _split_key(key):
@@ -399,6 +414,7 @@ class _UnboundLoad(Load):
         loader = Load(path_element)
         loader.context = context
         loader.strategy = self.strategy
+        loader.is_opts_only = self.is_opts_only
 
         path = loader.path
         for token in start_path:
@@ -420,24 +436,15 @@ class _UnboundLoad(Load):
 
         if effective_path.is_token:
             for path in effective_path.generate_for_superclasses():
-                if self._merge_into_path:
-                    # special helper for undefer_group
-                    existing = path.get(context, "loader")
-                    if existing:
-                        existing.local_opts.update(self.local_opts)
-                    else:
-                        path.set(context, "loader", loader)
-                elif self._is_chain_link:
-                    path.setdefault(context, "loader", loader)
-                else:
-                    path.set(context, "loader", loader)
+                loader._set_for_path(
+                    context, path,
+                    replace=not self._is_chain_link,
+                    merge_opts=self.is_opts_only)
         else:
-            # only supported for the undefer_group() wildcard opt
-            assert not self._merge_into_path
-            if self._is_chain_link:
-                effective_path.setdefault(context, "loader", loader)
-            else:
-                effective_path.set(context, "loader", loader)
+            loader._set_for_path(
+                context, effective_path,
+                replace=not self._is_chain_link,
+                merge_opts=self.is_opts_only)
 
     def _find_entity_prop_comparator(self, query, token, mapper, raiseerr):
         if _is_aliased_class(mapper):
@@ -1043,11 +1050,11 @@ def undefer_group(loadopt, name):
         :func:`.orm.undefer`
 
     """
-    loadopt._merge_into_path = True
     return loadopt.set_column_strategy(
         "*",
         None,
-        {"undefer_group_%s" % name: True}
+        {"undefer_group_%s" % name: True},
+        opts_only=True
     )
 
 
index e1e26c62ce4db4e8e25fe4a3c0de561c3708339e..e7b750cf4baec39255332e40c27cc17fecc333b0 100644 (file)
@@ -3,11 +3,14 @@ from sqlalchemy.orm import attributes, mapper, relationship, backref, \
     configure_mappers, create_session, synonym, Session, class_mapper, \
     aliased, column_property, joinedload_all, joinedload, Query,\
     util as orm_util, Load, defer
+from sqlalchemy.orm.query import QueryContext
+from sqlalchemy.orm import strategy_options
 import sqlalchemy as sa
 from sqlalchemy import testing
-from sqlalchemy.testing.assertions import eq_, assert_raises, assert_raises_message
+from sqlalchemy.testing.assertions import eq_, assert_raises_message
 from test.orm import _fixtures
 
+
 class QueryTest(_fixtures.FixtureTest):
     run_setup_mappers = 'once'
     run_inserts = 'once'
@@ -17,6 +20,7 @@ class QueryTest(_fixtures.FixtureTest):
     def setup_mappers(cls):
         cls._setup_stock_mapping()
 
+
 class PathTest(object):
     def _make_path(self, path):
         r = []
@@ -160,11 +164,11 @@ class LoadTest(PathTest, QueryTest):
         )
 
 
+
+
 class OptionsTest(PathTest, QueryTest):
 
     def _option_fixture(self, *arg):
-        from sqlalchemy.orm import strategy_options
-
         return strategy_options._UnboundLoad._from_keys(
                     strategy_options._UnboundLoad.joinedload, arg, True, {})
 
@@ -768,3 +772,121 @@ class OptionsNoPropTest(_fixtures.FixtureTest):
                               create_session().query(column).options,
                               joinedload(eager_option))
 
+
+class LocalOptsTest(PathTest, QueryTest):
+    @classmethod
+    def setup_class(cls):
+        super(LocalOptsTest, cls).setup_class()
+
+        @strategy_options.loader_option()
+        def some_col_opt_only(loadopt, key, opts):
+            return loadopt.set_column_strategy(
+                (key, ),
+                None,
+                opts,
+                opts_only=True
+            )
+
+        @strategy_options.loader_option()
+        def some_col_opt_strategy(loadopt, key, opts):
+            return loadopt.set_column_strategy(
+                (key, ),
+                {"deferred": True, "instrument": True},
+                opts
+            )
+
+        cls.some_col_opt_only = some_col_opt_only
+        cls.some_col_opt_strategy = some_col_opt_strategy
+
+    def _assert_attrs(self, opts, expected):
+        User = self.classes.User
+
+        query = create_session().query(User)
+        attr = {}
+
+        for opt in opts:
+            if isinstance(opt, strategy_options._UnboundLoad):
+                for tb in opt._to_bind:
+                    tb._bind_loader(query, attr, False)
+            else:
+                attr.update(opt.context)
+
+        key = (
+            'loader',
+            tuple(inspect(User)._path_registry[User.name.property]))
+        eq_(
+            attr[key].local_opts,
+            expected
+        )
+
+    def test_single_opt_only(self):
+        opt = strategy_options._UnboundLoad().some_col_opt_only(
+            "name", {"foo": "bar"}
+        )
+        self._assert_attrs([opt], {"foo": "bar"})
+
+    def test_unbound_multiple_opt_only(self):
+        opts = [
+            strategy_options._UnboundLoad().some_col_opt_only(
+                "name", {"foo": "bar"}
+            ),
+            strategy_options._UnboundLoad().some_col_opt_only(
+                "name", {"bat": "hoho"}
+            )
+        ]
+        self._assert_attrs(opts, {"foo": "bar", "bat": "hoho"})
+
+    def test_bound_multiple_opt_only(self):
+        User = self.classes.User
+        opts = [
+            Load(User).some_col_opt_only(
+                "name", {"foo": "bar"}
+            ).some_col_opt_only(
+                "name", {"bat": "hoho"}
+            )
+        ]
+        self._assert_attrs(opts, {"foo": "bar", "bat": "hoho"})
+
+    def test_bound_strat_opt_recvs_from_optonly(self):
+        User = self.classes.User
+        opts = [
+            Load(User).some_col_opt_only(
+                "name", {"foo": "bar"}
+            ).some_col_opt_strategy(
+                "name", {"bat": "hoho"}
+            )
+        ]
+        self._assert_attrs(opts, {"foo": "bar", "bat": "hoho"})
+
+    def test_unbound_strat_opt_recvs_from_optonly(self):
+        opts = [
+            strategy_options._UnboundLoad().some_col_opt_only(
+                "name", {"foo": "bar"}
+            ),
+            strategy_options._UnboundLoad().some_col_opt_strategy(
+                "name", {"bat": "hoho"}
+            )
+        ]
+        self._assert_attrs(opts, {"foo": "bar", "bat": "hoho"})
+
+    def test_unbound_opt_only_adds_to_strat(self):
+        opts = [
+            strategy_options._UnboundLoad().some_col_opt_strategy(
+                "name", {"bat": "hoho"}
+            ),
+            strategy_options._UnboundLoad().some_col_opt_only(
+                "name", {"foo": "bar"}
+            ),
+        ]
+        self._assert_attrs(opts, {"foo": "bar", "bat": "hoho"})
+
+    def test_bound_opt_only_adds_to_strat(self):
+        User = self.classes.User
+        opts = [
+            Load(User).some_col_opt_strategy(
+                "name", {"bat": "hoho"}
+            ).some_col_opt_only(
+                "name", {"foo": "bar"}
+            ),
+        ]
+        self._assert_attrs(opts, {"foo": "bar", "bat": "hoho"})