]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
honor NO_CACHE in lambdas
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 15 Aug 2021 22:12:42 +0000 (18:12 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 17 Aug 2021 18:30:24 +0000 (14:30 -0400)
Fixed issue in lambda caching system where an element of a query that
produces no cache key, like a custom option or clause element, would still
populate the expression in the "lambda cache" inappropriately.

This was discovered as part of :ticket:`6887` but is a separate
issue.

References: #6887
Change-Id: I1665f4320254ddc63a0abf3088e9daeaffbd1840

doc/build/changelog/unreleased_14/lmb_no_cache.rst [new file with mode: 0644]
lib/sqlalchemy/sql/lambdas.py
test/sql/test_lambdas.py

diff --git a/doc/build/changelog/unreleased_14/lmb_no_cache.rst b/doc/build/changelog/unreleased_14/lmb_no_cache.rst
new file mode 100644 (file)
index 0000000..4f6e193
--- /dev/null
@@ -0,0 +1,6 @@
+.. change::
+    :tags: bug, sql
+
+    Fixed issue in lambda caching system where an element of a query that
+    produces no cache key, like a custom option or clause element, would still
+    populate the expression in the "lambda cache" inappropriately.
index d33e8ebfb1f0fc3ea638f95762c918cd77e491d9..36e470ce7c06c08c0a92e22c842c60d078d41974 100644 (file)
@@ -182,28 +182,49 @@ class LambdaElement(elements.ClauseElement):
 
         self._resolved_bindparams = bindparams = []
 
-        anon_map = traversals.anon_map()
-        cache_key = tuple(
-            [
-                getter(closure, opts, anon_map, bindparams)
-                for getter in tracker.closure_trackers
-            ]
-        )
-
         if self.parent_lambda is not None:
-            cache_key = self.parent_lambda.closure_cache_key + cache_key
+            parent_closure_cache_key = self.parent_lambda.closure_cache_key
+        else:
+            parent_closure_cache_key = ()
+
+        if parent_closure_cache_key is not traversals.NO_CACHE:
+            anon_map = traversals.anon_map()
+            cache_key = tuple(
+                [
+                    getter(closure, opts, anon_map, bindparams)
+                    for getter in tracker.closure_trackers
+                ]
+            )
 
-        self.closure_cache_key = cache_key
+            if traversals.NO_CACHE not in anon_map:
+                cache_key = parent_closure_cache_key + cache_key
 
-        try:
-            rec = lambda_cache[tracker_key + cache_key]
-        except KeyError:
+                self.closure_cache_key = cache_key
+
+                try:
+                    rec = lambda_cache[tracker_key + cache_key]
+                except KeyError:
+                    rec = None
+            else:
+                cache_key = traversals.NO_CACHE
+                rec = None
+
+        else:
+            cache_key = traversals.NO_CACHE
             rec = None
 
+        self.closure_cache_key = cache_key
+
         if rec is None:
-            rec = AnalyzedFunction(tracker, self, apply_propagate_attrs, fn)
-            rec.closure_bindparams = bindparams
-            lambda_cache[tracker_key + cache_key] = rec
+            if cache_key is not traversals.NO_CACHE:
+                rec = AnalyzedFunction(
+                    tracker, self, apply_propagate_attrs, fn
+                )
+                rec.closure_bindparams = bindparams
+                lambda_cache[tracker_key + cache_key] = rec
+            else:
+                rec = NonAnalyzedFunction(self._invoke_user_fn(fn))
+
         else:
             bindparams[:] = [
                 orig_bind._with_value(new_bind.value, maintain_key=True)
@@ -212,21 +233,24 @@ class LambdaElement(elements.ClauseElement):
                 )
             ]
 
-        if self.parent_lambda is not None:
-            bindparams[:0] = self.parent_lambda._resolved_bindparams
-
         self._rec = rec
 
-        lambda_element = self
-        while lambda_element is not None:
-            rec = lambda_element._rec
-            if rec.bindparam_trackers:
-                tracker_instrumented_fn = rec.tracker_instrumented_fn
-                for tracker in rec.bindparam_trackers:
-                    tracker(
-                        lambda_element.fn, tracker_instrumented_fn, bindparams
-                    )
-            lambda_element = lambda_element.parent_lambda
+        if cache_key is not traversals.NO_CACHE:
+            if self.parent_lambda is not None:
+                bindparams[:0] = self.parent_lambda._resolved_bindparams
+
+            lambda_element = self
+            while lambda_element is not None:
+                rec = lambda_element._rec
+                if rec.bindparam_trackers:
+                    tracker_instrumented_fn = rec.tracker_instrumented_fn
+                    for tracker in rec.bindparam_trackers:
+                        tracker(
+                            lambda_element.fn,
+                            tracker_instrumented_fn,
+                            bindparams,
+                        )
+                lambda_element = lambda_element.parent_lambda
 
         return rec
 
@@ -304,6 +328,9 @@ class LambdaElement(elements.ClauseElement):
         return expr
 
     def _gen_cache_key(self, anon_map, bindparams):
+        if self.closure_cache_key is traversals.NO_CACHE:
+            anon_map[traversals.NO_CACHE] = True
+            return None
 
         cache_key = (
             self.fn.__code__,
@@ -914,6 +941,20 @@ class AnalyzedCode(object):
         )
 
 
+class NonAnalyzedFunction(object):
+    __slots__ = ("expr",)
+
+    closure_bindparams = None
+    bindparam_trackers = None
+
+    def __init__(self, expr):
+        self.expr = expr
+
+    @property
+    def expected_expr(self):
+        return self.expr
+
+
 class AnalyzedFunction(object):
     __slots__ = (
         "analyzed_code",
index 51530b0791ccd12f41ad8380d4187ebf9b384a0d..2e794d7bcf9640545906f86690e641f5657b84fe 100644 (file)
@@ -17,6 +17,7 @@ from sqlalchemy.sql import roles
 from sqlalchemy.sql import select
 from sqlalchemy.sql import table
 from sqlalchemy.sql import util as sql_util
+from sqlalchemy.sql.traversals import HasCacheKey
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import eq_
@@ -953,6 +954,63 @@ class LambdaElementTest(
         eq_(s1key.key, s2key.key)
         ne_(s1key.key, s3key.key)
 
+    def test_stmt_lambda_opt_w_key(self):
+        """test issue related to #6887"""
+
+        def go(opts):
+            stmt = lambdas.lambda_stmt(lambda: select(column("x")))
+            stmt += lambda stmt: stmt.options(*opts)
+
+            return stmt
+
+        class SomeOpt(HasCacheKey):
+            def _gen_cache_key(self, anon_map, bindparams):
+                return ("fixed_key",)
+
+        # generates no key, will not be cached
+        eq_(SomeOpt()._generate_cache_key().key, ("fixed_key",))
+
+        s1o, s2o = SomeOpt(), SomeOpt()
+        s1 = go([s1o])
+        s2 = go([s2o])
+
+        s1key = s1._generate_cache_key()
+        s2key = s2._generate_cache_key()
+
+        eq_(s1key.key[-1], (("fixed_key",),))
+        eq_(s1key.key, s2key.key)
+
+        eq_(s1._resolved._with_options, (s1o,))
+        eq_(s2._resolved._with_options, (s1o,))
+        ne_(s2._resolved._with_options, (s2o,))
+
+    def test_stmt_lambda_opt_w_no_key(self):
+        """test issue related to #6887"""
+
+        def go(opts):
+            stmt = lambdas.lambda_stmt(lambda: select(column("x")))
+            stmt += lambda stmt: stmt.options(*opts)
+
+            return stmt
+
+        class SomeOpt(HasCacheKey):
+            pass
+
+        # generates no key, will not be cached
+        eq_(SomeOpt()._generate_cache_key(), None)
+
+        s1o, s2o = SomeOpt(), SomeOpt()
+        s1 = go([s1o])
+        s2 = go([s2o])
+
+        s1key = s1._generate_cache_key()
+
+        eq_(s1key, None)
+
+        eq_(s1._resolved._with_options, (s1o,))
+        eq_(s2._resolved._with_options, (s2o,))
+        ne_(s2._resolved._with_options, (s1o,))
+
     def test_stmt_lambda_hey_theres_multiple_paths(self):
         def go(x, y):
             stmt = lambdas.lambda_stmt(lambda: select(column("x")))