]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add baked.Result.with_post_criteria method
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 12 Nov 2017 23:44:41 +0000 (18:44 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 12 Nov 2017 23:44:41 +0000 (18:44 -0500)
Added new method :meth:`.baked.Result.with_post_criteria` to baked
query system, allowing non-SQL-modifying transformations to take place
after the query has been pulled from the cache.  Among other things,
this method can be used with :class:`.horizontal_shard.ShardedQuery`
to set the shard identifier.   :class:`.horizontal_shard.ShardedQuery`
has also been modified such that its :meth:`.ShardedQuery.get` method
interacts correctly with that of :class:`.baked.Result`.

Change-Id: I04630c683240abbb4b99f0510a1a3dcb564815b4
Fixes: #4135
doc/build/changelog/unreleased_12/4135.rst [new file with mode: 0644]
lib/sqlalchemy/ext/baked.py
lib/sqlalchemy/ext/horizontal_shard.py
test/ext/test_baked.py
test/ext/test_horizontal_shard.py

diff --git a/doc/build/changelog/unreleased_12/4135.rst b/doc/build/changelog/unreleased_12/4135.rst
new file mode 100644 (file)
index 0000000..36dd869
--- /dev/null
@@ -0,0 +1,11 @@
+.. change::
+    :tags: enhancement, ext
+    :tickets: 4135
+
+    Added new method :meth:`.baked.Result.with_post_criteria` to baked
+    query system, allowing non-SQL-modifying transformations to take place
+    after the query has been pulled from the cache.  Among other things,
+    this method can be used with :class:`.horizontal_shard.ShardedQuery`
+    to set the shard identifier.   :class:`.horizontal_shard.ShardedQuery`
+    has also been modified such that its :meth:`.ShardedQuery.get` method
+    interacts correctly with that of :class:`.baked.Result`.
\ No newline at end of file
index c0fe963ac62b60ceb68e6736f9a44283bd23e4c0..8cae6e24b5af053a5c867d7f41a8dd312b3f27db 100644 (file)
@@ -261,12 +261,13 @@ class Result(object):
     against a target :class:`.Session`, and is then invoked for results.
 
     """
-    __slots__ = 'bq', 'session', '_params'
+    __slots__ = 'bq', 'session', '_params', '_post_criteria'
 
     def __init__(self, bq, session):
         self.bq = bq
         self.session = session
         self._params = {}
+        self._post_criteria = []
 
     def params(self, *args, **kw):
         """Specify parameters to be replaced into the string SQL statement."""
@@ -280,8 +281,37 @@ class Result(object):
         self._params.update(kw)
         return self
 
+    def _using_post_criteria(self, fns):
+        if fns:
+            self._post_criteria.extend(fns)
+        return self
+
+    def with_post_criteria(self, fn):
+        """Add a criteria function that will be applied post-cache.
+
+        This adds a function that will be run against the
+        :class:`.Query` object after it is retrieved from the
+        cache.    Functions here can be used to alter the query in ways
+        that **do not affect the SQL output**, such as execution options
+        and shard identifiers (when using a shard-enabled query object)
+
+        .. warning::  :meth:`.Result.with_post_criteria` functions are applied
+           to the :class:`.Query` object **after** the query's SQL statement
+           object has been retrieved from the cache.   Any operations here
+           which intend to modify the SQL should ensure that
+           :meth:`.BakedQuery.spoil` was called first.
+
+        .. versionadded:: 1.2
+
+
+        """
+        return self._using_post_criteria([fn])
+
     def _as_query(self):
-        return self.bq._as_query(self.session).params(self._params)
+        q = self.bq._as_query(self.session).params(self._params)
+        for fn in self._post_criteria:
+            q = fn(q)
+        return q
 
     def __str__(self):
         return str(self._as_query())
@@ -304,8 +334,11 @@ class Result(object):
         context.statement.use_labels = True
         if context.autoflush and not context.populate_existing:
             self.session._autoflush()
-        return context.query.params(self._params).\
-            with_session(self.session)._execute_and_instances(context)
+        q = context.query.params(self._params).with_session(self.session)
+        for fn in self._post_criteria:
+            q = fn(q)
+
+        return q._execute_and_instances(context)
 
     def count(self):
         """return the 'count'.
@@ -348,7 +381,9 @@ class Result(object):
 
         """
         bq = self.bq.with_criteria(lambda q: q.slice(0, 1))
-        ret = list(bq.for_session(self.session).params(self._params))
+        ret = list(
+            bq.for_session(self.session).params(self._params).
+            _using_post_criteria(self._post_criteria))
         if len(ret) > 0:
             return ret[0]
         else:
@@ -435,6 +470,8 @@ class Result(object):
 
             _lcl_get_clause = q._adapt_clause(_lcl_get_clause, True, False)
             q._criterion = _lcl_get_clause
+            for fn in self._post_criteria:
+                q = fn(q)
             return q
 
         # cache the query against a key that includes
index d20fbd4842c6ba1662bc1bafdee8b2b1bf0e070e..8902ae6065c94a3d38cbe92683b56a0017b1e9af 100644 (file)
@@ -61,17 +61,21 @@ class ShardedQuery(Query):
             # were done, this is where it would happen
             return iter(partial)
 
-    def get(self, ident, **kwargs):
-        if self._shard_id is not None:
-            return super(ShardedQuery, self).get(ident)
-        else:
-            ident = util.to_list(ident)
-            for shard_id in self.id_chooser(self, ident):
-                o = self.set_shard(shard_id).get(ident, **kwargs)
-                if o is not None:
-                    return o
+    def _get_impl(self, ident, fallback_fn):
+        def _fallback(query, ident):
+            if self._shard_id is not None:
+                return fallback_fn(self, ident)
             else:
-                return None
+                ident = util.to_list(ident)
+                for shard_id in self.id_chooser(self, ident):
+                    q = self.set_shard(shard_id)
+                    o = fallback_fn(q, ident)
+                    if o is not None:
+                        return o
+                else:
+                    return None
+
+        return super(ShardedQuery, self)._get_impl(ident, _fallback)
 
 
 class ShardedSession(Session):
index d2fcfbab85c60487ebddbf7dd53aff5d98846e01..47da6d0edfbcba117a47c7e7538ef09b7c431840 100644 (file)
@@ -13,6 +13,7 @@ from sqlalchemy.orm import exc as orm_exc
 import itertools
 from sqlalchemy.testing import mock
 from sqlalchemy.testing.assertsql import CompiledSQL
+import contextlib
 
 
 class BakedTest(_fixtures.FixtureTest):
@@ -375,6 +376,71 @@ class LikeQueryTest(BakedTest):
         eq_(len(bq._bakery), 4)
 
 
+class ResultPostCriteriaTest(BakedTest):
+
+    @classmethod
+    def setup_mappers(cls):
+        User = cls.classes.User
+        Address = cls.classes.Address
+        Order = cls.classes.Order
+
+        mapper(User, cls.tables.users, properties={
+            "addresses": relationship(
+                Address, order_by=cls.tables.addresses.c.id),
+            "orders": relationship(
+                Order, order_by=cls.tables.orders.c.id)
+        })
+        mapper(Address, cls.tables.addresses)
+        mapper(Order, cls.tables.orders)
+
+    @contextlib.contextmanager
+    def _fixture(self):
+        from sqlalchemy import event
+        User = self.classes.User
+
+        with testing.db.connect() as conn:
+            @event.listens_for(conn, "before_execute")
+            def before_execute(conn, clauseelement, multiparams, params):
+                assert "yes" in conn._execution_options
+
+            bq = self.bakery(
+                lambda s: s.query(User.id).order_by(User.id))
+
+            sess = Session(conn)
+
+            yield sess, bq
+
+    def test_first(self):
+        with self._fixture() as (sess, bq):
+            result = bq(sess).with_post_criteria(
+                lambda q: q.execution_options(yes=True))
+            eq_(result.first(), (7, ))
+
+    def test_iter(self):
+        with self._fixture() as (sess, bq):
+            result = bq(sess).with_post_criteria(
+                lambda q: q.execution_options(yes=True))
+            eq_(list(result)[0], (7, ))
+
+    def test_spoiled(self):
+        with self._fixture() as (sess, bq):
+
+            result = bq.spoil()(sess).with_post_criteria(
+                lambda q: q.execution_options(yes=True))
+
+            eq_(list(result)[0], (7, ))
+
+    def test_get(self):
+        User = self.classes.User
+        with self._fixture() as (sess, bq):
+            bq = self.bakery(
+                lambda s: s.query(User))
+
+            result = bq(sess).with_post_criteria(
+                lambda q: q.execution_options(yes=True))
+            eq_(result.get(7), User(id=7))
+
+
 class ResultTest(BakedTest):
     __backend__ = True
 
index 2a596d8c0358ddd4908ddd9a500f8c2a9d25902a..79487b2a7975d4287682db99438c41888d9dc4e2 100644 (file)
@@ -186,6 +186,54 @@ class ShardTest(object):
         eq_(set([c.city for c in asia_and_europe]), set(['Tokyo',
             'London', 'Dublin']))
 
+    def test_get_baked_query(self):
+        sess = self._fixture_data()
+
+        tokyo = sess.query(WeatherLocation).filter_by(city="Tokyo").one()
+        tokyo.city
+        sess.expunge_all()
+
+        from sqlalchemy.ext.baked import BakedQuery
+
+        bakery = BakedQuery.bakery()
+
+        bq = bakery(lambda session: session.query(WeatherLocation))
+        t = bq(sess).get(tokyo.id)
+        eq_(t.city, tokyo.city)
+
+    def test_get_baked_query_shard_id(self):
+        sess = self._fixture_data()
+
+        tokyo = sess.query(WeatherLocation).filter_by(city="Tokyo").one()
+        tokyo.city
+        sess.expunge_all()
+
+        from sqlalchemy.ext.baked import BakedQuery
+
+        bakery = BakedQuery.bakery()
+
+        bq = bakery(lambda session: session.query(WeatherLocation))
+        t = bq(sess).with_post_criteria(
+            lambda q: q.set_shard("asia")).get(tokyo.id)
+        eq_(t.city, tokyo.city)
+
+    def test_filter_baked_query_shard_id(self):
+        sess = self._fixture_data()
+
+        tokyo = sess.query(WeatherLocation).filter_by(city="Tokyo").one()
+        tokyo.city
+        sess.expunge_all()
+
+        from sqlalchemy.ext.baked import BakedQuery
+
+        bakery = BakedQuery.bakery()
+
+        bq = bakery(lambda session: session.query(WeatherLocation)).\
+            with_criteria(lambda q: q.filter_by(id=tokyo.id))
+        t = bq(sess).with_post_criteria(
+            lambda q: q.set_shard("asia")).one()
+        eq_(t.city, tokyo.city)
+
     def test_shard_id_event(self):
         canary = []