:param future: Use the 2.0 style :class:`_future.Engine` and
:class:`_future.Connection` API.
- ..versionadded:: 1.4
+ .. versionadded:: 1.4
.. seealso::
if (
compiled
and compiled._result_columns
- and context.cache_hit
+ and context.cache_hit is context.dialect.CACHE_HIT
and not compiled._rewrites_selected_columns
and compiled.statement is not context.invoked_statement
):
SERVER_SIDE_CURSOR_RE = re.compile(r"\s*SELECT", re.I | re.UNICODE)
+CACHE_HIT = util.symbol("CACHE_HIT")
+CACHE_MISS = util.symbol("CACHE_MISS")
+CACHING_DISABLED = util.symbol("CACHING_DISABLED")
+NO_CACHE_KEY = util.symbol("NO_CACHE_KEY")
+
+
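# Illustrative aside, not part of the patch: util.symbol() interns one
# object per name, so the four cache states above are singletons that can
# be compared with "is", as the result-metadata check does with
# "context.cache_hit is context.dialect.CACHE_HIT".
assert CACHE_HIT is util.symbol("CACHE_HIT")
assert CACHE_HIT is not CACHE_MISS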
class DefaultDialect(interfaces.Dialect):
"""Default implementation of Dialect"""
"""
+ CACHE_HIT = CACHE_HIT
+ CACHE_MISS = CACHE_MISS
+ CACHING_DISABLED = CACHING_DISABLED
+ NO_CACHE_KEY = NO_CACHE_KEY
+
@util.deprecated_params(
convert_unicode=(
"1.3",
_expanded_parameters = util.immutabledict()
+ cache_hit = NO_CACHE_KEY
+
@classmethod
    def _init_compiled(
cls,
parameters,
invoked_statement,
extracted_parameters,
- cache_hit=False,
+ cache_hit=CACHING_DISABLED,
):
"""Initialize execution context for a Compiled construct."""
return "raw sql"
now = util.perf_counter()
- if self.compiled.cache_key is None:
+
+ ch = self.cache_hit
+
+ if ch is NO_CACHE_KEY:
return "no key %.5fs" % (now - self.compiled._gen_time,)
- elif self.cache_hit:
+ elif ch is CACHE_HIT:
return "cached since %.4gs ago" % (now - self.compiled._gen_time,)
- else:
+ elif ch is CACHE_MISS:
return "generated in %.5fs" % (now - self.compiled._gen_time,)
+ elif ch is CACHING_DISABLED:
+ return "caching disabled %.5fs" % (now - self.compiled._gen_time,)
+ else:
+ return "unknown"
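# Illustrative sketch, not part of the patch: with echo=True on a 1.4
# engine, the strings built above appear as a suffix on the logged
# statement, e.g. "[generated in 0.00020s]" on the first execution of a
# construct and "[cached since ...s ago]" on later ones.  The in-memory
# SQLite URL and the demo_ variable names here are invented for the demo.
from sqlalchemy import create_engine, literal, select

demo_engine = create_engine("sqlite://", echo=True, future=True)
demo_stmt = select(literal(1))
with demo_engine.connect() as conn:
    conn.execute(demo_stmt)  # logged as a cache miss: "generated in ...s"
    conn.execute(demo_stmt)  # logged as a cache hit: "cached since ...s ago"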
@util.memoized_property
def engine(self):
import numbers
import re
+import types
from . import operators
from . import roles
from . import visitors
+from .base import Options
+from .traversals import HasCacheKey
from .visitors import Visitable
from .. import exc
from .. import inspection
of a SQL expression construct.
"""
+
return not isinstance(
- element, (Visitable, schema.SchemaEventTarget)
+ element, (Visitable, schema.SchemaEventTarget),
) and not hasattr(element, "__clause_element__")
+def _deep_is_literal(element):
+ """Return whether or not the element is a "literal" in the context
+ of a SQL expression construct.
+
+ does a deeper more esoteric check than _is_literal. is used
+ for lambda elements that have to distinguish values that would
+ be bound vs. not without any context.
+
+ """
+
+ return (
+ not isinstance(
+ element,
+ (Visitable, schema.SchemaEventTarget, HasCacheKey, Options,),
+ )
+ and not hasattr(element, "__clause_element__")
+ and (
+ not isinstance(element, type)
+ or not issubclass(element, HasCacheKey)
+ )
+ and not isinstance(element, types.FunctionType)
+ )
+
+
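# Illustrative aside, not part of the patch: how the two predicates above
# classify a few values.  _deep_is_literal() additionally rejects plain
# functions and HasCacheKey objects that _is_literal() would let through.
from sqlalchemy import column as _demo_column

assert _is_literal(5) and _deep_is_literal(5)

assert not _is_literal(_demo_column("q"))     # Visitable
assert not _deep_is_literal(_demo_column("q"))

assert _is_literal(lambda: 5)                 # a plain callable passes here
assert not _deep_is_literal(lambda: 5)        # ...but not the deeper check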
def _document_text_coercion(paramname, meth_rst, param_rst):
return util.add_parameter_text(
paramname,
class CoerceTextStatementImpl(_CoerceLiterals, RoleImpl):
__slots__ = ()
    def _literal_coercion(self, element, **kw):
if callable(element) and hasattr(element, "__code__"):
- return lambdas.StatementLambdaElement(element, self._role_class)
+ return lambdas.StatementLambdaElement(
+ element,
+ self._role_class,
+ additional_cache_criteria=kw.get(
+ "additional_cache_criteria", ()
+ ),
+                tracked=kw.get("tracked", True),
+ )
else:
return super(CoerceTextStatementImpl, self)._literal_coercion(
element, **kw
from .base import PARSE_AUTOCOMMIT
from .base import SingletonConstant
from .coercions import _document_text_coercion
-from .traversals import _copy_internals
from .traversals import _get_children
+from .traversals import HasCopyInternals
from .traversals import MemoizedHasCacheKey
from .traversals import NO_CACHE
from .visitors import cloned_traverse
roles.SQLRole,
SupportsWrappingAnnotations,
MemoizedHasCacheKey,
+ HasCopyInternals,
Traversible,
):
"""Base class for elements of a programmatically constructed SQL
"""
return traversals.compare(self, other, **kw)
- def _copy_internals(self, omit_attrs=(), **kw):
- """Reassign internal elements to be clones of themselves.
-
- Called during a copy-and-traverse operation on newly
- shallow-copied elements to create a deep copy.
-
- The given clone function should be used, which may be applying
- additional transformations to the element (i.e. replacement
- traversal, cloned traversal, annotations).
-
- """
-
- try:
- traverse_internals = self._traverse_internals
- except AttributeError:
- # user-defined classes may not have a _traverse_internals
- return
-
- for attrname, obj, meth in _copy_internals.run_generated_dispatch(
- self, traverse_internals, "_generated_copy_internals_traversal"
- ):
- if attrname in omit_attrs:
- continue
-
- if obj is not None:
- result = meth(self, attrname, obj, **kw)
- if result is not None:
- setattr(self, attrname, result)
-
def get_children(self, omit_attrs=(), **kw):
r"""Return immediate child :class:`.visitors.Traversible`
elements of this :class:`.visitors.Traversible`.
else:
elem_cache_key = None
- cache_hit = False
-
if elem_cache_key:
cache_key, extracted_params = elem_cache_key
key = (
compiled_sql = compiled_cache.get(key)
if compiled_sql is None:
+ cache_hit = dialect.CACHE_MISS
compiled_sql = self._compiler(
dialect,
cache_key=elem_cache_key,
)
compiled_cache[key] = compiled_sql
else:
- cache_hit = True
+ cache_hit = dialect.CACHE_HIT
else:
extracted_params = None
compiled_sql = self._compiler(
schema_translate_map=schema_translate_map,
**kw
)
+ cache_hit = (
+ dialect.CACHING_DISABLED
+ if compiled_cache is None
+ else dialect.NO_CACHE_KEY
+ )
return compiled_sql, extracted_params, cache_hit
if required is NO_ARG:
required = value is NO_ARG and callable_ is None
if value is NO_ARG:
- self._value_required_for_cache = False
value = None
- else:
- self._value_required_for_cache = True
if quote is not None:
key = quoted_name(key, quote)
"""Return a copy of this :class:`.BindParameter` with the given value
set.
"""
+
cloned = self._clone(maintain_key=maintain_key)
cloned.value = value
cloned.callable = None
anon_map[idself] = id_ = str(anon_map.index)
anon_map.index += 1
- bindparams.append(self)
+ if bindparams is not None:
+ bindparams.append(self)
return (
id_,
import itertools
import operator
import sys
+import types
import weakref
from . import coercions
from . import traversals
from . import type_api
from . import visitors
+from .base import _clone
from .operators import ColumnOperators
from .. import exc
from .. import inspection
from .. import util
from ..util import collections_abc
-_trackers = weakref.WeakKeyDictionary()
+_closure_per_cache_key = util.LRUCache(1000)
-_TRACKERS = 0
-_STALE_CHECK = 1
-_REAL_FN = 2
-_EXPR = 3
-_IS_SEQUENCE = 4
-_PROPAGATE_ATTRS = 5
-
-
-def lambda_stmt(lmb):
+def lambda_stmt(lmb, **opts):
"""Produce a SQL statement that is cached as a lambda.
- This SQL statement will only be constructed if element has not been
- compiled yet. The approach is used to save on Python function overhead
- when constructing statements that will be cached.
+ The Python code object within the lambda is scanned for both Python
+ literals that will become bound parameters as well as closure variables
+ that refer to Core or ORM constructs that may vary. The lambda itself
+ will be invoked only once per particular set of constructs detected.
E.g.::
"""
- return coercions.expect(roles.CoerceTextStatementRole, lmb)
+
+ return StatementLambdaElement(lmb, roles.CoerceTextStatementRole, **opts)
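# Illustrative sketch, not part of the patch, modeled on the tests further
# below: the lambda bodies run once per code location; later calls only
# re-extract the closure values and bind them into the cached construct.
# The names prefixed with demo_ are invented for the example.
from sqlalchemy import column, select, table
from sqlalchemy.sql import lambdas

demo_t = table("t", column("x"))

def demo_stmt(value):
    stmt = lambdas.lambda_stmt(lambda: select(demo_t))
    stmt += lambda s: s.where(demo_t.c.x == value)
    return stmt

s1 = demo_stmt(5)
s2 = demo_stmt(10)
# s1 and s2 share a cache key and render the same SQL string; only the
# bound value (5 vs. 10) differs.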
class LambdaElement(elements.ClauseElement):
_is_lambda_element = True
- _resolved_bindparams = ()
-
_traverse_internals = [
("_resolved", visitors.InternalTraversal.dp_clauseelement)
]
+ _transforms = ()
+
+ parent_lambda = None
+
def __repr__(self):
return "%s(%r)" % (self.__class__.__name__, self.fn.__code__)
def __init__(self, fn, role, apply_propagate_attrs=None, **kw):
self.fn = fn
self.role = role
- self.parent_lambda = None
+ self.tracker_key = (fn.__code__,)
if apply_propagate_attrs is None and (
role is roles.CoerceTextStatementRole
):
apply_propagate_attrs = self
- if fn.__code__ not in _trackers:
- rec = self._initialize_var_trackers(
- role, apply_propagate_attrs, kw
+ rec = self._retrieve_tracker_rec(fn, apply_propagate_attrs, kw)
+
+ if apply_propagate_attrs is not None:
+ propagate_attrs = rec.propagate_attrs
+ if propagate_attrs:
+ apply_propagate_attrs._propagate_attrs = propagate_attrs
+
+ def _retrieve_tracker_rec(self, fn, apply_propagate_attrs, kw):
+ lambda_cache = kw.get("lambda_cache", _closure_per_cache_key)
+
+ tracker_key = self.tracker_key
+
+ fn = self.fn
+ closure = fn.__closure__
+
+ tracker = AnalyzedCode.get(
+ fn,
+ self,
+ kw,
+ track_bound_values=kw.get("track_bound_values", True),
+ enable_tracking=kw.get("enable_tracking", True),
+ track_on=kw.get("track_on", None),
+ )
+
+ self._resolved_bindparams = bindparams = []
+
+ anon_map = traversals.anon_map()
+ cache_key = tuple(
+ [
+ getter(closure, kw, anon_map, bindparams)
+ for getter in tracker.closure_trackers
+ ]
+ )
+ if self.parent_lambda is not None:
+ cache_key = self.parent_lambda.closure_cache_key + cache_key
+
+ self.closure_cache_key = cache_key
+
+ try:
+ rec = lambda_cache[tracker_key + cache_key]
+ except KeyError:
+ rec = None
+
+ if rec is None:
+ rec = AnalyzedFunction(
+ tracker, self, apply_propagate_attrs, kw, fn
)
+ rec.closure_bindparams = bindparams
+ lambda_cache[tracker_key + cache_key] = rec
else:
- rec = _trackers[self.fn.__code__]
- closure = fn.__closure__
+ bindparams[:] = [
+ orig_bind._with_value(new_bind.value, maintain_key=True)
+ for orig_bind, new_bind in zip(
+ rec.closure_bindparams, bindparams
+ )
+ ]
+
+ if self.parent_lambda is not None:
+ bindparams[:0] = self.parent_lambda._resolved_bindparams
- # check if the objects fixed inside the lambda that we've cached
- # have been changed. This can apply to things like mappers that
- # were recreated in test suites. if so, re-initialize.
- #
- # this is a small performance hit on every use for a not very
- # common situation, however it's very hard to debug if the
- # condition does occur.
- for idx, obj in rec[_STALE_CHECK]:
- if closure[idx].cell_contents is not obj:
- rec = self._initialize_var_trackers(
- role, apply_propagate_attrs, kw
- )
- break
self._rec = rec
- if apply_propagate_attrs is not None:
- propagate_attrs = rec[_PROPAGATE_ATTRS]
- if propagate_attrs:
- apply_propagate_attrs._propagate_attrs = propagate_attrs
+ lambda_element = self
+ while lambda_element is not None:
+ rec = lambda_element._rec
+ if rec.bindparam_trackers:
+ tracker_instrumented_fn = rec.tracker_instrumented_fn
+ for tracker in rec.bindparam_trackers:
+ tracker(
+ lambda_element.fn, tracker_instrumented_fn, bindparams
+ )
+ lambda_element = lambda_element.parent_lambda
- if rec[_TRACKERS]:
- self._resolved_bindparams = bindparams = []
- for tracker in rec[_TRACKERS]:
- tracker(self.fn, bindparams)
+ return rec
def __getattr__(self, key):
- return getattr(self._rec[_EXPR], key)
+ return getattr(self._rec.expected_expr, key)
@property
def _is_sequence(self):
- return self._rec[_IS_SEQUENCE]
+ return self._rec.is_sequence
@property
def _select_iterable(self):
def _param_dict(self):
return {b.key: b.value for b in self._resolved_bindparams}
- @util.memoized_property
- def _resolved(self):
+ def _setup_binds_for_tracked_expr(self, expr):
bindparam_lookup = {b.key: b for b in self._resolved_bindparams}
def replace(thing):
and thing.key in bindparam_lookup
):
bind = bindparam_lookup[thing.key]
- # TODO: consider
- # if we should clone the bindparam here, re-cache the new
- # version, etc. also we make an assumption about "expanding"
- # in this case.
if thing.expanding:
bind.expanding = True
return bind
- expr = self._rec[_EXPR]
-
- if self._rec[_IS_SEQUENCE]:
+ if self._rec.is_sequence:
expr = [
visitors.replacement_traverse(sub_expr, {}, replace)
for sub_expr in expr
return expr
+ def _copy_internals(
+ self, clone=_clone, deferred_copy_internals=None, **kw
+ ):
+ # TODO: this needs A LOT of tests
+ self._resolved = clone(
+ self._resolved,
+ deferred_copy_internals=deferred_copy_internals,
+ **kw
+ )
+
+ @util.memoized_property
+ def _resolved(self):
+ expr = self._rec.expected_expr
+
+ if self._resolved_bindparams:
+ expr = self._setup_binds_for_tracked_expr(expr)
+
+ return expr
+
def _gen_cache_key(self, anon_map, bindparams):
- cache_key = (self.fn.__code__, self.__class__)
+ cache_key = (
+ self.fn.__code__,
+ self.__class__,
+ ) + self.closure_cache_key
+
+ parent = self.parent_lambda
+ while parent is not None:
+ cache_key = (
+ (parent.fn.__code__,) + parent.closure_cache_key + cache_key
+ )
+
+ parent = parent.parent_lambda
if self._resolved_bindparams:
bindparams.extend(self._resolved_bindparams)
def _invoke_user_fn(self, fn, *arg):
return fn()
- def _initialize_var_trackers(self, role, apply_propagate_attrs, coerce_kw):
- fn = self.fn
- # track objects referenced inside of lambdas, create bindparams
- # ahead of time for literal values. If bindparams are produced,
- # then rewrite the function globals and closure as necessary so that
- # it refers to the bindparams, then invoke the function
- new_closure = {}
- new_globals = fn.__globals__.copy()
- tracker_collection = []
- check_closure_for_stale = []
+class DeferredLambdaElement(LambdaElement):
+ """A LambdaElement where the lambda accepts arguments and is
+ invoked within the compile phase with special context.
- for name in fn.__code__.co_names:
- if name not in new_globals:
- continue
+ This lambda doesn't normally produce its real SQL expression outside of the
+ compile phase. It is passed a fixed set of initial arguments
+ so that it can generate a sample expression.
- bound_value = _roll_down_to_literal(new_globals[name])
+ """
- if coercions._is_literal(bound_value):
- new_globals[name] = bind = PyWrapper(name, bound_value)
- tracker_collection.append(_globals_tracker(name, bind))
+ def __init__(self, fn, role, lambda_args=(), **kw):
+ self.lambda_args = lambda_args
+ self.coerce_kw = kw
+ super(DeferredLambdaElement, self).__init__(fn, role, **kw)
- if fn.__closure__:
- for closure_index, (fv, cell) in enumerate(
- zip(fn.__code__.co_freevars, fn.__closure__)
- ):
+ def _invoke_user_fn(self, fn, *arg):
+ return fn(*self.lambda_args)
- bound_value = _roll_down_to_literal(cell.cell_contents)
+ def _resolve_with_args(self, *lambda_args):
+ tracker_fn = self._rec.tracker_instrumented_fn
+ expr = tracker_fn(*lambda_args)
- if coercions._is_literal(bound_value):
- new_closure[fv] = bind = PyWrapper(fv, bound_value)
- tracker_collection.append(
- _closure_tracker(fv, bind, closure_index)
- )
- else:
- new_closure[fv] = cell.cell_contents
- # for normal cell contents, add them to a list that
- # we can compare later when we get new lambdas. if
- # any identities have changed, then we will recalculate
- # the whole lambda and run it again.
- check_closure_for_stale.append(
- (closure_index, cell.cell_contents)
- )
+ expr = coercions.expect(self.role, expr, **self.coerce_kw)
- if tracker_collection:
- new_fn = _rewrite_code_obj(
- fn,
- [new_closure[name] for name in fn.__code__.co_freevars],
- new_globals,
- )
- expr = self._invoke_user_fn(new_fn)
+ if self._resolved_bindparams:
+ expr = self._setup_binds_for_tracked_expr(expr)
- else:
- new_fn = fn
- expr = self._invoke_user_fn(new_fn)
- tracker_collection = []
+ # TODO: TEST TEST TEST, this is very out there
+ for deferred_copy_internals in self._transforms:
+ expr = deferred_copy_internals(expr)
- if self.parent_lambda is None:
- if isinstance(expr, collections_abc.Sequence):
- expected_expr = [
- coercions.expect(
- role,
- sub_expr,
- apply_propagate_attrs=apply_propagate_attrs,
- **coerce_kw
- )
- for sub_expr in expr
- ]
- is_sequence = True
- else:
- expected_expr = coercions.expect(
- role,
- expr,
- apply_propagate_attrs=apply_propagate_attrs,
- **coerce_kw
- )
- is_sequence = False
- else:
- expected_expr = expr
- is_sequence = False
+ return expr
- if apply_propagate_attrs is not None:
- propagate_attrs = apply_propagate_attrs._propagate_attrs
- else:
- propagate_attrs = util.immutabledict()
-
- rec = _trackers[self.fn.__code__] = (
- tracker_collection,
- check_closure_for_stale,
- new_fn,
- expected_expr,
- is_sequence,
- propagate_attrs,
+ def _copy_internals(
+ self, clone=_clone, deferred_copy_internals=None, **kw
+ ):
+ super(DeferredLambdaElement, self)._copy_internals(
+ clone=clone, deferred_copy_internals=deferred_copy_internals, **kw
)
- return rec
+
+ # TODO: A LOT A LOT of tests. for _resolve_with_args, we don't know
+ # our expression yet. so hold onto the replacement
+ if deferred_copy_internals:
+ self._transforms += (deferred_copy_internals,)
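# Illustrative sketch, not part of the patch, mirroring the "eight()"
# fixture added to the comparison tests below: a DeferredLambdaElement is
# handed its arguments via lambda_args rather than closing over them.
from sqlalchemy import column, table
from sqlalchemy.sql import lambdas, roles

demo_tbl = table("t", column("a"))
demo_q = 5

demo_crit = lambdas.DeferredLambdaElement(
    lambda t: t.c.a > demo_q,     # demo_q is tracked as a bound value
    roles.WhereHavingRole,
    lambda_args=(demo_tbl,),
)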
class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement):
"""
+ def __init__(self, fn, parent_lambda, **kw):
+ self._default_kw = default_kw = {}
+ global_track_bound_values = kw.pop("global_track_bound_values", None)
+ if global_track_bound_values is not None:
+ default_kw["track_bound_values"] = global_track_bound_values
+ kw["track_bound_values"] = global_track_bound_values
+
+ if "lambda_cache" in kw:
+ default_kw["lambda_cache"] = kw["lambda_cache"]
+
+ super(StatementLambdaElement, self).__init__(fn, parent_lambda, **kw)
+
def __add__(self, other):
- return LinkedLambdaElement(other, parent_lambda=self)
+ return LinkedLambdaElement(
+ other, parent_lambda=self, **self._default_kw
+ )
+
+ def add_criteria(self, other, **kw):
+ if self._default_kw:
+ if kw:
+ default_kw = self._default_kw.copy()
+ default_kw.update(kw)
+ kw = default_kw
+ else:
+ kw = self._default_kw
+
+ return LinkedLambdaElement(other, parent_lambda=self, **kw)
def _execute_on_connection(
self, connection, multiparams, params, execution_options
):
- if self._rec[_EXPR].supports_execution:
+ if self._rec.expected_expr.supports_execution:
return connection._execute_clauseelement(
self, multiparams, params, execution_options
)
@property
def _with_options(self):
- return self._rec[_EXPR]._with_options
+ return self._rec.expected_expr._with_options
@property
def _effective_plugin_target(self):
- return self._rec[_EXPR]._effective_plugin_target
+ return self._rec.expected_expr._effective_plugin_target
@property
def _is_future(self):
- return self._rec[_EXPR]._is_future
+ return self._rec.expected_expr._is_future
@property
def _execution_options(self):
- return self._rec[_EXPR]._execution_options
+ return self._rec.expected_expr._execution_options
+
+ def spoil(self):
+ """Return a new :class:`.StatementLambdaElement` that will run
+ all lambdas unconditionally each time.
+
+ """
+ return NullLambdaStatement(self.fn())
+
+
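# Illustrative sketch, not part of the patch: spoil() returns the
# NullLambdaStatement defined just below, trading caching for easier
# debugging, while add_criteria() is the keyword-accepting form of the "+"
# operator exercised by the tests later in this patch.
from sqlalchemy import column, select, table
from sqlalchemy.sql import lambdas

spoil_t = table("t", column("x"))

stmt = lambdas.lambda_stmt(lambda: select(spoil_t))

# a spoiled statement invokes every lambda immediately and skips the cache,
# which helps rule the lambda machinery in or out when chasing a problem.
spoiled = stmt.spoil()
spoiled = spoiled.add_criteria(lambda s: s.where(spoil_t.c.x > 5))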
+class NullLambdaStatement(roles.AllowsLambdaRole, elements.ClauseElement):
+ """Provides the :class:`.StatementLambdaElement` API but does not
+ cache or analyze lambdas.
+
+ the lambdas are instead invoked immediately.
+
+ The intended use is to isolate issues that may arise when using
+ lambda statements.
+
+ """
+
+ __visit_name__ = "lambda_element"
+
+ _is_lambda_element = True
+
+ _traverse_internals = [
+ ("_resolved", visitors.InternalTraversal.dp_clauseelement)
+ ]
+
+ def __init__(self, statement):
+ self._resolved = statement
+ self._propagate_attrs = statement._propagate_attrs
+
+ def __getattr__(self, key):
+ return getattr(self._resolved, key)
+
+ def __add__(self, other):
+ statement = other(self._resolved)
+
+ return NullLambdaStatement(statement)
+
+ def add_criteria(self, other, **kw):
+ statement = other(self._resolved)
+
+ return NullLambdaStatement(statement)
+
+ def _execute_on_connection(
+ self, connection, multiparams, params, execution_options
+ ):
+ if self._resolved.supports_execution:
+ return connection._execute_clauseelement(
+ self, multiparams, params, execution_options
+ )
+ else:
+ raise exc.ObjectNotExecutableError(self)
class LinkedLambdaElement(StatementLambdaElement):
+ """Represent subsequent links of a :class:`.StatementLambdaElement`."""
+
+ role = None
+
def __init__(self, fn, parent_lambda, **kw):
+ self._default_kw = parent_lambda._default_kw
+
self.fn = fn
self.parent_lambda = parent_lambda
- role = None
- apply_propagate_attrs = self
+ self.tracker_key = parent_lambda.tracker_key + (fn.__code__,)
+ self._retrieve_tracker_rec(fn, self, kw)
+ self._propagate_attrs = parent_lambda._propagate_attrs
+
+ def _invoke_user_fn(self, fn, *arg):
+ return fn(self.parent_lambda._resolved)
+
+
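# Illustrative aside, not part of the patch: lambdas created by the same
# line of source share one __code__ object even though each call produces a
# new function object.  That is why AnalyzedCode below can key its analysis
# on fn.__code__ and reuse it across every per-call lambda instance.
def _demo_make(x):
    return lambda: x + 1

_demo_a, _demo_b = _demo_make(1), _demo_make(2)
assert _demo_a is not _demo_b
assert _demo_a.__code__ is _demo_b.__code__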
+class AnalyzedCode(object):
+ __slots__ = (
+ "track_closure_variables",
+ "track_bound_values",
+ "bindparam_trackers",
+ "closure_trackers",
+ "build_py_wrappers",
+ )
+ _fns = weakref.WeakKeyDictionary()
+
+ @classmethod
+ def get(cls, fn, lambda_element, lambda_kw, **kw):
+ try:
+ # TODO: validate kw haven't changed?
+ return cls._fns[fn.__code__]
+ except KeyError:
+ pass
+ cls._fns[fn.__code__] = analyzed = AnalyzedCode(
+ fn, lambda_element, lambda_kw, **kw
+ )
+ return analyzed
+
+ def __init__(
+ self,
+ fn,
+ lambda_element,
+ lambda_kw,
+ track_bound_values=True,
+ enable_tracking=True,
+ track_on=None,
+ ):
+ closure = fn.__closure__
+
+ self.track_closure_variables = not track_on
+
+ self.track_bound_values = track_bound_values
+
+ # a list of callables generated from _bound_parameter_getter_*
+ # functions. Each of these uses a PyWrapper object to retrieve
+ # a parameter value
+ self.bindparam_trackers = []
+
+ # a list of callables generated from _cache_key_getter_* functions
+ # these callables work to generate a cache key for the lambda
+ # based on what's inside its closure variables.
+ self.closure_trackers = []
- if fn.__code__ not in _trackers:
- rec = self._initialize_var_trackers(
- role, apply_propagate_attrs, kw
+ self.build_py_wrappers = []
+
+ if enable_tracking:
+ if track_on:
+ self._init_track_on(track_on)
+
+ self._init_globals(fn)
+
+ if closure:
+ self._init_closure(fn)
+
+ self._setup_additional_closure_trackers(fn, lambda_element, lambda_kw)
+
+ def _init_track_on(self, track_on):
+ self.closure_trackers.extend(
+ self._cache_key_getter_track_on(idx, elem)
+ for idx, elem in enumerate(track_on)
+ )
+
+ def _init_globals(self, fn):
+ build_py_wrappers = self.build_py_wrappers
+ bindparam_trackers = self.bindparam_trackers
+ track_bound_values = self.track_bound_values
+
+ for name in fn.__code__.co_names:
+ if name not in fn.__globals__:
+ continue
+
+ _bound_value = self._roll_down_to_literal(fn.__globals__[name])
+
+ if coercions._deep_is_literal(_bound_value):
+ build_py_wrappers.append((name, None))
+ if track_bound_values:
+ bindparam_trackers.append(
+ self._bound_parameter_getter_func_globals(name)
+ )
+
+ def _init_closure(self, fn):
+ build_py_wrappers = self.build_py_wrappers
+ closure = fn.__closure__
+
+ track_bound_values = self.track_bound_values
+ track_closure_variables = self.track_closure_variables
+ bindparam_trackers = self.bindparam_trackers
+ closure_trackers = self.closure_trackers
+
+ for closure_index, (fv, cell) in enumerate(
+ zip(fn.__code__.co_freevars, closure)
+ ):
+ _bound_value = self._roll_down_to_literal(cell.cell_contents)
+
+ if coercions._deep_is_literal(_bound_value):
+ build_py_wrappers.append((fv, closure_index))
+ if track_bound_values:
+ bindparam_trackers.append(
+ self._bound_parameter_getter_func_closure(
+ fv, closure_index
+ )
+ )
+ else:
+ # for normal cell contents, add them to a list that
+ # we can compare later when we get new lambdas. if
+ # any identities have changed, then we will
+ # recalculate the whole lambda and run it again.
+
+ if track_closure_variables:
+ closure_trackers.append(
+ self._cache_key_getter_closure_variable(
+ closure_index, cell.cell_contents
+ )
+ )
+
+ def _setup_additional_closure_trackers(
+ self, fn, lambda_element, lambda_kw
+ ):
+ # an additional step is to actually run the function, then
+ # go through the PyWrapper objects that were set up to catch a bound
+ # parameter. then if they *didn't* make a param, oh they're another
+ # object in the closure we have to track for our cache key. so
+ # create trackers to catch those.
+
+ analyzed_function = AnalyzedFunction(
+ self, lambda_element, None, lambda_kw, fn,
+ )
+
+ closure_trackers = self.closure_trackers
+
+ for pywrapper in analyzed_function.closure_pywrappers:
+ if not pywrapper._sa__has_param:
+ closure_trackers.append(
+ self._cache_key_getter_tracked_literal(pywrapper)
+ )
+
+ @classmethod
+ def _roll_down_to_literal(cls, element):
+ is_clause_element = hasattr(element, "__clause_element__")
+
+ if is_clause_element:
+ while not isinstance(
+ element, (elements.ClauseElement, schema.SchemaItem)
+ ):
+ try:
+ element = element.__clause_element__()
+ except AttributeError:
+ break
+
+ if not is_clause_element:
+ insp = inspection.inspect(element, raiseerr=False)
+ if insp is not None:
+ try:
+ return insp.__clause_element__()
+ except AttributeError:
+ return insp
+
+ # TODO: should we coerce consts None/True/False here?
+ return element
+ else:
+ return element
+
+ def _bound_parameter_getter_func_globals(self, name):
+ """Return a getter that will extend a list of bound parameters
+ with new entries from the ``__globals__`` collection of a particular
+ lambda.
+
+ """
+
+ def extract_parameter_value(
+ current_fn, tracker_instrumented_fn, result
+ ):
+ wrapper = tracker_instrumented_fn.__globals__[name]
+ object.__getattribute__(wrapper, "_extract_bound_parameters")(
+ current_fn.__globals__[name], result
+ )
+
+ return extract_parameter_value
+
+ def _bound_parameter_getter_func_closure(self, name, closure_index):
+ """Return a getter that will extend a list of bound parameters
+ with new entries from the ``__closure__`` collection of a particular
+ lambda.
+
+ """
+
+ def extract_parameter_value(
+ current_fn, tracker_instrumented_fn, result
+ ):
+ wrapper = tracker_instrumented_fn.__closure__[
+ closure_index
+ ].cell_contents
+ object.__getattribute__(wrapper, "_extract_bound_parameters")(
+ current_fn.__closure__[closure_index].cell_contents, result
+ )
+
+ return extract_parameter_value
+
+ def _cache_key_getter_track_on(self, idx, elem):
+ """Return a getter that will extend a cache key with new entries
+ from the "track_on" parameter passed to a :class:`.LambdaElement`.
+
+ """
+ if isinstance(elem, traversals.HasCacheKey):
+
+ def get(closure, kw, anon_map, bindparams):
+ return kw["track_on"][idx]._gen_cache_key(anon_map, bindparams)
+
+ else:
+
+ def get(closure, kw, anon_map, bindparams):
+ return kw["track_on"][idx]
+
+ return get
+
+ def _cache_key_getter_closure_variable(self, idx, cell_contents):
+ """Return a getter that will extend a cache key with new entries
+ from the ``__closure__`` collection of a particular lambda.
+
+ """
+
+ if isinstance(cell_contents, traversals.HasCacheKey):
+
+ def get(closure, kw, anon_map, bindparams):
+ return closure[idx].cell_contents._gen_cache_key(
+ anon_map, bindparams
+ )
+
+ elif isinstance(cell_contents, types.FunctionType):
+
+ def get(closure, kw, anon_map, bindparams):
+ return closure[idx].cell_contents.__code__
+
+ elif cell_contents.__hash__ is None:
+ # this covers dict, etc.
+ def get(closure, kw, anon_map, bindparams):
+ return ()
+
+ else:
+
+ def get(closure, kw, anon_map, bindparams):
+ return closure[idx].cell_contents
+
+ return get
+
+ def _cache_key_getter_tracked_literal(self, pytracker):
+ """Return a getter that will extend a cache key with new entries
+ from the ``__closure__`` collection of a particular lambda.
+
+ this getter differs from _cache_key_getter_closure_variable
+ in that these are detected after the function is run, and PyWrapper
+ objects have recorded that a particular literal value is in fact
+ not being interpreted as a bound parameter.
+
+ """
+
+ elem = pytracker._sa__to_evaluate
+ closure_index = pytracker._sa__closure_index
+
+ if isinstance(elem, set):
+ raise exc.ArgumentError(
+ "Can't create a cache key for lambda closure variable "
+ '"%s" because it\'s a set. try using a list'
+ % pytracker._sa__name
)
+
+ elif isinstance(elem, list):
+
+ def get(closure, kw, anon_map, bindparams):
+ return tuple(
+ elem._gen_cache_key(anon_map, bindparams)
+ for elem in closure[closure_index].cell_contents
+ )
+
+ elif elem.__hash__ is None:
+ # this covers dict, etc.
+ def get(closure, kw, anon_map, bindparams):
+ return ()
+
else:
- rec = _trackers[self.fn.__code__]
+ def get(closure, kw, anon_map, bindparams):
+ return closure[closure_index].cell_contents
+
+ return get
+
+
+class AnalyzedFunction(object):
+ __slots__ = (
+ "analyzed_code",
+ "fn",
+ "closure_pywrappers",
+ "tracker_instrumented_fn",
+ "expr",
+ "bindparam_trackers",
+ "expected_expr",
+ "is_sequence",
+ "propagate_attrs",
+ "closure_bindparams",
+ )
+
+ def __init__(
+ self, analyzed_code, lambda_element, apply_propagate_attrs, kw, fn,
+ ):
+ self.analyzed_code = analyzed_code
+ self.fn = fn
+
+ self.bindparam_trackers = analyzed_code.bindparam_trackers
+
+ self._instrument_and_run_function(lambda_element)
+
+ self._coerce_expression(lambda_element, apply_propagate_attrs, kw)
+
+ def _instrument_and_run_function(self, lambda_element):
+ analyzed_code = self.analyzed_code
+
+ fn = self.fn
+ self.closure_pywrappers = closure_pywrappers = []
+
+ build_py_wrappers = analyzed_code.build_py_wrappers
+
+ if not build_py_wrappers:
+ self.tracker_instrumented_fn = tracker_instrumented_fn = fn
+ self.expr = lambda_element._invoke_user_fn(tracker_instrumented_fn)
+ else:
+ track_closure_variables = analyzed_code.track_closure_variables
closure = fn.__closure__
- # check if objects referred to by the lambda have changed and
- # re-scan the lambda if so. see comments for this same section in
- # LambdaElement.
- for idx, obj in rec[_STALE_CHECK]:
- if closure[idx].cell_contents is not obj:
- rec = self._initialize_var_trackers(
- role, apply_propagate_attrs, kw
+ # will form the __closure__ of the function when we rebuild it
+ if closure:
+ new_closure = {
+ fv: cell.cell_contents
+ for fv, cell in zip(fn.__code__.co_freevars, closure)
+ }
+ else:
+ new_closure = {}
+
+ # will form the __globals__ of the function when we rebuild it
+ new_globals = fn.__globals__.copy()
+
+ for name, closure_index in build_py_wrappers:
+ if closure_index is not None:
+ value = closure[closure_index].cell_contents
+ new_closure[name] = bind = PyWrapper(
+ name, value, closure_index=closure_index
)
- break
+ if track_closure_variables:
+ closure_pywrappers.append(bind)
+ else:
+ value = fn.__globals__[name]
+ new_globals[name] = bind = PyWrapper(name, value)
+
+ # rewrite the original fn. things that look like they will
+ # become bound parameters are wrapped in a PyWrapper.
+ self.tracker_instrumented_fn = (
+ tracker_instrumented_fn
+ ) = self._rewrite_code_obj(
+ fn,
+ [new_closure[name] for name in fn.__code__.co_freevars],
+ new_globals,
+ )
- self._rec = rec
+ # now invoke the function. This will give us a new SQL
+ # expression, but all the places that there would be a bound
+ # parameter, the PyWrapper in its place will give us a bind
+ # with a predictable name we can match up later.
- self._propagate_attrs = parent_lambda._propagate_attrs
+ # additionally, each PyWrapper will log that it did in fact
+ # create a parameter, otherwise, it's some kind of Python
+ # object in the closure and we want to track that, to make
+            # sure it doesn't change to something else, or if it does,
+ # that we create a different tracked function with that
+ # variable.
+ self.expr = lambda_element._invoke_user_fn(tracker_instrumented_fn)
- self._resolved_bindparams = bindparams = []
- rec = self._rec
- while True:
- if rec[_TRACKERS]:
- for tracker in rec[_TRACKERS]:
- tracker(self.fn, bindparams)
- if self.parent_lambda is not None:
- self = self.parent_lambda
- rec = self._rec
+ def _coerce_expression(self, lambda_element, apply_propagate_attrs, kw):
+ """Run the tracker-generated expression through coercion rules.
+
+ After the user-defined lambda has been invoked to produce a statement
+ for re-use, run it through coercion rules to both check that it's the
+ correct type of object and also to coerce it to its useful form.
+
+ """
+
+ parent_lambda = lambda_element.parent_lambda
+ expr = self.expr
+
+ if parent_lambda is None:
+ if isinstance(expr, collections_abc.Sequence):
+ self.expected_expr = [
+ coercions.expect(
+ lambda_element.role,
+ sub_expr,
+ apply_propagate_attrs=apply_propagate_attrs,
+ **kw
+ )
+ for sub_expr in expr
+ ]
+ self.is_sequence = True
else:
- break
+ self.expected_expr = coercions.expect(
+ lambda_element.role,
+ expr,
+ apply_propagate_attrs=apply_propagate_attrs,
+ **kw
+ )
+ self.is_sequence = False
+ else:
+ self.expected_expr = expr
+ self.is_sequence = False
- def _invoke_user_fn(self, fn, *arg):
- return fn(self.parent_lambda._rec[_EXPR])
+ if apply_propagate_attrs is not None:
+ self.propagate_attrs = apply_propagate_attrs._propagate_attrs
+ else:
+ self.propagate_attrs = util.EMPTY_DICT
- def _gen_cache_key(self, anon_map, bindparams):
- if self._resolved_bindparams:
- bindparams.extend(self._resolved_bindparams)
+ def _rewrite_code_obj(self, f, cell_values, globals_):
+ """Return a copy of f, with a new closure and new globals
- cache_key = (self.fn.__code__, self.__class__)
+ yes it works in pypy :P
- parent = self.parent_lambda
- while parent is not None:
- cache_key = (parent.fn.__code__,) + cache_key
- parent = parent.parent_lambda
+ """
- return cache_key
+ argrange = range(len(cell_values))
+
+ code = "def make_cells():\n"
+ if cell_values:
+ code += " (%s) = (%s)\n" % (
+ ", ".join("i%d" % i for i in argrange),
+ ", ".join("o%d" % i for i in argrange),
+ )
+ code += " def closure():\n"
+ code += " return %s\n" % ", ".join("i%d" % i for i in argrange)
+ code += " return closure.__closure__"
+ vars_ = {"o%d" % i: cell_values[i] for i in argrange}
+ exec(code, vars_, vars_)
+ closure = vars_["make_cells"]()
+
+ func = type(f)(
+ f.__code__, globals_, f.__name__, f.__defaults__, closure
+ )
+ if sys.version_info >= (3,):
+ func.__annotations__ = f.__annotations__
+ func.__kwdefaults__ = f.__kwdefaults__
+ func.__doc__ = f.__doc__
+ func.__module__ = f.__module__
+
+ return func
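# Illustrative aside, not part of the patch: the closure-rebuilding trick
# used by _rewrite_code_obj above, reduced to its core.  types.FunctionType
# accepts an existing __code__ along with replacement globals and closure
# cells; the _demo_ helper names here are invented for the demo.
def _demo_adder(n):
    def add(x):
        return x + n
    return add

_demo_add_five = _demo_adder(5)

def _demo_cell(value):
    # build a closure cell holding `value` by capturing it in a throwaway lambda
    return (lambda: value).__closure__[0]

_demo_rebuilt = types.FunctionType(
    _demo_add_five.__code__,
    _demo_add_five.__globals__,
    _demo_add_five.__name__,
    _demo_add_five.__defaults__,
    (_demo_cell(10),),  # swap the captured n=5 for n=10
)

assert _demo_add_five(3) == 8
assert _demo_rebuilt(3) == 13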
class PyWrapper(ColumnOperators):
- def __init__(self, name, to_evaluate, getter=None):
+ """A wrapper object that is injected into the ``__globals__`` and
+ ``__closure__`` of a Python function.
+
+ When the function is instrumented with :class:`.PyWrapper` objects, it is
+ then invoked just once in order to set up the wrappers. We look through
+ all the :class:`.PyWrapper` objects we made to find the ones that generated
+ a :class:`.BindParameter` object, e.g. the expression system interpreted
+ something as a literal. Those positions in the globals/closure are then
+ ones that we will look at, each time a new lambda comes in that refers to
+ the same ``__code__`` object. In this way, we keep a single version of
+ the SQL expression that this lambda produced, without calling upon the
+ Python function that created it more than once, unless its other closure
+ variables have changed. The expression is then transformed to have the
+ new bound values embedded into it.
+
+ """
+
+ def __init__(self, name, to_evaluate, closure_index=None, getter=None):
self._name = name
self._to_evaluate = to_evaluate
self._param = None
+ self._has_param = False
self._bind_paths = {}
self._getter = getter
+ self._closure_index = closure_index
def __call__(self, *arg, **kw):
elem = object.__getattribute__(self, "_to_evaluate")
value = elem(*arg, **kw)
- if coercions._is_literal(value) and not isinstance(
+ if coercions._deep_is_literal(value) and not isinstance(
# TODO: coverage where an ORM option or similar is here
value,
traversals.HasCacheKey,
if param is None:
name = object.__getattribute__(self, "_name")
self._param = param = elements.BindParameter(name, unique=True)
+ self._has_param = True
param.type = type_api._resolve_value_to_type(to_evaluate)
-
return param._with_value(to_evaluate, maintain_key=True)
def __getattribute__(self, key):
else:
return self._sa__add_getter(key, operator.attrgetter)
+ def __iter__(self):
+ elem = object.__getattribute__(self, "_to_evaluate")
+ return iter(elem)
+
def __getitem__(self, key):
+ elem = object.__getattribute__(self, "_to_evaluate")
+ if not hasattr(elem, "__getitem__"):
+ raise AttributeError("__getitem__")
+
if isinstance(key, PyWrapper):
# TODO: coverage
raise exc.InvalidRequestError(
elem = object.__getattribute__(self, "_to_evaluate")
value = getter(elem)
- if coercions._is_literal(value):
- wrapper = PyWrapper(key, value, getter)
+ if coercions._deep_is_literal(value):
+ wrapper = PyWrapper(key, value, getter=getter)
bind_paths[bind_path_key] = wrapper
return wrapper
else:
return value
-def _roll_down_to_literal(element):
- is_clause_element = hasattr(element, "__clause_element__")
-
- if is_clause_element:
- while not isinstance(
- element, (elements.ClauseElement, schema.SchemaItem)
- ):
- try:
- element = element.__clause_element__()
- except AttributeError:
- break
-
- if not is_clause_element:
- insp = inspection.inspect(element, raiseerr=False)
- if insp is not None:
- try:
- return insp.__clause_element__()
- except AttributeError:
- return insp
-
- # TODO: should we coerce consts None/True/False here?
- return element
- else:
- return element
-
-
-def _globals_tracker(name, wrapper):
- def extract_parameter_value(current_fn, result):
- object.__getattribute__(wrapper, "_extract_bound_parameters")(
- current_fn.__globals__[name], result
- )
-
- return extract_parameter_value
-
-
-def _closure_tracker(name, wrapper, closure_index):
- def extract_parameter_value(current_fn, result):
- object.__getattribute__(wrapper, "_extract_bound_parameters")(
- current_fn.__closure__[closure_index].cell_contents, result
- )
-
- return extract_parameter_value
-
-
-def _rewrite_code_obj(f, cell_values, globals_):
- """Return a copy of f, with a new closure and new globals
-
- yes it works in pypy :P
-
- """
-
- argrange = range(len(cell_values))
-
- code = "def make_cells():\n"
- if cell_values:
- code += " (%s) = (%s)\n" % (
- ", ".join("i%d" % i for i in argrange),
- ", ".join("o%d" % i for i in argrange),
- )
- code += " def closure():\n"
- code += " return %s\n" % ", ".join("i%d" % i for i in argrange)
- code += " return closure.__closure__"
- vars_ = {"o%d" % i: cell_values[i] for i in argrange}
- exec(code, vars_, vars_)
- closure = vars_["make_cells"]()
-
- func = type(f)(f.__code__, globals_, f.__name__, f.__defaults__, closure)
- if sys.version_info >= (3,):
- func.__annotations__ = f.__annotations__
- func.__kwdefaults__ = f.__kwdefaults__
- func.__doc__ = f.__doc__
- func.__module__ = f.__module__
-
- return func
-
-
@inspection._inspects(LambdaElement)
def insp(lmb):
return inspection.inspect(lmb._resolved)
"""
_is_select_statement = True
+ is_select = True
def _generate_fromclause_column_proxies(self, fromclause):
# type: (FromClause) -> None
[
("_raw_columns", InternalTraversal.dp_clauseelement_list),
("_from_obj", InternalTraversal.dp_clauseelement_list),
- ("_where_criteria", InternalTraversal.dp_clauseelement_list),
- ("_having_criteria", InternalTraversal.dp_clauseelement_list),
- ("_order_by_clauses", InternalTraversal.dp_clauseelement_list,),
- ("_group_by_clauses", InternalTraversal.dp_clauseelement_list,),
+ ("_where_criteria", InternalTraversal.dp_clauseelement_tuple),
+ ("_having_criteria", InternalTraversal.dp_clauseelement_tuple),
+ ("_order_by_clauses", InternalTraversal.dp_clauseelement_tuple,),
+ ("_group_by_clauses", InternalTraversal.dp_clauseelement_tuple,),
("_setup_joins", InternalTraversal.dp_setup_join_tuple,),
("_legacy_setup_joins", InternalTraversal.dp_setup_join_tuple,),
- ("_correlate", InternalTraversal.dp_clauseelement_list),
- ("_correlate_except", InternalTraversal.dp_clauseelement_list,),
+ ("_correlate", InternalTraversal.dp_clauseelement_tuple),
+ ("_correlate_except", InternalTraversal.dp_clauseelement_tuple,),
("_limit_clause", InternalTraversal.dp_clauseelement),
("_offset_clause", InternalTraversal.dp_clauseelement),
("_for_update_arg", InternalTraversal.dp_clauseelement),
("_distinct", InternalTraversal.dp_boolean),
- ("_distinct_on", InternalTraversal.dp_clauseelement_list),
+ ("_distinct_on", InternalTraversal.dp_clauseelement_tuple),
("_label_style", InternalTraversal.dp_plain_obj),
("_is_future", InternalTraversal.dp_boolean),
]
@_generative
def join(self, target, onclause=None, isouter=False, full=False):
- r"""Create a SQL JOIN against this :class:`_expresson.Select`
+ r"""Create a SQL JOIN against this :class:`_expression.Select`
object's criterion
and apply generatively, returning the newly resulting
:class:`_expression.Select`.
# they've become. This allows us to ensure the same cloned from
# is used when other items such as columns are "cloned"
- all_the_froms = list(
+ all_the_froms = set(
itertools.chain(
_from_objects(*self._raw_columns),
_from_objects(*self._where_criteria),
new_froms = {f: clone(f, **kw) for f in all_the_froms}
# 2. copy FROM collections, adding in joins that we've created.
- self._from_obj = tuple(clone(f, **kw) for f in self._from_obj) + tuple(
- f for f in new_froms.values() if isinstance(f, Join)
+ existing_from_obj = [clone(f, **kw) for f in self._from_obj]
+ add_froms = (
+ set(f for f in new_froms.values() if isinstance(f, Join))
+ .difference(all_the_froms)
+ .difference(existing_from_obj)
)
+ self._from_obj = tuple(existing_from_obj) + tuple(add_froms)
+
# 3. clone everything else, making sure we use columns
# corresponding to the froms we just made.
def replace(obj, **kw):
"""
+ assert isinstance(self._where_criteria, tuple)
self._where_criteria += (
coercions.expect(roles.WhereHavingRole, whereclause),
)
_is_textual = True
+ is_text = True
+ is_select = True
+
def __init__(self, text, columns, positional=False):
self.element = text
# convert for ORM attributes->columns, etc
# statements, not so much, but they usually won't have
# annotations.
result += self._annotations_cache_key
- elif meth is InternalTraversal.dp_clauseelement_list:
+ elif (
+ meth is InternalTraversal.dp_clauseelement_list
+ or meth is InternalTraversal.dp_clauseelement_tuple
+ ):
result += (
attrname,
tuple(
visit_has_cache_key = visit_clauseelement = CALL_GEN_CACHE_KEY
visit_clauseelement_list = InternalTraversal.dp_clauseelement_list
visit_annotations_key = InternalTraversal.dp_annotations_key
+ visit_clauseelement_tuple = InternalTraversal.dp_clauseelement_tuple
visit_string = (
visit_boolean
tuple(elem._gen_cache_key(anon_map, bindparams) for elem in obj),
)
+ visit_executable_options = visit_has_cache_key_list
+
def visit_inspectable_list(
self, attrname, obj, parent, anon_map, bindparams
):
_cache_key_traversal_visitor = _CacheKey()
+class HasCopyInternals(object):
+ def _clone(self, **kw):
+ raise NotImplementedError()
+
+ def _copy_internals(self, omit_attrs=(), **kw):
+ """Reassign internal elements to be clones of themselves.
+
+ Called during a copy-and-traverse operation on newly
+ shallow-copied elements to create a deep copy.
+
+ The given clone function should be used, which may be applying
+ additional transformations to the element (i.e. replacement
+ traversal, cloned traversal, annotations).
+
+ """
+
+ try:
+ traverse_internals = self._traverse_internals
+ except AttributeError:
+ # user-defined classes may not have a _traverse_internals
+ return
+
+ for attrname, obj, meth in _copy_internals.run_generated_dispatch(
+ self, traverse_internals, "_generated_copy_internals_traversal"
+ ):
+ if attrname in omit_attrs:
+ continue
+
+ if obj is not None:
+
+ result = meth(attrname, self, obj, **kw)
+ if result is not None:
+ setattr(self, attrname, result)
+
+
class _CopyInternals(InternalTraversal):
"""Generate a _copy_internals internal traversal dispatch for classes
with a _traverse_internals collection."""
):
return [clone(clause, **kw) for clause in element]
+ def visit_clauseelement_tuple(
+ self, attrname, parent, element, clone=_clone, **kw
+ ):
+ return tuple([clone(clause, **kw) for clause in element])
+
+ def visit_executable_options(
+ self, attrname, parent, element, clone=_clone, **kw
+ ):
+ return tuple([clone(clause, **kw) for clause in element])
+
def visit_clauseelement_unordered_set(
self, attrname, parent, element, clone=_clone, **kw
):
def visit_clauseelement_list(self, element, **kw):
return element
+ def visit_clauseelement_tuple(self, element, **kw):
+ return element
+
def visit_clauseelement_tuples(self, element, **kw):
return itertools.chain.from_iterable(element)
if not isinstance(target, str):
yield _flatten_clauseelement(target)
- # if onclause is not None and not isinstance(onclause, str):
- # yield _flatten_clauseelement(onclause)
+ if onclause is not None and not isinstance(onclause, str):
+ yield _flatten_clauseelement(onclause)
def visit_dml_ordered_values(self, element, **kw):
for k, v in element:
):
return COMPARE_FAILED
+ visit_executable_options = visit_has_cache_key_list
+
def visit_clauseelement(
self, attrname, left_parent, left, right_parent, right, **kw
):
for l, r in util.zip_longest(left, right, fillvalue=None):
self.stack.append((l, r))
+ def visit_clauseelement_tuple(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ for l, r in util.zip_longest(left, right, fillvalue=None):
+ self.stack.append((l, r))
+
def _compare_unordered_sequences(self, seq1, seq2, **kw):
if seq1 is None:
return seq2 is None
"""
- dp_clauseelement_tuples = symbol("CT")
+ dp_clauseelement_tuples = symbol("CTS")
"""Visit a list of tuples which contain :class:`_expression.ClauseElement`
objects.
"""
+ dp_clauseelement_tuple = symbol("CT")
+ """Visit a tuple of :class:`_expression.ClauseElement` objects.
+
+ """
+
+ dp_executable_options = symbol("EO")
+
dp_fromclause_ordered_set = symbol("CO")
"""Visit an ordered set of :class:`_expression.FromClause` objects. """
cloned = {}
stop_on = set(opts.get("stop_on", []))
+ def deferred_copy_internals(obj):
+ return cloned_traverse(obj, opts, visitors)
+
def clone(elem, **kw):
if elem in stop_on:
return elem
return cloned[id(elem)]
if obj is not None:
- obj = clone(obj)
+ obj = clone(obj, deferred_copy_internals=deferred_copy_internals)
clone = None # remove gc cycles
return obj
cloned = {}
stop_on = {id(x) for x in opts.get("stop_on", [])}
+ def deferred_copy_internals(obj):
+ return replacement_traverse(obj, opts, replace)
+
def clone(elem, **kw):
if (
id(elem) in stop_on
stop_on.add(id(newelem))
return newelem
else:
-
- if elem not in cloned:
+ # base "already seen" on id(), not hash, so that we don't
+ # replace an Annotated element with its non-annotated one, and
+ # vice versa
+ id_elem = id(elem)
+ if id_elem not in cloned:
if "replace" in kw:
newelem = kw["replace"](elem)
if newelem is not None:
- cloned[elem] = newelem
+ cloned[id_elem] = newelem
return newelem
- cloned[elem] = newelem = elem._clone()
+ cloned[id_elem] = newelem = elem._clone()
newelem._copy_internals(clone=clone, **kw)
- return cloned[elem]
+ return cloned[id_elem]
if obj is not None:
- obj = clone(obj, **opts)
+ obj = clone(
+ obj, deferred_copy_internals=deferred_copy_internals, **opts
+ )
clone = None # remove gc cycles
return obj
from functools import update_wrapper # noqa
from ._collections import coerce_generator_arg # noqa
+from ._collections import coerce_to_immutabledict # noqa
from ._collections import collections_abc # noqa
from ._collections import column_dict # noqa
from ._collections import column_set # noqa
+from ._collections import EMPTY_DICT # noqa
from ._collections import EMPTY_SET # noqa
from ._collections import FacadeDict # noqa
from ._collections import flatten_iterator # noqa
def __reduce__(self):
return _immutabledict_reconstructor, (dict(self),)
- def union(self, d):
- if not d:
+ def union(self, __d=None):
+ if not __d:
return self
new = dict.__new__(self.__class__)
dict.__init__(new, self)
- dict.update(new, d)
+ dict.update(new, __d)
+ return new
+
+ def _union_w_kw(self, __d=None, **kw):
+ # not sure if C version works correctly w/ this yet
+ if not __d and not kw:
+ return self
+
+ new = dict.__new__(self.__class__)
+ dict.__init__(new, self)
+ if __d:
+ dict.update(new, __d)
+ dict.update(new, kw)
return new
def merge_with(self, *dicts):
return immutabledict(*arg)
+def coerce_to_immutabledict(d):
+ if not d:
+ return EMPTY_DICT
+ elif isinstance(d, immutabledict):
+ return d
+ else:
+ return immutabledict(d)
+
+
+EMPTY_DICT = immutabledict()
+
+
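# Illustrative aside, not part of the patch: the helpers above never mutate
# in place, and coerce_to_immutabledict() hands back the shared EMPTY_DICT
# for falsy input.
_demo_d = immutabledict({"a": 1})
_demo_d2 = _demo_d.union({"b": 2})
assert _demo_d2 == {"a": 1, "b": 2} and _demo_d == {"a": 1}
assert coerce_to_immutabledict(None) is EMPTY_DICT
assert coerce_to_immutabledict(_demo_d) is _demo_d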
class FacadeDict(ImmutableContainer, dict):
"""A dictionary that is not publicly mutable."""
]
+class Foo:
+ x = 10
+ y = 15
+
+
dml.Insert.argument_for("sqlite", "foo", None)
dml.Update.argument_for("sqlite", "foo", None)
dml.Delete.argument_for("sqlite", "foo", None)
def two():
r = random.randint(1, 10)
- q = 20
+ q = 408
return LambdaElement(
lambda: table_a.c.a + q == r, roles.WhereHavingRole
)
roles.WhereHavingRole,
)
- class Foo:
- x = 10
- y = 15
-
def four():
return LambdaElement(
lambda: and_(table_a.c.a == Foo.x), roles.WhereHavingRole
lambda s: s.where(table_a.c.a == value)
)
+ from sqlalchemy.sql import lambdas
+
+ def eight():
+ q = 5
+ return lambdas.DeferredLambdaElement(
+ lambda t: t.c.a > q,
+ roles.WhereHavingRole,
+ lambda_args=(table_a,),
+ )
+
return [
one(),
two(),
five(),
six(),
seven(),
+ eight(),
]
dont_compare_values_fixtures.append(_lambda_fixtures)
"JOIN table2 ON table1.col1 = table2.col2) AS anon_1",
)
+ def test_this_thing_using_setup_joins_three(self):
+
+ j = t1.join(t2, t1.c.col1 == t2.c.col2)
+
+ s1 = select(j)
+
+ s2 = s1.join(t3, t1.c.col1 == t3.c.col1)
+
+ self.assert_compile(
+ s2,
+ "SELECT table1.col1, table1.col2, table1.col3, "
+ "table2.col1, table2.col2, table2.col3 FROM table1 "
+ "JOIN table2 ON table1.col1 = table2.col2 JOIN table3 "
+ "ON table3.col1 = table1.col1",
+ )
+
+ vis = sql_util.ClauseAdapter(j)
+
+ s3 = vis.traverse(s1)
+
+ s4 = s3.join(t3, t1.c.col1 == t3.c.col1)
+
+ self.assert_compile(
+ s4,
+ "SELECT table1.col1, table1.col2, table1.col3, "
+ "table2.col1, table2.col2, table2.col3 FROM table1 "
+ "JOIN table2 ON table1.col1 = table2.col2 JOIN table3 "
+ "ON table3.col1 = table1.col1",
+ )
+
+ s5 = vis.traverse(s3)
+
+ s6 = s5.join(t3, t1.c.col1 == t3.c.col1)
+
+ self.assert_compile(
+ s6,
+ "SELECT table1.col1, table1.col2, table1.col3, "
+ "table2.col1, table2.col2, table2.col3 FROM table1 "
+ "JOIN table2 ON table1.col1 = table2.col2 JOIN table3 "
+ "ON table3.col1 = table1.col1",
+ )
+
+ def test_this_thing_using_setup_joins_four(self):
+
+ j = t1.join(t2, t1.c.col1 == t2.c.col2)
+
+ s1 = select(j)
+
+ assert not s1._from_obj
+
+ s2 = s1.join(t3, t1.c.col1 == t3.c.col1)
+
+ self.assert_compile(
+ s2,
+ "SELECT table1.col1, table1.col2, table1.col3, "
+ "table2.col1, table2.col2, table2.col3 FROM table1 "
+ "JOIN table2 ON table1.col1 = table2.col2 JOIN table3 "
+ "ON table3.col1 = table1.col1",
+ )
+
+ s3 = visitors.replacement_traverse(s1, {}, lambda elem: None)
+
+ s4 = s3.join(t3, t1.c.col1 == t3.c.col1)
+
+ self.assert_compile(
+ s4,
+ "SELECT table1.col1, table1.col2, table1.col3, "
+ "table2.col1, table2.col2, table2.col3 FROM table1 "
+ "JOIN table2 ON table1.col1 = table2.col2 JOIN table3 "
+ "ON table3.col1 = table1.col1",
+ )
+
+ s5 = visitors.replacement_traverse(s3, {}, lambda elem: None)
+
+ s6 = s5.join(t3, t1.c.col1 == t3.c.col1)
+
+ self.assert_compile(
+ s6,
+ "SELECT table1.col1, table1.col2, table1.col3, "
+ "table2.col1, table2.col2, table2.col3 FROM table1 "
+ "JOIN table2 ON table1.col1 = table2.col2 JOIN table3 "
+ "ON table3.col1 = table1.col1",
+ )
+
def test_select_fromtwice_one(self):
t1a = t1.alias()
from sqlalchemy.schema import ForeignKey
from sqlalchemy.schema import Table
from sqlalchemy.sql import and_
+from sqlalchemy.sql import bindparam
from sqlalchemy.sql import coercions
from sqlalchemy.sql import column
from sqlalchemy.sql import join
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
+from sqlalchemy.testing import ne_
from sqlalchemy.testing.assertsql import CompiledSQL
from sqlalchemy.types import Integer
from sqlalchemy.types import String
go(), "SELECT t1.q, t1.p FROM t1 WHERE t1.q = :x_1 AND t1.p = :y_1"
)
+ def test_global_tracking(self):
+ t1 = table("t1", column("q"), column("p"))
+
+ global global_x, global_y
+
+ global_x = 10
+ global_y = 17
+
+ def go():
+ return select([t1]).where(
+ lambda: and_(t1.c.q == global_x, t1.c.p == global_y)
+ )
+
+ self.assert_compile(
+ go(),
+ "SELECT t1.q, t1.p FROM t1 WHERE t1.q = :global_x_1 "
+ "AND t1.p = :global_y_1",
+ checkparams={"global_x_1": 10, "global_y_1": 17},
+ )
+
+ global_y = 9
+
+ self.assert_compile(
+ go(),
+ "SELECT t1.q, t1.p FROM t1 WHERE t1.q = :global_x_1 "
+ "AND t1.p = :global_y_1",
+ checkparams={"global_x_1": 10, "global_y_1": 9},
+ )
+
def test_stale_checker_embedded(self):
def go(x):
- stmt = select([lambda: x])
+ stmt = select(lambda: x)
return stmt
c1 = column("x")
def test_stale_checker_statement(self):
def go(x):
- stmt = lambdas.lambda_stmt(lambda: select([x]))
+ stmt = lambdas.lambda_stmt(lambda: select(x))
return stmt
c1 = column("x")
def test_stale_checker_linked(self):
def go(x, y):
- stmt = lambdas.lambda_stmt(lambda: select([x])) + (
+ stmt = lambdas.lambda_stmt(lambda: select(x)) + (
lambda s: s.where(y > 5)
)
return stmt
- c1 = column("x")
- c2 = column("y")
+ c1 = oldc1 = column("x")
+ c2 = oldc2 = column("y")
s1 = go(c1, c2)
s2 = go(c1, c2)
s3 = go(c1, c2)
self.assert_compile(s3, "SELECT q WHERE p > :p_1")
+ s4 = go(c1, c2)
+ self.assert_compile(s4, "SELECT q WHERE p > :p_1")
+
+ s5 = go(oldc1, oldc2)
+ self.assert_compile(s5, "SELECT x WHERE y > :y_1")
+
+ def test_stmt_lambda_w_additional_hascachekey_variants(self):
+ def go(col_expr, q):
+ stmt = lambdas.lambda_stmt(lambda: select(col_expr))
+ stmt += lambda stmt: stmt.where(col_expr == q)
+
+ return stmt
+
+ c1 = column("x")
+ c2 = column("y")
+
+ s1 = go(c1, 5)
+ s2 = go(c2, 10)
+ s3 = go(c1, 8)
+ s4 = go(c2, 12)
+
+ self.assert_compile(
+ s1, "SELECT x WHERE x = :q_1", checkparams={"q_1": 5}
+ )
+ self.assert_compile(
+ s2, "SELECT y WHERE y = :q_1", checkparams={"q_1": 10}
+ )
+ self.assert_compile(
+ s3, "SELECT x WHERE x = :q_1", checkparams={"q_1": 8}
+ )
+ self.assert_compile(
+ s4, "SELECT y WHERE y = :q_1", checkparams={"q_1": 12}
+ )
+
+ s1key = s1._generate_cache_key()
+ s2key = s2._generate_cache_key()
+ s3key = s3._generate_cache_key()
+ s4key = s4._generate_cache_key()
+
+ eq_(s1key[0], s3key[0])
+ eq_(s2key[0], s4key[0])
+ ne_(s1key[0], s2key[0])
+
+ def test_stmt_lambda_w_atonce_whereclause_values_notrack(self):
+ def go(col_expr, whereclause):
+ stmt = lambdas.lambda_stmt(lambda: select(col_expr))
+ stmt = stmt.add_criteria(
+ lambda stmt: stmt.where(whereclause), enable_tracking=False
+ )
+
+ return stmt
+
+ c1 = column("x")
+
+ s1 = go(c1, c1 == 5)
+ s2 = go(c1, c1 == 10)
+
+ self.assert_compile(
+ s1, "SELECT x WHERE x = :x_1", checkparams={"x_1": 5}
+ )
+
+ # and as we see, this is wrong. Because whereclause
+ # is fixed for the lambda and we do not re-evaluate the closure
+ # for this value changing. this can't be passed unless
+ # enable_tracking=False.
+ self.assert_compile(
+ s2, "SELECT x WHERE x = :x_1", checkparams={"x_1": 5}
+ )
+
+ def test_stmt_lambda_w_atonce_whereclause_values(self):
+ c2 = column("y")
+
+ def go(col_expr, whereclause, x):
+ stmt = lambdas.lambda_stmt(lambda: select(col_expr))
+ stmt = stmt.add_criteria(
+ lambda stmt: stmt.where(whereclause).order_by(c2 > x),
+ )
+
+ return stmt
+
+ c1 = column("x")
+
+ s1 = go(c1, c1 == 5, 9)
+ s2 = go(c1, c1 == 10, 15)
+
+ s1key = s1._generate_cache_key()
+ s2key = s2._generate_cache_key()
+
+ eq_([b.value for b in s1key.bindparams], [5, 9])
+ eq_([b.value for b in s2key.bindparams], [10, 15])
+
+ self.assert_compile(
+ s1,
+ "SELECT x WHERE x = :x_1 ORDER BY y > :x_2",
+ checkparams={"x_1": 5, "x_2": 9},
+ )
+
+ self.assert_compile(
+ s2,
+ "SELECT x WHERE x = :x_1 ORDER BY y > :x_2",
+ checkparams={"x_1": 10, "x_2": 15},
+ )
+
+ def test_stmt_lambda_plain_customtrack(self):
+ c2 = column("y")
+
+ def go(col_expr, whereclause, p):
+ stmt = lambdas.lambda_stmt(lambda: select(col_expr))
+ stmt = stmt.add_criteria(lambda stmt: stmt.where(whereclause))
+ stmt = stmt.add_criteria(
+ lambda stmt: stmt.order_by(col_expr), track_on=(col_expr,)
+ )
+ stmt = stmt.add_criteria(lambda stmt: stmt.where(col_expr == p))
+ return stmt
+
+ c1 = column("x")
+ c2 = column("y")
+
+ s1 = go(c1, c1 == 5, 9)
+ s2 = go(c1, c1 == 10, 15)
+ s3 = go(c2, c2 == 18, 12)
+
+ s1key = s1._generate_cache_key()
+ s2key = s2._generate_cache_key()
+ s3key = s3._generate_cache_key()
+
+ eq_([b.value for b in s1key.bindparams], [5, 9])
+ eq_([b.value for b in s2key.bindparams], [10, 15])
+ eq_([b.value for b in s3key.bindparams], [18, 12])
+
+ self.assert_compile(
+ s1,
+ "SELECT x WHERE x = :x_1 AND x = :p_1 ORDER BY x",
+ checkparams={"x_1": 5, "p_1": 9},
+ )
+
+ self.assert_compile(
+ s2,
+ "SELECT x WHERE x = :x_1 AND x = :p_1 ORDER BY x",
+ checkparams={"x_1": 10, "p_1": 15},
+ )
+
+ self.assert_compile(
+ s3,
+ "SELECT y WHERE y = :y_1 AND y = :p_1 ORDER BY y",
+ checkparams={"y_1": 18, "p_1": 12},
+ )
+
+ def test_stmt_lambda_w_atonce_whereclause_customtrack_binds(self):
+ c2 = column("y")
+
+ # this pattern is *completely unnecessary*, and I would prefer
+ # that we detect it and just raise, because when it is not done
+ # correctly, it is *extremely* difficult to catch it failing.
+ # however, I also can't come up with a reliable way to detect it,
+ # so we will keep the use of "track_on" internal.
+
+ def go(col_expr, whereclause, p):
+ stmt = lambdas.lambda_stmt(lambda: select(col_expr))
+ stmt = stmt.add_criteria(
+ lambda stmt: stmt.where(whereclause).order_by(col_expr > p),
+ track_on=(whereclause, whereclause.right.value),
+ )
+
+ return stmt
+
+ c1 = column("x")
+ c2 = column("y")
+
+ s1 = go(c1, c1 == 5, 9)
+ s2 = go(c1, c1 == 10, 15)
+ s3 = go(c2, c2 == 18, 12)
+
+ s1key = s1._generate_cache_key()
+ s2key = s2._generate_cache_key()
+ s3key = s3._generate_cache_key()
+
+ eq_([b.value for b in s1key.bindparams], [5, 9])
+ eq_([b.value for b in s2key.bindparams], [10, 15])
+ eq_([b.value for b in s3key.bindparams], [18, 12])
+
+ self.assert_compile(
+ s1,
+ "SELECT x WHERE x = :x_1 ORDER BY x > :p_1",
+ checkparams={"x_1": 5, "p_1": 9},
+ )
+
+ self.assert_compile(
+ s2,
+ "SELECT x WHERE x = :x_1 ORDER BY x > :p_1",
+ checkparams={"x_1": 10, "p_1": 15},
+ )
+
+ self.assert_compile(
+ s3,
+ "SELECT y WHERE y = :y_1 ORDER BY y > :p_1",
+ checkparams={"y_1": 18, "p_1": 12},
+ )
+
+ def test_stmt_lambda_track_closure_binds_one(self):
+ def go(col_expr, whereclause):
+ stmt = lambdas.lambda_stmt(lambda: select(col_expr))
+ stmt += lambda stmt: stmt.where(whereclause)
+
+ return stmt
+
+ c1 = column("x")
+
+ s1 = go(c1, c1 == 5)
+ s2 = go(c1, c1 == 10)
+
+ self.assert_compile(
+ s1, "SELECT x WHERE x = :x_1", checkparams={"x_1": 5}
+ )
+ self.assert_compile(
+ s2, "SELECT x WHERE x = :x_1", checkparams={"x_1": 10}
+ )
+
+ s1key = s1._generate_cache_key()
+ s2key = s2._generate_cache_key()
+
+ eq_(s1key.key, s2key.key)
+
+ eq_([b.value for b in s1key.bindparams], [5])
+ eq_([b.value for b in s2key.bindparams], [10])
+
+ def test_stmt_lambda_track_closure_binds_two(self):
+ def go(col_expr, whereclause, x, y):
+ stmt = lambdas.lambda_stmt(lambda: select(col_expr))
+ stmt += lambda stmt: stmt.where(whereclause).where(
+ and_(c1 == x, c1 < y)
+ )
+
+ return stmt
+
+ c1 = column("x")
+
+ s1 = go(c1, c1 == 5, 8, 9)
+ s2 = go(c1, c1 == 10, 12, 14)
+
+ s1key = s1._generate_cache_key()
+ s2key = s2._generate_cache_key()
+
+ self.assert_compile(
+ s1,
+ "SELECT x WHERE x = :x_1 AND x = :x_2 AND x < :y_1",
+ checkparams={"x_1": 5, "x_2": 8, "y_1": 9},
+ )
+ self.assert_compile(
+ s2,
+ "SELECT x WHERE x = :x_1 AND x = :x_2 AND x < :y_1",
+ checkparams={"x_1": 10, "x_2": 12, "y_1": 14},
+ )
+
+ eq_([b.value for b in s1key.bindparams], [5, 8, 9])
+ eq_([b.value for b in s2key.bindparams], [10, 12, 14])
+
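+ # a Compiled built against s1's cache key can have its parameters
+ # reconstructed from s2's extracted parameters; this is how a cached
+ # compilation is reused across parameter sets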
+ s1_compiled_cached = s1.compile(cache_key=s1key)
+
+ params = s1_compiled_cached.construct_params(
+ extracted_parameters=s2key[1]
+ )
+
+ eq_(params, {"x_1": 10, "x_2": 12, "y_1": 14})
+
+ def test_stmt_lambda_track_closure_binds_three(self):
+ def go(col_expr, whereclause, x, y):
+ stmt = lambdas.lambda_stmt(lambda: select(col_expr))
+ stmt += lambda stmt: stmt.where(whereclause)
+ stmt += lambda stmt: stmt.where(and_(c1 == x, c1 < y))
+
+ return stmt
+
+ c1 = column("x")
+
+ s1 = go(c1, c1 == 5, 8, 9)
+ s2 = go(c1, c1 == 10, 12, 14)
+
+ s1key = s1._generate_cache_key()
+ s2key = s2._generate_cache_key()
+
+ self.assert_compile(
+ s1,
+ "SELECT x WHERE x = :x_1 AND x = :x_2 AND x < :y_1",
+ checkparams={"x_1": 5, "x_2": 8, "y_1": 9},
+ )
+ self.assert_compile(
+ s2,
+ "SELECT x WHERE x = :x_1 AND x = :x_2 AND x < :y_1",
+ checkparams={"x_1": 10, "x_2": 12, "y_1": 14},
+ )
+
+ eq_([b.value for b in s1key.bindparams], [5, 8, 9])
+ eq_([b.value for b in s2key.bindparams], [10, 12, 14])
+
+ s1_compiled_cached = s1.compile(cache_key=s1key)
+
+ params = s1_compiled_cached.construct_params(
+ extracted_parameters=s2key[1]
+ )
+
+ eq_(params, {"x_1": 10, "x_2": 12, "y_1": 14})
+
+ def test_stmt_lambda_w_atonce_whereclause_novalue(self):
+ def go(col_expr, whereclause):
+ stmt = lambdas.lambda_stmt(lambda: select(col_expr))
+ stmt += lambda stmt: stmt.where(whereclause)
+
+ return stmt
+
+ c1 = column("x")
+
+ s1 = go(c1, bindparam("x"))
+
+ self.assert_compile(s1, "SELECT x WHERE :x")
+
+ def test_stmt_lambda_w_additional_hashable_variants(self):
+ # note a Python 2 old style class would fail here because it
+ # isn't hashable. right now we do a hard check for __hash__ which
+ # will raise if the attr isn't present
+ class Thing(object):
+ def __init__(self, col_expr):
+ self.col_expr = col_expr
+
+ def go(thing, q):
+ stmt = lambdas.lambda_stmt(lambda: select(thing.col_expr))
+ stmt += lambda stmt: stmt.where(thing.col_expr == q)
+
+ return stmt
+
+ c1 = Thing(column("x"))
+ c2 = Thing(column("y"))
+
+ s1 = go(c1, 5)
+ s2 = go(c2, 10)
+ s3 = go(c1, 8)
+ s4 = go(c2, 12)
+
+ self.assert_compile(
+ s1, "SELECT x WHERE x = :q_1", checkparams={"q_1": 5}
+ )
+ self.assert_compile(
+ s2, "SELECT y WHERE y = :q_1", checkparams={"q_1": 10}
+ )
+ self.assert_compile(
+ s3, "SELECT x WHERE x = :q_1", checkparams={"q_1": 8}
+ )
+ self.assert_compile(
+ s4, "SELECT y WHERE y = :q_1", checkparams={"q_1": 12}
+ )
+
+ s1key = s1._generate_cache_key()
+ s2key = s2._generate_cache_key()
+ s3key = s3._generate_cache_key()
+ s4key = s4._generate_cache_key()
+
+ eq_(s1key[0], s3key[0])
+ eq_(s2key[0], s4key[0])
+ ne_(s1key[0], s2key[0])
+
+ def test_stmt_lambda_w_set_of_opts(self):
+
+ stmt = lambdas.lambda_stmt(lambda: select(column("x")))
+
+ opts = {column("x"), column("y")}
+
+ assert_raises_message(
+ exc.ArgumentError,
+ 'Can\'t create a cache key for lambda closure variable "opts" '
+ "because it's a set. try using a list",
+ stmt.__add__,
+ lambda stmt: stmt.options(*opts),
+ )
+
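+ # editorial sketch, not part of the original patch: per the error
+ # message above, the workaround for the set limitation is to close
+ # over a deterministically ordered sequence instead; the next test
+ # uses a plain list, and a set can be sorted into one up front.  the
+ # test name and the sort key below are hypothetical.
+ def test_stmt_lambda_w_set_of_opts_sorted_sketch(self):
+ def go(opts):
+ stmt = lambdas.lambda_stmt(lambda: select(column("x")))
+ # sort into a list with a stable order before the lambda
+ # closes over the value
+ opts = sorted(opts, key=lambda col: col.name)
+ stmt += lambda stmt: stmt.options(*opts)
+
+ return stmt
+
+ s1 = go({column("a"), column("b")})
+ s2 = go({column("a"), column("b")})
+
+ s1key = s1._generate_cache_key()
+ s2key = s2._generate_cache_key()
+
+ eq_(s1key.key, s2key.key)
+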
+ def test_stmt_lambda_w_list_of_opts(self):
+ def go(opts):
+ stmt = lambdas.lambda_stmt(lambda: select(column("x")))
+ stmt += lambda stmt: stmt.options(*opts)
+
+ return stmt
+
+ s1 = go([column("a"), column("b")])
+
+ s2 = go([column("a"), column("b")])
+
+ s3 = go([column("q"), column("b")])
+
+ s1key = s1._generate_cache_key()
+ s2key = s2._generate_cache_key()
+ s3key = s3._generate_cache_key()
+
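+ # lists with the same column structure produce equal cache keys,
+ # while changing an element changes the key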
+ eq_(s1key.key, s2key.key)
+ ne_(s1key.key, s3key.key)
+
+ def test_stmt_lambda_hey_theres_multiple_paths(self):
+ def go(x, y):
+ stmt = lambdas.lambda_stmt(lambda: select(column("x")))
+
+ if x > 5:
+ stmt += lambda stmt: stmt.where(column("x") == x)
+ else:
+ stmt += lambda stmt: stmt.where(column("y") == y)
+
+ stmt += lambda stmt: stmt.order_by(column("q"))
+
+ # TODO: need more path variety here to exercise
+ # using a full path key
+
+ return stmt
+
+ s1 = go(2, 5)
+ s2 = go(8, 7)
+ s3 = go(4, 9)
+ s4 = go(10, 1)
+
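+ # go() takes the "x" branch for x > 5 and the "y" branch otherwise,
+ # so each statement reflects the branch that built it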
+ self.assert_compile(s1, "SELECT x WHERE y = :y_1 ORDER BY q")
+ self.assert_compile(s2, "SELECT x WHERE x = :x_1 ORDER BY q")
+ self.assert_compile(s3, "SELECT x WHERE y = :y_1 ORDER BY q")
+ self.assert_compile(s4, "SELECT x WHERE x = :x_1 ORDER BY q")
+
def test_coercion_cols_clause(self):
assert_raises_message(
exc.ArgumentError,
from sqlalchemy import union
from sqlalchemy import util
from sqlalchemy.sql import Alias
+from sqlalchemy.sql import annotation
from sqlalchemy.sql import base
from sqlalchemy.sql import column
from sqlalchemy.sql import elements
annot = obj._annotate({})
ne_(set([obj]), set([annot]))
+
+ def test_replacement_traverse_preserve(self):
+ """test that replacement traverse that hits an unannotated column
+ does not use it when replacing an annotated column.
+
+ this requires that replacement traverse store elements in the
+ "seen" hash based on id(), not hash.
+
+ """
+ t = table("t", column("x"))
+
+ stmt = select([t.c.x])
+
+ whereclause = annotation._deep_annotate(t.c.x == 5, {"foo": "bar"})
+
+ eq_(whereclause._annotations, {"foo": "bar"})
+ eq_(whereclause.left._annotations, {"foo": "bar"})
+ eq_(whereclause.right._annotations, {"foo": "bar"})
+
+ stmt = stmt.where(whereclause)
+
+ s2 = visitors.replacement_traverse(stmt, {}, lambda elem: None)
+
+ whereclause = s2._where_criteria[0]
+ eq_(whereclause._annotations, {"foo": "bar"})
+ eq_(whereclause.left._annotations, {"foo": "bar"})
+ eq_(whereclause.right._annotations, {"foo": "bar"})
+
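+ # editorial sketch, not part of the original patch: the annotation
+ # wrapper hashes like the element it wraps, so a hash-keyed "seen"
+ # collection could conflate the annotated and unannotated columns;
+ # keying on id() keeps each distinct object separate, roughly::
+ #
+ #     cloned = {}
+ #
+ #     def clone(elem):
+ #         if id(elem) not in cloned:      # identity, not hash
+ #             cloned[id(elem)] = elem._clone()
+ #         return cloned[id(elem)]
+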
def test_proxy_set_iteration_includes_annotated(self):
from sqlalchemy.schema import Column