]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Include Session._query_cls as part of the cache key
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 26 Aug 2018 16:35:59 +0000 (12:35 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 26 Aug 2018 19:49:06 +0000 (15:49 -0400)
Fixed issue where :class:`.BakedQuery` did not include the specific query
class used by the :class:`.Session` as part of the cache key, leading to
incompatibilities when using custom query classes, in particular the
:class:`.ShardedQuery` which has some different argument signatures.

Fixes: #4328
Change-Id: I829c2a8b09c91e91c8dc8ea5476c0d7aa47028bd

doc/build/changelog/unreleased_12/4328.rst [new file with mode: 0644]
lib/sqlalchemy/ext/baked.py
test/ext/test_horizontal_shard.py

diff --git a/doc/build/changelog/unreleased_12/4328.rst b/doc/build/changelog/unreleased_12/4328.rst
new file mode 100644 (file)
index 0000000..fd63ce5
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, ext
+    :tickets: 4328
+
+    Fixed issue where :class:`.BakedQuery` did not include the specific query
+    class used by the :class:`.Session` as part of the cache key, leading to
+    incompatibilities when using custom query classes, in particular the
+    :class:`.ShardedQuery` which has some different argument signatures.
index addec90da55dfb1a2487dd5982ec2e0c988745e6..e605fb1f99d6863f693172ba7110c8104625df20 100644 (file)
@@ -154,6 +154,19 @@ class BakedQuery(object):
         self._spoiled = True
         return self
 
+    def _effective_key(self, session):
+        """Return the key that actually goes into the cache dictionary for
+        this :class:`.BakedQuery`, taking into account the given
+        :class:`.Session`.
+
+        This basically means we also will include the session's query_class,
+        as the actual :class:`.Query` object is part of what's cached
+        and needs to match the type of :class:`.Query` that a later
+        session will want to use.
+
+        """
+        return self._cache_key + (session._query_cls, )
+
     def _with_lazyload_options(self, options, effective_path, cache_path=None):
         """Cloning version of _add_lazyload_options.
         """
@@ -195,10 +208,10 @@ class BakedQuery(object):
         )
 
     def _retrieve_baked_query(self, session):
-        query = self._bakery.get(self._cache_key, None)
+        query = self._bakery.get(self._effective_key(session), None)
         if query is None:
             query = self._as_query(session)
-            self._bakery[self._cache_key] = query.with_session(None)
+            self._bakery[self._effective_key(session)] = query.with_session(None)
         return query.with_session(session)
 
     def _bake(self, session):
@@ -218,7 +231,7 @@ class BakedQuery(object):
                 '_correlate', '_from_obj', '_mapper_adapter_map',
                 '_joinpath', '_joinpoint'):
             query.__dict__.pop(attr, None)
-        self._bakery[self._cache_key] = context
+        self._bakery[self._effective_key(session)] = context
         return context
 
     def _as_query(self, session):
@@ -332,7 +345,7 @@ class Result(object):
         if not self.session.enable_baked_queries or bq._spoiled:
             return iter(self._as_query())
 
-        baked_context = bq._bakery.get(bq._cache_key, None)
+        baked_context = bq._bakery.get(bq._effective_key(self.session), None)
         if baked_context is None:
             baked_context = bq._bake(self.session)
 
index e270dd3536069345dae15780c2bb7ecb2bb64b92..2773be7d533a8abaf80a823ad6a26890293b0112 100644 (file)
@@ -295,6 +295,32 @@ class ShardTest(object):
              'south_america']
         )
 
+    def test_baked_mix(self):
+        sess = self._fixture_data()
+
+        tokyo = sess.query(WeatherLocation).filter_by(city="Tokyo").one()
+        tokyo.city
+        sess.expunge_all()
+
+        from sqlalchemy.ext.baked import BakedQuery
+
+        bakery = BakedQuery.bakery()
+
+        def get_tokyo(sess):
+            bq = bakery(lambda session: session.query(WeatherLocation))
+            t = bq(sess).get(tokyo.id)
+            return t
+
+        Sess = sessionmaker(class_=Session, bind=db2,
+                                      autoflush=True, autocommit=False)
+        sess2 = Sess()
+
+        t = get_tokyo(sess)
+        eq_(t.city, tokyo.city)
+
+        t = get_tokyo(sess2)
+        eq_(t.city, tokyo.city)
+
 
 from sqlalchemy.testing import provision