From: Mike Bayer Date: Wed, 5 Aug 2020 20:42:26 +0000 (-0400) Subject: Robustness for lambdas, lambda statements X-Git-Tag: rel_1_4_0b1~198 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=cc57ea495f6460dd56daa6de57e40047ed999369;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Robustness for lambdas, lambda statements in order to accommodate relationship loaders with lambda caching, a lot more is needed. This is a full refactor of the lambda system such that it now has two levels of caching; the first level caches what can be known from the __code__ element, then the next level of caching is against the lambda itself and the contents of __closure__. This allows for the elements inside the lambdas, like columns and entities, to change and then be part of the cache key. Lazy/selectinloads' use of baked queries had to add distinct cache key elements, which was attempted here but overall things needed to be more robust than that. This commit is broken out from the very long and sprawling commit at Id6b5c03b1ce9ddb7b280f66792212a0ef0a1c541 . Change-Id: I29a513c98917b1d503abfdd61e6b6e8800851aa8 --- diff --git a/lib/sqlalchemy/engine/create.py b/lib/sqlalchemy/engine/create.py index 8b0377a582..985a12fa00 100644 --- a/lib/sqlalchemy/engine/create.py +++ b/lib/sqlalchemy/engine/create.py @@ -230,7 +230,7 @@ def create_engine(url, **kwargs): :param future: Use the 2.0 style :class:`_future.Engine` and :class:`_future.Connection` API. - ..versionadded:: 1.4 + .. versionadded:: 1.4 .. seealso:: diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py index 6f4934521c..bcffca9324 100644 --- a/lib/sqlalchemy/engine/cursor.py +++ b/lib/sqlalchemy/engine/cursor.py @@ -1247,7 +1247,7 @@ class BaseCursorResult(object): if ( compiled and compiled._result_columns - and context.cache_hit + and context.cache_hit is context.dialect.CACHE_HIT and not compiled._rewrites_selected_columns and compiled.statement is not context.invoked_statement ): diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index c431fa7555..8d3c5de157 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -39,6 +39,12 @@ AUTOCOMMIT_REGEXP = re.compile( SERVER_SIDE_CURSOR_RE = re.compile(r"\s*SELECT", re.I | re.UNICODE) +CACHE_HIT = util.symbol("CACHE_HIT") +CACHE_MISS = util.symbol("CACHE_MISS") +CACHING_DISABLED = util.symbol("CACHING_DISABLED") +NO_CACHE_KEY = util.symbol("NO_CACHE_KEY") + + class DefaultDialect(interfaces.Dialect): """Default implementation of Dialect""" @@ -195,6 +201,11 @@ class DefaultDialect(interfaces.Dialect): """ + CACHE_HIT = CACHE_HIT + CACHE_MISS = CACHE_MISS + CACHING_DISABLED = CACHING_DISABLED + NO_CACHE_KEY = NO_CACHE_KEY + @util.deprecated_params( convert_unicode=( "1.3", @@ -725,6 +736,8 @@ class DefaultExecutionContext(interfaces.ExecutionContext): _expanded_parameters = util.immutabledict() + cache_hit = NO_CACHE_KEY + @classmethod def _init_ddl( cls, @@ -788,7 +801,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): parameters, invoked_statement, extracted_parameters, - cache_hit=False, + cache_hit=CACHING_DISABLED, ): """Initialize execution context for a Compiled construct.""" @@ -1026,12 +1039,19 @@ class DefaultExecutionContext(interfaces.ExecutionContext): return "raw sql" now = util.perf_counter() - if self.compiled.cache_key is None: + + ch = self.cache_hit + + if ch is NO_CACHE_KEY: return "no key %.5fs" % (now - self.compiled._gen_time,) - elif self.cache_hit: + elif ch is CACHE_HIT: return "cached since %.4gs ago" % (now - self.compiled._gen_time,) - else: + elif ch is CACHE_MISS: return "generated in %.5fs" % (now - self.compiled._gen_time,) + elif ch is CACHING_DISABLED: + return "caching disabled %.5fs" % (now - self.compiled._gen_time,) + else: + return "unknown" @util.memoized_property def engine(self): diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index 588c485aeb..fa0f9c4357 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -7,10 +7,13 @@ import numbers import re +import types from . import operators from . import roles from . import visitors +from .base import Options +from .traversals import HasCacheKey from .visitors import Visitable from .. import exc from .. import inspection @@ -33,11 +36,36 @@ def _is_literal(element): of a SQL expression construct. """ + return not isinstance( - element, (Visitable, schema.SchemaEventTarget) + element, (Visitable, schema.SchemaEventTarget), ) and not hasattr(element, "__clause_element__") +def _deep_is_literal(element): + """Return whether or not the element is a "literal" in the context + of a SQL expression construct. + + does a deeper more esoteric check than _is_literal. is used + for lambda elements that have to distinguish values that would + be bound vs. not without any context. + + """ + + return ( + not isinstance( + element, + (Visitable, schema.SchemaEventTarget, HasCacheKey, Options,), + ) + and not hasattr(element, "__clause_element__") + and ( + not isinstance(element, type) + or not issubclass(element, HasCacheKey) + ) + and not isinstance(element, types.FunctionType) + ) + + def _document_text_coercion(paramname, meth_rst, param_rst): return util.add_parameter_text( paramname, @@ -711,9 +739,16 @@ class StatementImpl(_NoTextCoercion, RoleImpl): class CoerceTextStatementImpl(_CoerceLiterals, RoleImpl): __slots__ = () - def _literal_coercion(self, element, **kw): + def _dont_literal_coercion(self, element, **kw): if callable(element) and hasattr(element, "__code__"): - return lambdas.StatementLambdaElement(element, self._role_class) + return lambdas.StatementLambdaElement( + element, + self._role_class, + additional_cache_criteria=kw.get( + "additional_cache_criteria", () + ), + tracked=kw["tra"], + ) else: return super(CoerceTextStatementImpl, self)._literal_coercion( element, **kw diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index ca73a43924..8a506446db 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -32,8 +32,8 @@ from .base import NO_ARG from .base import PARSE_AUTOCOMMIT from .base import SingletonConstant from .coercions import _document_text_coercion -from .traversals import _copy_internals from .traversals import _get_children +from .traversals import HasCopyInternals from .traversals import MemoizedHasCacheKey from .traversals import NO_CACHE from .visitors import cloned_traverse @@ -182,6 +182,7 @@ class ClauseElement( roles.SQLRole, SupportsWrappingAnnotations, MemoizedHasCacheKey, + HasCopyInternals, Traversible, ): """Base class for elements of a programmatically constructed SQL @@ -372,35 +373,6 @@ class ClauseElement( """ return traversals.compare(self, other, **kw) - def _copy_internals(self, omit_attrs=(), **kw): - """Reassign internal elements to be clones of themselves. - - Called during a copy-and-traverse operation on newly - shallow-copied elements to create a deep copy. - - The given clone function should be used, which may be applying - additional transformations to the element (i.e. replacement - traversal, cloned traversal, annotations). - - """ - - try: - traverse_internals = self._traverse_internals - except AttributeError: - # user-defined classes may not have a _traverse_internals - return - - for attrname, obj, meth in _copy_internals.run_generated_dispatch( - self, traverse_internals, "_generated_copy_internals_traversal" - ): - if attrname in omit_attrs: - continue - - if obj is not None: - result = meth(self, attrname, obj, **kw) - if result is not None: - setattr(self, attrname, result) - def get_children(self, omit_attrs=(), **kw): r"""Return immediate child :class:`.visitors.Traversible` elements of this :class:`.visitors.Traversible`. @@ -535,8 +507,6 @@ class ClauseElement( else: elem_cache_key = None - cache_hit = False - if elem_cache_key: cache_key, extracted_params = elem_cache_key key = ( @@ -549,6 +519,7 @@ class ClauseElement( compiled_sql = compiled_cache.get(key) if compiled_sql is None: + cache_hit = dialect.CACHE_MISS compiled_sql = self._compiler( dialect, cache_key=elem_cache_key, @@ -559,7 +530,7 @@ class ClauseElement( ) compiled_cache[key] = compiled_sql else: - cache_hit = True + cache_hit = dialect.CACHE_HIT else: extracted_params = None compiled_sql = self._compiler( @@ -570,6 +541,11 @@ class ClauseElement( schema_translate_map=schema_translate_map, **kw ) + cache_hit = ( + dialect.CACHING_DISABLED + if compiled_cache is None + else dialect.NO_CACHE_KEY + ) return compiled_sql, extracted_params, cache_hit @@ -1343,10 +1319,7 @@ class BindParameter(roles.InElementRole, ColumnElement): if required is NO_ARG: required = value is NO_ARG and callable_ is None if value is NO_ARG: - self._value_required_for_cache = False value = None - else: - self._value_required_for_cache = True if quote is not None: key = quoted_name(key, quote) @@ -1412,6 +1385,7 @@ class BindParameter(roles.InElementRole, ColumnElement): """Return a copy of this :class:`.BindParameter` with the given value set. """ + cloned = self._clone(maintain_key=maintain_key) cloned.value = value cloned.callable = None @@ -1465,7 +1439,8 @@ class BindParameter(roles.InElementRole, ColumnElement): anon_map[idself] = id_ = str(anon_map.index) anon_map.index += 1 - bindparams.append(self) + if bindparams is not None: + bindparams.append(self) return ( id_, diff --git a/lib/sqlalchemy/sql/lambdas.py b/lib/sqlalchemy/sql/lambdas.py index 7924111896..3270039026 100644 --- a/lib/sqlalchemy/sql/lambdas.py +++ b/lib/sqlalchemy/sql/lambdas.py @@ -8,6 +8,7 @@ import itertools import operator import sys +import types import weakref from . import coercions @@ -17,29 +18,23 @@ from . import schema from . import traversals from . import type_api from . import visitors +from .base import _clone from .operators import ColumnOperators from .. import exc from .. import inspection from .. import util from ..util import collections_abc -_trackers = weakref.WeakKeyDictionary() +_closure_per_cache_key = util.LRUCache(1000) -_TRACKERS = 0 -_STALE_CHECK = 1 -_REAL_FN = 2 -_EXPR = 3 -_IS_SEQUENCE = 4 -_PROPAGATE_ATTRS = 5 - - -def lambda_stmt(lmb): +def lambda_stmt(lmb, **opts): """Produce a SQL statement that is cached as a lambda. - This SQL statement will only be constructed if element has not been - compiled yet. The approach is used to save on Python function overhead - when constructing statements that will be cached. + The Python code object within the lambda is scanned for both Python + literals that will become bound parameters as well as closure variables + that refer to Core or ORM constructs that may vary. The lambda itself + will be invoked only once per particular set of constructs detected. E.g.:: @@ -60,7 +55,8 @@ def lambda_stmt(lmb): """ - return coercions.expect(roles.CoerceTextStatementRole, lmb) + + return StatementLambdaElement(lmb, roles.CoerceTextStatementRole, **opts) class LambdaElement(elements.ClauseElement): @@ -87,64 +83,108 @@ class LambdaElement(elements.ClauseElement): _is_lambda_element = True - _resolved_bindparams = () - _traverse_internals = [ ("_resolved", visitors.InternalTraversal.dp_clauseelement) ] + _transforms = () + + parent_lambda = None + def __repr__(self): return "%s(%r)" % (self.__class__.__name__, self.fn.__code__) def __init__(self, fn, role, apply_propagate_attrs=None, **kw): self.fn = fn self.role = role - self.parent_lambda = None + self.tracker_key = (fn.__code__,) if apply_propagate_attrs is None and ( role is roles.CoerceTextStatementRole ): apply_propagate_attrs = self - if fn.__code__ not in _trackers: - rec = self._initialize_var_trackers( - role, apply_propagate_attrs, kw + rec = self._retrieve_tracker_rec(fn, apply_propagate_attrs, kw) + + if apply_propagate_attrs is not None: + propagate_attrs = rec.propagate_attrs + if propagate_attrs: + apply_propagate_attrs._propagate_attrs = propagate_attrs + + def _retrieve_tracker_rec(self, fn, apply_propagate_attrs, kw): + lambda_cache = kw.get("lambda_cache", _closure_per_cache_key) + + tracker_key = self.tracker_key + + fn = self.fn + closure = fn.__closure__ + + tracker = AnalyzedCode.get( + fn, + self, + kw, + track_bound_values=kw.get("track_bound_values", True), + enable_tracking=kw.get("enable_tracking", True), + track_on=kw.get("track_on", None), + ) + + self._resolved_bindparams = bindparams = [] + + anon_map = traversals.anon_map() + cache_key = tuple( + [ + getter(closure, kw, anon_map, bindparams) + for getter in tracker.closure_trackers + ] + ) + if self.parent_lambda is not None: + cache_key = self.parent_lambda.closure_cache_key + cache_key + + self.closure_cache_key = cache_key + + try: + rec = lambda_cache[tracker_key + cache_key] + except KeyError: + rec = None + + if rec is None: + rec = AnalyzedFunction( + tracker, self, apply_propagate_attrs, kw, fn ) + rec.closure_bindparams = bindparams + lambda_cache[tracker_key + cache_key] = rec else: - rec = _trackers[self.fn.__code__] - closure = fn.__closure__ + bindparams[:] = [ + orig_bind._with_value(new_bind.value, maintain_key=True) + for orig_bind, new_bind in zip( + rec.closure_bindparams, bindparams + ) + ] + + if self.parent_lambda is not None: + bindparams[:0] = self.parent_lambda._resolved_bindparams - # check if the objects fixed inside the lambda that we've cached - # have been changed. This can apply to things like mappers that - # were recreated in test suites. if so, re-initialize. - # - # this is a small performance hit on every use for a not very - # common situation, however it's very hard to debug if the - # condition does occur. - for idx, obj in rec[_STALE_CHECK]: - if closure[idx].cell_contents is not obj: - rec = self._initialize_var_trackers( - role, apply_propagate_attrs, kw - ) - break self._rec = rec - if apply_propagate_attrs is not None: - propagate_attrs = rec[_PROPAGATE_ATTRS] - if propagate_attrs: - apply_propagate_attrs._propagate_attrs = propagate_attrs + lambda_element = self + while lambda_element is not None: + rec = lambda_element._rec + if rec.bindparam_trackers: + tracker_instrumented_fn = rec.tracker_instrumented_fn + for tracker in rec.bindparam_trackers: + tracker( + lambda_element.fn, tracker_instrumented_fn, bindparams + ) + lambda_element = lambda_element.parent_lambda - if rec[_TRACKERS]: - self._resolved_bindparams = bindparams = [] - for tracker in rec[_TRACKERS]: - tracker(self.fn, bindparams) + return rec def __getattr__(self, key): - return getattr(self._rec[_EXPR], key) + return getattr(self._rec.expected_expr, key) @property def _is_sequence(self): - return self._rec[_IS_SEQUENCE] + return self._rec.is_sequence @property def _select_iterable(self): @@ -169,8 +209,7 @@ class LambdaElement(elements.ClauseElement): def _param_dict(self): return {b.key: b.value for b in self._resolved_bindparams} - @util.memoized_property - def _resolved(self): + def _setup_binds_for_tracked_expr(self, expr): bindparam_lookup = {b.key: b for b in self._resolved_bindparams} def replace(thing): @@ -179,17 +218,11 @@ class LambdaElement(elements.ClauseElement): and thing.key in bindparam_lookup ): bind = bindparam_lookup[thing.key] - # TODO: consider - # if we should clone the bindparam here, re-cache the new - # version, etc. also we make an assumption about "expanding" - # in this case. if thing.expanding: bind.expanding = True return bind - expr = self._rec[_EXPR] - - if self._rec[_IS_SEQUENCE]: + if self._rec.is_sequence: expr = [ visitors.replacement_traverse(sub_expr, {}, replace) for sub_expr in expr @@ -199,9 +232,39 @@ class LambdaElement(elements.ClauseElement): return expr + def _copy_internals( + self, clone=_clone, deferred_copy_internals=None, **kw + ): + # TODO: this needs A LOT of tests + self._resolved = clone( + self._resolved, + deferred_copy_internals=deferred_copy_internals, + **kw + ) + + @util.memoized_property + def _resolved(self): + expr = self._rec.expected_expr + + if self._resolved_bindparams: + expr = self._setup_binds_for_tracked_expr(expr) + + return expr + def _gen_cache_key(self, anon_map, bindparams): - cache_key = (self.fn.__code__, self.__class__) + cache_key = ( + self.fn.__code__, + self.__class__, + ) + self.closure_cache_key + + parent = self.parent_lambda + while parent is not None: + cache_key = ( + (parent.fn.__code__,) + parent.closure_cache_key + cache_key + ) + + parent = parent.parent_lambda if self._resolved_bindparams: bindparams.extend(self._resolved_bindparams) @@ -211,101 +274,51 @@ class LambdaElement(elements.ClauseElement): def _invoke_user_fn(self, fn, *arg): return fn() - def _initialize_var_trackers(self, role, apply_propagate_attrs, coerce_kw): - fn = self.fn - # track objects referenced inside of lambdas, create bindparams - # ahead of time for literal values. If bindparams are produced, - # then rewrite the function globals and closure as necessary so that - # it refers to the bindparams, then invoke the function - new_closure = {} - new_globals = fn.__globals__.copy() - tracker_collection = [] - check_closure_for_stale = [] +class DeferredLambdaElement(LambdaElement): + """A LambdaElement where the lambda accepts arguments and is + invoked within the compile phase with special context. - for name in fn.__code__.co_names: - if name not in new_globals: - continue + This lambda doesn't normally produce its real SQL expression outside of the + compile phase. It is passed a fixed set of initial arguments + so that it can generate a sample expression. - bound_value = _roll_down_to_literal(new_globals[name]) + """ - if coercions._is_literal(bound_value): - new_globals[name] = bind = PyWrapper(name, bound_value) - tracker_collection.append(_globals_tracker(name, bind)) + def __init__(self, fn, role, lambda_args=(), **kw): + self.lambda_args = lambda_args + self.coerce_kw = kw + super(DeferredLambdaElement, self).__init__(fn, role, **kw) - if fn.__closure__: - for closure_index, (fv, cell) in enumerate( - zip(fn.__code__.co_freevars, fn.__closure__) - ): + def _invoke_user_fn(self, fn, *arg): + return fn(*self.lambda_args) - bound_value = _roll_down_to_literal(cell.cell_contents) + def _resolve_with_args(self, *lambda_args): + tracker_fn = self._rec.tracker_instrumented_fn + expr = tracker_fn(*lambda_args) - if coercions._is_literal(bound_value): - new_closure[fv] = bind = PyWrapper(fv, bound_value) - tracker_collection.append( - _closure_tracker(fv, bind, closure_index) - ) - else: - new_closure[fv] = cell.cell_contents - # for normal cell contents, add them to a list that - # we can compare later when we get new lambdas. if - # any identities have changed, then we will recalculate - # the whole lambda and run it again. - check_closure_for_stale.append( - (closure_index, cell.cell_contents) - ) + expr = coercions.expect(self.role, expr, **self.coerce_kw) - if tracker_collection: - new_fn = _rewrite_code_obj( - fn, - [new_closure[name] for name in fn.__code__.co_freevars], - new_globals, - ) - expr = self._invoke_user_fn(new_fn) + if self._resolved_bindparams: + expr = self._setup_binds_for_tracked_expr(expr) - else: - new_fn = fn - expr = self._invoke_user_fn(new_fn) - tracker_collection = [] + # TODO: TEST TEST TEST, this is very out there + for deferred_copy_internals in self._transforms: + expr = deferred_copy_internals(expr) - if self.parent_lambda is None: - if isinstance(expr, collections_abc.Sequence): - expected_expr = [ - coercions.expect( - role, - sub_expr, - apply_propagate_attrs=apply_propagate_attrs, - **coerce_kw - ) - for sub_expr in expr - ] - is_sequence = True - else: - expected_expr = coercions.expect( - role, - expr, - apply_propagate_attrs=apply_propagate_attrs, - **coerce_kw - ) - is_sequence = False - else: - expected_expr = expr - is_sequence = False + return expr - if apply_propagate_attrs is not None: - propagate_attrs = apply_propagate_attrs._propagate_attrs - else: - propagate_attrs = util.immutabledict() - - rec = _trackers[self.fn.__code__] = ( - tracker_collection, - check_closure_for_stale, - new_fn, - expected_expr, - is_sequence, - propagate_attrs, + def _copy_internals( + self, clone=_clone, deferred_copy_internals=None, **kw + ): + super(DeferredLambdaElement, self)._copy_internals( + clone=clone, deferred_copy_internals=deferred_copy_internals, **kw ) - return rec + + # TODO: A LOT A LOT of tests. for _resolve_with_args, we don't know + # our expression yet. so hold onto the replacement + if deferred_copy_internals: + self._transforms += (deferred_copy_internals,) class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement): @@ -334,13 +347,38 @@ class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement): """ + def __init__(self, fn, parent_lambda, **kw): + self._default_kw = default_kw = {} + global_track_bound_values = kw.pop("global_track_bound_values", None) + if global_track_bound_values is not None: + default_kw["track_bound_values"] = global_track_bound_values + kw["track_bound_values"] = global_track_bound_values + + if "lambda_cache" in kw: + default_kw["lambda_cache"] = kw["lambda_cache"] + + super(StatementLambdaElement, self).__init__(fn, parent_lambda, **kw) + def __add__(self, other): - return LinkedLambdaElement(other, parent_lambda=self) + return LinkedLambdaElement( + other, parent_lambda=self, **self._default_kw + ) + + def add_criteria(self, other, **kw): + if self._default_kw: + if kw: + default_kw = self._default_kw.copy() + default_kw.update(kw) + kw = default_kw + else: + kw = self._default_kw + + return LinkedLambdaElement(other, parent_lambda=self, **kw) def _execute_on_connection( self, connection, multiparams, params, execution_options ): - if self._rec[_EXPR].supports_execution: + if self._rec.expected_expr.supports_execution: return connection._execute_clauseelement( self, multiparams, params, execution_options ) @@ -349,93 +387,579 @@ class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement): @property def _with_options(self): - return self._rec[_EXPR]._with_options + return self._rec.expected_expr._with_options @property def _effective_plugin_target(self): - return self._rec[_EXPR]._effective_plugin_target + return self._rec.expected_expr._effective_plugin_target @property def _is_future(self): - return self._rec[_EXPR]._is_future + return self._rec.expected_expr._is_future @property def _execution_options(self): - return self._rec[_EXPR]._execution_options + return self._rec.expected_expr._execution_options + + def spoil(self): + """Return a new :class:`.StatementLambdaElement` that will run + all lambdas unconditionally each time. + + """ + return NullLambdaStatement(self.fn()) + + +class NullLambdaStatement(roles.AllowsLambdaRole, elements.ClauseElement): + """Provides the :class:`.StatementLambdaElement` API but does not + cache or analyze lambdas. + + the lambdas are instead invoked immediately. + + The intended use is to isolate issues that may arise when using + lambda statements. + + """ + + __visit_name__ = "lambda_element" + + _is_lambda_element = True + + _traverse_internals = [ + ("_resolved", visitors.InternalTraversal.dp_clauseelement) + ] + + def __init__(self, statement): + self._resolved = statement + self._propagate_attrs = statement._propagate_attrs + + def __getattr__(self, key): + return getattr(self._resolved, key) + + def __add__(self, other): + statement = other(self._resolved) + + return NullLambdaStatement(statement) + + def add_criteria(self, other, **kw): + statement = other(self._resolved) + + return NullLambdaStatement(statement) + + def _execute_on_connection( + self, connection, multiparams, params, execution_options + ): + if self._resolved.supports_execution: + return connection._execute_clauseelement( + self, multiparams, params, execution_options + ) + else: + raise exc.ObjectNotExecutableError(self) class LinkedLambdaElement(StatementLambdaElement): + """Represent subsequent links of a :class:`.StatementLambdaElement`.""" + + role = None + def __init__(self, fn, parent_lambda, **kw): + self._default_kw = parent_lambda._default_kw + self.fn = fn self.parent_lambda = parent_lambda - role = None - apply_propagate_attrs = self + self.tracker_key = parent_lambda.tracker_key + (fn.__code__,) + self._retrieve_tracker_rec(fn, self, kw) + self._propagate_attrs = parent_lambda._propagate_attrs + + def _invoke_user_fn(self, fn, *arg): + return fn(self.parent_lambda._resolved) + + +class AnalyzedCode(object): + __slots__ = ( + "track_closure_variables", + "track_bound_values", + "bindparam_trackers", + "closure_trackers", + "build_py_wrappers", + ) + _fns = weakref.WeakKeyDictionary() + + @classmethod + def get(cls, fn, lambda_element, lambda_kw, **kw): + try: + # TODO: validate kw haven't changed? + return cls._fns[fn.__code__] + except KeyError: + pass + cls._fns[fn.__code__] = analyzed = AnalyzedCode( + fn, lambda_element, lambda_kw, **kw + ) + return analyzed + + def __init__( + self, + fn, + lambda_element, + lambda_kw, + track_bound_values=True, + enable_tracking=True, + track_on=None, + ): + closure = fn.__closure__ + + self.track_closure_variables = not track_on + + self.track_bound_values = track_bound_values + + # a list of callables generated from _bound_parameter_getter_* + # functions. Each of these uses a PyWrapper object to retrieve + # a parameter value + self.bindparam_trackers = [] + + # a list of callables generated from _cache_key_getter_* functions + # these callables work to generate a cache key for the lambda + # based on what's inside its closure variables. + self.closure_trackers = [] - if fn.__code__ not in _trackers: - rec = self._initialize_var_trackers( - role, apply_propagate_attrs, kw + self.build_py_wrappers = [] + + if enable_tracking: + if track_on: + self._init_track_on(track_on) + + self._init_globals(fn) + + if closure: + self._init_closure(fn) + + self._setup_additional_closure_trackers(fn, lambda_element, lambda_kw) + + def _init_track_on(self, track_on): + self.closure_trackers.extend( + self._cache_key_getter_track_on(idx, elem) + for idx, elem in enumerate(track_on) + ) + + def _init_globals(self, fn): + build_py_wrappers = self.build_py_wrappers + bindparam_trackers = self.bindparam_trackers + track_bound_values = self.track_bound_values + + for name in fn.__code__.co_names: + if name not in fn.__globals__: + continue + + _bound_value = self._roll_down_to_literal(fn.__globals__[name]) + + if coercions._deep_is_literal(_bound_value): + build_py_wrappers.append((name, None)) + if track_bound_values: + bindparam_trackers.append( + self._bound_parameter_getter_func_globals(name) + ) + + def _init_closure(self, fn): + build_py_wrappers = self.build_py_wrappers + closure = fn.__closure__ + + track_bound_values = self.track_bound_values + track_closure_variables = self.track_closure_variables + bindparam_trackers = self.bindparam_trackers + closure_trackers = self.closure_trackers + + for closure_index, (fv, cell) in enumerate( + zip(fn.__code__.co_freevars, closure) + ): + _bound_value = self._roll_down_to_literal(cell.cell_contents) + + if coercions._deep_is_literal(_bound_value): + build_py_wrappers.append((fv, closure_index)) + if track_bound_values: + bindparam_trackers.append( + self._bound_parameter_getter_func_closure( + fv, closure_index + ) + ) + else: + # for normal cell contents, add them to a list that + # we can compare later when we get new lambdas. if + # any identities have changed, then we will + # recalculate the whole lambda and run it again. + + if track_closure_variables: + closure_trackers.append( + self._cache_key_getter_closure_variable( + closure_index, cell.cell_contents + ) + ) + + def _setup_additional_closure_trackers( + self, fn, lambda_element, lambda_kw + ): + # an additional step is to actually run the function, then + # go through the PyWrapper objects that were set up to catch a bound + # parameter. then if they *didn't* make a param, oh they're another + # object in the closure we have to track for our cache key. so + # create trackers to catch those. + + analyzed_function = AnalyzedFunction( + self, lambda_element, None, lambda_kw, fn, + ) + + closure_trackers = self.closure_trackers + + for pywrapper in analyzed_function.closure_pywrappers: + if not pywrapper._sa__has_param: + closure_trackers.append( + self._cache_key_getter_tracked_literal(pywrapper) + ) + + @classmethod + def _roll_down_to_literal(cls, element): + is_clause_element = hasattr(element, "__clause_element__") + + if is_clause_element: + while not isinstance( + element, (elements.ClauseElement, schema.SchemaItem) + ): + try: + element = element.__clause_element__() + except AttributeError: + break + + if not is_clause_element: + insp = inspection.inspect(element, raiseerr=False) + if insp is not None: + try: + return insp.__clause_element__() + except AttributeError: + return insp + + # TODO: should we coerce consts None/True/False here? + return element + else: + return element + + def _bound_parameter_getter_func_globals(self, name): + """Return a getter that will extend a list of bound parameters + with new entries from the ``__globals__`` collection of a particular + lambda. + + """ + + def extract_parameter_value( + current_fn, tracker_instrumented_fn, result + ): + wrapper = tracker_instrumented_fn.__globals__[name] + object.__getattribute__(wrapper, "_extract_bound_parameters")( + current_fn.__globals__[name], result + ) + + return extract_parameter_value + + def _bound_parameter_getter_func_closure(self, name, closure_index): + """Return a getter that will extend a list of bound parameters + with new entries from the ``__closure__`` collection of a particular + lambda. + + """ + + def extract_parameter_value( + current_fn, tracker_instrumented_fn, result + ): + wrapper = tracker_instrumented_fn.__closure__[ + closure_index + ].cell_contents + object.__getattribute__(wrapper, "_extract_bound_parameters")( + current_fn.__closure__[closure_index].cell_contents, result + ) + + return extract_parameter_value + + def _cache_key_getter_track_on(self, idx, elem): + """Return a getter that will extend a cache key with new entries + from the "track_on" parameter passed to a :class:`.LambdaElement`. + + """ + if isinstance(elem, traversals.HasCacheKey): + + def get(closure, kw, anon_map, bindparams): + return kw["track_on"][idx]._gen_cache_key(anon_map, bindparams) + + else: + + def get(closure, kw, anon_map, bindparams): + return kw["track_on"][idx] + + return get + + def _cache_key_getter_closure_variable(self, idx, cell_contents): + """Return a getter that will extend a cache key with new entries + from the ``__closure__`` collection of a particular lambda. + + """ + + if isinstance(cell_contents, traversals.HasCacheKey): + + def get(closure, kw, anon_map, bindparams): + return closure[idx].cell_contents._gen_cache_key( + anon_map, bindparams + ) + + elif isinstance(cell_contents, types.FunctionType): + + def get(closure, kw, anon_map, bindparams): + return closure[idx].cell_contents.__code__ + + elif cell_contents.__hash__ is None: + # this covers dict, etc. + def get(closure, kw, anon_map, bindparams): + return () + + else: + + def get(closure, kw, anon_map, bindparams): + return closure[idx].cell_contents + + return get + + def _cache_key_getter_tracked_literal(self, pytracker): + """Return a getter that will extend a cache key with new entries + from the ``__closure__`` collection of a particular lambda. + + this getter differs from _cache_key_getter_closure_variable + in that these are detected after the function is run, and PyWrapper + objects have recorded that a particular literal value is in fact + not being interpreted as a bound parameter. + + """ + + elem = pytracker._sa__to_evaluate + closure_index = pytracker._sa__closure_index + + if isinstance(elem, set): + raise exc.ArgumentError( + "Can't create a cache key for lambda closure variable " + '"%s" because it\'s a set. try using a list' + % pytracker._sa__name ) + + elif isinstance(elem, list): + + def get(closure, kw, anon_map, bindparams): + return tuple( + elem._gen_cache_key(anon_map, bindparams) + for elem in closure[closure_index].cell_contents + ) + + elif elem.__hash__ is None: + # this covers dict, etc. + def get(closure, kw, anon_map, bindparams): + return () + else: - rec = _trackers[self.fn.__code__] + def get(closure, kw, anon_map, bindparams): + return closure[closure_index].cell_contents + + return get + + +class AnalyzedFunction(object): + __slots__ = ( + "analyzed_code", + "fn", + "closure_pywrappers", + "tracker_instrumented_fn", + "expr", + "bindparam_trackers", + "expected_expr", + "is_sequence", + "propagate_attrs", + "closure_bindparams", + ) + + def __init__( + self, analyzed_code, lambda_element, apply_propagate_attrs, kw, fn, + ): + self.analyzed_code = analyzed_code + self.fn = fn + + self.bindparam_trackers = analyzed_code.bindparam_trackers + + self._instrument_and_run_function(lambda_element) + + self._coerce_expression(lambda_element, apply_propagate_attrs, kw) + + def _instrument_and_run_function(self, lambda_element): + analyzed_code = self.analyzed_code + + fn = self.fn + self.closure_pywrappers = closure_pywrappers = [] + + build_py_wrappers = analyzed_code.build_py_wrappers + + if not build_py_wrappers: + self.tracker_instrumented_fn = tracker_instrumented_fn = fn + self.expr = lambda_element._invoke_user_fn(tracker_instrumented_fn) + else: + track_closure_variables = analyzed_code.track_closure_variables closure = fn.__closure__ - # check if objects referred to by the lambda have changed and - # re-scan the lambda if so. see comments for this same section in - # LambdaElement. - for idx, obj in rec[_STALE_CHECK]: - if closure[idx].cell_contents is not obj: - rec = self._initialize_var_trackers( - role, apply_propagate_attrs, kw + # will form the __closure__ of the function when we rebuild it + if closure: + new_closure = { + fv: cell.cell_contents + for fv, cell in zip(fn.__code__.co_freevars, closure) + } + else: + new_closure = {} + + # will form the __globals__ of the function when we rebuild it + new_globals = fn.__globals__.copy() + + for name, closure_index in build_py_wrappers: + if closure_index is not None: + value = closure[closure_index].cell_contents + new_closure[name] = bind = PyWrapper( + name, value, closure_index=closure_index ) - break + if track_closure_variables: + closure_pywrappers.append(bind) + else: + value = fn.__globals__[name] + new_globals[name] = bind = PyWrapper(name, value) + + # rewrite the original fn. things that look like they will + # become bound parameters are wrapped in a PyWrapper. + self.tracker_instrumented_fn = ( + tracker_instrumented_fn + ) = self._rewrite_code_obj( + fn, + [new_closure[name] for name in fn.__code__.co_freevars], + new_globals, + ) - self._rec = rec + # now invoke the function. This will give us a new SQL + # expression, but all the places that there would be a bound + # parameter, the PyWrapper in its place will give us a bind + # with a predictable name we can match up later. - self._propagate_attrs = parent_lambda._propagate_attrs + # additionally, each PyWrapper will log that it did in fact + # create a parameter, otherwise, it's some kind of Python + # object in the closure and we want to track that, to make + # sure it doesn't change to somehting else, or if it does, + # that we create a different tracked function with that + # variable. + self.expr = lambda_element._invoke_user_fn(tracker_instrumented_fn) - self._resolved_bindparams = bindparams = [] - rec = self._rec - while True: - if rec[_TRACKERS]: - for tracker in rec[_TRACKERS]: - tracker(self.fn, bindparams) - if self.parent_lambda is not None: - self = self.parent_lambda - rec = self._rec + def _coerce_expression(self, lambda_element, apply_propagate_attrs, kw): + """Run the tracker-generated expression through coercion rules. + + After the user-defined lambda has been invoked to produce a statement + for re-use, run it through coercion rules to both check that it's the + correct type of object and also to coerce it to its useful form. + + """ + + parent_lambda = lambda_element.parent_lambda + expr = self.expr + + if parent_lambda is None: + if isinstance(expr, collections_abc.Sequence): + self.expected_expr = [ + coercions.expect( + lambda_element.role, + sub_expr, + apply_propagate_attrs=apply_propagate_attrs, + **kw + ) + for sub_expr in expr + ] + self.is_sequence = True else: - break + self.expected_expr = coercions.expect( + lambda_element.role, + expr, + apply_propagate_attrs=apply_propagate_attrs, + **kw + ) + self.is_sequence = False + else: + self.expected_expr = expr + self.is_sequence = False - def _invoke_user_fn(self, fn, *arg): - return fn(self.parent_lambda._rec[_EXPR]) + if apply_propagate_attrs is not None: + self.propagate_attrs = apply_propagate_attrs._propagate_attrs + else: + self.propagate_attrs = util.EMPTY_DICT - def _gen_cache_key(self, anon_map, bindparams): - if self._resolved_bindparams: - bindparams.extend(self._resolved_bindparams) + def _rewrite_code_obj(self, f, cell_values, globals_): + """Return a copy of f, with a new closure and new globals - cache_key = (self.fn.__code__, self.__class__) + yes it works in pypy :P - parent = self.parent_lambda - while parent is not None: - cache_key = (parent.fn.__code__,) + cache_key - parent = parent.parent_lambda + """ - return cache_key + argrange = range(len(cell_values)) + + code = "def make_cells():\n" + if cell_values: + code += " (%s) = (%s)\n" % ( + ", ".join("i%d" % i for i in argrange), + ", ".join("o%d" % i for i in argrange), + ) + code += " def closure():\n" + code += " return %s\n" % ", ".join("i%d" % i for i in argrange) + code += " return closure.__closure__" + vars_ = {"o%d" % i: cell_values[i] for i in argrange} + exec(code, vars_, vars_) + closure = vars_["make_cells"]() + + func = type(f)( + f.__code__, globals_, f.__name__, f.__defaults__, closure + ) + if sys.version_info >= (3,): + func.__annotations__ = f.__annotations__ + func.__kwdefaults__ = f.__kwdefaults__ + func.__doc__ = f.__doc__ + func.__module__ = f.__module__ + + return func class PyWrapper(ColumnOperators): - def __init__(self, name, to_evaluate, getter=None): + """A wrapper object that is injected into the ``__globals__`` and + ``__closure__`` of a Python function. + + When the function is instrumented with :class:`.PyWrapper` objects, it is + then invoked just once in order to set up the wrappers. We look through + all the :class:`.PyWrapper` objects we made to find the ones that generated + a :class:`.BindParameter` object, e.g. the expression system interpreted + something as a literal. Those positions in the globals/closure are then + ones that we will look at, each time a new lambda comes in that refers to + the same ``__code__`` object. In this way, we keep a single version of + the SQL expression that this lambda produced, without calling upon the + Python function that created it more than once, unless its other closure + variables have changed. The expression is then transformed to have the + new bound values embedded into it. + + """ + + def __init__(self, name, to_evaluate, closure_index=None, getter=None): self._name = name self._to_evaluate = to_evaluate self._param = None + self._has_param = False self._bind_paths = {} self._getter = getter + self._closure_index = closure_index def __call__(self, *arg, **kw): elem = object.__getattribute__(self, "_to_evaluate") value = elem(*arg, **kw) - if coercions._is_literal(value) and not isinstance( + if coercions._deep_is_literal(value) and not isinstance( # TODO: coverage where an ORM option or similar is here value, traversals.HasCacheKey, @@ -481,8 +1005,8 @@ class PyWrapper(ColumnOperators): if param is None: name = object.__getattribute__(self, "_name") self._param = param = elements.BindParameter(name, unique=True) + self._has_param = True param.type = type_api._resolve_value_to_type(to_evaluate) - return param._with_value(to_evaluate, maintain_key=True) def __getattribute__(self, key): @@ -497,7 +1021,15 @@ class PyWrapper(ColumnOperators): else: return self._sa__add_getter(key, operator.attrgetter) + def __iter__(self): + elem = object.__getattribute__(self, "_to_evaluate") + return iter(elem) + def __getitem__(self, key): + elem = object.__getattribute__(self, "_to_evaluate") + if not hasattr(elem, "__getitem__"): + raise AttributeError("__getitem__") + if isinstance(key, PyWrapper): # TODO: coverage raise exc.InvalidRequestError( @@ -518,90 +1050,14 @@ class PyWrapper(ColumnOperators): elem = object.__getattribute__(self, "_to_evaluate") value = getter(elem) - if coercions._is_literal(value): - wrapper = PyWrapper(key, value, getter) + if coercions._deep_is_literal(value): + wrapper = PyWrapper(key, value, getter=getter) bind_paths[bind_path_key] = wrapper return wrapper else: return value -def _roll_down_to_literal(element): - is_clause_element = hasattr(element, "__clause_element__") - - if is_clause_element: - while not isinstance( - element, (elements.ClauseElement, schema.SchemaItem) - ): - try: - element = element.__clause_element__() - except AttributeError: - break - - if not is_clause_element: - insp = inspection.inspect(element, raiseerr=False) - if insp is not None: - try: - return insp.__clause_element__() - except AttributeError: - return insp - - # TODO: should we coerce consts None/True/False here? - return element - else: - return element - - -def _globals_tracker(name, wrapper): - def extract_parameter_value(current_fn, result): - object.__getattribute__(wrapper, "_extract_bound_parameters")( - current_fn.__globals__[name], result - ) - - return extract_parameter_value - - -def _closure_tracker(name, wrapper, closure_index): - def extract_parameter_value(current_fn, result): - object.__getattribute__(wrapper, "_extract_bound_parameters")( - current_fn.__closure__[closure_index].cell_contents, result - ) - - return extract_parameter_value - - -def _rewrite_code_obj(f, cell_values, globals_): - """Return a copy of f, with a new closure and new globals - - yes it works in pypy :P - - """ - - argrange = range(len(cell_values)) - - code = "def make_cells():\n" - if cell_values: - code += " (%s) = (%s)\n" % ( - ", ".join("i%d" % i for i in argrange), - ", ".join("o%d" % i for i in argrange), - ) - code += " def closure():\n" - code += " return %s\n" % ", ".join("i%d" % i for i in argrange) - code += " return closure.__closure__" - vars_ = {"o%d" % i: cell_values[i] for i in argrange} - exec(code, vars_, vars_) - closure = vars_["make_cells"]() - - func = type(f)(f.__code__, globals_, f.__name__, f.__defaults__, closure) - if sys.version_info >= (3,): - func.__annotations__ = f.__annotations__ - func.__kwdefaults__ = f.__kwdefaults__ - func.__doc__ = f.__doc__ - func.__module__ = f.__module__ - - return func - - @inspection._inspects(LambdaElement) def insp(lmb): return inspection.inspect(lmb._resolved) diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 1155c273b4..d67b617434 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -2426,6 +2426,7 @@ class SelectBase( """ _is_select_statement = True + is_select = True def _generate_fromclause_column_proxies(self, fromclause): # type: (FromClause) -> None @@ -3867,19 +3868,19 @@ class Select( [ ("_raw_columns", InternalTraversal.dp_clauseelement_list), ("_from_obj", InternalTraversal.dp_clauseelement_list), - ("_where_criteria", InternalTraversal.dp_clauseelement_list), - ("_having_criteria", InternalTraversal.dp_clauseelement_list), - ("_order_by_clauses", InternalTraversal.dp_clauseelement_list,), - ("_group_by_clauses", InternalTraversal.dp_clauseelement_list,), + ("_where_criteria", InternalTraversal.dp_clauseelement_tuple), + ("_having_criteria", InternalTraversal.dp_clauseelement_tuple), + ("_order_by_clauses", InternalTraversal.dp_clauseelement_tuple,), + ("_group_by_clauses", InternalTraversal.dp_clauseelement_tuple,), ("_setup_joins", InternalTraversal.dp_setup_join_tuple,), ("_legacy_setup_joins", InternalTraversal.dp_setup_join_tuple,), - ("_correlate", InternalTraversal.dp_clauseelement_list), - ("_correlate_except", InternalTraversal.dp_clauseelement_list,), + ("_correlate", InternalTraversal.dp_clauseelement_tuple), + ("_correlate_except", InternalTraversal.dp_clauseelement_tuple,), ("_limit_clause", InternalTraversal.dp_clauseelement), ("_offset_clause", InternalTraversal.dp_clauseelement), ("_for_update_arg", InternalTraversal.dp_clauseelement), ("_distinct", InternalTraversal.dp_boolean), - ("_distinct_on", InternalTraversal.dp_clauseelement_list), + ("_distinct_on", InternalTraversal.dp_clauseelement_tuple), ("_label_style", InternalTraversal.dp_plain_obj), ("_is_future", InternalTraversal.dp_boolean), ] @@ -4345,7 +4346,7 @@ class Select( @_generative def join(self, target, onclause=None, isouter=False, full=False): - r"""Create a SQL JOIN against this :class:`_expresson.Select` + r"""Create a SQL JOIN against this :class:`_expression.Select` object's criterion and apply generatively, returning the newly resulting :class:`_expression.Select`. @@ -4474,7 +4475,7 @@ class Select( # they've become. This allows us to ensure the same cloned from # is used when other items such as columns are "cloned" - all_the_froms = list( + all_the_froms = set( itertools.chain( _from_objects(*self._raw_columns), _from_objects(*self._where_criteria), @@ -4490,10 +4491,15 @@ class Select( new_froms = {f: clone(f, **kw) for f in all_the_froms} # 2. copy FROM collections, adding in joins that we've created. - self._from_obj = tuple(clone(f, **kw) for f in self._from_obj) + tuple( - f for f in new_froms.values() if isinstance(f, Join) + existing_from_obj = [clone(f, **kw) for f in self._from_obj] + add_froms = ( + set(f for f in new_froms.values() if isinstance(f, Join)) + .difference(all_the_froms) + .difference(existing_from_obj) ) + self._from_obj = tuple(existing_from_obj) + tuple(add_froms) + # 3. clone everything else, making sure we use columns # corresponding to the froms we just made. def replace(obj, **kw): @@ -4687,6 +4693,7 @@ class Select( """ + assert isinstance(self._where_criteria, tuple) self._where_criteria += ( coercions.expect(roles.WhereHavingRole, whereclause), ) @@ -5371,6 +5378,9 @@ class TextualSelect(SelectBase): _is_textual = True + is_text = True + is_select = True + def __init__(self, text, columns, positional=False): self.element = text # convert for ORM attributes->columns, etc diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index f41480a947..cb38df6afa 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -190,7 +190,10 @@ class HasCacheKey(object): # statements, not so much, but they usually won't have # annotations. result += self._annotations_cache_key - elif meth is InternalTraversal.dp_clauseelement_list: + elif ( + meth is InternalTraversal.dp_clauseelement_list + or meth is InternalTraversal.dp_clauseelement_tuple + ): result += ( attrname, tuple( @@ -390,6 +393,7 @@ class _CacheKey(ExtendedInternalTraversal): visit_has_cache_key = visit_clauseelement = CALL_GEN_CACHE_KEY visit_clauseelement_list = InternalTraversal.dp_clauseelement_list visit_annotations_key = InternalTraversal.dp_annotations_key + visit_clauseelement_tuple = InternalTraversal.dp_clauseelement_tuple visit_string = ( visit_boolean @@ -451,6 +455,8 @@ class _CacheKey(ExtendedInternalTraversal): tuple(elem._gen_cache_key(anon_map, bindparams) for elem in obj), ) + visit_executable_options = visit_has_cache_key_list + def visit_inspectable_list( self, attrname, obj, parent, anon_map, bindparams ): @@ -682,6 +688,41 @@ class _CacheKey(ExtendedInternalTraversal): _cache_key_traversal_visitor = _CacheKey() +class HasCopyInternals(object): + def _clone(self, **kw): + raise NotImplementedError() + + def _copy_internals(self, omit_attrs=(), **kw): + """Reassign internal elements to be clones of themselves. + + Called during a copy-and-traverse operation on newly + shallow-copied elements to create a deep copy. + + The given clone function should be used, which may be applying + additional transformations to the element (i.e. replacement + traversal, cloned traversal, annotations). + + """ + + try: + traverse_internals = self._traverse_internals + except AttributeError: + # user-defined classes may not have a _traverse_internals + return + + for attrname, obj, meth in _copy_internals.run_generated_dispatch( + self, traverse_internals, "_generated_copy_internals_traversal" + ): + if attrname in omit_attrs: + continue + + if obj is not None: + + result = meth(attrname, self, obj, **kw) + if result is not None: + setattr(self, attrname, result) + + class _CopyInternals(InternalTraversal): """Generate a _copy_internals internal traversal dispatch for classes with a _traverse_internals collection.""" @@ -696,6 +737,16 @@ class _CopyInternals(InternalTraversal): ): return [clone(clause, **kw) for clause in element] + def visit_clauseelement_tuple( + self, attrname, parent, element, clone=_clone, **kw + ): + return tuple([clone(clause, **kw) for clause in element]) + + def visit_executable_options( + self, attrname, parent, element, clone=_clone, **kw + ): + return tuple([clone(clause, **kw) for clause in element]) + def visit_clauseelement_unordered_set( self, attrname, parent, element, clone=_clone, **kw ): @@ -817,6 +868,9 @@ class _GetChildren(InternalTraversal): def visit_clauseelement_list(self, element, **kw): return element + def visit_clauseelement_tuple(self, element, **kw): + return element + def visit_clauseelement_tuples(self, element, **kw): return itertools.chain.from_iterable(element) @@ -840,8 +894,8 @@ class _GetChildren(InternalTraversal): if not isinstance(target, str): yield _flatten_clauseelement(target) - # if onclause is not None and not isinstance(onclause, str): - # yield _flatten_clauseelement(onclause) + if onclause is not None and not isinstance(onclause, str): + yield _flatten_clauseelement(onclause) def visit_dml_ordered_values(self, element, **kw): for k, v in element: @@ -1015,6 +1069,8 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): ): return COMPARE_FAILED + visit_executable_options = visit_has_cache_key_list + def visit_clauseelement( self, attrname, left_parent, left, right_parent, right, **kw ): @@ -1057,6 +1113,12 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): for l, r in util.zip_longest(left, right, fillvalue=None): self.stack.append((l, r)) + def visit_clauseelement_tuple( + self, attrname, left_parent, left, right_parent, right, **kw + ): + for l, r in util.zip_longest(left, right, fillvalue=None): + self.stack.append((l, r)) + def _compare_unordered_sequences(self, seq1, seq2, **kw): if seq1 is None: return seq2 is None diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 56d3c93b3c..5cb3cba709 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -257,7 +257,7 @@ class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)): """ - dp_clauseelement_tuples = symbol("CT") + dp_clauseelement_tuples = symbol("CTS") """Visit a list of tuples which contain :class:`_expression.ClauseElement` objects. @@ -268,6 +268,13 @@ class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)): """ + dp_clauseelement_tuple = symbol("CT") + """Visit a tuple of :class:`_expression.ClauseElement` objects. + + """ + + dp_executable_options = symbol("EO") + dp_fromclause_ordered_set = symbol("CO") """Visit an ordered set of :class:`_expression.FromClause` objects. """ @@ -712,6 +719,9 @@ def cloned_traverse(obj, opts, visitors): cloned = {} stop_on = set(opts.get("stop_on", [])) + def deferred_copy_internals(obj): + return cloned_traverse(obj, opts, visitors) + def clone(elem, **kw): if elem in stop_on: return elem @@ -732,7 +742,7 @@ def cloned_traverse(obj, opts, visitors): return cloned[id(elem)] if obj is not None: - obj = clone(obj) + obj = clone(obj, deferred_copy_internals=deferred_copy_internals) clone = None # remove gc cycles return obj @@ -764,6 +774,9 @@ def replacement_traverse(obj, opts, replace): cloned = {} stop_on = {id(x) for x in opts.get("stop_on", [])} + def deferred_copy_internals(obj): + return replacement_traverse(obj, opts, replace) + def clone(elem, **kw): if ( id(elem) in stop_on @@ -776,19 +789,24 @@ def replacement_traverse(obj, opts, replace): stop_on.add(id(newelem)) return newelem else: - - if elem not in cloned: + # base "already seen" on id(), not hash, so that we don't + # replace an Annotated element with its non-annotated one, and + # vice versa + id_elem = id(elem) + if id_elem not in cloned: if "replace" in kw: newelem = kw["replace"](elem) if newelem is not None: - cloned[elem] = newelem + cloned[id_elem] = newelem return newelem - cloned[elem] = newelem = elem._clone() + cloned[id_elem] = newelem = elem._clone() newelem._copy_internals(clone=clone, **kw) - return cloned[elem] + return cloned[id_elem] if obj is not None: - obj = clone(obj, **opts) + obj = clone( + obj, deferred_copy_internals=deferred_copy_internals, **opts + ) clone = None # remove gc cycles return obj diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index cea9c4f667..42992af130 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -12,9 +12,11 @@ from functools import partial # noqa from functools import update_wrapper # noqa from ._collections import coerce_generator_arg # noqa +from ._collections import coerce_to_immutabledict # noqa from ._collections import collections_abc # noqa from ._collections import column_dict # noqa from ._collections import column_set # noqa +from ._collections import EMPTY_DICT # noqa from ._collections import EMPTY_SET # noqa from ._collections import FacadeDict # noqa from ._collections import flatten_iterator # noqa diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py index 6056864947..7c109b358e 100644 --- a/lib/sqlalchemy/util/_collections.py +++ b/lib/sqlalchemy/util/_collections.py @@ -49,13 +49,25 @@ def _immutabledict_py_fallback(): def __reduce__(self): return _immutabledict_reconstructor, (dict(self),) - def union(self, d): - if not d: + def union(self, __d=None): + if not __d: return self new = dict.__new__(self.__class__) dict.__init__(new, self) - dict.update(new, d) + dict.update(new, __d) + return new + + def _union_w_kw(self, __d=None, **kw): + # not sure if C version works correctly w/ this yet + if not __d and not kw: + return self + + new = dict.__new__(self.__class__) + dict.__init__(new, self) + if __d: + dict.update(new, __d) + dict.update(new, kw) return new def merge_with(self, *dicts): @@ -90,6 +102,18 @@ except ImportError: return immutabledict(*arg) +def coerce_to_immutabledict(d): + if not d: + return EMPTY_DICT + elif isinstance(d, immutabledict): + return d + else: + return immutabledict(d) + + +EMPTY_DICT = immutabledict() + + class FacadeDict(ImmutableContainer, dict): """A dictionary that is not publicly mutable.""" diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py index 7ac716dbeb..b573accbd9 100644 --- a/test/sql/test_compare.py +++ b/test/sql/test_compare.py @@ -131,6 +131,11 @@ class MyEntity(HasCacheKey): ] +class Foo: + x = 10 + y = 15 + + dml.Insert.argument_for("sqlite", "foo", None) dml.Update.argument_for("sqlite", "foo", None) dml.Delete.argument_for("sqlite", "foo", None) @@ -790,7 +795,7 @@ class CoreFixtures(object): def two(): r = random.randint(1, 10) - q = 20 + q = 408 return LambdaElement( lambda: table_a.c.a + q == r, roles.WhereHavingRole ) @@ -803,10 +808,6 @@ class CoreFixtures(object): roles.WhereHavingRole, ) - class Foo: - x = 10 - y = 15 - def four(): return LambdaElement( lambda: and_(table_a.c.a == Foo.x), roles.WhereHavingRole @@ -833,6 +834,16 @@ class CoreFixtures(object): lambda s: s.where(table_a.c.a == value) ) + from sqlalchemy.sql import lambdas + + def eight(): + q = 5 + return lambdas.DeferredLambdaElement( + lambda t: t.c.a > q, + roles.WhereHavingRole, + lambda_args=(table_a,), + ) + return [ one(), two(), @@ -841,6 +852,7 @@ class CoreFixtures(object): five(), six(), seven(), + eight(), ] dont_compare_values_fixtures.append(_lambda_fixtures) diff --git a/test/sql/test_external_traversal.py b/test/sql/test_external_traversal.py index aefcaf252c..4918afc9c6 100644 --- a/test/sql/test_external_traversal.py +++ b/test/sql/test_external_traversal.py @@ -791,6 +791,90 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL): "JOIN table2 ON table1.col1 = table2.col2) AS anon_1", ) + def test_this_thing_using_setup_joins_three(self): + + j = t1.join(t2, t1.c.col1 == t2.c.col2) + + s1 = select(j) + + s2 = s1.join(t3, t1.c.col1 == t3.c.col1) + + self.assert_compile( + s2, + "SELECT table1.col1, table1.col2, table1.col3, " + "table2.col1, table2.col2, table2.col3 FROM table1 " + "JOIN table2 ON table1.col1 = table2.col2 JOIN table3 " + "ON table3.col1 = table1.col1", + ) + + vis = sql_util.ClauseAdapter(j) + + s3 = vis.traverse(s1) + + s4 = s3.join(t3, t1.c.col1 == t3.c.col1) + + self.assert_compile( + s4, + "SELECT table1.col1, table1.col2, table1.col3, " + "table2.col1, table2.col2, table2.col3 FROM table1 " + "JOIN table2 ON table1.col1 = table2.col2 JOIN table3 " + "ON table3.col1 = table1.col1", + ) + + s5 = vis.traverse(s3) + + s6 = s5.join(t3, t1.c.col1 == t3.c.col1) + + self.assert_compile( + s6, + "SELECT table1.col1, table1.col2, table1.col3, " + "table2.col1, table2.col2, table2.col3 FROM table1 " + "JOIN table2 ON table1.col1 = table2.col2 JOIN table3 " + "ON table3.col1 = table1.col1", + ) + + def test_this_thing_using_setup_joins_four(self): + + j = t1.join(t2, t1.c.col1 == t2.c.col2) + + s1 = select(j) + + assert not s1._from_obj + + s2 = s1.join(t3, t1.c.col1 == t3.c.col1) + + self.assert_compile( + s2, + "SELECT table1.col1, table1.col2, table1.col3, " + "table2.col1, table2.col2, table2.col3 FROM table1 " + "JOIN table2 ON table1.col1 = table2.col2 JOIN table3 " + "ON table3.col1 = table1.col1", + ) + + s3 = visitors.replacement_traverse(s1, {}, lambda elem: None) + + s4 = s3.join(t3, t1.c.col1 == t3.c.col1) + + self.assert_compile( + s4, + "SELECT table1.col1, table1.col2, table1.col3, " + "table2.col1, table2.col2, table2.col3 FROM table1 " + "JOIN table2 ON table1.col1 = table2.col2 JOIN table3 " + "ON table3.col1 = table1.col1", + ) + + s5 = visitors.replacement_traverse(s3, {}, lambda elem: None) + + s6 = s5.join(t3, t1.c.col1 == t3.c.col1) + + self.assert_compile( + s6, + "SELECT table1.col1, table1.col2, table1.col3, " + "table2.col1, table2.col2, table2.col3 FROM table1 " + "JOIN table2 ON table1.col1 = table2.col2 JOIN table3 " + "ON table3.col1 = table1.col1", + ) + def test_select_fromtwice_one(self): t1a = t1.alias() diff --git a/test/sql/test_lambdas.py b/test/sql/test_lambdas.py index 53f6a9544c..a91242de5e 100644 --- a/test/sql/test_lambdas.py +++ b/test/sql/test_lambdas.py @@ -5,6 +5,7 @@ from sqlalchemy.schema import Column from sqlalchemy.schema import ForeignKey from sqlalchemy.schema import Table from sqlalchemy.sql import and_ +from sqlalchemy.sql import bindparam from sqlalchemy.sql import coercions from sqlalchemy.sql import column from sqlalchemy.sql import join @@ -19,6 +20,7 @@ from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ +from sqlalchemy.testing import ne_ from sqlalchemy.testing.assertsql import CompiledSQL from sqlalchemy.types import Integer from sqlalchemy.types import String @@ -46,10 +48,39 @@ class DeferredLambdaTest( go(), "SELECT t1.q, t1.p FROM t1 WHERE t1.q = :x_1 AND t1.p = :y_1" ) + def test_global_tracking(self): + t1 = table("t1", column("q"), column("p")) + + global global_x, global_y + + global_x = 10 + global_y = 17 + + def go(): + return select([t1]).where( + lambda: and_(t1.c.q == global_x, t1.c.p == global_y) + ) + + self.assert_compile( + go(), + "SELECT t1.q, t1.p FROM t1 WHERE t1.q = :global_x_1 " + "AND t1.p = :global_y_1", + checkparams={"global_x_1": 10, "global_y_1": 17}, + ) + + global_y = 9 + + self.assert_compile( + go(), + "SELECT t1.q, t1.p FROM t1 WHERE t1.q = :global_x_1 " + "AND t1.p = :global_y_1", + checkparams={"global_x_1": 10, "global_y_1": 9}, + ) + def test_stale_checker_embedded(self): def go(x): - stmt = select([lambda: x]) + stmt = select(lambda: x) return stmt c1 = column("x") @@ -67,7 +98,7 @@ class DeferredLambdaTest( def test_stale_checker_statement(self): def go(x): - stmt = lambdas.lambda_stmt(lambda: select([x])) + stmt = lambdas.lambda_stmt(lambda: select(x)) return stmt c1 = column("x") @@ -85,13 +116,13 @@ class DeferredLambdaTest( def test_stale_checker_linked(self): def go(x, y): - stmt = lambdas.lambda_stmt(lambda: select([x])) + ( + stmt = lambdas.lambda_stmt(lambda: select(x)) + ( lambda s: s.where(y > 5) ) return stmt - c1 = column("x") - c2 = column("y") + c1 = oldc1 = column("x") + c2 = oldc2 = column("y") s1 = go(c1, c2) s2 = go(c1, c2) @@ -104,6 +135,426 @@ class DeferredLambdaTest( s3 = go(c1, c2) self.assert_compile(s3, "SELECT q WHERE p > :p_1") + s4 = go(c1, c2) + self.assert_compile(s4, "SELECT q WHERE p > :p_1") + + s5 = go(oldc1, oldc2) + self.assert_compile(s5, "SELECT x WHERE y > :y_1") + + def test_stmt_lambda_w_additional_hascachekey_variants(self): + def go(col_expr, q): + stmt = lambdas.lambda_stmt(lambda: select(col_expr)) + stmt += lambda stmt: stmt.where(col_expr == q) + + return stmt + + c1 = column("x") + c2 = column("y") + + s1 = go(c1, 5) + s2 = go(c2, 10) + s3 = go(c1, 8) + s4 = go(c2, 12) + + self.assert_compile( + s1, "SELECT x WHERE x = :q_1", checkparams={"q_1": 5} + ) + self.assert_compile( + s2, "SELECT y WHERE y = :q_1", checkparams={"q_1": 10} + ) + self.assert_compile( + s3, "SELECT x WHERE x = :q_1", checkparams={"q_1": 8} + ) + self.assert_compile( + s4, "SELECT y WHERE y = :q_1", checkparams={"q_1": 12} + ) + + s1key = s1._generate_cache_key() + s2key = s2._generate_cache_key() + s3key = s3._generate_cache_key() + s4key = s4._generate_cache_key() + + eq_(s1key[0], s3key[0]) + eq_(s2key[0], s4key[0]) + ne_(s1key[0], s2key[0]) + + def test_stmt_lambda_w_atonce_whereclause_values_notrack(self): + def go(col_expr, whereclause): + stmt = lambdas.lambda_stmt(lambda: select(col_expr)) + stmt = stmt.add_criteria( + lambda stmt: stmt.where(whereclause), enable_tracking=False + ) + + return stmt + + c1 = column("x") + + s1 = go(c1, c1 == 5) + s2 = go(c1, c1 == 10) + + self.assert_compile( + s1, "SELECT x WHERE x = :x_1", checkparams={"x_1": 5} + ) + + # and as we see, this is wrong. Because whereclause + # is fixed for the lambda and we do not re-evaluate the closure + # for this value changing. this can't be passed unless + # enable_tracking=False. + self.assert_compile( + s2, "SELECT x WHERE x = :x_1", checkparams={"x_1": 5} + ) + + def test_stmt_lambda_w_atonce_whereclause_values(self): + c2 = column("y") + + def go(col_expr, whereclause, x): + stmt = lambdas.lambda_stmt(lambda: select(col_expr)) + stmt = stmt.add_criteria( + lambda stmt: stmt.where(whereclause).order_by(c2 > x), + ) + + return stmt + + c1 = column("x") + + s1 = go(c1, c1 == 5, 9) + s2 = go(c1, c1 == 10, 15) + + s1key = s1._generate_cache_key() + s2key = s2._generate_cache_key() + + eq_([b.value for b in s1key.bindparams], [5, 9]) + eq_([b.value for b in s2key.bindparams], [10, 15]) + + self.assert_compile( + s1, + "SELECT x WHERE x = :x_1 ORDER BY y > :x_2", + checkparams={"x_1": 5, "x_2": 9}, + ) + + self.assert_compile( + s2, + "SELECT x WHERE x = :x_1 ORDER BY y > :x_2", + checkparams={"x_1": 10, "x_2": 15}, + ) + + def test_stmt_lambda_plain_customtrack(self): + c2 = column("y") + + def go(col_expr, whereclause, p): + stmt = lambdas.lambda_stmt(lambda: select(col_expr)) + stmt = stmt.add_criteria(lambda stmt: stmt.where(whereclause)) + stmt = stmt.add_criteria( + lambda stmt: stmt.order_by(col_expr), track_on=(col_expr,) + ) + stmt = stmt.add_criteria(lambda stmt: stmt.where(col_expr == p)) + return stmt + + c1 = column("x") + c2 = column("y") + + s1 = go(c1, c1 == 5, 9) + s2 = go(c1, c1 == 10, 15) + s3 = go(c2, c2 == 18, 12) + + s1key = s1._generate_cache_key() + s2key = s2._generate_cache_key() + s3key = s3._generate_cache_key() + + eq_([b.value for b in s1key.bindparams], [5, 9]) + eq_([b.value for b in s2key.bindparams], [10, 15]) + eq_([b.value for b in s3key.bindparams], [18, 12]) + + self.assert_compile( + s1, + "SELECT x WHERE x = :x_1 AND x = :p_1 ORDER BY x", + checkparams={"x_1": 5, "p_1": 9}, + ) + + self.assert_compile( + s2, + "SELECT x WHERE x = :x_1 AND x = :p_1 ORDER BY x", + checkparams={"x_1": 10, "p_1": 15}, + ) + + self.assert_compile( + s3, + "SELECT y WHERE y = :y_1 AND y = :p_1 ORDER BY y", + checkparams={"y_1": 18, "p_1": 12}, + ) + + def test_stmt_lambda_w_atonce_whereclause_customtrack_binds(self): + c2 = column("y") + + # this pattern is *completely unnecessary*, and I would prefer + # if we can detect this and just raise, because when it is not done + # correctly, it is *extremely* difficult to catch it failing. + # however I also can't come up with a reliable way to catch it. + # so we will keep the use of "track_on" to be internal. + + def go(col_expr, whereclause, p): + stmt = lambdas.lambda_stmt(lambda: select(col_expr)) + stmt = stmt.add_criteria( + lambda stmt: stmt.where(whereclause).order_by(col_expr > p), + track_on=(whereclause, whereclause.right.value), + ) + + return stmt + + c1 = column("x") + c2 = column("y") + + s1 = go(c1, c1 == 5, 9) + s2 = go(c1, c1 == 10, 15) + s3 = go(c2, c2 == 18, 12) + + s1key = s1._generate_cache_key() + s2key = s2._generate_cache_key() + s3key = s3._generate_cache_key() + + eq_([b.value for b in s1key.bindparams], [5, 9]) + eq_([b.value for b in s2key.bindparams], [10, 15]) + eq_([b.value for b in s3key.bindparams], [18, 12]) + + self.assert_compile( + s1, + "SELECT x WHERE x = :x_1 ORDER BY x > :p_1", + checkparams={"x_1": 5, "p_1": 9}, + ) + + self.assert_compile( + s2, + "SELECT x WHERE x = :x_1 ORDER BY x > :p_1", + checkparams={"x_1": 10, "p_1": 15}, + ) + + self.assert_compile( + s3, + "SELECT y WHERE y = :y_1 ORDER BY y > :p_1", + checkparams={"y_1": 18, "p_1": 12}, + ) + + def test_stmt_lambda_track_closure_binds_one(self): + def go(col_expr, whereclause): + stmt = lambdas.lambda_stmt(lambda: select(col_expr)) + stmt += lambda stmt: stmt.where(whereclause) + + return stmt + + c1 = column("x") + + s1 = go(c1, c1 == 5) + s2 = go(c1, c1 == 10) + + self.assert_compile( + s1, "SELECT x WHERE x = :x_1", checkparams={"x_1": 5} + ) + self.assert_compile( + s2, "SELECT x WHERE x = :x_1", checkparams={"x_1": 10} + ) + + s1key = s1._generate_cache_key() + s2key = s2._generate_cache_key() + + eq_(s1key.key, s2key.key) + + eq_([b.value for b in s1key.bindparams], [5]) + eq_([b.value for b in s2key.bindparams], [10]) + + def test_stmt_lambda_track_closure_binds_two(self): + def go(col_expr, whereclause, x, y): + stmt = lambdas.lambda_stmt(lambda: select(col_expr)) + stmt += lambda stmt: stmt.where(whereclause).where( + and_(c1 == x, c1 < y) + ) + + return stmt + + c1 = column("x") + + s1 = go(c1, c1 == 5, 8, 9) + s2 = go(c1, c1 == 10, 12, 14) + + s1key = s1._generate_cache_key() + s2key = s2._generate_cache_key() + + self.assert_compile( + s1, + "SELECT x WHERE x = :x_1 AND x = :x_2 AND x < :y_1", + checkparams={"x_1": 5, "x_2": 8, "y_1": 9}, + ) + self.assert_compile( + s2, + "SELECT x WHERE x = :x_1 AND x = :x_2 AND x < :y_1", + checkparams={"x_1": 10, "x_2": 12, "y_1": 14}, + ) + + eq_([b.value for b in s1key.bindparams], [5, 8, 9]) + eq_([b.value for b in s2key.bindparams], [10, 12, 14]) + + s1_compiled_cached = s1.compile(cache_key=s1key) + + params = s1_compiled_cached.construct_params( + extracted_parameters=s2key[1] + ) + + eq_(params, {"x_1": 10, "x_2": 12, "y_1": 14}) + + def test_stmt_lambda_track_closure_binds_three(self): + def go(col_expr, whereclause, x, y): + stmt = lambdas.lambda_stmt(lambda: select(col_expr)) + stmt += lambda stmt: stmt.where(whereclause) + stmt += lambda stmt: stmt.where(and_(c1 == x, c1 < y)) + + return stmt + + c1 = column("x") + + s1 = go(c1, c1 == 5, 8, 9) + s2 = go(c1, c1 == 10, 12, 14) + + s1key = s1._generate_cache_key() + s2key = s2._generate_cache_key() + + self.assert_compile( + s1, + "SELECT x WHERE x = :x_1 AND x = :x_2 AND x < :y_1", + checkparams={"x_1": 5, "x_2": 8, "y_1": 9}, + ) + self.assert_compile( + s2, + "SELECT x WHERE x = :x_1 AND x = :x_2 AND x < :y_1", + checkparams={"x_1": 10, "x_2": 12, "y_1": 14}, + ) + + eq_([b.value for b in s1key.bindparams], [5, 8, 9]) + eq_([b.value for b in s2key.bindparams], [10, 12, 14]) + + s1_compiled_cached = s1.compile(cache_key=s1key) + + params = s1_compiled_cached.construct_params( + extracted_parameters=s2key[1] + ) + + eq_(params, {"x_1": 10, "x_2": 12, "y_1": 14}) + + def test_stmt_lambda_w_atonce_whereclause_novalue(self): + def go(col_expr, whereclause): + stmt = lambdas.lambda_stmt(lambda: select(col_expr)) + stmt += lambda stmt: stmt.where(whereclause) + + return stmt + + c1 = column("x") + + s1 = go(c1, bindparam("x")) + + self.assert_compile(s1, "SELECT x WHERE :x") + + def test_stmt_lambda_w_additional_hashable_variants(self): + # note a Python 2 old style class would fail here because it + # isn't hashable. right now we do a hard check for __hash__ which + # will raise if the attr isn't present + class Thing(object): + def __init__(self, col_expr): + self.col_expr = col_expr + + def go(thing, q): + stmt = lambdas.lambda_stmt(lambda: select(thing.col_expr)) + stmt += lambda stmt: stmt.where(thing.col_expr == q) + + return stmt + + c1 = Thing(column("x")) + c2 = Thing(column("y")) + + s1 = go(c1, 5) + s2 = go(c2, 10) + s3 = go(c1, 8) + s4 = go(c2, 12) + + self.assert_compile( + s1, "SELECT x WHERE x = :q_1", checkparams={"q_1": 5} + ) + self.assert_compile( + s2, "SELECT y WHERE y = :q_1", checkparams={"q_1": 10} + ) + self.assert_compile( + s3, "SELECT x WHERE x = :q_1", checkparams={"q_1": 8} + ) + self.assert_compile( + s4, "SELECT y WHERE y = :q_1", checkparams={"q_1": 12} + ) + + s1key = s1._generate_cache_key() + s2key = s2._generate_cache_key() + s3key = s3._generate_cache_key() + s4key = s4._generate_cache_key() + + eq_(s1key[0], s3key[0]) + eq_(s2key[0], s4key[0]) + ne_(s1key[0], s2key[0]) + + def test_stmt_lambda_w_set_of_opts(self): + + stmt = lambdas.lambda_stmt(lambda: select(column("x"))) + + opts = {column("x"), column("y")} + + assert_raises_message( + exc.ArgumentError, + 'Can\'t create a cache key for lambda closure variable "opts" ' + "because it's a set. try using a list", + stmt.__add__, + lambda stmt: stmt.options(*opts), + ) + + def test_stmt_lambda_w_list_of_opts(self): + def go(opts): + stmt = lambdas.lambda_stmt(lambda: select(column("x"))) + stmt += lambda stmt: stmt.options(*opts) + + return stmt + + s1 = go([column("a"), column("b")]) + + s2 = go([column("a"), column("b")]) + + s3 = go([column("q"), column("b")]) + + s1key = s1._generate_cache_key() + s2key = s2._generate_cache_key() + s3key = s3._generate_cache_key() + + eq_(s1key.key, s2key.key) + ne_(s1key.key, s3key.key) + + def test_stmt_lambda_hey_theres_multiple_paths(self): + def go(x, y): + stmt = lambdas.lambda_stmt(lambda: select(column("x"))) + + if x > 5: + stmt += lambda stmt: stmt.where(column("x") == x) + else: + stmt += lambda stmt: stmt.where(column("y") == y) + + stmt += lambda stmt: stmt.order_by(column("q")) + + # TODO: need more path variety here to exercise + # using a full path key + + return stmt + + s1 = go(2, 5) + s2 = go(8, 7) + s3 = go(4, 9) + s4 = go(10, 1) + + self.assert_compile(s1, "SELECT x WHERE y = :y_1 ORDER BY q") + self.assert_compile(s2, "SELECT x WHERE x = :x_1 ORDER BY q") + self.assert_compile(s3, "SELECT x WHERE y = :y_1 ORDER BY q") + self.assert_compile(s4, "SELECT x WHERE x = :x_1 ORDER BY q") + def test_coercion_cols_clause(self): assert_raises_message( exc.ArgumentError, diff --git a/test/sql/test_selectable.py b/test/sql/test_selectable.py index 01c8d7ca65..58280bb677 100644 --- a/test/sql/test_selectable.py +++ b/test/sql/test_selectable.py @@ -29,6 +29,7 @@ from sqlalchemy import TypeDecorator from sqlalchemy import union from sqlalchemy import util from sqlalchemy.sql import Alias +from sqlalchemy.sql import annotation from sqlalchemy.sql import base from sqlalchemy.sql import column from sqlalchemy.sql import elements @@ -2352,6 +2353,33 @@ class AnnotationsTest(fixtures.TestBase): annot = obj._annotate({}) ne_(set([obj]), set([annot])) + def test_replacement_traverse_preserve(self): + """test that replacement traverse that hits an unannotated column + does not use it when replacing an annotated column. + + this requires that replacement traverse store elements in the + "seen" hash based on id(), not hash. + + """ + t = table("t", column("x")) + + stmt = select([t.c.x]) + + whereclause = annotation._deep_annotate(t.c.x == 5, {"foo": "bar"}) + + eq_(whereclause._annotations, {"foo": "bar"}) + eq_(whereclause.left._annotations, {"foo": "bar"}) + eq_(whereclause.right._annotations, {"foo": "bar"}) + + stmt = stmt.where(whereclause) + + s2 = visitors.replacement_traverse(stmt, {}, lambda elem: None) + + whereclause = s2._where_criteria[0] + eq_(whereclause._annotations, {"foo": "bar"}) + eq_(whereclause.left._annotations, {"foo": "bar"}) + eq_(whereclause.right._annotations, {"foo": "bar"}) + def test_proxy_set_iteration_includes_annotated(self): from sqlalchemy.schema import Column