]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Robustness for lambdas, lambda statements
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 5 Aug 2020 20:42:26 +0000 (16:42 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 5 Aug 2020 20:42:26 +0000 (16:42 -0400)
in order to accommodate relationship loaders
with lambda caching, a lot more is needed.  This is
a full refactor of the lambda system such that it
now has two levels of caching; the first level caches what
can be known from the __code__ element, then the next level
of caching is against the lambda itself and the contents
of __closure__.  This allows for the elements inside
the lambdas, like columns and entities, to change and
then be part of the cache key.  Lazy/selectinloads' use of
baked queries had to add distinct cache key elements,
which was attempted here but overall things needed to be
more robust than that.

This commit is broken out from the very long and sprawling
commit at Id6b5c03b1ce9ddb7b280f66792212a0ef0a1c541 .

Change-Id: I29a513c98917b1d503abfdd61e6b6e8800851aa8

15 files changed:
lib/sqlalchemy/engine/create.py
lib/sqlalchemy/engine/cursor.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/sql/coercions.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/lambdas.py
lib/sqlalchemy/sql/selectable.py
lib/sqlalchemy/sql/traversals.py
lib/sqlalchemy/sql/visitors.py
lib/sqlalchemy/util/__init__.py
lib/sqlalchemy/util/_collections.py
test/sql/test_compare.py
test/sql/test_external_traversal.py
test/sql/test_lambdas.py
test/sql/test_selectable.py

index 8b0377a58267704fedcac69326f80b33936d8d9f..985a12fa00d93e1b0f1251ce37a9e465929047a5 100644 (file)
@@ -230,7 +230,7 @@ def create_engine(url, **kwargs):
     :param future: Use the 2.0 style :class:`_future.Engine` and
         :class:`_future.Connection` API.
 
-        ..versionadded:: 1.4
+        .. versionadded:: 1.4
 
         .. seealso::
 
index 6f4934521ce01d630002f15ce522dc83b3bba75f..bcffca9324e2d97eb3550222d64d497700be1ef4 100644 (file)
@@ -1247,7 +1247,7 @@ class BaseCursorResult(object):
             if (
                 compiled
                 and compiled._result_columns
-                and context.cache_hit
+                and context.cache_hit is context.dialect.CACHE_HIT
                 and not compiled._rewrites_selected_columns
                 and compiled.statement is not context.invoked_statement
             ):
index c431fa7555f3da2484ce82f4551a149e1648c4b9..8d3c5de1573ae325a7064749b9ad065da28a0c13 100644 (file)
@@ -39,6 +39,12 @@ AUTOCOMMIT_REGEXP = re.compile(
 SERVER_SIDE_CURSOR_RE = re.compile(r"\s*SELECT", re.I | re.UNICODE)
 
 
+CACHE_HIT = util.symbol("CACHE_HIT")
+CACHE_MISS = util.symbol("CACHE_MISS")
+CACHING_DISABLED = util.symbol("CACHING_DISABLED")
+NO_CACHE_KEY = util.symbol("NO_CACHE_KEY")
+
+
 class DefaultDialect(interfaces.Dialect):
     """Default implementation of Dialect"""
 
@@ -195,6 +201,11 @@ class DefaultDialect(interfaces.Dialect):
 
     """
 
+    CACHE_HIT = CACHE_HIT
+    CACHE_MISS = CACHE_MISS
+    CACHING_DISABLED = CACHING_DISABLED
+    NO_CACHE_KEY = NO_CACHE_KEY
+
     @util.deprecated_params(
         convert_unicode=(
             "1.3",
@@ -725,6 +736,8 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
 
     _expanded_parameters = util.immutabledict()
 
+    cache_hit = NO_CACHE_KEY
+
     @classmethod
     def _init_ddl(
         cls,
@@ -788,7 +801,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
         parameters,
         invoked_statement,
         extracted_parameters,
-        cache_hit=False,
+        cache_hit=CACHING_DISABLED,
     ):
         """Initialize execution context for a Compiled construct."""
 
@@ -1026,12 +1039,19 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
             return "raw sql"
 
         now = util.perf_counter()
-        if self.compiled.cache_key is None:
+
+        ch = self.cache_hit
+
+        if ch is NO_CACHE_KEY:
             return "no key %.5fs" % (now - self.compiled._gen_time,)
-        elif self.cache_hit:
+        elif ch is CACHE_HIT:
             return "cached since %.4gs ago" % (now - self.compiled._gen_time,)
-        else:
+        elif ch is CACHE_MISS:
             return "generated in %.5fs" % (now - self.compiled._gen_time,)
+        elif ch is CACHING_DISABLED:
+            return "caching disabled %.5fs" % (now - self.compiled._gen_time,)
+        else:
+            return "unknown"
 
     @util.memoized_property
     def engine(self):
index 588c485aeba249fe61dca05aede97327ce46e125..fa0f9c4357332fefcb2b20ee45579798749a2c62 100644 (file)
@@ -7,10 +7,13 @@
 
 import numbers
 import re
+import types
 
 from . import operators
 from . import roles
 from . import visitors
+from .base import Options
+from .traversals import HasCacheKey
 from .visitors import Visitable
 from .. import exc
 from .. import inspection
@@ -33,11 +36,36 @@ def _is_literal(element):
     of a SQL expression construct.
 
     """
+
     return not isinstance(
-        element, (Visitable, schema.SchemaEventTarget)
+        element, (Visitable, schema.SchemaEventTarget),
     ) and not hasattr(element, "__clause_element__")
 
 
+def _deep_is_literal(element):
+    """Return whether or not the element is a "literal" in the context
+    of a SQL expression construct.
+
+    does a deeper more esoteric check than _is_literal.   is used
+    for lambda elements that have to distinguish values that would
+    be bound vs. not without any context.
+
+    """
+
+    return (
+        not isinstance(
+            element,
+            (Visitable, schema.SchemaEventTarget, HasCacheKey, Options,),
+        )
+        and not hasattr(element, "__clause_element__")
+        and (
+            not isinstance(element, type)
+            or not issubclass(element, HasCacheKey)
+        )
+        and not isinstance(element, types.FunctionType)
+    )
+
+
 def _document_text_coercion(paramname, meth_rst, param_rst):
     return util.add_parameter_text(
         paramname,
@@ -711,9 +739,16 @@ class StatementImpl(_NoTextCoercion, RoleImpl):
 class CoerceTextStatementImpl(_CoerceLiterals, RoleImpl):
     __slots__ = ()
 
-    def _literal_coercion(self, element, **kw):
+    def _dont_literal_coercion(self, element, **kw):
         if callable(element) and hasattr(element, "__code__"):
-            return lambdas.StatementLambdaElement(element, self._role_class)
+            return lambdas.StatementLambdaElement(
+                element,
+                self._role_class,
+                additional_cache_criteria=kw.get(
+                    "additional_cache_criteria", ()
+                ),
+                tracked=kw["tra"],
+            )
         else:
             return super(CoerceTextStatementImpl, self)._literal_coercion(
                 element, **kw
index ca73a439246d255c58a158b00eeda213fcc6060f..8a506446db3f52a2743944491eaa8a39f7f9772c 100644 (file)
@@ -32,8 +32,8 @@ from .base import NO_ARG
 from .base import PARSE_AUTOCOMMIT
 from .base import SingletonConstant
 from .coercions import _document_text_coercion
-from .traversals import _copy_internals
 from .traversals import _get_children
+from .traversals import HasCopyInternals
 from .traversals import MemoizedHasCacheKey
 from .traversals import NO_CACHE
 from .visitors import cloned_traverse
@@ -182,6 +182,7 @@ class ClauseElement(
     roles.SQLRole,
     SupportsWrappingAnnotations,
     MemoizedHasCacheKey,
+    HasCopyInternals,
     Traversible,
 ):
     """Base class for elements of a programmatically constructed SQL
@@ -372,35 +373,6 @@ class ClauseElement(
         """
         return traversals.compare(self, other, **kw)
 
-    def _copy_internals(self, omit_attrs=(), **kw):
-        """Reassign internal elements to be clones of themselves.
-
-        Called during a copy-and-traverse operation on newly
-        shallow-copied elements to create a deep copy.
-
-        The given clone function should be used, which may be applying
-        additional transformations to the element (i.e. replacement
-        traversal, cloned traversal, annotations).
-
-        """
-
-        try:
-            traverse_internals = self._traverse_internals
-        except AttributeError:
-            # user-defined classes may not have a _traverse_internals
-            return
-
-        for attrname, obj, meth in _copy_internals.run_generated_dispatch(
-            self, traverse_internals, "_generated_copy_internals_traversal"
-        ):
-            if attrname in omit_attrs:
-                continue
-
-            if obj is not None:
-                result = meth(self, attrname, obj, **kw)
-                if result is not None:
-                    setattr(self, attrname, result)
-
     def get_children(self, omit_attrs=(), **kw):
         r"""Return immediate child :class:`.visitors.Traversible`
         elements of this :class:`.visitors.Traversible`.
@@ -535,8 +507,6 @@ class ClauseElement(
         else:
             elem_cache_key = None
 
-        cache_hit = False
-
         if elem_cache_key:
             cache_key, extracted_params = elem_cache_key
             key = (
@@ -549,6 +519,7 @@ class ClauseElement(
             compiled_sql = compiled_cache.get(key)
 
             if compiled_sql is None:
+                cache_hit = dialect.CACHE_MISS
                 compiled_sql = self._compiler(
                     dialect,
                     cache_key=elem_cache_key,
@@ -559,7 +530,7 @@ class ClauseElement(
                 )
                 compiled_cache[key] = compiled_sql
             else:
-                cache_hit = True
+                cache_hit = dialect.CACHE_HIT
         else:
             extracted_params = None
             compiled_sql = self._compiler(
@@ -570,6 +541,11 @@ class ClauseElement(
                 schema_translate_map=schema_translate_map,
                 **kw
             )
+            cache_hit = (
+                dialect.CACHING_DISABLED
+                if compiled_cache is None
+                else dialect.NO_CACHE_KEY
+            )
 
         return compiled_sql, extracted_params, cache_hit
 
@@ -1343,10 +1319,7 @@ class BindParameter(roles.InElementRole, ColumnElement):
         if required is NO_ARG:
             required = value is NO_ARG and callable_ is None
         if value is NO_ARG:
-            self._value_required_for_cache = False
             value = None
-        else:
-            self._value_required_for_cache = True
 
         if quote is not None:
             key = quoted_name(key, quote)
@@ -1412,6 +1385,7 @@ class BindParameter(roles.InElementRole, ColumnElement):
         """Return a copy of this :class:`.BindParameter` with the given value
         set.
         """
+
         cloned = self._clone(maintain_key=maintain_key)
         cloned.value = value
         cloned.callable = None
@@ -1465,7 +1439,8 @@ class BindParameter(roles.InElementRole, ColumnElement):
             anon_map[idself] = id_ = str(anon_map.index)
             anon_map.index += 1
 
-        bindparams.append(self)
+        if bindparams is not None:
+            bindparams.append(self)
 
         return (
             id_,
index 79241118967f87b5f6b2d656a965b113a3f4b1fe..3270039026075553be31d2488ed99f02b5a07786 100644 (file)
@@ -8,6 +8,7 @@
 import itertools
 import operator
 import sys
+import types
 import weakref
 
 from . import coercions
@@ -17,29 +18,23 @@ from . import schema
 from . import traversals
 from . import type_api
 from . import visitors
+from .base import _clone
 from .operators import ColumnOperators
 from .. import exc
 from .. import inspection
 from .. import util
 from ..util import collections_abc
 
-_trackers = weakref.WeakKeyDictionary()
+_closure_per_cache_key = util.LRUCache(1000)
 
 
-_TRACKERS = 0
-_STALE_CHECK = 1
-_REAL_FN = 2
-_EXPR = 3
-_IS_SEQUENCE = 4
-_PROPAGATE_ATTRS = 5
-
-
-def lambda_stmt(lmb):
+def lambda_stmt(lmb, **opts):
     """Produce a SQL statement that is cached as a lambda.
 
-    This SQL statement will only be constructed if element has not been
-    compiled yet.    The approach is used to save on Python function overhead
-    when constructing statements that will be cached.
+    The Python code object within the lambda is scanned for both Python
+    literals that will become bound parameters as well as closure variables
+    that refer to Core or ORM constructs that may vary.   The lambda itself
+    will be invoked only once per particular set of constructs detected.
 
     E.g.::
 
@@ -60,7 +55,8 @@ def lambda_stmt(lmb):
 
 
     """
-    return coercions.expect(roles.CoerceTextStatementRole, lmb)
+
+    return StatementLambdaElement(lmb, roles.CoerceTextStatementRole, **opts)
 
 
 class LambdaElement(elements.ClauseElement):
@@ -87,64 +83,108 @@ class LambdaElement(elements.ClauseElement):
 
     _is_lambda_element = True
 
-    _resolved_bindparams = ()
-
     _traverse_internals = [
         ("_resolved", visitors.InternalTraversal.dp_clauseelement)
     ]
 
+    _transforms = ()
+
+    parent_lambda = None
+
     def __repr__(self):
         return "%s(%r)" % (self.__class__.__name__, self.fn.__code__)
 
     def __init__(self, fn, role, apply_propagate_attrs=None, **kw):
         self.fn = fn
         self.role = role
-        self.parent_lambda = None
+        self.tracker_key = (fn.__code__,)
 
         if apply_propagate_attrs is None and (
             role is roles.CoerceTextStatementRole
         ):
             apply_propagate_attrs = self
 
-        if fn.__code__ not in _trackers:
-            rec = self._initialize_var_trackers(
-                role, apply_propagate_attrs, kw
+        rec = self._retrieve_tracker_rec(fn, apply_propagate_attrs, kw)
+
+        if apply_propagate_attrs is not None:
+            propagate_attrs = rec.propagate_attrs
+            if propagate_attrs:
+                apply_propagate_attrs._propagate_attrs = propagate_attrs
+
+    def _retrieve_tracker_rec(self, fn, apply_propagate_attrs, kw):
+        lambda_cache = kw.get("lambda_cache", _closure_per_cache_key)
+
+        tracker_key = self.tracker_key
+
+        fn = self.fn
+        closure = fn.__closure__
+
+        tracker = AnalyzedCode.get(
+            fn,
+            self,
+            kw,
+            track_bound_values=kw.get("track_bound_values", True),
+            enable_tracking=kw.get("enable_tracking", True),
+            track_on=kw.get("track_on", None),
+        )
+
+        self._resolved_bindparams = bindparams = []
+
+        anon_map = traversals.anon_map()
+        cache_key = tuple(
+            [
+                getter(closure, kw, anon_map, bindparams)
+                for getter in tracker.closure_trackers
+            ]
+        )
+        if self.parent_lambda is not None:
+            cache_key = self.parent_lambda.closure_cache_key + cache_key
+
+        self.closure_cache_key = cache_key
+
+        try:
+            rec = lambda_cache[tracker_key + cache_key]
+        except KeyError:
+            rec = None
+
+        if rec is None:
+            rec = AnalyzedFunction(
+                tracker, self, apply_propagate_attrs, kw, fn
             )
+            rec.closure_bindparams = bindparams
+            lambda_cache[tracker_key + cache_key] = rec
         else:
-            rec = _trackers[self.fn.__code__]
-            closure = fn.__closure__
+            bindparams[:] = [
+                orig_bind._with_value(new_bind.value, maintain_key=True)
+                for orig_bind, new_bind in zip(
+                    rec.closure_bindparams, bindparams
+                )
+            ]
+
+        if self.parent_lambda is not None:
+            bindparams[:0] = self.parent_lambda._resolved_bindparams
 
-            # check if the objects fixed inside the lambda that we've cached
-            # have been changed.   This can apply to things like mappers that
-            # were recreated in test suites.    if so, re-initialize.
-            #
-            # this is a small performance hit on every use for a not very
-            # common situation, however it's very hard to debug if the
-            # condition does occur.
-            for idx, obj in rec[_STALE_CHECK]:
-                if closure[idx].cell_contents is not obj:
-                    rec = self._initialize_var_trackers(
-                        role, apply_propagate_attrs, kw
-                    )
-                    break
         self._rec = rec
 
-        if apply_propagate_attrs is not None:
-            propagate_attrs = rec[_PROPAGATE_ATTRS]
-            if propagate_attrs:
-                apply_propagate_attrs._propagate_attrs = propagate_attrs
+        lambda_element = self
+        while lambda_element is not None:
+            rec = lambda_element._rec
+            if rec.bindparam_trackers:
+                tracker_instrumented_fn = rec.tracker_instrumented_fn
+                for tracker in rec.bindparam_trackers:
+                    tracker(
+                        lambda_element.fn, tracker_instrumented_fn, bindparams
+                    )
+            lambda_element = lambda_element.parent_lambda
 
-        if rec[_TRACKERS]:
-            self._resolved_bindparams = bindparams = []
-            for tracker in rec[_TRACKERS]:
-                tracker(self.fn, bindparams)
+        return rec
 
     def __getattr__(self, key):
-        return getattr(self._rec[_EXPR], key)
+        return getattr(self._rec.expected_expr, key)
 
     @property
     def _is_sequence(self):
-        return self._rec[_IS_SEQUENCE]
+        return self._rec.is_sequence
 
     @property
     def _select_iterable(self):
@@ -169,8 +209,7 @@ class LambdaElement(elements.ClauseElement):
     def _param_dict(self):
         return {b.key: b.value for b in self._resolved_bindparams}
 
-    @util.memoized_property
-    def _resolved(self):
+    def _setup_binds_for_tracked_expr(self, expr):
         bindparam_lookup = {b.key: b for b in self._resolved_bindparams}
 
         def replace(thing):
@@ -179,17 +218,11 @@ class LambdaElement(elements.ClauseElement):
                 and thing.key in bindparam_lookup
             ):
                 bind = bindparam_lookup[thing.key]
-                # TODO: consider
-                # if we should clone the bindparam here, re-cache the new
-                # version, etc.  also we make an assumption about "expanding"
-                # in this case.
                 if thing.expanding:
                     bind.expanding = True
                 return bind
 
-        expr = self._rec[_EXPR]
-
-        if self._rec[_IS_SEQUENCE]:
+        if self._rec.is_sequence:
             expr = [
                 visitors.replacement_traverse(sub_expr, {}, replace)
                 for sub_expr in expr
@@ -199,9 +232,39 @@ class LambdaElement(elements.ClauseElement):
 
         return expr
 
+    def _copy_internals(
+        self, clone=_clone, deferred_copy_internals=None, **kw
+    ):
+        # TODO: this needs A LOT of tests
+        self._resolved = clone(
+            self._resolved,
+            deferred_copy_internals=deferred_copy_internals,
+            **kw
+        )
+
+    @util.memoized_property
+    def _resolved(self):
+        expr = self._rec.expected_expr
+
+        if self._resolved_bindparams:
+            expr = self._setup_binds_for_tracked_expr(expr)
+
+        return expr
+
     def _gen_cache_key(self, anon_map, bindparams):
 
-        cache_key = (self.fn.__code__, self.__class__)
+        cache_key = (
+            self.fn.__code__,
+            self.__class__,
+        ) + self.closure_cache_key
+
+        parent = self.parent_lambda
+        while parent is not None:
+            cache_key = (
+                (parent.fn.__code__,) + parent.closure_cache_key + cache_key
+            )
+
+            parent = parent.parent_lambda
 
         if self._resolved_bindparams:
             bindparams.extend(self._resolved_bindparams)
@@ -211,101 +274,51 @@ class LambdaElement(elements.ClauseElement):
     def _invoke_user_fn(self, fn, *arg):
         return fn()
 
-    def _initialize_var_trackers(self, role, apply_propagate_attrs, coerce_kw):
-        fn = self.fn
 
-        # track objects referenced inside of lambdas, create bindparams
-        # ahead of time for literal values.   If bindparams are produced,
-        # then rewrite the function globals and closure as necessary so that
-        # it refers to the bindparams, then invoke the function
-        new_closure = {}
-        new_globals = fn.__globals__.copy()
-        tracker_collection = []
-        check_closure_for_stale = []
+class DeferredLambdaElement(LambdaElement):
+    """A LambdaElement where the lambda accepts arguments and is
+    invoked within the compile phase with special context.
 
-        for name in fn.__code__.co_names:
-            if name not in new_globals:
-                continue
+    This lambda doesn't normally produce its real SQL expression outside of the
+    compile phase.  It is passed a fixed set of initial arguments
+    so that it can generate a sample expression.
 
-            bound_value = _roll_down_to_literal(new_globals[name])
+    """
 
-            if coercions._is_literal(bound_value):
-                new_globals[name] = bind = PyWrapper(name, bound_value)
-                tracker_collection.append(_globals_tracker(name, bind))
+    def __init__(self, fn, role, lambda_args=(), **kw):
+        self.lambda_args = lambda_args
+        self.coerce_kw = kw
+        super(DeferredLambdaElement, self).__init__(fn, role, **kw)
 
-        if fn.__closure__:
-            for closure_index, (fv, cell) in enumerate(
-                zip(fn.__code__.co_freevars, fn.__closure__)
-            ):
+    def _invoke_user_fn(self, fn, *arg):
+        return fn(*self.lambda_args)
 
-                bound_value = _roll_down_to_literal(cell.cell_contents)
+    def _resolve_with_args(self, *lambda_args):
+        tracker_fn = self._rec.tracker_instrumented_fn
+        expr = tracker_fn(*lambda_args)
 
-                if coercions._is_literal(bound_value):
-                    new_closure[fv] = bind = PyWrapper(fv, bound_value)
-                    tracker_collection.append(
-                        _closure_tracker(fv, bind, closure_index)
-                    )
-                else:
-                    new_closure[fv] = cell.cell_contents
-                    # for normal cell contents, add them to a list that
-                    # we can compare later when we get new lambdas.  if
-                    # any identities have changed, then we will recalculate
-                    # the whole lambda and run it again.
-                    check_closure_for_stale.append(
-                        (closure_index, cell.cell_contents)
-                    )
+        expr = coercions.expect(self.role, expr, **self.coerce_kw)
 
-        if tracker_collection:
-            new_fn = _rewrite_code_obj(
-                fn,
-                [new_closure[name] for name in fn.__code__.co_freevars],
-                new_globals,
-            )
-            expr = self._invoke_user_fn(new_fn)
+        if self._resolved_bindparams:
+            expr = self._setup_binds_for_tracked_expr(expr)
 
-        else:
-            new_fn = fn
-            expr = self._invoke_user_fn(new_fn)
-            tracker_collection = []
+        # TODO: TEST TEST TEST, this is very out there
+        for deferred_copy_internals in self._transforms:
+            expr = deferred_copy_internals(expr)
 
-        if self.parent_lambda is None:
-            if isinstance(expr, collections_abc.Sequence):
-                expected_expr = [
-                    coercions.expect(
-                        role,
-                        sub_expr,
-                        apply_propagate_attrs=apply_propagate_attrs,
-                        **coerce_kw
-                    )
-                    for sub_expr in expr
-                ]
-                is_sequence = True
-            else:
-                expected_expr = coercions.expect(
-                    role,
-                    expr,
-                    apply_propagate_attrs=apply_propagate_attrs,
-                    **coerce_kw
-                )
-                is_sequence = False
-        else:
-            expected_expr = expr
-            is_sequence = False
+        return expr
 
-        if apply_propagate_attrs is not None:
-            propagate_attrs = apply_propagate_attrs._propagate_attrs
-        else:
-            propagate_attrs = util.immutabledict()
-
-        rec = _trackers[self.fn.__code__] = (
-            tracker_collection,
-            check_closure_for_stale,
-            new_fn,
-            expected_expr,
-            is_sequence,
-            propagate_attrs,
+    def _copy_internals(
+        self, clone=_clone, deferred_copy_internals=None, **kw
+    ):
+        super(DeferredLambdaElement, self)._copy_internals(
+            clone=clone, deferred_copy_internals=deferred_copy_internals, **kw
         )
-        return rec
+
+        # TODO: A LOT A LOT of tests.   for _resolve_with_args, we don't know
+        # our expression yet.   so hold onto the replacement
+        if deferred_copy_internals:
+            self._transforms += (deferred_copy_internals,)
 
 
 class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement):
@@ -334,13 +347,38 @@ class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement):
 
     """
 
+    def __init__(self, fn, parent_lambda, **kw):
+        self._default_kw = default_kw = {}
+        global_track_bound_values = kw.pop("global_track_bound_values", None)
+        if global_track_bound_values is not None:
+            default_kw["track_bound_values"] = global_track_bound_values
+            kw["track_bound_values"] = global_track_bound_values
+
+        if "lambda_cache" in kw:
+            default_kw["lambda_cache"] = kw["lambda_cache"]
+
+        super(StatementLambdaElement, self).__init__(fn, parent_lambda, **kw)
+
     def __add__(self, other):
-        return LinkedLambdaElement(other, parent_lambda=self)
+        return LinkedLambdaElement(
+            other, parent_lambda=self, **self._default_kw
+        )
+
+    def add_criteria(self, other, **kw):
+        if self._default_kw:
+            if kw:
+                default_kw = self._default_kw.copy()
+                default_kw.update(kw)
+                kw = default_kw
+            else:
+                kw = self._default_kw
+
+        return LinkedLambdaElement(other, parent_lambda=self, **kw)
 
     def _execute_on_connection(
         self, connection, multiparams, params, execution_options
     ):
-        if self._rec[_EXPR].supports_execution:
+        if self._rec.expected_expr.supports_execution:
             return connection._execute_clauseelement(
                 self, multiparams, params, execution_options
             )
@@ -349,93 +387,579 @@ class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement):
 
     @property
     def _with_options(self):
-        return self._rec[_EXPR]._with_options
+        return self._rec.expected_expr._with_options
 
     @property
     def _effective_plugin_target(self):
-        return self._rec[_EXPR]._effective_plugin_target
+        return self._rec.expected_expr._effective_plugin_target
 
     @property
     def _is_future(self):
-        return self._rec[_EXPR]._is_future
+        return self._rec.expected_expr._is_future
 
     @property
     def _execution_options(self):
-        return self._rec[_EXPR]._execution_options
+        return self._rec.expected_expr._execution_options
+
+    def spoil(self):
+        """Return a new :class:`.StatementLambdaElement` that will run
+        all lambdas unconditionally each time.
+
+        """
+        return NullLambdaStatement(self.fn())
+
+
+class NullLambdaStatement(roles.AllowsLambdaRole, elements.ClauseElement):
+    """Provides the :class:`.StatementLambdaElement` API but does not
+    cache or analyze lambdas.
+
+    the lambdas are instead invoked immediately.
+
+    The intended use is to isolate issues that may arise when using
+    lambda statements.
+
+    """
+
+    __visit_name__ = "lambda_element"
+
+    _is_lambda_element = True
+
+    _traverse_internals = [
+        ("_resolved", visitors.InternalTraversal.dp_clauseelement)
+    ]
+
+    def __init__(self, statement):
+        self._resolved = statement
+        self._propagate_attrs = statement._propagate_attrs
+
+    def __getattr__(self, key):
+        return getattr(self._resolved, key)
+
+    def __add__(self, other):
+        statement = other(self._resolved)
+
+        return NullLambdaStatement(statement)
+
+    def add_criteria(self, other, **kw):
+        statement = other(self._resolved)
+
+        return NullLambdaStatement(statement)
+
+    def _execute_on_connection(
+        self, connection, multiparams, params, execution_options
+    ):
+        if self._resolved.supports_execution:
+            return connection._execute_clauseelement(
+                self, multiparams, params, execution_options
+            )
+        else:
+            raise exc.ObjectNotExecutableError(self)
 
 
 class LinkedLambdaElement(StatementLambdaElement):
+    """Represent subsequent links of a :class:`.StatementLambdaElement`."""
+
+    role = None
+
     def __init__(self, fn, parent_lambda, **kw):
+        self._default_kw = parent_lambda._default_kw
+
         self.fn = fn
         self.parent_lambda = parent_lambda
-        role = None
 
-        apply_propagate_attrs = self
+        self.tracker_key = parent_lambda.tracker_key + (fn.__code__,)
+        self._retrieve_tracker_rec(fn, self, kw)
+        self._propagate_attrs = parent_lambda._propagate_attrs
+
+    def _invoke_user_fn(self, fn, *arg):
+        return fn(self.parent_lambda._resolved)
+
+
+class AnalyzedCode(object):
+    __slots__ = (
+        "track_closure_variables",
+        "track_bound_values",
+        "bindparam_trackers",
+        "closure_trackers",
+        "build_py_wrappers",
+    )
+    _fns = weakref.WeakKeyDictionary()
+
+    @classmethod
+    def get(cls, fn, lambda_element, lambda_kw, **kw):
+        try:
+            # TODO: validate kw haven't changed?
+            return cls._fns[fn.__code__]
+        except KeyError:
+            pass
+        cls._fns[fn.__code__] = analyzed = AnalyzedCode(
+            fn, lambda_element, lambda_kw, **kw
+        )
+        return analyzed
+
+    def __init__(
+        self,
+        fn,
+        lambda_element,
+        lambda_kw,
+        track_bound_values=True,
+        enable_tracking=True,
+        track_on=None,
+    ):
+        closure = fn.__closure__
+
+        self.track_closure_variables = not track_on
+
+        self.track_bound_values = track_bound_values
+
+        # a list of callables generated from _bound_parameter_getter_*
+        # functions.  Each of these uses a PyWrapper object to retrieve
+        # a parameter value
+        self.bindparam_trackers = []
+
+        # a list of callables generated from _cache_key_getter_* functions
+        # these callables work to generate a cache key for the lambda
+        # based on what's inside its closure variables.
+        self.closure_trackers = []
 
-        if fn.__code__ not in _trackers:
-            rec = self._initialize_var_trackers(
-                role, apply_propagate_attrs, kw
+        self.build_py_wrappers = []
+
+        if enable_tracking:
+            if track_on:
+                self._init_track_on(track_on)
+
+            self._init_globals(fn)
+
+            if closure:
+                self._init_closure(fn)
+
+        self._setup_additional_closure_trackers(fn, lambda_element, lambda_kw)
+
+    def _init_track_on(self, track_on):
+        self.closure_trackers.extend(
+            self._cache_key_getter_track_on(idx, elem)
+            for idx, elem in enumerate(track_on)
+        )
+
+    def _init_globals(self, fn):
+        build_py_wrappers = self.build_py_wrappers
+        bindparam_trackers = self.bindparam_trackers
+        track_bound_values = self.track_bound_values
+
+        for name in fn.__code__.co_names:
+            if name not in fn.__globals__:
+                continue
+
+            _bound_value = self._roll_down_to_literal(fn.__globals__[name])
+
+            if coercions._deep_is_literal(_bound_value):
+                build_py_wrappers.append((name, None))
+                if track_bound_values:
+                    bindparam_trackers.append(
+                        self._bound_parameter_getter_func_globals(name)
+                    )
+
+    def _init_closure(self, fn):
+        build_py_wrappers = self.build_py_wrappers
+        closure = fn.__closure__
+
+        track_bound_values = self.track_bound_values
+        track_closure_variables = self.track_closure_variables
+        bindparam_trackers = self.bindparam_trackers
+        closure_trackers = self.closure_trackers
+
+        for closure_index, (fv, cell) in enumerate(
+            zip(fn.__code__.co_freevars, closure)
+        ):
+            _bound_value = self._roll_down_to_literal(cell.cell_contents)
+
+            if coercions._deep_is_literal(_bound_value):
+                build_py_wrappers.append((fv, closure_index))
+                if track_bound_values:
+                    bindparam_trackers.append(
+                        self._bound_parameter_getter_func_closure(
+                            fv, closure_index
+                        )
+                    )
+            else:
+                # for normal cell contents, add them to a list that
+                # we can compare later when we get new lambdas.  if
+                # any identities have changed, then we will
+                # recalculate the whole lambda and run it again.
+
+                if track_closure_variables:
+                    closure_trackers.append(
+                        self._cache_key_getter_closure_variable(
+                            closure_index, cell.cell_contents
+                        )
+                    )
+
+    def _setup_additional_closure_trackers(
+        self, fn, lambda_element, lambda_kw
+    ):
+        # an additional step is to actually run the function, then
+        # go through the PyWrapper objects that were set up to catch a bound
+        # parameter.   then if they *didn't* make a param, oh they're another
+        # object in the closure we have to track for our cache key.  so
+        # create trackers to catch those.
+
+        analyzed_function = AnalyzedFunction(
+            self, lambda_element, None, lambda_kw, fn,
+        )
+
+        closure_trackers = self.closure_trackers
+
+        for pywrapper in analyzed_function.closure_pywrappers:
+            if not pywrapper._sa__has_param:
+                closure_trackers.append(
+                    self._cache_key_getter_tracked_literal(pywrapper)
+                )
+
+    @classmethod
+    def _roll_down_to_literal(cls, element):
+        is_clause_element = hasattr(element, "__clause_element__")
+
+        if is_clause_element:
+            while not isinstance(
+                element, (elements.ClauseElement, schema.SchemaItem)
+            ):
+                try:
+                    element = element.__clause_element__()
+                except AttributeError:
+                    break
+
+        if not is_clause_element:
+            insp = inspection.inspect(element, raiseerr=False)
+            if insp is not None:
+                try:
+                    return insp.__clause_element__()
+                except AttributeError:
+                    return insp
+
+            # TODO: should we coerce consts None/True/False here?
+            return element
+        else:
+            return element
+
+    def _bound_parameter_getter_func_globals(self, name):
+        """Return a getter that will extend a list of bound parameters
+        with new entries from the ``__globals__`` collection of a particular
+        lambda.
+
+        """
+
+        def extract_parameter_value(
+            current_fn, tracker_instrumented_fn, result
+        ):
+            wrapper = tracker_instrumented_fn.__globals__[name]
+            object.__getattribute__(wrapper, "_extract_bound_parameters")(
+                current_fn.__globals__[name], result
+            )
+
+        return extract_parameter_value
+
+    def _bound_parameter_getter_func_closure(self, name, closure_index):
+        """Return a getter that will extend a list of bound parameters
+        with new entries from the ``__closure__`` collection of a particular
+        lambda.
+
+        """
+
+        def extract_parameter_value(
+            current_fn, tracker_instrumented_fn, result
+        ):
+            wrapper = tracker_instrumented_fn.__closure__[
+                closure_index
+            ].cell_contents
+            object.__getattribute__(wrapper, "_extract_bound_parameters")(
+                current_fn.__closure__[closure_index].cell_contents, result
+            )
+
+        return extract_parameter_value
+
+    def _cache_key_getter_track_on(self, idx, elem):
+        """Return a getter that will extend a cache key with new entries
+        from the "track_on" parameter passed to a :class:`.LambdaElement`.
+
+        """
+        if isinstance(elem, traversals.HasCacheKey):
+
+            def get(closure, kw, anon_map, bindparams):
+                return kw["track_on"][idx]._gen_cache_key(anon_map, bindparams)
+
+        else:
+
+            def get(closure, kw, anon_map, bindparams):
+                return kw["track_on"][idx]
+
+        return get
+
+    def _cache_key_getter_closure_variable(self, idx, cell_contents):
+        """Return a getter that will extend a cache key with new entries
+        from the ``__closure__`` collection of a particular lambda.
+
+        """
+
+        if isinstance(cell_contents, traversals.HasCacheKey):
+
+            def get(closure, kw, anon_map, bindparams):
+                return closure[idx].cell_contents._gen_cache_key(
+                    anon_map, bindparams
+                )
+
+        elif isinstance(cell_contents, types.FunctionType):
+
+            def get(closure, kw, anon_map, bindparams):
+                return closure[idx].cell_contents.__code__
+
+        elif cell_contents.__hash__ is None:
+            # this covers dict, etc.
+            def get(closure, kw, anon_map, bindparams):
+                return ()
+
+        else:
+
+            def get(closure, kw, anon_map, bindparams):
+                return closure[idx].cell_contents
+
+        return get
+
+    def _cache_key_getter_tracked_literal(self, pytracker):
+        """Return a getter that will extend a cache key with new entries
+        from the ``__closure__`` collection of a particular lambda.
+
+        this getter differs from _cache_key_getter_closure_variable
+        in that these are detected after the function is run, and PyWrapper
+        objects have recorded that a particular literal value is in fact
+        not being interpreted as a bound parameter.
+
+        """
+
+        elem = pytracker._sa__to_evaluate
+        closure_index = pytracker._sa__closure_index
+
+        if isinstance(elem, set):
+            raise exc.ArgumentError(
+                "Can't create a cache key for lambda closure variable "
+                '"%s" because it\'s a set.  try using a list'
+                % pytracker._sa__name
             )
+
+        elif isinstance(elem, list):
+
+            def get(closure, kw, anon_map, bindparams):
+                return tuple(
+                    elem._gen_cache_key(anon_map, bindparams)
+                    for elem in closure[closure_index].cell_contents
+                )
+
+        elif elem.__hash__ is None:
+            # this covers dict, etc.
+            def get(closure, kw, anon_map, bindparams):
+                return ()
+
         else:
-            rec = _trackers[self.fn.__code__]
 
+            def get(closure, kw, anon_map, bindparams):
+                return closure[closure_index].cell_contents
+
+        return get
+
+
+class AnalyzedFunction(object):
+    __slots__ = (
+        "analyzed_code",
+        "fn",
+        "closure_pywrappers",
+        "tracker_instrumented_fn",
+        "expr",
+        "bindparam_trackers",
+        "expected_expr",
+        "is_sequence",
+        "propagate_attrs",
+        "closure_bindparams",
+    )
+
+    def __init__(
+        self, analyzed_code, lambda_element, apply_propagate_attrs, kw, fn,
+    ):
+        self.analyzed_code = analyzed_code
+        self.fn = fn
+
+        self.bindparam_trackers = analyzed_code.bindparam_trackers
+
+        self._instrument_and_run_function(lambda_element)
+
+        self._coerce_expression(lambda_element, apply_propagate_attrs, kw)
+
+    def _instrument_and_run_function(self, lambda_element):
+        analyzed_code = self.analyzed_code
+
+        fn = self.fn
+        self.closure_pywrappers = closure_pywrappers = []
+
+        build_py_wrappers = analyzed_code.build_py_wrappers
+
+        if not build_py_wrappers:
+            self.tracker_instrumented_fn = tracker_instrumented_fn = fn
+            self.expr = lambda_element._invoke_user_fn(tracker_instrumented_fn)
+        else:
+            track_closure_variables = analyzed_code.track_closure_variables
             closure = fn.__closure__
 
-            # check if objects referred to by the lambda have changed and
-            # re-scan the lambda if so. see comments for this same section in
-            # LambdaElement.
-            for idx, obj in rec[_STALE_CHECK]:
-                if closure[idx].cell_contents is not obj:
-                    rec = self._initialize_var_trackers(
-                        role, apply_propagate_attrs, kw
+            # will form the __closure__ of the function when we rebuild it
+            if closure:
+                new_closure = {
+                    fv: cell.cell_contents
+                    for fv, cell in zip(fn.__code__.co_freevars, closure)
+                }
+            else:
+                new_closure = {}
+
+            # will form the __globals__ of the function when we rebuild it
+            new_globals = fn.__globals__.copy()
+
+            for name, closure_index in build_py_wrappers:
+                if closure_index is not None:
+                    value = closure[closure_index].cell_contents
+                    new_closure[name] = bind = PyWrapper(
+                        name, value, closure_index=closure_index
                     )
-                    break
+                    if track_closure_variables:
+                        closure_pywrappers.append(bind)
+                else:
+                    value = fn.__globals__[name]
+                    new_globals[name] = bind = PyWrapper(name, value)
+
+            # rewrite the original fn.   things that look like they will
+            # become bound parameters are wrapped in a PyWrapper.
+            self.tracker_instrumented_fn = (
+                tracker_instrumented_fn
+            ) = self._rewrite_code_obj(
+                fn,
+                [new_closure[name] for name in fn.__code__.co_freevars],
+                new_globals,
+            )
 
-        self._rec = rec
+            # now invoke the function.  This will give us a new SQL
+            # expression, but all the places that there would be a bound
+            # parameter, the PyWrapper in its place will give us a bind
+            # with a predictable name we can match up later.
 
-        self._propagate_attrs = parent_lambda._propagate_attrs
+            # additionally, each PyWrapper will log that it did in fact
+            # create a parameter, otherwise, it's some kind of Python
+            # object in the closure and we want to track that, to make
+            # sure it doesn't change to somehting else, or if it does,
+            # that we create a different tracked function with that
+            # variable.
+            self.expr = lambda_element._invoke_user_fn(tracker_instrumented_fn)
 
-        self._resolved_bindparams = bindparams = []
-        rec = self._rec
-        while True:
-            if rec[_TRACKERS]:
-                for tracker in rec[_TRACKERS]:
-                    tracker(self.fn, bindparams)
-            if self.parent_lambda is not None:
-                self = self.parent_lambda
-                rec = self._rec
+    def _coerce_expression(self, lambda_element, apply_propagate_attrs, kw):
+        """Run the tracker-generated expression through coercion rules.
+
+        After the user-defined lambda has been invoked to produce a statement
+        for re-use, run it through coercion rules to both check that it's the
+        correct type of object and also to coerce it to its useful form.
+
+        """
+
+        parent_lambda = lambda_element.parent_lambda
+        expr = self.expr
+
+        if parent_lambda is None:
+            if isinstance(expr, collections_abc.Sequence):
+                self.expected_expr = [
+                    coercions.expect(
+                        lambda_element.role,
+                        sub_expr,
+                        apply_propagate_attrs=apply_propagate_attrs,
+                        **kw
+                    )
+                    for sub_expr in expr
+                ]
+                self.is_sequence = True
             else:
-                break
+                self.expected_expr = coercions.expect(
+                    lambda_element.role,
+                    expr,
+                    apply_propagate_attrs=apply_propagate_attrs,
+                    **kw
+                )
+                self.is_sequence = False
+        else:
+            self.expected_expr = expr
+            self.is_sequence = False
 
-    def _invoke_user_fn(self, fn, *arg):
-        return fn(self.parent_lambda._rec[_EXPR])
+        if apply_propagate_attrs is not None:
+            self.propagate_attrs = apply_propagate_attrs._propagate_attrs
+        else:
+            self.propagate_attrs = util.EMPTY_DICT
 
-    def _gen_cache_key(self, anon_map, bindparams):
-        if self._resolved_bindparams:
-            bindparams.extend(self._resolved_bindparams)
+    def _rewrite_code_obj(self, f, cell_values, globals_):
+        """Return a copy of f, with a new closure and new globals
 
-        cache_key = (self.fn.__code__, self.__class__)
+        yes it works in pypy :P
 
-        parent = self.parent_lambda
-        while parent is not None:
-            cache_key = (parent.fn.__code__,) + cache_key
-            parent = parent.parent_lambda
+        """
 
-        return cache_key
+        argrange = range(len(cell_values))
+
+        code = "def make_cells():\n"
+        if cell_values:
+            code += "    (%s) = (%s)\n" % (
+                ", ".join("i%d" % i for i in argrange),
+                ", ".join("o%d" % i for i in argrange),
+            )
+        code += "    def closure():\n"
+        code += "        return %s\n" % ", ".join("i%d" % i for i in argrange)
+        code += "    return closure.__closure__"
+        vars_ = {"o%d" % i: cell_values[i] for i in argrange}
+        exec(code, vars_, vars_)
+        closure = vars_["make_cells"]()
+
+        func = type(f)(
+            f.__code__, globals_, f.__name__, f.__defaults__, closure
+        )
+        if sys.version_info >= (3,):
+            func.__annotations__ = f.__annotations__
+            func.__kwdefaults__ = f.__kwdefaults__
+        func.__doc__ = f.__doc__
+        func.__module__ = f.__module__
+
+        return func
 
 
 class PyWrapper(ColumnOperators):
-    def __init__(self, name, to_evaluate, getter=None):
+    """A wrapper object that is injected into the ``__globals__`` and
+    ``__closure__`` of a Python function.
+
+    When the function is instrumented with :class:`.PyWrapper` objects, it is
+    then invoked just once in order to set up the wrappers.  We look through
+    all the :class:`.PyWrapper` objects we made to find the ones that generated
+    a :class:`.BindParameter` object, e.g. the expression system interpreted
+    something as a literal.   Those positions in the globals/closure are then
+    ones that we will look at, each time a new lambda comes in that refers to
+    the same ``__code__`` object.   In this way, we keep a single version of
+    the SQL expression that this lambda produced, without calling upon the
+    Python function that created it more than once, unless its other closure
+    variables have changed.   The expression is then transformed to have the
+    new bound values embedded into it.
+
+    """
+
+    def __init__(self, name, to_evaluate, closure_index=None, getter=None):
         self._name = name
         self._to_evaluate = to_evaluate
         self._param = None
+        self._has_param = False
         self._bind_paths = {}
         self._getter = getter
+        self._closure_index = closure_index
 
     def __call__(self, *arg, **kw):
         elem = object.__getattribute__(self, "_to_evaluate")
         value = elem(*arg, **kw)
-        if coercions._is_literal(value) and not isinstance(
+        if coercions._deep_is_literal(value) and not isinstance(
             # TODO: coverage where an ORM option or similar is here
             value,
             traversals.HasCacheKey,
@@ -481,8 +1005,8 @@ class PyWrapper(ColumnOperators):
         if param is None:
             name = object.__getattribute__(self, "_name")
             self._param = param = elements.BindParameter(name, unique=True)
+            self._has_param = True
             param.type = type_api._resolve_value_to_type(to_evaluate)
-
         return param._with_value(to_evaluate, maintain_key=True)
 
     def __getattribute__(self, key):
@@ -497,7 +1021,15 @@ class PyWrapper(ColumnOperators):
         else:
             return self._sa__add_getter(key, operator.attrgetter)
 
+    def __iter__(self):
+        elem = object.__getattribute__(self, "_to_evaluate")
+        return iter(elem)
+
     def __getitem__(self, key):
+        elem = object.__getattribute__(self, "_to_evaluate")
+        if not hasattr(elem, "__getitem__"):
+            raise AttributeError("__getitem__")
+
         if isinstance(key, PyWrapper):
             # TODO: coverage
             raise exc.InvalidRequestError(
@@ -518,90 +1050,14 @@ class PyWrapper(ColumnOperators):
         elem = object.__getattribute__(self, "_to_evaluate")
         value = getter(elem)
 
-        if coercions._is_literal(value):
-            wrapper = PyWrapper(key, value, getter)
+        if coercions._deep_is_literal(value):
+            wrapper = PyWrapper(key, value, getter=getter)
             bind_paths[bind_path_key] = wrapper
             return wrapper
         else:
             return value
 
 
-def _roll_down_to_literal(element):
-    is_clause_element = hasattr(element, "__clause_element__")
-
-    if is_clause_element:
-        while not isinstance(
-            element, (elements.ClauseElement, schema.SchemaItem)
-        ):
-            try:
-                element = element.__clause_element__()
-            except AttributeError:
-                break
-
-    if not is_clause_element:
-        insp = inspection.inspect(element, raiseerr=False)
-        if insp is not None:
-            try:
-                return insp.__clause_element__()
-            except AttributeError:
-                return insp
-
-        # TODO: should we coerce consts None/True/False here?
-        return element
-    else:
-        return element
-
-
-def _globals_tracker(name, wrapper):
-    def extract_parameter_value(current_fn, result):
-        object.__getattribute__(wrapper, "_extract_bound_parameters")(
-            current_fn.__globals__[name], result
-        )
-
-    return extract_parameter_value
-
-
-def _closure_tracker(name, wrapper, closure_index):
-    def extract_parameter_value(current_fn, result):
-        object.__getattribute__(wrapper, "_extract_bound_parameters")(
-            current_fn.__closure__[closure_index].cell_contents, result
-        )
-
-    return extract_parameter_value
-
-
-def _rewrite_code_obj(f, cell_values, globals_):
-    """Return a copy of f, with a new closure and new globals
-
-    yes it works in pypy :P
-
-    """
-
-    argrange = range(len(cell_values))
-
-    code = "def make_cells():\n"
-    if cell_values:
-        code += "    (%s) = (%s)\n" % (
-            ", ".join("i%d" % i for i in argrange),
-            ", ".join("o%d" % i for i in argrange),
-        )
-    code += "    def closure():\n"
-    code += "        return %s\n" % ", ".join("i%d" % i for i in argrange)
-    code += "    return closure.__closure__"
-    vars_ = {"o%d" % i: cell_values[i] for i in argrange}
-    exec(code, vars_, vars_)
-    closure = vars_["make_cells"]()
-
-    func = type(f)(f.__code__, globals_, f.__name__, f.__defaults__, closure)
-    if sys.version_info >= (3,):
-        func.__annotations__ = f.__annotations__
-        func.__kwdefaults__ = f.__kwdefaults__
-    func.__doc__ = f.__doc__
-    func.__module__ = f.__module__
-
-    return func
-
-
 @inspection._inspects(LambdaElement)
 def insp(lmb):
     return inspection.inspect(lmb._resolved)
index 1155c273b42ff83ce28db719119dd90208eb1101..d67b6174347b3c68e38ee72d486b44b74c85596a 100644 (file)
@@ -2426,6 +2426,7 @@ class SelectBase(
     """
 
     _is_select_statement = True
+    is_select = True
 
     def _generate_fromclause_column_proxies(self, fromclause):
         # type: (FromClause) -> None
@@ -3867,19 +3868,19 @@ class Select(
         [
             ("_raw_columns", InternalTraversal.dp_clauseelement_list),
             ("_from_obj", InternalTraversal.dp_clauseelement_list),
-            ("_where_criteria", InternalTraversal.dp_clauseelement_list),
-            ("_having_criteria", InternalTraversal.dp_clauseelement_list),
-            ("_order_by_clauses", InternalTraversal.dp_clauseelement_list,),
-            ("_group_by_clauses", InternalTraversal.dp_clauseelement_list,),
+            ("_where_criteria", InternalTraversal.dp_clauseelement_tuple),
+            ("_having_criteria", InternalTraversal.dp_clauseelement_tuple),
+            ("_order_by_clauses", InternalTraversal.dp_clauseelement_tuple,),
+            ("_group_by_clauses", InternalTraversal.dp_clauseelement_tuple,),
             ("_setup_joins", InternalTraversal.dp_setup_join_tuple,),
             ("_legacy_setup_joins", InternalTraversal.dp_setup_join_tuple,),
-            ("_correlate", InternalTraversal.dp_clauseelement_list),
-            ("_correlate_except", InternalTraversal.dp_clauseelement_list,),
+            ("_correlate", InternalTraversal.dp_clauseelement_tuple),
+            ("_correlate_except", InternalTraversal.dp_clauseelement_tuple,),
             ("_limit_clause", InternalTraversal.dp_clauseelement),
             ("_offset_clause", InternalTraversal.dp_clauseelement),
             ("_for_update_arg", InternalTraversal.dp_clauseelement),
             ("_distinct", InternalTraversal.dp_boolean),
-            ("_distinct_on", InternalTraversal.dp_clauseelement_list),
+            ("_distinct_on", InternalTraversal.dp_clauseelement_tuple),
             ("_label_style", InternalTraversal.dp_plain_obj),
             ("_is_future", InternalTraversal.dp_boolean),
         ]
@@ -4345,7 +4346,7 @@ class Select(
 
     @_generative
     def join(self, target, onclause=None, isouter=False, full=False):
-        r"""Create a SQL JOIN against this :class:`_expresson.Select`
+        r"""Create a SQL JOIN against this :class:`_expression.Select`
         object's criterion
         and apply generatively, returning the newly resulting
         :class:`_expression.Select`.
@@ -4474,7 +4475,7 @@ class Select(
         # they've become.  This allows us to ensure the same cloned from
         # is used when other items such as columns are "cloned"
 
-        all_the_froms = list(
+        all_the_froms = set(
             itertools.chain(
                 _from_objects(*self._raw_columns),
                 _from_objects(*self._where_criteria),
@@ -4490,10 +4491,15 @@ class Select(
         new_froms = {f: clone(f, **kw) for f in all_the_froms}
 
         # 2. copy FROM collections, adding in joins that we've created.
-        self._from_obj = tuple(clone(f, **kw) for f in self._from_obj) + tuple(
-            f for f in new_froms.values() if isinstance(f, Join)
+        existing_from_obj = [clone(f, **kw) for f in self._from_obj]
+        add_froms = (
+            set(f for f in new_froms.values() if isinstance(f, Join))
+            .difference(all_the_froms)
+            .difference(existing_from_obj)
         )
 
+        self._from_obj = tuple(existing_from_obj) + tuple(add_froms)
+
         # 3. clone everything else, making sure we use columns
         # corresponding to the froms we just made.
         def replace(obj, **kw):
@@ -4687,6 +4693,7 @@ class Select(
 
         """
 
+        assert isinstance(self._where_criteria, tuple)
         self._where_criteria += (
             coercions.expect(roles.WhereHavingRole, whereclause),
         )
@@ -5371,6 +5378,9 @@ class TextualSelect(SelectBase):
 
     _is_textual = True
 
+    is_text = True
+    is_select = True
+
     def __init__(self, text, columns, positional=False):
         self.element = text
         # convert for ORM attributes->columns, etc
index f41480a9479fd0314bd56fe82320cc530a88ffdf..cb38df6afa2fb29a65fd22ee807bd5cf6fcdff08 100644 (file)
@@ -190,7 +190,10 @@ class HasCacheKey(object):
                         # statements, not so much, but they usually won't have
                         # annotations.
                         result += self._annotations_cache_key
-                    elif meth is InternalTraversal.dp_clauseelement_list:
+                    elif (
+                        meth is InternalTraversal.dp_clauseelement_list
+                        or meth is InternalTraversal.dp_clauseelement_tuple
+                    ):
                         result += (
                             attrname,
                             tuple(
@@ -390,6 +393,7 @@ class _CacheKey(ExtendedInternalTraversal):
     visit_has_cache_key = visit_clauseelement = CALL_GEN_CACHE_KEY
     visit_clauseelement_list = InternalTraversal.dp_clauseelement_list
     visit_annotations_key = InternalTraversal.dp_annotations_key
+    visit_clauseelement_tuple = InternalTraversal.dp_clauseelement_tuple
 
     visit_string = (
         visit_boolean
@@ -451,6 +455,8 @@ class _CacheKey(ExtendedInternalTraversal):
             tuple(elem._gen_cache_key(anon_map, bindparams) for elem in obj),
         )
 
+    visit_executable_options = visit_has_cache_key_list
+
     def visit_inspectable_list(
         self, attrname, obj, parent, anon_map, bindparams
     ):
@@ -682,6 +688,41 @@ class _CacheKey(ExtendedInternalTraversal):
 _cache_key_traversal_visitor = _CacheKey()
 
 
+class HasCopyInternals(object):
+    def _clone(self, **kw):
+        raise NotImplementedError()
+
+    def _copy_internals(self, omit_attrs=(), **kw):
+        """Reassign internal elements to be clones of themselves.
+
+        Called during a copy-and-traverse operation on newly
+        shallow-copied elements to create a deep copy.
+
+        The given clone function should be used, which may be applying
+        additional transformations to the element (i.e. replacement
+        traversal, cloned traversal, annotations).
+
+        """
+
+        try:
+            traverse_internals = self._traverse_internals
+        except AttributeError:
+            # user-defined classes may not have a _traverse_internals
+            return
+
+        for attrname, obj, meth in _copy_internals.run_generated_dispatch(
+            self, traverse_internals, "_generated_copy_internals_traversal"
+        ):
+            if attrname in omit_attrs:
+                continue
+
+            if obj is not None:
+
+                result = meth(attrname, self, obj, **kw)
+                if result is not None:
+                    setattr(self, attrname, result)
+
+
 class _CopyInternals(InternalTraversal):
     """Generate a _copy_internals internal traversal dispatch for classes
     with a _traverse_internals collection."""
@@ -696,6 +737,16 @@ class _CopyInternals(InternalTraversal):
     ):
         return [clone(clause, **kw) for clause in element]
 
+    def visit_clauseelement_tuple(
+        self, attrname, parent, element, clone=_clone, **kw
+    ):
+        return tuple([clone(clause, **kw) for clause in element])
+
+    def visit_executable_options(
+        self, attrname, parent, element, clone=_clone, **kw
+    ):
+        return tuple([clone(clause, **kw) for clause in element])
+
     def visit_clauseelement_unordered_set(
         self, attrname, parent, element, clone=_clone, **kw
     ):
@@ -817,6 +868,9 @@ class _GetChildren(InternalTraversal):
     def visit_clauseelement_list(self, element, **kw):
         return element
 
+    def visit_clauseelement_tuple(self, element, **kw):
+        return element
+
     def visit_clauseelement_tuples(self, element, **kw):
         return itertools.chain.from_iterable(element)
 
@@ -840,8 +894,8 @@ class _GetChildren(InternalTraversal):
             if not isinstance(target, str):
                 yield _flatten_clauseelement(target)
 
-    #            if onclause is not None and not isinstance(onclause, str):
-    #                yield _flatten_clauseelement(onclause)
+            if onclause is not None and not isinstance(onclause, str):
+                yield _flatten_clauseelement(onclause)
 
     def visit_dml_ordered_values(self, element, **kw):
         for k, v in element:
@@ -1015,6 +1069,8 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
             ):
                 return COMPARE_FAILED
 
+    visit_executable_options = visit_has_cache_key_list
+
     def visit_clauseelement(
         self, attrname, left_parent, left, right_parent, right, **kw
     ):
@@ -1057,6 +1113,12 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
         for l, r in util.zip_longest(left, right, fillvalue=None):
             self.stack.append((l, r))
 
+    def visit_clauseelement_tuple(
+        self, attrname, left_parent, left, right_parent, right, **kw
+    ):
+        for l, r in util.zip_longest(left, right, fillvalue=None):
+            self.stack.append((l, r))
+
     def _compare_unordered_sequences(self, seq1, seq2, **kw):
         if seq1 is None:
             return seq2 is None
index 56d3c93b3cb764bce539d69e8cd0756f780eb10d..5cb3cba709908366211ef5bf023fda37f1b63d9f 100644 (file)
@@ -257,7 +257,7 @@ class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)):
 
     """
 
-    dp_clauseelement_tuples = symbol("CT")
+    dp_clauseelement_tuples = symbol("CTS")
     """Visit a list of tuples which contain :class:`_expression.ClauseElement`
     objects.
 
@@ -268,6 +268,13 @@ class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)):
 
     """
 
+    dp_clauseelement_tuple = symbol("CT")
+    """Visit a tuple of :class:`_expression.ClauseElement` objects.
+
+    """
+
+    dp_executable_options = symbol("EO")
+
     dp_fromclause_ordered_set = symbol("CO")
     """Visit an ordered set of :class:`_expression.FromClause` objects. """
 
@@ -712,6 +719,9 @@ def cloned_traverse(obj, opts, visitors):
     cloned = {}
     stop_on = set(opts.get("stop_on", []))
 
+    def deferred_copy_internals(obj):
+        return cloned_traverse(obj, opts, visitors)
+
     def clone(elem, **kw):
         if elem in stop_on:
             return elem
@@ -732,7 +742,7 @@ def cloned_traverse(obj, opts, visitors):
             return cloned[id(elem)]
 
     if obj is not None:
-        obj = clone(obj)
+        obj = clone(obj, deferred_copy_internals=deferred_copy_internals)
     clone = None  # remove gc cycles
     return obj
 
@@ -764,6 +774,9 @@ def replacement_traverse(obj, opts, replace):
     cloned = {}
     stop_on = {id(x) for x in opts.get("stop_on", [])}
 
+    def deferred_copy_internals(obj):
+        return replacement_traverse(obj, opts, replace)
+
     def clone(elem, **kw):
         if (
             id(elem) in stop_on
@@ -776,19 +789,24 @@ def replacement_traverse(obj, opts, replace):
                 stop_on.add(id(newelem))
                 return newelem
             else:
-
-                if elem not in cloned:
+                # base "already seen" on id(), not hash, so that we don't
+                # replace an Annotated element with its non-annotated one, and
+                # vice versa
+                id_elem = id(elem)
+                if id_elem not in cloned:
                     if "replace" in kw:
                         newelem = kw["replace"](elem)
                         if newelem is not None:
-                            cloned[elem] = newelem
+                            cloned[id_elem] = newelem
                             return newelem
 
-                    cloned[elem] = newelem = elem._clone()
+                    cloned[id_elem] = newelem = elem._clone()
                     newelem._copy_internals(clone=clone, **kw)
-                return cloned[elem]
+                return cloned[id_elem]
 
     if obj is not None:
-        obj = clone(obj, **opts)
+        obj = clone(
+            obj, deferred_copy_internals=deferred_copy_internals, **opts
+        )
     clone = None  # remove gc cycles
     return obj
index cea9c4f667d95eea348fb70942aaca14e0bf5f1c..42992af1308218fe0fcfb8f277e9c539560dbc24 100644 (file)
@@ -12,9 +12,11 @@ from functools import partial  # noqa
 from functools import update_wrapper  # noqa
 
 from ._collections import coerce_generator_arg  # noqa
+from ._collections import coerce_to_immutabledict  # noqa
 from ._collections import collections_abc  # noqa
 from ._collections import column_dict  # noqa
 from ._collections import column_set  # noqa
+from ._collections import EMPTY_DICT  # noqa
 from ._collections import EMPTY_SET  # noqa
 from ._collections import FacadeDict  # noqa
 from ._collections import flatten_iterator  # noqa
index 60568649478691733634b38fefa0bf4cc0335ad1..7c109b358e033bdd9506725c6f165c77a6fdc35b 100644 (file)
@@ -49,13 +49,25 @@ def _immutabledict_py_fallback():
         def __reduce__(self):
             return _immutabledict_reconstructor, (dict(self),)
 
-        def union(self, d):
-            if not d:
+        def union(self, __d=None):
+            if not __d:
                 return self
 
             new = dict.__new__(self.__class__)
             dict.__init__(new, self)
-            dict.update(new, d)
+            dict.update(new, __d)
+            return new
+
+        def _union_w_kw(self, __d=None, **kw):
+            # not sure if C version works correctly w/ this yet
+            if not __d and not kw:
+                return self
+
+            new = dict.__new__(self.__class__)
+            dict.__init__(new, self)
+            if __d:
+                dict.update(new, __d)
+            dict.update(new, kw)
             return new
 
         def merge_with(self, *dicts):
@@ -90,6 +102,18 @@ except ImportError:
         return immutabledict(*arg)
 
 
+def coerce_to_immutabledict(d):
+    if not d:
+        return EMPTY_DICT
+    elif isinstance(d, immutabledict):
+        return d
+    else:
+        return immutabledict(d)
+
+
+EMPTY_DICT = immutabledict()
+
+
 class FacadeDict(ImmutableContainer, dict):
     """A dictionary that is not publicly mutable."""
 
index 7ac716dbeb8e0005fd444296ebd2ef9a71d5ee1a..b573accbd947a2d71382e0668d63acba70c9c4b2 100644 (file)
@@ -131,6 +131,11 @@ class MyEntity(HasCacheKey):
     ]
 
 
+class Foo:
+    x = 10
+    y = 15
+
+
 dml.Insert.argument_for("sqlite", "foo", None)
 dml.Update.argument_for("sqlite", "foo", None)
 dml.Delete.argument_for("sqlite", "foo", None)
@@ -790,7 +795,7 @@ class CoreFixtures(object):
 
         def two():
             r = random.randint(1, 10)
-            q = 20
+            q = 408
             return LambdaElement(
                 lambda: table_a.c.a + q == r, roles.WhereHavingRole
             )
@@ -803,10 +808,6 @@ class CoreFixtures(object):
                 roles.WhereHavingRole,
             )
 
-        class Foo:
-            x = 10
-            y = 15
-
         def four():
             return LambdaElement(
                 lambda: and_(table_a.c.a == Foo.x), roles.WhereHavingRole
@@ -833,6 +834,16 @@ class CoreFixtures(object):
                 lambda s: s.where(table_a.c.a == value)
             )
 
+        from sqlalchemy.sql import lambdas
+
+        def eight():
+            q = 5
+            return lambdas.DeferredLambdaElement(
+                lambda t: t.c.a > q,
+                roles.WhereHavingRole,
+                lambda_args=(table_a,),
+            )
+
         return [
             one(),
             two(),
@@ -841,6 +852,7 @@ class CoreFixtures(object):
             five(),
             six(),
             seven(),
+            eight(),
         ]
 
     dont_compare_values_fixtures.append(_lambda_fixtures)
index aefcaf252ce2ff971a52f3491b48c730c1a96afd..4918afc9c6dcd521840bd34bc71ff44a7135c8bd 100644 (file)
@@ -791,6 +791,90 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL):
             "JOIN table2 ON table1.col1 = table2.col2) AS anon_1",
         )
 
+    def test_this_thing_using_setup_joins_three(self):
+
+        j = t1.join(t2, t1.c.col1 == t2.c.col2)
+
+        s1 = select(j)
+
+        s2 = s1.join(t3, t1.c.col1 == t3.c.col1)
+
+        self.assert_compile(
+            s2,
+            "SELECT table1.col1, table1.col2, table1.col3, "
+            "table2.col1, table2.col2, table2.col3 FROM table1 "
+            "JOIN table2 ON table1.col1 = table2.col2 JOIN table3 "
+            "ON table3.col1 = table1.col1",
+        )
+
+        vis = sql_util.ClauseAdapter(j)
+
+        s3 = vis.traverse(s1)
+
+        s4 = s3.join(t3, t1.c.col1 == t3.c.col1)
+
+        self.assert_compile(
+            s4,
+            "SELECT table1.col1, table1.col2, table1.col3, "
+            "table2.col1, table2.col2, table2.col3 FROM table1 "
+            "JOIN table2 ON table1.col1 = table2.col2 JOIN table3 "
+            "ON table3.col1 = table1.col1",
+        )
+
+        s5 = vis.traverse(s3)
+
+        s6 = s5.join(t3, t1.c.col1 == t3.c.col1)
+
+        self.assert_compile(
+            s6,
+            "SELECT table1.col1, table1.col2, table1.col3, "
+            "table2.col1, table2.col2, table2.col3 FROM table1 "
+            "JOIN table2 ON table1.col1 = table2.col2 JOIN table3 "
+            "ON table3.col1 = table1.col1",
+        )
+
+    def test_this_thing_using_setup_joins_four(self):
+
+        j = t1.join(t2, t1.c.col1 == t2.c.col2)
+
+        s1 = select(j)
+
+        assert not s1._from_obj
+
+        s2 = s1.join(t3, t1.c.col1 == t3.c.col1)
+
+        self.assert_compile(
+            s2,
+            "SELECT table1.col1, table1.col2, table1.col3, "
+            "table2.col1, table2.col2, table2.col3 FROM table1 "
+            "JOIN table2 ON table1.col1 = table2.col2 JOIN table3 "
+            "ON table3.col1 = table1.col1",
+        )
+
+        s3 = visitors.replacement_traverse(s1, {}, lambda elem: None)
+
+        s4 = s3.join(t3, t1.c.col1 == t3.c.col1)
+
+        self.assert_compile(
+            s4,
+            "SELECT table1.col1, table1.col2, table1.col3, "
+            "table2.col1, table2.col2, table2.col3 FROM table1 "
+            "JOIN table2 ON table1.col1 = table2.col2 JOIN table3 "
+            "ON table3.col1 = table1.col1",
+        )
+
+        s5 = visitors.replacement_traverse(s3, {}, lambda elem: None)
+
+        s6 = s5.join(t3, t1.c.col1 == t3.c.col1)
+
+        self.assert_compile(
+            s6,
+            "SELECT table1.col1, table1.col2, table1.col3, "
+            "table2.col1, table2.col2, table2.col3 FROM table1 "
+            "JOIN table2 ON table1.col1 = table2.col2 JOIN table3 "
+            "ON table3.col1 = table1.col1",
+        )
+
     def test_select_fromtwice_one(self):
         t1a = t1.alias()
 
index 53f6a9544c0edfaee22ad22d99bf54c859731dc1..a91242de5ec503d2ecefda37eee446bcb0d35e3e 100644 (file)
@@ -5,6 +5,7 @@ from sqlalchemy.schema import Column
 from sqlalchemy.schema import ForeignKey
 from sqlalchemy.schema import Table
 from sqlalchemy.sql import and_
+from sqlalchemy.sql import bindparam
 from sqlalchemy.sql import coercions
 from sqlalchemy.sql import column
 from sqlalchemy.sql import join
@@ -19,6 +20,7 @@ from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
+from sqlalchemy.testing import ne_
 from sqlalchemy.testing.assertsql import CompiledSQL
 from sqlalchemy.types import Integer
 from sqlalchemy.types import String
@@ -46,10 +48,39 @@ class DeferredLambdaTest(
             go(), "SELECT t1.q, t1.p FROM t1 WHERE t1.q = :x_1 AND t1.p = :y_1"
         )
 
+    def test_global_tracking(self):
+        t1 = table("t1", column("q"), column("p"))
+
+        global global_x, global_y
+
+        global_x = 10
+        global_y = 17
+
+        def go():
+            return select([t1]).where(
+                lambda: and_(t1.c.q == global_x, t1.c.p == global_y)
+            )
+
+        self.assert_compile(
+            go(),
+            "SELECT t1.q, t1.p FROM t1 WHERE t1.q = :global_x_1 "
+            "AND t1.p = :global_y_1",
+            checkparams={"global_x_1": 10, "global_y_1": 17},
+        )
+
+        global_y = 9
+
+        self.assert_compile(
+            go(),
+            "SELECT t1.q, t1.p FROM t1 WHERE t1.q = :global_x_1 "
+            "AND t1.p = :global_y_1",
+            checkparams={"global_x_1": 10, "global_y_1": 9},
+        )
+
     def test_stale_checker_embedded(self):
         def go(x):
 
-            stmt = select([lambda: x])
+            stmt = select(lambda: x)
             return stmt
 
         c1 = column("x")
@@ -67,7 +98,7 @@ class DeferredLambdaTest(
     def test_stale_checker_statement(self):
         def go(x):
 
-            stmt = lambdas.lambda_stmt(lambda: select([x]))
+            stmt = lambdas.lambda_stmt(lambda: select(x))
             return stmt
 
         c1 = column("x")
@@ -85,13 +116,13 @@ class DeferredLambdaTest(
     def test_stale_checker_linked(self):
         def go(x, y):
 
-            stmt = lambdas.lambda_stmt(lambda: select([x])) + (
+            stmt = lambdas.lambda_stmt(lambda: select(x)) + (
                 lambda s: s.where(y > 5)
             )
             return stmt
 
-        c1 = column("x")
-        c2 = column("y")
+        c1 = oldc1 = column("x")
+        c2 = oldc2 = column("y")
         s1 = go(c1, c2)
         s2 = go(c1, c2)
 
@@ -104,6 +135,426 @@ class DeferredLambdaTest(
         s3 = go(c1, c2)
         self.assert_compile(s3, "SELECT q WHERE p > :p_1")
 
+        s4 = go(c1, c2)
+        self.assert_compile(s4, "SELECT q WHERE p > :p_1")
+
+        s5 = go(oldc1, oldc2)
+        self.assert_compile(s5, "SELECT x WHERE y > :y_1")
+
+    def test_stmt_lambda_w_additional_hascachekey_variants(self):
+        def go(col_expr, q):
+            stmt = lambdas.lambda_stmt(lambda: select(col_expr))
+            stmt += lambda stmt: stmt.where(col_expr == q)
+
+            return stmt
+
+        c1 = column("x")
+        c2 = column("y")
+
+        s1 = go(c1, 5)
+        s2 = go(c2, 10)
+        s3 = go(c1, 8)
+        s4 = go(c2, 12)
+
+        self.assert_compile(
+            s1, "SELECT x WHERE x = :q_1", checkparams={"q_1": 5}
+        )
+        self.assert_compile(
+            s2, "SELECT y WHERE y = :q_1", checkparams={"q_1": 10}
+        )
+        self.assert_compile(
+            s3, "SELECT x WHERE x = :q_1", checkparams={"q_1": 8}
+        )
+        self.assert_compile(
+            s4, "SELECT y WHERE y = :q_1", checkparams={"q_1": 12}
+        )
+
+        s1key = s1._generate_cache_key()
+        s2key = s2._generate_cache_key()
+        s3key = s3._generate_cache_key()
+        s4key = s4._generate_cache_key()
+
+        eq_(s1key[0], s3key[0])
+        eq_(s2key[0], s4key[0])
+        ne_(s1key[0], s2key[0])
+
+    def test_stmt_lambda_w_atonce_whereclause_values_notrack(self):
+        def go(col_expr, whereclause):
+            stmt = lambdas.lambda_stmt(lambda: select(col_expr))
+            stmt = stmt.add_criteria(
+                lambda stmt: stmt.where(whereclause), enable_tracking=False
+            )
+
+            return stmt
+
+        c1 = column("x")
+
+        s1 = go(c1, c1 == 5)
+        s2 = go(c1, c1 == 10)
+
+        self.assert_compile(
+            s1, "SELECT x WHERE x = :x_1", checkparams={"x_1": 5}
+        )
+
+        # and as we see, this is wrong.   Because whereclause
+        # is fixed for the lambda and we do not re-evaluate the closure
+        # for this value changing.   this can't be passed unless
+        # enable_tracking=False.
+        self.assert_compile(
+            s2, "SELECT x WHERE x = :x_1", checkparams={"x_1": 5}
+        )
+
+    def test_stmt_lambda_w_atonce_whereclause_values(self):
+        c2 = column("y")
+
+        def go(col_expr, whereclause, x):
+            stmt = lambdas.lambda_stmt(lambda: select(col_expr))
+            stmt = stmt.add_criteria(
+                lambda stmt: stmt.where(whereclause).order_by(c2 > x),
+            )
+
+            return stmt
+
+        c1 = column("x")
+
+        s1 = go(c1, c1 == 5, 9)
+        s2 = go(c1, c1 == 10, 15)
+
+        s1key = s1._generate_cache_key()
+        s2key = s2._generate_cache_key()
+
+        eq_([b.value for b in s1key.bindparams], [5, 9])
+        eq_([b.value for b in s2key.bindparams], [10, 15])
+
+        self.assert_compile(
+            s1,
+            "SELECT x WHERE x = :x_1 ORDER BY y > :x_2",
+            checkparams={"x_1": 5, "x_2": 9},
+        )
+
+        self.assert_compile(
+            s2,
+            "SELECT x WHERE x = :x_1 ORDER BY y > :x_2",
+            checkparams={"x_1": 10, "x_2": 15},
+        )
+
+    def test_stmt_lambda_plain_customtrack(self):
+        c2 = column("y")
+
+        def go(col_expr, whereclause, p):
+            stmt = lambdas.lambda_stmt(lambda: select(col_expr))
+            stmt = stmt.add_criteria(lambda stmt: stmt.where(whereclause))
+            stmt = stmt.add_criteria(
+                lambda stmt: stmt.order_by(col_expr), track_on=(col_expr,)
+            )
+            stmt = stmt.add_criteria(lambda stmt: stmt.where(col_expr == p))
+            return stmt
+
+        c1 = column("x")
+        c2 = column("y")
+
+        s1 = go(c1, c1 == 5, 9)
+        s2 = go(c1, c1 == 10, 15)
+        s3 = go(c2, c2 == 18, 12)
+
+        s1key = s1._generate_cache_key()
+        s2key = s2._generate_cache_key()
+        s3key = s3._generate_cache_key()
+
+        eq_([b.value for b in s1key.bindparams], [5, 9])
+        eq_([b.value for b in s2key.bindparams], [10, 15])
+        eq_([b.value for b in s3key.bindparams], [18, 12])
+
+        self.assert_compile(
+            s1,
+            "SELECT x WHERE x = :x_1 AND x = :p_1 ORDER BY x",
+            checkparams={"x_1": 5, "p_1": 9},
+        )
+
+        self.assert_compile(
+            s2,
+            "SELECT x WHERE x = :x_1 AND x = :p_1 ORDER BY x",
+            checkparams={"x_1": 10, "p_1": 15},
+        )
+
+        self.assert_compile(
+            s3,
+            "SELECT y WHERE y = :y_1 AND y = :p_1 ORDER BY y",
+            checkparams={"y_1": 18, "p_1": 12},
+        )
+
+    def test_stmt_lambda_w_atonce_whereclause_customtrack_binds(self):
+        c2 = column("y")
+
+        # this pattern is *completely unnecessary*, and I would prefer
+        # if we can detect this and just raise, because when it is not done
+        # correctly, it is *extremely* difficult to catch it failing.
+        # however I also can't come up with a reliable way to catch it.
+        # so we will keep the use of "track_on" to be internal.
+
+        def go(col_expr, whereclause, p):
+            stmt = lambdas.lambda_stmt(lambda: select(col_expr))
+            stmt = stmt.add_criteria(
+                lambda stmt: stmt.where(whereclause).order_by(col_expr > p),
+                track_on=(whereclause, whereclause.right.value),
+            )
+
+            return stmt
+
+        c1 = column("x")
+        c2 = column("y")
+
+        s1 = go(c1, c1 == 5, 9)
+        s2 = go(c1, c1 == 10, 15)
+        s3 = go(c2, c2 == 18, 12)
+
+        s1key = s1._generate_cache_key()
+        s2key = s2._generate_cache_key()
+        s3key = s3._generate_cache_key()
+
+        eq_([b.value for b in s1key.bindparams], [5, 9])
+        eq_([b.value for b in s2key.bindparams], [10, 15])
+        eq_([b.value for b in s3key.bindparams], [18, 12])
+
+        self.assert_compile(
+            s1,
+            "SELECT x WHERE x = :x_1 ORDER BY x > :p_1",
+            checkparams={"x_1": 5, "p_1": 9},
+        )
+
+        self.assert_compile(
+            s2,
+            "SELECT x WHERE x = :x_1 ORDER BY x > :p_1",
+            checkparams={"x_1": 10, "p_1": 15},
+        )
+
+        self.assert_compile(
+            s3,
+            "SELECT y WHERE y = :y_1 ORDER BY y > :p_1",
+            checkparams={"y_1": 18, "p_1": 12},
+        )
+
+    def test_stmt_lambda_track_closure_binds_one(self):
+        def go(col_expr, whereclause):
+            stmt = lambdas.lambda_stmt(lambda: select(col_expr))
+            stmt += lambda stmt: stmt.where(whereclause)
+
+            return stmt
+
+        c1 = column("x")
+
+        s1 = go(c1, c1 == 5)
+        s2 = go(c1, c1 == 10)
+
+        self.assert_compile(
+            s1, "SELECT x WHERE x = :x_1", checkparams={"x_1": 5}
+        )
+        self.assert_compile(
+            s2, "SELECT x WHERE x = :x_1", checkparams={"x_1": 10}
+        )
+
+        s1key = s1._generate_cache_key()
+        s2key = s2._generate_cache_key()
+
+        eq_(s1key.key, s2key.key)
+
+        eq_([b.value for b in s1key.bindparams], [5])
+        eq_([b.value for b in s2key.bindparams], [10])
+
+    def test_stmt_lambda_track_closure_binds_two(self):
+        def go(col_expr, whereclause, x, y):
+            stmt = lambdas.lambda_stmt(lambda: select(col_expr))
+            stmt += lambda stmt: stmt.where(whereclause).where(
+                and_(c1 == x, c1 < y)
+            )
+
+            return stmt
+
+        c1 = column("x")
+
+        s1 = go(c1, c1 == 5, 8, 9)
+        s2 = go(c1, c1 == 10, 12, 14)
+
+        s1key = s1._generate_cache_key()
+        s2key = s2._generate_cache_key()
+
+        self.assert_compile(
+            s1,
+            "SELECT x WHERE x = :x_1 AND x = :x_2 AND x < :y_1",
+            checkparams={"x_1": 5, "x_2": 8, "y_1": 9},
+        )
+        self.assert_compile(
+            s2,
+            "SELECT x WHERE x = :x_1 AND x = :x_2 AND x < :y_1",
+            checkparams={"x_1": 10, "x_2": 12, "y_1": 14},
+        )
+
+        eq_([b.value for b in s1key.bindparams], [5, 8, 9])
+        eq_([b.value for b in s2key.bindparams], [10, 12, 14])
+
+        s1_compiled_cached = s1.compile(cache_key=s1key)
+
+        params = s1_compiled_cached.construct_params(
+            extracted_parameters=s2key[1]
+        )
+
+        eq_(params, {"x_1": 10, "x_2": 12, "y_1": 14})
+
+    def test_stmt_lambda_track_closure_binds_three(self):
+        def go(col_expr, whereclause, x, y):
+            stmt = lambdas.lambda_stmt(lambda: select(col_expr))
+            stmt += lambda stmt: stmt.where(whereclause)
+            stmt += lambda stmt: stmt.where(and_(c1 == x, c1 < y))
+
+            return stmt
+
+        c1 = column("x")
+
+        s1 = go(c1, c1 == 5, 8, 9)
+        s2 = go(c1, c1 == 10, 12, 14)
+
+        s1key = s1._generate_cache_key()
+        s2key = s2._generate_cache_key()
+
+        self.assert_compile(
+            s1,
+            "SELECT x WHERE x = :x_1 AND x = :x_2 AND x < :y_1",
+            checkparams={"x_1": 5, "x_2": 8, "y_1": 9},
+        )
+        self.assert_compile(
+            s2,
+            "SELECT x WHERE x = :x_1 AND x = :x_2 AND x < :y_1",
+            checkparams={"x_1": 10, "x_2": 12, "y_1": 14},
+        )
+
+        eq_([b.value for b in s1key.bindparams], [5, 8, 9])
+        eq_([b.value for b in s2key.bindparams], [10, 12, 14])
+
+        s1_compiled_cached = s1.compile(cache_key=s1key)
+
+        params = s1_compiled_cached.construct_params(
+            extracted_parameters=s2key[1]
+        )
+
+        eq_(params, {"x_1": 10, "x_2": 12, "y_1": 14})
+
+    def test_stmt_lambda_w_atonce_whereclause_novalue(self):
+        def go(col_expr, whereclause):
+            stmt = lambdas.lambda_stmt(lambda: select(col_expr))
+            stmt += lambda stmt: stmt.where(whereclause)
+
+            return stmt
+
+        c1 = column("x")
+
+        s1 = go(c1, bindparam("x"))
+
+        self.assert_compile(s1, "SELECT x WHERE :x")
+
+    def test_stmt_lambda_w_additional_hashable_variants(self):
+        # note a Python 2 old style class would fail here because it
+        # isn't hashable.   right now we do a hard check for __hash__ which
+        # will raise if the attr isn't present
+        class Thing(object):
+            def __init__(self, col_expr):
+                self.col_expr = col_expr
+
+        def go(thing, q):
+            stmt = lambdas.lambda_stmt(lambda: select(thing.col_expr))
+            stmt += lambda stmt: stmt.where(thing.col_expr == q)
+
+            return stmt
+
+        c1 = Thing(column("x"))
+        c2 = Thing(column("y"))
+
+        s1 = go(c1, 5)
+        s2 = go(c2, 10)
+        s3 = go(c1, 8)
+        s4 = go(c2, 12)
+
+        self.assert_compile(
+            s1, "SELECT x WHERE x = :q_1", checkparams={"q_1": 5}
+        )
+        self.assert_compile(
+            s2, "SELECT y WHERE y = :q_1", checkparams={"q_1": 10}
+        )
+        self.assert_compile(
+            s3, "SELECT x WHERE x = :q_1", checkparams={"q_1": 8}
+        )
+        self.assert_compile(
+            s4, "SELECT y WHERE y = :q_1", checkparams={"q_1": 12}
+        )
+
+        s1key = s1._generate_cache_key()
+        s2key = s2._generate_cache_key()
+        s3key = s3._generate_cache_key()
+        s4key = s4._generate_cache_key()
+
+        eq_(s1key[0], s3key[0])
+        eq_(s2key[0], s4key[0])
+        ne_(s1key[0], s2key[0])
+
+    def test_stmt_lambda_w_set_of_opts(self):
+
+        stmt = lambdas.lambda_stmt(lambda: select(column("x")))
+
+        opts = {column("x"), column("y")}
+
+        assert_raises_message(
+            exc.ArgumentError,
+            'Can\'t create a cache key for lambda closure variable "opts" '
+            "because it's a set.  try using a list",
+            stmt.__add__,
+            lambda stmt: stmt.options(*opts),
+        )
+
+    def test_stmt_lambda_w_list_of_opts(self):
+        def go(opts):
+            stmt = lambdas.lambda_stmt(lambda: select(column("x")))
+            stmt += lambda stmt: stmt.options(*opts)
+
+            return stmt
+
+        s1 = go([column("a"), column("b")])
+
+        s2 = go([column("a"), column("b")])
+
+        s3 = go([column("q"), column("b")])
+
+        s1key = s1._generate_cache_key()
+        s2key = s2._generate_cache_key()
+        s3key = s3._generate_cache_key()
+
+        eq_(s1key.key, s2key.key)
+        ne_(s1key.key, s3key.key)
+
+    def test_stmt_lambda_hey_theres_multiple_paths(self):
+        def go(x, y):
+            stmt = lambdas.lambda_stmt(lambda: select(column("x")))
+
+            if x > 5:
+                stmt += lambda stmt: stmt.where(column("x") == x)
+            else:
+                stmt += lambda stmt: stmt.where(column("y") == y)
+
+            stmt += lambda stmt: stmt.order_by(column("q"))
+
+            # TODO: need more path variety here to exercise
+            # using a full path key
+
+            return stmt
+
+        s1 = go(2, 5)
+        s2 = go(8, 7)
+        s3 = go(4, 9)
+        s4 = go(10, 1)
+
+        self.assert_compile(s1, "SELECT x WHERE y = :y_1 ORDER BY q")
+        self.assert_compile(s2, "SELECT x WHERE x = :x_1 ORDER BY q")
+        self.assert_compile(s3, "SELECT x WHERE y = :y_1 ORDER BY q")
+        self.assert_compile(s4, "SELECT x WHERE x = :x_1 ORDER BY q")
+
     def test_coercion_cols_clause(self):
         assert_raises_message(
             exc.ArgumentError,
index 01c8d7ca658efbe297c356178fc2148f4952b6fa..58280bb6770d2016c02c77b24e3e8c9972b1c2cb 100644 (file)
@@ -29,6 +29,7 @@ from sqlalchemy import TypeDecorator
 from sqlalchemy import union
 from sqlalchemy import util
 from sqlalchemy.sql import Alias
+from sqlalchemy.sql import annotation
 from sqlalchemy.sql import base
 from sqlalchemy.sql import column
 from sqlalchemy.sql import elements
@@ -2352,6 +2353,33 @@ class AnnotationsTest(fixtures.TestBase):
             annot = obj._annotate({})
             ne_(set([obj]), set([annot]))
 
+    def test_replacement_traverse_preserve(self):
+        """test that replacement traverse that hits an unannotated column
+        does not use it when replacing an annotated column.
+
+        this requires that replacement traverse store elements in the
+        "seen" hash based on id(), not hash.
+
+        """
+        t = table("t", column("x"))
+
+        stmt = select([t.c.x])
+
+        whereclause = annotation._deep_annotate(t.c.x == 5, {"foo": "bar"})
+
+        eq_(whereclause._annotations, {"foo": "bar"})
+        eq_(whereclause.left._annotations, {"foo": "bar"})
+        eq_(whereclause.right._annotations, {"foo": "bar"})
+
+        stmt = stmt.where(whereclause)
+
+        s2 = visitors.replacement_traverse(stmt, {}, lambda elem: None)
+
+        whereclause = s2._where_criteria[0]
+        eq_(whereclause._annotations, {"foo": "bar"})
+        eq_(whereclause.left._annotations, {"foo": "bar"})
+        eq_(whereclause.right._annotations, {"foo": "bar"})
+
     def test_proxy_set_iteration_includes_annotated(self):
         from sqlalchemy.schema import Column