.. autoclass:: BindParameter
:members:
+.. autoclass:: CacheKey
+ :members:
+
.. autoclass:: Case
:members:
.. autoclass:: ClauseElement
:members:
+ :inherited-members:
.. autoclass:: ClauseList
.. automodule:: sqlalchemy.sql.visitors
:members:
+ :private-members:
\ No newline at end of file
inserted_alias = getattr(self, "inserted_alias", None)
self._post_values_clause = OnDuplicateClause(inserted_alias, values)
- return self
insert = public_factory(Insert, ".dialects.mysql.insert")
return "ONLY " + sqltext
def get_select_precolumns(self, select, **kw):
- if select._distinct is not False:
- if select._distinct is True:
- return "DISTINCT "
- elif isinstance(select._distinct, (list, tuple)):
+ if select._distinct or select._distinct_on:
+ if select._distinct_on:
return (
"DISTINCT ON ("
+ ", ".join(
- [self.process(col, **kw) for col in select._distinct]
+ [
+ self.process(col, **kw)
+ for col in select._distinct_on
+ ]
)
+ ") "
)
else:
- return (
- "DISTINCT ON ("
- + self.process(select._distinct, **kw)
- + ") "
- )
+ return "DISTINCT "
else:
return ""
self._post_values_clause = OnConflictDoUpdate(
constraint, index_elements, index_where, set_, where
)
- return self
@_generative
def on_conflict_do_nothing(
self._post_values_clause = OnConflictDoNothing(
constraint, index_elements, index_where
)
- return self
insert = public_factory(Insert, ".dialects.postgresql.insert")
self.spoil()
else:
for opt in options:
- cache_key = opt._generate_cache_key(cache_path)
+ cache_key = opt._generate_path_cache_key(cache_path)
if cache_key is False:
self.spoil()
elif cache_key is not None:
if hasattr(class_, "_compiler_dispatcher"):
# regenerate default _compiler_dispatch
- visitors._generate_dispatch(class_)
+ visitors._generate_compiler_dispatch(class_)
# remove custom directive
del class_._compiler_dispatcher
from .. import event
from .. import inspection
from .. import util
+from ..sql import base as sql_base
+from ..sql import visitors
@inspection._self_inspects
interfaces._MappedAttribute,
interfaces.InspectionAttr,
interfaces.PropComparator,
+ sql_base.HasCacheKey,
):
"""Base class for :term:`descriptor` objects that intercept
attribute events on behalf of a :class:`.MapperProperty`
if base[key].dispatch._active_history:
self.dispatch._active_history = True
+ _cache_key_traversal = [
+ # ("class_", visitors.ExtendedInternalTraversal.dp_plain_obj),
+ ("key", visitors.ExtendedInternalTraversal.dp_string),
+ ("_parententity", visitors.ExtendedInternalTraversal.dp_multi),
+ ("_of_type", visitors.ExtendedInternalTraversal.dp_multi),
+ ]
+
@util.memoized_property
def _supports_population(self):
return self.impl.supports_population
for assertion in assertions:
assertion(self, fn.__name__)
fn(self, *args[1:], **kw)
- return self
return generate
from .. import inspection
from .. import util
from ..sql import operators
+from ..sql import visitors
+from ..sql.traversals import HasCacheKey
__all__ = (
)
-class MapperProperty(_MappedAttribute, InspectionAttr, util.MemoizedSlots):
+class MapperProperty(
+ HasCacheKey, _MappedAttribute, InspectionAttr, util.MemoizedSlots
+):
"""Represent a particular class attribute mapped by :class:`.Mapper`.
The most common occurrences of :class:`.MapperProperty` are the
"info",
)
+ _cache_key_traversal = [
+ ("parent", visitors.ExtendedInternalTraversal.dp_has_cache_key),
+ ("key", visitors.ExtendedInternalTraversal.dp_string),
+ ]
+
cascade = frozenset()
"""The set of 'cascade' attribute names.
self.process_query(query)
- def _generate_cache_key(self, path):
+ def _generate_path_cache_key(self, path):
"""Used by the "baked lazy loader" to see if this option can be cached.
The "baked lazy loader" refers to the :class:`.Query` that is
@inspection._self_inspects
@log.class_logger
-class Mapper(InspectionAttr):
+class Mapper(sql_base.HasCacheKey, InspectionAttr):
"""Define the correlation of class attributes to database table
columns.
"""
return self
+ _cache_key_traversal = [
+ ("class_", visitors.ExtendedInternalTraversal.dp_plain_obj)
+ ]
+
@property
def entity(self):
r"""Part of the inspection API.
from .. import exc
from .. import inspection
from .. import util
-
+from ..sql import visitors
+from ..sql.traversals import HasCacheKey
log = logging.getLogger(__name__)
_DEFAULT_TOKEN = "_sa_default"
-class PathRegistry(object):
+class PathRegistry(HasCacheKey):
"""Represent query load paths and registry functions.
Basically represents structures like:
is_token = False
is_root = False
+ _cache_key_traversal = [
+ ("path", visitors.ExtendedInternalTraversal.dp_has_cache_key_list)
+ ]
+
def __eq__(self, other):
return other is not None and self.path == other.path
def __len__(self):
return len(self.path)
+ def __hash__(self):
+ return id(self)
+
@property
def length(self):
return len(self.path)
from .. import util
from ..sql import coercions
from ..sql import roles
+from ..sql import visitors
from ..sql.base import _generative
from ..sql.base import Generative
+from ..sql.traversals import HasCacheKey
-class Load(Generative, MapperOption):
+class Load(HasCacheKey, Generative, MapperOption):
"""Represents loader options which modify the state of a
:class:`.Query` in order to affect how various mapped attributes are
loaded.
"""
+ _cache_key_traversal = [
+ ("path", visitors.ExtendedInternalTraversal.dp_has_cache_key),
+ ("strategy", visitors.ExtendedInternalTraversal.dp_plain_obj),
+ ("_of_type", visitors.ExtendedInternalTraversal.dp_multi),
+ (
+ "_context_cache_key",
+ visitors.ExtendedInternalTraversal.dp_has_cache_key_tuples,
+ ),
+ ("local_opts", visitors.ExtendedInternalTraversal.dp_plain_dict),
+ ]
+
def __init__(self, entity):
insp = inspect(entity)
self.path = insp._path_registry
load._of_type = None
return load
- def _generate_cache_key(self, path):
+ @property
+ def _context_cache_key(self):
+ serialized = []
+ for (key, loader_path), obj in self.context.items():
+ if key != "loader":
+ continue
+ serialized.append(loader_path + (obj,))
+ return serialized
+
+ def _generate_path_cache_key(self, path):
if path.path[0].is_aliased_class:
return False
self._to_bind = []
self.local_opts = {}
+ _cache_key_traversal = [
+ ("path", visitors.ExtendedInternalTraversal.dp_multi_list),
+ ("strategy", visitors.ExtendedInternalTraversal.dp_plain_obj),
+ ("_to_bind", visitors.ExtendedInternalTraversal.dp_has_cache_key_list),
+ ("local_opts", visitors.ExtendedInternalTraversal.dp_plain_dict),
+ ]
+
_is_chain_link = False
- def _generate_cache_key(self, path):
+ def _generate_path_cache_key(self, path):
serialized = ()
for val in self._to_bind:
for local_elem, val_elem in zip(self.path, val.path):
else:
opt = val._bind_loader([path.path[0]], None, None, False)
if opt:
- c_key = opt._generate_cache_key(path)
+ c_key = opt._generate_path_cache_key(path)
if c_key is False:
return False
elif c_key:
opt = meth(opt, all_tokens[-1], **kw)
opt._is_chain_link = False
-
return opt
def _chop_path(self, to_chop, path):
from .. import inspection
from .. import sql
from .. import util
+from ..sql import base as sql_base
from ..sql import coercions
from ..sql import expression
from ..sql import roles
from ..sql import util as sql_util
+from ..sql import visitors
all_cascades = frozenset(
return str(self._aliased_insp)
-class AliasedInsp(InspectionAttr):
+class AliasedInsp(sql_base.HasCacheKey, InspectionAttr):
"""Provide an inspection interface for an
:class:`.AliasedClass` object.
def __clause_element__(self):
return self.selectable
+ _cache_key_traversal = [
+ ("name", visitors.ExtendedInternalTraversal.dp_string),
+ ("_adapt_on_names", visitors.ExtendedInternalTraversal.dp_boolean),
+ ("selectable", visitors.ExtendedInternalTraversal.dp_clauseelement),
+ ]
+
@property
def class_(self):
"""Return the mapped class ultimately represented by this
"""
from . import operators
+from .base import HasCacheKey
+from .visitors import InternalTraversal
from .. import util
-class SupportsCloneAnnotations(object):
+class SupportsAnnotations(object):
+ @util.memoized_property
+ def _annotation_traversals(self):
+ return [
+ (
+ key,
+ InternalTraversal.dp_has_cache_key
+ if isinstance(value, HasCacheKey)
+ else InternalTraversal.dp_plain_obj,
+ )
+ for key, value in self._annotations.items()
+ ]
+
+
+class SupportsCloneAnnotations(SupportsAnnotations):
_annotations = util.immutabledict()
+ _traverse_internals = [
+ ("_annotations", InternalTraversal.dp_annotations_state)
+ ]
+
def _annotate(self, values):
"""return a copy of this ClauseElement with annotations
updated by the given dictionary.
"""
new = self._clone()
new._annotations = new._annotations.union(values)
+ new.__dict__.pop("_annotation_traversals", None)
return new
def _with_annotations(self, values):
"""
new = self._clone()
new._annotations = util.immutabledict(values)
+ new.__dict__.pop("_annotation_traversals", None)
return new
def _deannotate(self, values=None, clone=False):
# the expression for a deep deannotation
new = self._clone()
new._annotations = {}
+ new.__dict__.pop("_annotation_traversals", None)
return new
else:
return self
-class SupportsWrappingAnnotations(object):
+class SupportsWrappingAnnotations(SupportsAnnotations):
def _annotate(self, values):
"""return a copy of this ClauseElement with annotations
updated by the given dictionary.
def __init__(self, element, values):
self.__dict__ = element.__dict__.copy()
+ self.__dict__.pop("_annotation_traversals", None)
self.__element = element
self._annotations = values
self._hash = hash(element)
def _with_annotations(self, values):
clone = self.__class__.__new__(self.__class__)
clone.__dict__ = self.__dict__.copy()
+ clone.__dict__.pop("_annotation_traversals", None)
clone._annotations = values
return clone
"""
- def clone(elem):
+ # annotated objects hack the __hash__() method so if we want to
+ # uniquely process them we have to use id()
+
+ cloned_ids = {}
+
+ def clone(elem, **kw):
+ id_ = id(elem)
+
+ if id_ in cloned_ids:
+ return cloned_ids[id_]
+
if (
exclude
and hasattr(elem, "proxy_set")
else:
newelem = elem
newelem._copy_internals(clone=clone)
+ cloned_ids[id_] = newelem
return newelem
if element is not None:
def _deep_deannotate(element, values=None):
"""Deep copy the given element, removing annotations."""
- cloned = util.column_dict()
+ cloned = {}
- def clone(elem):
- # if a values dict is given,
- # the elem must be cloned each time it appears,
- # as there may be different annotations in source
- # elements that are remaining. if totally
- # removing all annotations, can assume the same
- # slate...
- if values or elem not in cloned:
+ def clone(elem, **kw):
+ if values:
+ key = id(elem)
+ else:
+ key = elem
+
+ if key not in cloned:
newelem = elem._deannotate(values=values, clone=True)
newelem._copy_internals(clone=clone)
- if not values:
- cloned[elem] = newelem
+ cloned[key] = newelem
return newelem
else:
- return cloned[elem]
+ return cloned[key]
if element is not None:
element = clone(element)
"Annotated%s" % cls.__name__, (base_cls, cls), {}
)
globals()["Annotated%s" % cls.__name__] = anno_cls
+
+ if "_traverse_internals" in cls.__dict__:
+ anno_cls._traverse_internals = list(cls._traverse_internals) + [
+ ("_annotations", InternalTraversal.dp_annotations_state)
+ ]
return anno_cls
import operator
import re
+from .traversals import HasCacheKey # noqa
from .visitors import ClauseVisitor
from .. import exc
from .. import util
def _clone(self):
return self
+ def _copy_internals(self, **kw):
+ pass
+
+
+class HasMemoized(object):
+ def _reset_memoizations(self):
+ self._memoized_property.expire_instance(self)
+
+ def _reset_exported(self):
+ self._memoized_property.expire_instance(self)
+
+ def _copy_internals(self, **kw):
+ super(HasMemoized, self)._copy_internals(**kw)
+ self._reset_memoizations()
+
def _from_objects(*elements):
return itertools.chain(*[element._from_objects for element in elements])
def _generative(fn):
+ """non-caching _generative() decorator.
+
+ This is basically the legacy decorator that copies the object and
+ runs a method on the new copy.
+
+ """
+
@util.decorator
- def _generative(fn, *args, **kw):
+ def _generative(fn, self, *args, **kw):
"""Mark a method as generative."""
- self = args[0]._generate()
- fn(self, *args[1:], **kw)
+ self = self._generate()
+ x = fn(self, *args, **kw)
+ assert x is None, "generative methods must have no return value"
return self
decorated = _generative(fn)
class Generative(object):
- """Allow a ClauseElement to generate itself via the
- @_generative decorator.
-
- """
+ """Provide a method-chaining pattern in conjunction with the
+ @_generative decorator."""
def _generate(self):
s = self.__class__.__new__(self.__class__)
+++ /dev/null
-from collections import deque
-
-from . import operators
-from .. import util
-
-
-SKIP_TRAVERSE = util.symbol("skip_traverse")
-
-
-def compare(obj1, obj2, **kw):
- if kw.get("use_proxies", False):
- strategy = ColIdentityComparatorStrategy()
- else:
- strategy = StructureComparatorStrategy()
-
- return strategy.compare(obj1, obj2, **kw)
-
-
-class StructureComparatorStrategy(object):
- __slots__ = "compare_stack", "cache"
-
- def __init__(self):
- self.compare_stack = deque()
- self.cache = set()
-
- def compare(self, obj1, obj2, **kw):
- stack = self.compare_stack
- cache = self.cache
-
- stack.append((obj1, obj2))
-
- while stack:
- left, right = stack.popleft()
-
- if left is right:
- continue
- elif left is None or right is None:
- # we know they are different so no match
- return False
- elif (left, right) in cache:
- continue
- cache.add((left, right))
-
- visit_name = left.__visit_name__
-
- # we're not exactly looking for identical types, because
- # there are things like Column and AnnotatedColumn. So the
- # visit_name has to at least match up
- if visit_name != right.__visit_name__:
- return False
-
- meth = getattr(self, "compare_%s" % visit_name, None)
-
- if meth:
- comparison = meth(left, right, **kw)
- if comparison is False:
- return False
- elif comparison is SKIP_TRAVERSE:
- continue
-
- for c1, c2 in util.zip_longest(
- left.get_children(column_collections=False),
- right.get_children(column_collections=False),
- fillvalue=None,
- ):
- if c1 is None or c2 is None:
- # collections are different sizes, comparison fails
- return False
- stack.append((c1, c2))
-
- return True
-
- def compare_inner(self, obj1, obj2, **kw):
- stack = self.compare_stack
- try:
- self.compare_stack = deque()
- return self.compare(obj1, obj2, **kw)
- finally:
- self.compare_stack = stack
-
- def _compare_unordered_sequences(self, seq1, seq2, **kw):
- if seq1 is None:
- return seq2 is None
-
- completed = set()
- for clause in seq1:
- for other_clause in set(seq2).difference(completed):
- if self.compare_inner(clause, other_clause, **kw):
- completed.add(other_clause)
- break
- return len(completed) == len(seq1) == len(seq2)
-
- def compare_bindparam(self, left, right, **kw):
- # note the ".key" is often generated from id(self) so can't
- # be compared, as far as determining structure.
- return (
- left.type._compare_type_affinity(right.type)
- and left.value == right.value
- and left.callable == right.callable
- and left._orig_key == right._orig_key
- )
-
- def compare_clauselist(self, left, right, **kw):
- if left.operator is right.operator:
- if operators.is_associative(left.operator):
- if self._compare_unordered_sequences(
- left.clauses, right.clauses
- ):
- return SKIP_TRAVERSE
- else:
- return False
- else:
- # normal ordered traversal
- return True
- else:
- return False
-
- def compare_unary(self, left, right, **kw):
- if left.operator:
- disp = self._get_operator_dispatch(
- left.operator, "unary", "operator"
- )
- if disp is not None:
- result = disp(left, right, left.operator, **kw)
- if result is not True:
- return result
- elif left.modifier:
- disp = self._get_operator_dispatch(
- left.modifier, "unary", "modifier"
- )
- if disp is not None:
- result = disp(left, right, left.operator, **kw)
- if result is not True:
- return result
- return (
- left.operator == right.operator and left.modifier == right.modifier
- )
-
- def compare_binary(self, left, right, **kw):
- disp = self._get_operator_dispatch(left.operator, "binary", None)
- if disp:
- result = disp(left, right, left.operator, **kw)
- if result is not True:
- return result
-
- if left.operator == right.operator:
- if operators.is_commutative(left.operator):
- if (
- compare(left.left, right.left, **kw)
- and compare(left.right, right.right, **kw)
- ) or (
- compare(left.left, right.right, **kw)
- and compare(left.right, right.left, **kw)
- ):
- return SKIP_TRAVERSE
- else:
- return False
- else:
- return True
- else:
- return False
-
- def _get_operator_dispatch(self, operator_, qualifier1, qualifier2):
- # used by compare_binary, compare_unary
- attrname = "visit_%s_%s%s" % (
- operator_.__name__,
- qualifier1,
- "_" + qualifier2 if qualifier2 else "",
- )
- return getattr(self, attrname, None)
-
- def visit_function_as_comparison_op_binary(
- self, left, right, operator, **kw
- ):
- return (
- left.left_index == right.left_index
- and left.right_index == right.right_index
- )
-
- def compare_function(self, left, right, **kw):
- return left.name == right.name
-
- def compare_column(self, left, right, **kw):
- if left.table is not None:
- self.compare_stack.appendleft((left.table, right.table))
- return (
- left.key == right.key
- and left.name == right.name
- and (
- left.type._compare_type_affinity(right.type)
- if left.type is not None
- else right.type is None
- )
- and left.is_literal == right.is_literal
- )
-
- def compare_collation(self, left, right, **kw):
- return left.collation == right.collation
-
- def compare_type_coerce(self, left, right, **kw):
- return left.type._compare_type_affinity(right.type)
-
- @util.dependencies("sqlalchemy.sql.elements")
- def compare_alias(self, elements, left, right, **kw):
- return (
- left.name == right.name
- if not isinstance(left.name, elements._anonymous_label)
- else isinstance(right.name, elements._anonymous_label)
- )
-
- def compare_cte(self, elements, left, right, **kw):
- raise NotImplementedError("TODO")
-
- def compare_extract(self, left, right, **kw):
- return left.field == right.field
-
- def compare_textual_label_reference(self, left, right, **kw):
- return left.element == right.element
-
- def compare_slice(self, left, right, **kw):
- return (
- left.start == right.start
- and left.stop == right.stop
- and left.step == right.step
- )
-
- def compare_over(self, left, right, **kw):
- return left.range_ == right.range_ and left.rows == right.rows
-
- @util.dependencies("sqlalchemy.sql.elements")
- def compare_label(self, elements, left, right, **kw):
- return left._type._compare_type_affinity(right._type) and (
- left.name == right.name
- if not isinstance(left.name, elements._anonymous_label)
- else isinstance(right.name, elements._anonymous_label)
- )
-
- def compare_typeclause(self, left, right, **kw):
- return left.type._compare_type_affinity(right.type)
-
- def compare_join(self, left, right, **kw):
- return left.isouter == right.isouter and left.full == right.full
-
- def compare_table(self, left, right, **kw):
- if left.name != right.name:
- return False
-
- self.compare_stack.extendleft(
- util.zip_longest(left.columns, right.columns)
- )
-
- def compare_compound_select(self, left, right, **kw):
-
- if not self._compare_unordered_sequences(
- left.selects, right.selects, **kw
- ):
- return False
-
- if left.keyword != right.keyword:
- return False
-
- if left._for_update_arg != right._for_update_arg:
- return False
-
- if not self.compare_inner(
- left._order_by_clause, right._order_by_clause, **kw
- ):
- return False
-
- if not self.compare_inner(
- left._group_by_clause, right._group_by_clause, **kw
- ):
- return False
-
- return SKIP_TRAVERSE
-
- def compare_select(self, left, right, **kw):
- if not self._compare_unordered_sequences(
- left._correlate, right._correlate
- ):
- return False
- if not self._compare_unordered_sequences(
- left._correlate_except, right._correlate_except
- ):
- return False
-
- if not self._compare_unordered_sequences(
- left._from_obj, right._from_obj
- ):
- return False
-
- if left._for_update_arg != right._for_update_arg:
- return False
-
- return True
-
- def compare_textual_select(self, left, right, **kw):
- self.compare_stack.extendleft(
- util.zip_longest(left.column_args, right.column_args)
- )
- return left.positional == right.positional
-
-
-class ColIdentityComparatorStrategy(StructureComparatorStrategy):
- def compare_column_element(
- self, left, right, use_proxies=True, equivalents=(), **kw
- ):
- """Compare ColumnElements using proxies and equivalent collections.
-
- This is a comparison strategy specific to the ORM.
- """
-
- to_compare = (right,)
- if equivalents and right in equivalents:
- to_compare = equivalents[right].union(to_compare)
-
- for oth in to_compare:
- if use_proxies and left.shares_lineage(oth):
- return True
- elif hash(left) == hash(right):
- return True
- else:
- return False
-
- def compare_column(self, left, right, **kw):
- return self.compare_column_element(left, right, **kw)
-
- def compare_label(self, left, right, **kw):
- return self.compare_column_element(left, right, **kw)
-
- def compare_table(self, left, right, **kw):
- # tables compare on identity, since it's not really feasible to
- # compare them column by column with the above rules
- return left is right
return self
+class prefix_anon_map(dict):
+ """A map that creates new keys for missing key access.
+
+ Considers keys of the form "<ident> <name>" to produce
+ new symbols "<name>_<index>", where "index" is an incrementing integer
+ corresponding to <name>.
+
+ Inlines the approach taken by :class:`sqlalchemy.util.PopulateDict` which
+ is otherwise usually used for this type of operation.
+
+ """
+
+ def __missing__(self, key):
+ (ident, derived) = key.split(" ", 1)
+ anonymous_counter = self.get(derived, 1)
+ self[derived] = anonymous_counter + 1
+ value = derived + "_" + str(anonymous_counter)
+ self[key] = value
+ return value
+
+
class SQLCompiler(Compiled):
"""Default implementation of :class:`.Compiled`.
# a map which tracks "anonymous" identifiers that are created on
# the fly here
- self.anon_map = util.PopulateDict(self._process_anon)
+ self.anon_map = prefix_anon_map()
# a map which tracks "truncated" names based on
# dialect.label_length or dialect.max_identifier_length
def _anonymize(self, name):
return name % self.anon_map
- def _process_anon(self, key):
- (ident, derived) = key.split(" ", 1)
- anonymous_counter = self.anon_map.get(derived, 1)
- self.anon_map[derived] = anonymous_counter + 1
- return derived + "_" + str(anonymous_counter)
-
def bindparam_string(
self,
name,
def _inv_impl(expr, op, **kw):
"""See :meth:`.ColumnOperators.__inv__`."""
+
+ # undocumented element currently used by the ORM for
+ # relationship.contains()
if hasattr(expr, "negation_clause"):
return expr.negation_clause
else:
import operator
import re
-from . import clause_compare
from . import coercions
from . import operators
from . import roles
+from . import traversals
from . import type_api
from .annotation import Annotated
from .annotation import SupportsWrappingAnnotations
from .base import _clone
from .base import _generative
from .base import Executable
+from .base import HasCacheKey
+from .base import HasMemoized
from .base import Immutable
from .base import NO_ARG
from .base import PARSE_AUTOCOMMIT
from .coercions import _document_text_coercion
+from .traversals import _copy_internals
+from .traversals import _get_children
+from .traversals import NO_CACHE
from .visitors import cloned_traverse
+from .visitors import InternalTraversal
from .visitors import traverse
-from .visitors import Visitable
+from .visitors import Traversible
from .. import exc
from .. import inspection
from .. import util
@inspection._self_inspects
-class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable):
+class ClauseElement(
+ roles.SQLRole, SupportsWrappingAnnotations, HasCacheKey, Traversible
+):
"""Base class for elements of a programmatically constructed SQL
expression.
_order_by_label_element = None
+ @property
+ def _cache_key_traversal(self):
+ try:
+ return self._traverse_internals
+ except AttributeError:
+ return NO_CACHE
+
def _clone(self):
"""Create a shallow copy of this ClauseElement.
"""
return self
- def _cache_key(self, **kw):
- """return an optional cache key.
-
- The cache key is a tuple which can contain any series of
- objects that are hashable and also identifies
- this object uniquely within the presence of a larger SQL expression
- or statement, for the purposes of caching the resulting query.
-
- The cache key should be based on the SQL compiled structure that would
- ultimately be produced. That is, two structures that are composed in
- exactly the same way should produce the same cache key; any difference
- in the strucures that would affect the SQL string or the type handlers
- should result in a different cache key.
-
- If a structure cannot produce a useful cache key, it should raise
- NotImplementedError, which will result in the entire structure
- for which it's part of not being useful as a cache key.
-
-
- """
- raise NotImplementedError()
-
@property
def _constructor(self):
"""return the 'constructor' for this ClauseElement.
(see :class:`.ColumnElement`)
"""
- return clause_compare.compare(self, other, **kw)
+ return traversals.compare(self, other, **kw)
- def _copy_internals(self, clone=_clone, **kw):
+ def _copy_internals(self, **kw):
"""Reassign internal elements to be clones of themselves.
Called during a copy-and-traverse operation on newly
traversal, cloned traversal, annotations).
"""
- pass
- def get_children(self, **kwargs):
- r"""Return immediate child elements of this :class:`.ClauseElement`.
+ try:
+ traverse_internals = self._traverse_internals
+ except AttributeError:
+ return
+
+ for attrname, obj, meth in _copy_internals.run_generated_dispatch(
+ self, traverse_internals, "_generated_copy_internals_traversal"
+ ):
+ if obj is not None:
+ result = meth(self, obj, **kw)
+ if result is not None:
+ setattr(self, attrname, result)
+
+ def get_children(self, omit_attrs=None, **kw):
+ r"""Return immediate child :class:`.Traversible` elements of this
+ :class:`.Traversible`.
This is used for visit traversal.
- \**kwargs may contain flags that change the collection that is
+ \**kw may contain flags that change the collection that is
returned, for example to return a subset of items in order to
cut down on larger traversals, or to return child items from a
different context (such as schema-level collections instead of
clause-level).
"""
- return []
+ result = []
+ try:
+ traverse_internals = self._traverse_internals
+ except AttributeError:
+ return result
+
+ for attrname, obj, meth in _get_children.run_generated_dispatch(
+ self, traverse_internals, "_generated_get_children_traversal"
+ ):
+ if obj is None or omit_attrs and attrname in omit_attrs:
+ continue
+ result.extend(meth(obj, **kw))
+ return result
def self_group(self, against=None):
# type: (Optional[Any]) -> ClauseElement
return or_(self, other)
def __invert__(self):
+ # undocumented element currently used by the ORM for
+ # relationship.contains()
if hasattr(self, "negation_clause"):
return self.negation_clause
else:
def _negate(self):
return UnaryExpression(
- self.self_group(against=operators.inv),
- operator=operators.inv,
- negate=None,
+ self.self_group(against=operators.inv), operator=operators.inv
)
def __bool__(self):
else:
return comparator_factory(self)
- def _cache_key(self, **kw):
- raise NotImplementedError(self.__class__)
-
def __getattr__(self, key):
try:
return getattr(self.comparator, key)
__visit_name__ = "bindparam"
+ _traverse_internals = [
+ ("key", InternalTraversal.dp_anon_name),
+ ("type", InternalTraversal.dp_type),
+ ("callable", InternalTraversal.dp_plain_dict),
+ ("value", InternalTraversal.dp_plain_obj),
+ ]
+
_is_crud = False
_expanding_in_types = ()
)
return c
- def _cache_key(self, bindparams=None, **kw):
- if bindparams is None:
- # even though _cache_key is a private method, we would like to
- # be super paranoid about this point. You can't include the
- # "value" or "callable" in the cache key, because the value is
- # not part of the structure of a statement and is likely to
- # change every time. However you cannot *throw it away* either,
- # because you can't invoke the statement without the parameter
- # values that were explicitly placed. So require that they
- # are collected here to make sure this happens.
- if self._value_required_for_cache:
- raise NotImplementedError(
- "bindparams collection argument required for _cache_key "
- "implementation. Bound parameter cache keys are not safe "
- "to use without accommodating for the value or callable "
- "within the parameter itself."
- )
- else:
- bindparams.append(self)
- return (BindParameter, self.type._cache_key, self._orig_key)
+ def _gen_cache_key(self, anon_map, bindparams):
+ if self in anon_map:
+ return (anon_map[self], self.__class__)
+
+ id_ = anon_map[self]
+ bindparams.append(self)
+
+ return (
+ id_,
+ self.__class__,
+ self.type._gen_cache_key,
+ traversals._resolve_name_for_compare(self, self.key, anon_map),
+ )
def _convert_to_unique(self):
if not self.unique:
__visit_name__ = "typeclause"
+ _traverse_internals = [("type", InternalTraversal.dp_type)]
+
def __init__(self, type_):
self.type = type_
- def _cache_key(self, **kw):
- return (TypeClause, self.type._cache_key)
-
class TextClause(
roles.DDLConstraintColumnRole,
__visit_name__ = "textclause"
+ _traverse_internals = [
+ ("_bindparams", InternalTraversal.dp_string_clauseelement_dict),
+ ("text", InternalTraversal.dp_string),
+ ]
+
_is_text_clause = True
_is_textual = True
else:
return self
- def _copy_internals(self, clone=_clone, **kw):
- self._bindparams = dict(
- (b.key, clone(b, **kw)) for b in self._bindparams.values()
- )
-
- def get_children(self, **kwargs):
- return list(self._bindparams.values())
-
- def _cache_key(self, **kw):
- return (self.text,) + tuple(
- bind._cache_key for bind in self._bindparams.values()
- )
-
class Null(roles.ConstExprRole, ColumnElement):
"""Represent the NULL keyword in a SQL statement.
__visit_name__ = "null"
+ _traverse_internals = []
+
@util.memoized_property
def type(self):
return type_api.NULLTYPE
return Null()
- def _cache_key(self, **kw):
- return (Null,)
-
class False_(roles.ConstExprRole, ColumnElement):
"""Represent the ``false`` keyword, or equivalent, in a SQL statement.
"""
__visit_name__ = "false"
+ _traverse_internals = []
@util.memoized_property
def type(self):
return False_()
- def _cache_key(self, **kw):
- return (False_,)
-
class True_(roles.ConstExprRole, ColumnElement):
"""Represent the ``true`` keyword, or equivalent, in a SQL statement.
__visit_name__ = "true"
+ _traverse_internals = []
+
@util.memoized_property
def type(self):
return type_api.BOOLEANTYPE
return True_()
- def _cache_key(self, **kw):
- return (True_,)
-
class ClauseList(
roles.InElementRole,
__visit_name__ = "clauselist"
+ _traverse_internals = [
+ ("clauses", InternalTraversal.dp_clauseelement_list),
+ ("operator", InternalTraversal.dp_operator),
+ ]
+
def __init__(self, *clauses, **kwargs):
self.operator = kwargs.pop("operator", operators.comma_op)
self.group = kwargs.pop("group", True)
coercions.expect(self._text_converter_role, clause)
)
- def _copy_internals(self, clone=_clone, **kw):
- self.clauses = [clone(clause, **kw) for clause in self.clauses]
-
- def get_children(self, **kwargs):
- return self.clauses
-
- def _cache_key(self, **kw):
- return (ClauseList, self.operator) + tuple(
- clause._cache_key(**kw) for clause in self.clauses
- )
-
@property
def _from_objects(self):
return list(itertools.chain(*[c._from_objects for c in self.clauses]))
"BooleanClauseList has a private constructor"
)
- def _cache_key(self, **kw):
- return (BooleanClauseList, self.operator) + tuple(
- clause._cache_key(**kw) for clause in self.clauses
- )
-
@classmethod
def _construct(cls, operator, continue_on, skip_on, *clauses, **kw):
convert_clauses = []
class Tuple(ClauseList, ColumnElement):
"""Represent a SQL tuple."""
+ _traverse_internals = ClauseList._traverse_internals + []
+
def __init__(self, *clauses, **kw):
"""Return a :class:`.Tuple`.
def _select_iterable(self):
return (self,)
- def _cache_key(self, **kw):
- return (Tuple,) + tuple(
- clause._cache_key(**kw) for clause in self.clauses
- )
-
def _bind_param(self, operator, obj, type_=None):
return Tuple(
*[
__visit_name__ = "case"
+ _traverse_internals = [
+ ("value", InternalTraversal.dp_clauseelement),
+ ("whens", InternalTraversal.dp_clauseelement_tuples),
+ ("else_", InternalTraversal.dp_clauseelement),
+ ]
+
def __init__(self, whens, value=None, else_=None):
r"""Produce a ``CASE`` expression.
else:
self.else_ = None
- def _copy_internals(self, clone=_clone, **kw):
- if self.value is not None:
- self.value = clone(self.value, **kw)
- self.whens = [(clone(x, **kw), clone(y, **kw)) for x, y in self.whens]
- if self.else_ is not None:
- self.else_ = clone(self.else_, **kw)
-
- def get_children(self, **kwargs):
- if self.value is not None:
- yield self.value
- for x, y in self.whens:
- yield x
- yield y
- if self.else_ is not None:
- yield self.else_
-
- def _cache_key(self, **kw):
- return (
- (
- Case,
- self.value._cache_key(**kw)
- if self.value is not None
- else None,
- )
- + tuple(
- (x._cache_key(**kw), y._cache_key(**kw)) for x, y in self.whens
- )
- + (
- self.else_._cache_key(**kw)
- if self.else_ is not None
- else None,
- )
- )
-
@property
def _from_objects(self):
return list(
__visit_name__ = "cast"
+ _traverse_internals = [
+ ("clause", InternalTraversal.dp_clauseelement),
+ ("typeclause", InternalTraversal.dp_clauseelement),
+ ]
+
def __init__(self, expression, type_):
r"""Produce a ``CAST`` expression.
)
self.typeclause = TypeClause(self.type)
- def _copy_internals(self, clone=_clone, **kw):
- self.clause = clone(self.clause, **kw)
- self.typeclause = clone(self.typeclause, **kw)
-
- def get_children(self, **kwargs):
- return self.clause, self.typeclause
-
- def _cache_key(self, **kw):
- return (
- Cast,
- self.clause._cache_key(**kw),
- self.typeclause._cache_key(**kw),
- )
-
@property
def _from_objects(self):
return self.clause._from_objects
return self.clause
-class TypeCoerce(WrapsColumnExpression, ColumnElement):
+class TypeCoerce(HasMemoized, WrapsColumnExpression, ColumnElement):
"""Represent a Python-side type-coercion wrapper.
:class:`.TypeCoerce` supplies the :func:`.expression.type_coerce`
__visit_name__ = "type_coerce"
+ _traverse_internals = [
+ ("clause", InternalTraversal.dp_clauseelement),
+ ("type", InternalTraversal.dp_type),
+ ]
+
+ _memoized_property = util.group_expirable_memoized_property()
+
def __init__(self, expression, type_):
r"""Associate a SQL expression with a particular type, without rendering
``CAST``.
roles.ExpressionElementRole, expression, type_=self.type
)
- def _copy_internals(self, clone=_clone, **kw):
- self.clause = clone(self.clause, **kw)
- self.__dict__.pop("typed_expression", None)
-
- def get_children(self, **kwargs):
- return (self.clause,)
-
- def _cache_key(self, **kw):
- return (TypeCoerce, self.type._cache_key, self.clause._cache_key(**kw))
-
@property
def _from_objects(self):
return self.clause._from_objects
- @util.memoized_property
+ @_memoized_property
def typed_expression(self):
if isinstance(self.clause, BindParameter):
bp = self.clause._clone()
__visit_name__ = "extract"
+ _traverse_internals = [
+ ("expr", InternalTraversal.dp_clauseelement),
+ ("field", InternalTraversal.dp_string),
+ ]
+
def __init__(self, field, expr, **kwargs):
"""Return a :class:`.Extract` construct.
self.field = field
self.expr = coercions.expect(roles.ExpressionElementRole, expr)
- def _copy_internals(self, clone=_clone, **kw):
- self.expr = clone(self.expr, **kw)
-
- def get_children(self, **kwargs):
- return (self.expr,)
-
- def _cache_key(self, **kw):
- return (Extract, self.field, self.expr._cache_key(**kw))
-
@property
def _from_objects(self):
return self.expr._from_objects
__visit_name__ = "label_reference"
+ _traverse_internals = [("element", InternalTraversal.dp_clauseelement)]
+
def __init__(self, element):
self.element = element
- def _copy_internals(self, clone=_clone, **kw):
- self.element = clone(self.element, **kw)
-
- def _cache_key(self, **kw):
- return (_label_reference, self.element._cache_key(**kw))
-
- def get_children(self, **kwargs):
- return [self.element]
-
@property
def _from_objects(self):
return ()
class _textual_label_reference(ColumnElement):
__visit_name__ = "textual_label_reference"
+ _traverse_internals = [("element", InternalTraversal.dp_string)]
+
def __init__(self, element):
self.element = element
def _text_clause(self):
return TextClause._create_text(self.element)
- def _cache_key(self, **kw):
- return (_textual_label_reference, self.element)
-
class UnaryExpression(ColumnElement):
"""Define a 'unary' expression.
__visit_name__ = "unary"
+ _traverse_internals = [
+ ("element", InternalTraversal.dp_clauseelement),
+ ("operator", InternalTraversal.dp_operator),
+ ("modifier", InternalTraversal.dp_operator),
+ ]
+
def __init__(
self,
element,
operator=None,
modifier=None,
type_=None,
- negate=None,
wraps_column_expression=False,
):
self.operator = operator
against=self.operator or self.modifier
)
self.type = type_api.to_instance(type_)
- self.negate = negate
self.wraps_column_expression = wraps_column_expression
@classmethod
def _from_objects(self):
return self.element._from_objects
- def _copy_internals(self, clone=_clone, **kw):
- self.element = clone(self.element, **kw)
-
- def _cache_key(self, **kw):
- return (
- UnaryExpression,
- self.element._cache_key(**kw),
- self.operator,
- self.modifier,
- )
-
- def get_children(self, **kwargs):
- return (self.element,)
-
def _negate(self):
- if self.negate is not None:
- return UnaryExpression(
- self.element,
- operator=self.negate,
- negate=self.operator,
- modifier=self.modifier,
- type_=self.type,
- wraps_column_expression=self.wraps_column_expression,
- )
- elif self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity:
+ if self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity:
return UnaryExpression(
self.self_group(against=operators.inv),
operator=operators.inv,
type_=type_api.BOOLEANTYPE,
wraps_column_expression=self.wraps_column_expression,
- negate=None,
)
else:
return ClauseElement._negate(self)
# type: (Optional[Any]) -> ClauseElement
return self
- def _cache_key(self, **kw):
- return (
- self.element._cache_key(**kw),
- self.type._cache_key,
- self.operator,
- self.negate,
- self.modifier,
- )
-
def _negate(self):
if isinstance(self.element, (True_, False_)):
return self.element._negate()
__visit_name__ = "binary"
+ _traverse_internals = [
+ ("left", InternalTraversal.dp_clauseelement),
+ ("right", InternalTraversal.dp_clauseelement),
+ ("operator", InternalTraversal.dp_operator),
+ ("negate", InternalTraversal.dp_operator),
+ ("modifiers", InternalTraversal.dp_plain_dict),
+ ]
+
_is_implicitly_boolean = True
"""Indicates that any database will know this is a boolean expression
even if the database does not have an explicit boolean datatype.
def _from_objects(self):
return self.left._from_objects + self.right._from_objects
- def _copy_internals(self, clone=_clone, **kw):
- self.left = clone(self.left, **kw)
- self.right = clone(self.right, **kw)
-
- def get_children(self, **kwargs):
- return self.left, self.right
-
- def _cache_key(self, **kw):
- return (
- BinaryExpression,
- self.left._cache_key(**kw),
- self.right._cache_key(**kw),
- )
-
def self_group(self, against=None):
# type: (Optional[Any]) -> ClauseElement
__visit_name__ = "slice"
+ _traverse_internals = [
+ ("start", InternalTraversal.dp_plain_obj),
+ ("stop", InternalTraversal.dp_plain_obj),
+ ("step", InternalTraversal.dp_plain_obj),
+ ]
+
def __init__(self, start, stop, step):
self.start = start
self.stop = stop
assert against is operator.getitem
return self
- def _cache_key(self, **kw):
- return (Slice, self.start, self.stop, self.step)
-
class IndexExpression(BinaryExpression):
"""Represent the class of expressions that are like an "index" operation.
class Grouping(GroupedElement, ColumnElement):
"""Represent a grouping within a column expression"""
+ _traverse_internals = [
+ ("element", InternalTraversal.dp_clauseelement),
+ ("type", InternalTraversal.dp_type),
+ ]
+
def __init__(self, element):
self.element = element
self.type = getattr(element, "type", type_api.NULLTYPE)
def _label(self):
return getattr(self.element, "_label", None) or self.anon_label
- def _copy_internals(self, clone=_clone, **kw):
- self.element = clone(self.element, **kw)
-
- def get_children(self, **kwargs):
- return (self.element,)
-
- def _cache_key(self, **kw):
- return (Grouping, self.element._cache_key(**kw))
-
@property
def _from_objects(self):
return self.element._from_objects
__visit_name__ = "over"
+ _traverse_internals = [
+ ("element", InternalTraversal.dp_clauseelement),
+ ("order_by", InternalTraversal.dp_clauseelement),
+ ("partition_by", InternalTraversal.dp_clauseelement),
+ ("range_", InternalTraversal.dp_plain_obj),
+ ("rows", InternalTraversal.dp_plain_obj),
+ ]
+
order_by = None
partition_by = None
def type(self):
return self.element.type
- def get_children(self, **kwargs):
- return [
- c
- for c in (self.element, self.partition_by, self.order_by)
- if c is not None
- ]
-
- def _cache_key(self, **kw):
- return (
- (Over,)
- + tuple(
- e._cache_key(**kw) if e is not None else None
- for e in (self.element, self.partition_by, self.order_by)
- )
- + (self.range_, self.rows)
- )
-
- def _copy_internals(self, clone=_clone, **kw):
- self.element = clone(self.element, **kw)
- if self.partition_by is not None:
- self.partition_by = clone(self.partition_by, **kw)
- if self.order_by is not None:
- self.order_by = clone(self.order_by, **kw)
-
@property
def _from_objects(self):
return list(
__visit_name__ = "withingroup"
+ _traverse_internals = [
+ ("element", InternalTraversal.dp_clauseelement),
+ ("order_by", InternalTraversal.dp_clauseelement),
+ ]
+
order_by = None
def __init__(self, element, *order_by):
else:
return self.element.type
- def get_children(self, **kwargs):
- return [c for c in (self.element, self.order_by) if c is not None]
-
- def _cache_key(self, **kw):
- return (
- WithinGroup,
- self.element._cache_key(**kw)
- if self.element is not None
- else None,
- self.order_by._cache_key(**kw)
- if self.order_by is not None
- else None,
- )
-
- def _copy_internals(self, clone=_clone, **kw):
- self.element = clone(self.element, **kw)
- if self.order_by is not None:
- self.order_by = clone(self.order_by, **kw)
-
@property
def _from_objects(self):
return list(
__visit_name__ = "funcfilter"
+ _traverse_internals = [
+ ("func", InternalTraversal.dp_clauseelement),
+ ("criterion", InternalTraversal.dp_clauseelement),
+ ]
+
criterion = None
def __init__(self, func, *criterion):
def type(self):
return self.func.type
- def get_children(self, **kwargs):
- return [c for c in (self.func, self.criterion) if c is not None]
-
- def _copy_internals(self, clone=_clone, **kw):
- self.func = clone(self.func, **kw)
- if self.criterion is not None:
- self.criterion = clone(self.criterion, **kw)
-
- def _cache_key(self, **kw):
- return (
- FunctionFilter,
- self.func._cache_key(**kw),
- self.criterion._cache_key(**kw)
- if self.criterion is not None
- else None,
- )
-
@property
def _from_objects(self):
return list(
)
-class Label(roles.LabeledColumnExprRole, ColumnElement):
+class Label(HasMemoized, roles.LabeledColumnExprRole, ColumnElement):
"""Represents a column label (AS).
Represent a label, as typically applied to any column-level
__visit_name__ = "label"
+ _traverse_internals = [
+ ("name", InternalTraversal.dp_anon_name),
+ ("_type", InternalTraversal.dp_type),
+ ("_element", InternalTraversal.dp_clauseelement),
+ ]
+
+ _memoized_property = util.group_expirable_memoized_property()
+
def __init__(self, name, element, type_=None):
"""Return a :class:`Label` object for the
given :class:`.ColumnElement`.
def __reduce__(self):
return self.__class__, (self.name, self._element, self._type)
- def _cache_key(self, **kw):
- return (Label, self.element._cache_key(**kw), self._resolve_label)
-
@util.memoized_property
def _is_implicitly_boolean(self):
return self.element._is_implicitly_boolean
- @util.memoized_property
+ @_memoized_property
def _allow_label_resolve(self):
return self.element._allow_label_resolve
self._type or getattr(self._element, "type", None)
)
- @util.memoized_property
+ @_memoized_property
def element(self):
return self._element.self_group(against=operators.as_)
def foreign_keys(self):
return self.element.foreign_keys
- def get_children(self, **kwargs):
- return (self.element,)
-
def _copy_internals(self, clone=_clone, anonymize_labels=False, **kw):
+ self._reset_memoizations()
self._element = clone(self._element, **kw)
- self.__dict__.pop("element", None)
- self.__dict__.pop("_allow_label_resolve", None)
if anonymize_labels:
self.name = self._resolve_label = _anonymous_label(
"%%(%d %s)s"
__visit_name__ = "column"
+ _traverse_internals = [
+ ("name", InternalTraversal.dp_string),
+ ("type", InternalTraversal.dp_type),
+ ("table", InternalTraversal.dp_clauseelement),
+ ("is_literal", InternalTraversal.dp_boolean),
+ ]
+
onupdate = default = server_default = server_onupdate = None
_is_multiparam_column = False
table = property(_get_table, _set_table)
- def _cache_key(self, **kw):
- return (
- self.name,
- self.table.name if self.table is not None else None,
- self.is_literal,
- self.type._cache_key,
- )
-
@_memoized_property
def _from_objects(self):
t = self.table
class CollationClause(ColumnElement):
__visit_name__ = "collation"
+ _traverse_internals = [("collation", InternalTraversal.dp_string)]
+
def __init__(self, collation):
self.collation = collation
- def _cache_key(self, **kw):
- return (CollationClause, self.collation)
-
class _IdentifiedClause(Executable, ClauseElement):
from .base import _from_objects # noqa
from .base import ColumnCollection # noqa
from .base import Executable # noqa
-from .base import Generative # noqa
from .base import PARSE_AUTOCOMMIT # noqa
from .dml import Delete # noqa
from .dml import Insert # noqa
_Case = Case
_Tuple = Tuple
_Over = Over
-_Generative = Generative
_TypeClause = TypeClause
_Extract = Extract
_Exists = Exists
from . import util as sqlutil
from .base import ColumnCollection
from .base import Executable
-from .elements import _clone
from .elements import _type_from_args
from .elements import BinaryExpression
from .elements import BindParameter
from .selectable import Alias
from .selectable import FromClause
from .selectable import Select
-from .visitors import VisitableType
+from .visitors import InternalTraversal
+from .visitors import TraversibleType
from .. import util
"""
+ _traverse_internals = [("clause_expr", InternalTraversal.dp_clauseelement)]
+
packagenames = ()
_has_args = False
+ _memoized_property = FromClause._memoized_property
+
def __init__(self, *clauses, **kwargs):
r"""Construct a :class:`.FunctionElement`.
col = self.label(None)
return ColumnCollection(columns=[(col.key, col)])
- @util.memoized_property
+ @_memoized_property
def clauses(self):
"""Return the underlying :class:`.ClauseList` which contains
the arguments for this :class:`.FunctionElement`.
def _from_objects(self):
return self.clauses._from_objects
- def get_children(self, **kwargs):
- return (self.clause_expr,)
-
- def _cache_key(self, **kw):
- return (FunctionElement, self.clause_expr._cache_key(**kw))
-
- def _copy_internals(self, clone=_clone, **kw):
- self.clause_expr = clone(self.clause_expr, **kw)
- self._reset_exported()
- FunctionElement.clauses._reset(self)
-
def within_group_type(self, within_group):
"""For types that define their return type as based on the criteria
within a WITHIN GROUP (ORDER BY) expression, called by the
class FunctionAsBinary(BinaryExpression):
+ _traverse_internals = [
+ ("sql_function", InternalTraversal.dp_clauseelement),
+ ("left_index", InternalTraversal.dp_plain_obj),
+ ("right_index", InternalTraversal.dp_plain_obj),
+ ("modifiers", InternalTraversal.dp_plain_dict),
+ ]
+
def __init__(self, fn, left_index, right_index):
self.sql_function = fn
self.left_index = left_index
def right(self, value):
self.sql_function.clauses.clauses[self.right_index - 1] = value
- def _copy_internals(self, clone=_clone, **kw):
- self.sql_function = clone(self.sql_function, **kw)
-
- def get_children(self, **kw):
- yield self.sql_function
-
- def _cache_key(self, **kw):
- return (
- FunctionAsBinary,
- self.sql_function._cache_key(**kw),
- self.left_index,
- self.right_index,
- )
-
class _FunctionGenerator(object):
"""Generate SQL function expressions.
__visit_name__ = "function"
+ _traverse_internals = FunctionElement._traverse_internals + [
+ ("packagenames", InternalTraversal.dp_plain_obj),
+ ("name", InternalTraversal.dp_string),
+ ("type", InternalTraversal.dp_type),
+ ]
+
def __init__(self, name, *clauses, **kw):
"""Construct a :class:`.Function`.
unique=True,
)
- def _cache_key(self, **kw):
- return (
- (Function,) + tuple(self.packagenames)
- if self.packagenames
- else () + (self.name, self.clause_expr._cache_key(**kw))
- )
-
-class _GenericMeta(VisitableType):
+class _GenericMeta(TraversibleType):
def __init__(cls, clsname, bases, clsdict):
if annotation.Annotated not in cls.__mro__:
cls.name = name = clsdict.get("name", clsname)
type = sqltypes.Integer()
name = "next_value"
+ _traverse_internals = [
+ ("sequence", InternalTraversal.dp_named_ddl_element)
+ ]
+
def __init__(self, seq, **kw):
assert isinstance(
seq, schema.Sequence
self._bind = kw.get("bind", None)
self.sequence = seq
- def _cache_key(self, **kw):
- return (next_value, self.sequence.name)
-
def compare(self, other, **kw):
return (
isinstance(other, next_value)
and self.sequence.name == other.sequence.name
)
- def get_children(self, **kwargs):
- return []
-
- def _copy_internals(self, **kw):
- pass
-
@property
def _from_objects(self):
return []
from .elements import quoted_name
from .elements import TextClause
from .selectable import TableClause
+from .visitors import InternalTraversal
from .. import event
from .. import exc
from .. import inspection
__visit_name__ = "table"
+ _traverse_internals = TableClause._traverse_internals + [
+ ("schema", InternalTraversal.dp_string)
+ ]
+
+ def _gen_cache_key(self, anon_map, bindparams):
+ return (self,)
+
+ @util.deprecated_params(
+ useexisting=(
+ "0.7",
+ "The :paramref:`.Table.useexisting` parameter is deprecated and "
+ "will be removed in a future release. Please use "
+ ":paramref:`.Table.extend_existing`.",
+ )
+ )
def __new__(cls, *args, **kw):
if not args:
# python3k pickle seems to call this
def get_children(
self, column_collections=True, schema_visitor=False, **kw
):
+        # TODO: consider that we probably don't need column_collections=True
+        # at all; it does not appear to impact anything
if not schema_visitor:
return TableClause.get_children(
self, column_collections=column_collections, **kw
from .base import DedupeColumnCollection
from .base import Executable
from .base import Generative
+from .base import HasMemoized
from .base import Immutable
from .coercions import _document_text_coercion
from .elements import _anonymous_label
from .elements import BindParameter
from .elements import ClauseElement
from .elements import ClauseList
+from .elements import ColumnClause
from .elements import GroupedElement
from .elements import Grouping
from .elements import literal_column
from .elements import True_
from .elements import UnaryExpression
+from .visitors import InternalTraversal
from .. import exc
from .. import util
class HasPrefixes(object):
_prefixes = ()
+ _traverse_internals = [("_prefixes", InternalTraversal.dp_prefix_sequence)]
+
@_generative
@_document_text_coercion(
"expr",
class HasSuffixes(object):
_suffixes = ()
+ _traverse_internals = [("_suffixes", InternalTraversal.dp_prefix_sequence)]
+
@_generative
@_document_text_coercion(
"expr",
)
-class FromClause(roles.AnonymizedFromClauseRole, Selectable):
+class FromClause(HasMemoized, roles.AnonymizedFromClauseRole, Selectable):
"""Represent an element that can be used within the ``FROM``
clause of a ``SELECT`` statement.
"""
return getattr(self, "name", self.__class__.__name__ + " object")
- def _reset_exported(self):
- """delete memoized collections when a FromClause is cloned."""
-
- self._memoized_property.expire_instance(self)
-
def _generate_fromclause_column_proxies(self, fromclause):
fromclause._columns._populate_separate_keys(
col._make_proxy(fromclause) for col in self.c
__visit_name__ = "join"
+ _traverse_internals = [
+ ("left", InternalTraversal.dp_clauseelement),
+ ("right", InternalTraversal.dp_clauseelement),
+ ("onclause", InternalTraversal.dp_clauseelement),
+ ("isouter", InternalTraversal.dp_boolean),
+ ("full", InternalTraversal.dp_boolean),
+ ]
+
_is_join = True
def __init__(self, left, right, onclause=None, isouter=False, full=False):
self.left._refresh_for_new_column(column)
self.right._refresh_for_new_column(column)
- def _copy_internals(self, clone=_clone, **kw):
- self._reset_exported()
- self.left = clone(self.left, **kw)
- self.right = clone(self.right, **kw)
- self.onclause = clone(self.onclause, **kw)
-
- def get_children(self, **kwargs):
- return self.left, self.right, self.onclause
-
- def _cache_key(self, **kw):
- return (
- Join,
- self.isouter,
- self.full,
- self.left._cache_key(**kw),
- self.right._cache_key(**kw),
- self.onclause._cache_key(**kw),
- )
-
def _match_primaries(self, left, right):
if isinstance(left, Join):
left_right = left.right
_is_from_container = True
named_with_column = True
+ _traverse_internals = [
+ ("element", InternalTraversal.dp_clauseelement),
+ ("name", InternalTraversal.dp_anon_name),
+ ]
+
def __init__(self, *arg, **kw):
raise NotImplementedError(
"The %s class is not intended to be constructed "
def _copy_internals(self, clone=_clone, **kw):
element = clone(self.element, **kw)
+
+        # The element clone is usually performed against a Table, which
+        # returns the same object; don't reset the exported .c collection
+        # and other memoized details if nothing has changed.
if element is not self.element:
self._reset_exported()
- self.element = element
-
- def get_children(self, column_collections=True, **kw):
- if column_collections:
- for c in self.c:
- yield c
- yield self.element
-
- def _cache_key(self, **kw):
- return (self.__class__, self.element._cache_key(**kw), self._orig_name)
+ self.element = element
@property
def _from_objects(self):
__visit_name__ = "tablesample"
+ _traverse_internals = AliasedReturnsRows._traverse_internals + [
+ ("sampling", InternalTraversal.dp_clauseelement),
+ ("seed", InternalTraversal.dp_clauseelement),
+ ]
+
@classmethod
def _factory(cls, selectable, sampling, name=None, seed=None):
"""Return a :class:`.TableSample` object.
__visit_name__ = "cte"
+ _traverse_internals = (
+ AliasedReturnsRows._traverse_internals
+ + [
+ ("_cte_alias", InternalTraversal.dp_clauseelement),
+ ("_restates", InternalTraversal.dp_clauseelement_unordered_set),
+ ("recursive", InternalTraversal.dp_boolean),
+ ]
+ + HasSuffixes._traverse_internals
+ )
+
@classmethod
def _factory(cls, selectable, name=None, recursive=False):
r"""Return a new :class:`.CTE`, or Common Table Expression instance.
def _copy_internals(self, clone=_clone, **kw):
super(CTE, self)._copy_internals(clone, **kw)
+ # TODO: I don't like that we can't use the traversal data here
if self._cte_alias is not None:
self._cte_alias = clone(self._cte_alias, **kw)
self._restates = frozenset(
[clone(elem, **kw) for elem in self._restates]
)
- def _cache_key(self, *arg, **kw):
- raise NotImplementedError("TODO")
-
def alias(self, name=None, flat=False):
"""Return an :class:`.Alias` of this :class:`.CTE`.
class FromGrouping(GroupedElement, FromClause):
"""Represent a grouping of a FROM clause"""
+ _traverse_internals = [("element", InternalTraversal.dp_clauseelement)]
+
def __init__(self, element):
self.element = coercions.expect(roles.FromClauseRole, element)
def _hide_froms(self):
return self.element._hide_froms
- def get_children(self, **kwargs):
- return (self.element,)
-
- def _copy_internals(self, clone=_clone, **kw):
- self.element = clone(self.element, **kw)
-
- def _cache_key(self, **kw):
- return (FromGrouping, self.element._cache_key(**kw))
-
@property
def _from_objects(self):
return self.element._from_objects
__visit_name__ = "table"
+ _traverse_internals = [
+ (
+ "columns",
+ InternalTraversal.dp_fromclause_canonical_column_collection,
+ ),
+ ("name", InternalTraversal.dp_string),
+ ]
+
named_with_column = True
implicit_returning = False
self._columns.add(c)
c.table = self
- def get_children(self, column_collections=True, **kwargs):
- if column_collections:
- return [c for c in self.c]
- else:
- return []
-
- def _cache_key(self, **kw):
- return (TableClause, self.name) + tuple(
- col._cache_key(**kw) for col in self._columns
- )
-
@util.dependencies("sqlalchemy.sql.dml")
def insert(self, dml, values=None, inline=False, **kwargs):
"""Generate an :func:`.insert` construct against this
class ForUpdateArg(ClauseElement):
+ _traverse_internals = [
+ ("of", InternalTraversal.dp_clauseelement_list),
+ ("nowait", InternalTraversal.dp_boolean),
+ ("read", InternalTraversal.dp_boolean),
+ ("skip_locked", InternalTraversal.dp_boolean),
+ ]
+
@classmethod
def parse_legacy_select(self, arg):
"""Parse the for_update argument of :func:`.select`.
def __hash__(self):
return id(self)
- def _copy_internals(self, clone=_clone, **kw):
- if self.of is not None:
- self.of = [clone(col, **kw) for col in self.of]
-
- def _cache_key(self, **kw):
- return (
- ForUpdateArg,
- self.nowait,
- self.read,
- self.skip_locked,
- self.of._cache_key(**kw) if self.of is not None else None,
- )
-
def __init__(
self,
nowait=False,
roles.DMLSelectRole,
roles.CompoundElementRole,
roles.InElementRole,
+ HasMemoized,
HasCTE,
Executable,
SupportsCloneAnnotations,
_memoized_property = util.group_expirable_memoized_property()
- def _reset_memoizations(self):
- self._memoized_property.expire_instance(self)
-
def _generate_fromclause_column_proxies(self, fromclause):
# type: (FromClause)
raise NotImplementedError()
"""
__visit_name__ = "grouping"
+ _traverse_internals = [("element", InternalTraversal.dp_clauseelement)]
_is_select_container = True
def select_statement(self):
return self.element
- def get_children(self, **kwargs):
- return (self.element,)
-
def self_group(self, against=None):
# type: (Optional[Any]) -> FromClause
return self
"""
return self.element.selected_columns
- def _copy_internals(self, clone=_clone, **kw):
- self.element = clone(self.element, **kw)
-
- def _cache_key(self, **kw):
- return (SelectStatementGrouping, self.element._cache_key(**kw))
-
@property
def _from_objects(self):
return self.element._from_objects
def _label_resolve_dict(self):
raise NotImplementedError()
- def _copy_internals(self, clone=_clone, **kw):
- raise NotImplementedError()
-
class CompoundSelect(GenerativeSelect):
"""Forms the basis of ``UNION``, ``UNION ALL``, and other
__visit_name__ = "compound_select"
+ _traverse_internals = [
+ ("selects", InternalTraversal.dp_clauseelement_list),
+ ("_limit_clause", InternalTraversal.dp_clauseelement),
+ ("_offset_clause", InternalTraversal.dp_clauseelement),
+ ("_order_by_clause", InternalTraversal.dp_clauseelement),
+ ("_group_by_clause", InternalTraversal.dp_clauseelement),
+ ("_for_update_arg", InternalTraversal.dp_clauseelement),
+ ("keyword", InternalTraversal.dp_string),
+ ] + SupportsCloneAnnotations._traverse_internals
+
UNION = util.symbol("UNION")
UNION_ALL = util.symbol("UNION ALL")
EXCEPT = util.symbol("EXCEPT")
"""
return self.selects[0].selected_columns
- def _copy_internals(self, clone=_clone, **kw):
- self._reset_memoizations()
- self.selects = [clone(s, **kw) for s in self.selects]
- if hasattr(self, "_col_map"):
- del self._col_map
- for attr in (
- "_limit_clause",
- "_offset_clause",
- "_order_by_clause",
- "_group_by_clause",
- "_for_update_arg",
- ):
- if getattr(self, attr) is not None:
- setattr(self, attr, clone(getattr(self, attr), **kw))
-
- def get_children(self, **kwargs):
- return [self._order_by_clause, self._group_by_clause] + list(
- self.selects
- )
-
- def _cache_key(self, **kw):
- return (
- (CompoundSelect, self.keyword)
- + tuple(stmt._cache_key(**kw) for stmt in self.selects)
- + (
- self._order_by_clause._cache_key(**kw)
- if self._order_by_clause is not None
- else None,
- )
- + (
- self._group_by_clause._cache_key(**kw)
- if self._group_by_clause is not None
- else None,
- )
- + (
- self._for_update_arg._cache_key(**kw)
- if self._for_update_arg is not None
- else None,
- )
- )
-
def bind(self):
if self._bind:
return self._bind
_hints = util.immutabledict()
_statement_hints = ()
_distinct = False
- _from_cloned = None
+ _distinct_on = ()
_correlate = ()
_correlate_except = None
_memoized_property = SelectBase._memoized_property
+ _traverse_internals = (
+ [
+ ("_from_obj", InternalTraversal.dp_fromclause_ordered_set),
+ ("_raw_columns", InternalTraversal.dp_clauseelement_list),
+ ("_whereclause", InternalTraversal.dp_clauseelement),
+ ("_having", InternalTraversal.dp_clauseelement),
+ ("_order_by_clause", InternalTraversal.dp_clauseelement_list),
+ ("_group_by_clause", InternalTraversal.dp_clauseelement_list),
+ ("_correlate", InternalTraversal.dp_clauseelement_unordered_set),
+ (
+ "_correlate_except",
+ InternalTraversal.dp_clauseelement_unordered_set,
+ ),
+ ("_for_update_arg", InternalTraversal.dp_clauseelement),
+ ("_statement_hints", InternalTraversal.dp_statement_hint_list),
+ ("_hints", InternalTraversal.dp_table_hint_list),
+ ("_distinct", InternalTraversal.dp_boolean),
+ ("_distinct_on", InternalTraversal.dp_clauseelement_list),
+ ]
+ + HasPrefixes._traverse_internals
+ + HasSuffixes._traverse_internals
+ + SupportsCloneAnnotations._traverse_internals
+ )
+
@util.deprecated_params(
autocommit=(
"0.6",
"""
self._auto_correlate = correlate
if distinct is not False:
- if distinct is True:
- self._distinct = True
- else:
- self._distinct = [
- coercions.expect(roles.WhereHavingRole, e)
- for e in util.to_list(distinct)
- ]
+ self._distinct = True
+ if not isinstance(distinct, bool):
+ self._distinct_on = tuple(
+ [
+ coercions.expect(roles.WhereHavingRole, e)
+ for e in util.to_list(distinct)
+ ]
+ )
if from_obj is not None:
self._from_obj = util.OrderedSet(
GenerativeSelect.__init__(self, **kwargs)
+ # @_memoized_property
@property
def _froms(self):
- # would love to cache this,
- # but there's just enough edge cases, particularly now that
- # declarative encourages construction of SQL expressions
- # without tables present, to just regen this each time.
+        # The current roadblock to caching is a pair of tests which assert
+        # that a SELECT can be compiled to a string, then a Table is created
+        # against its columns, and then it can be compiled again and still
+        # works.  This is somewhat valid, as people make a select() against a
+        # declarative class whose columns don't have their Table yet, and
+        # some operations may call upon _froms and cache it too soon.
froms = []
seen = set()
- translate = self._from_cloned
for item in itertools.chain(
_from_objects(*self._raw_columns),
raise exc.InvalidRequestError(
"select() construct refers to itself as a FROM"
)
- if translate and item in translate:
- item = translate[item]
if not seen.intersection(item._cloned_set):
froms.append(item)
seen.update(item._cloned_set)
itertools.chain(*[_expand_cloned(f._hide_froms) for f in froms])
)
if toremove:
- # if we're maintaining clones of froms,
- # add the copies out to the toremove list. only include
- # clones that are lexical equivalents.
- if self._from_cloned:
- toremove.update(
- self._from_cloned[f]
- for f in toremove.intersection(self._from_cloned)
- if self._from_cloned[f]._is_lexical_equivalent(f)
- )
# filter out to FROM clauses not in the list,
# using a list to maintain ordering
froms = [f for f in froms if f not in toremove]
return False
def _copy_internals(self, clone=_clone, **kw):
-
# Select() object has been cloned and probably adapted by the
# given clone function. Apply the cloning function to internal
# objects
# as of 0.7.4 we also put the current version of _froms, which
# gets cleared on each generation. previously we were "baking"
# _froms into self._from_obj.
- self._from_cloned = from_cloned = dict(
- (f, clone(f, **kw)) for f in self._from_obj.union(self._froms)
- )
- # 3. update persistent _from_obj with the cloned versions.
- self._from_obj = util.OrderedSet(
- from_cloned[f] for f in self._from_obj
+ all_the_froms = list(
+ itertools.chain(
+ _from_objects(*self._raw_columns),
+ _from_objects(self._whereclause)
+ if self._whereclause is not None
+ else (),
+ )
)
+ new_froms = {f: clone(f, **kw) for f in all_the_froms}
+ # copy FROM collections
- # the _correlate collection is done separately, what can happen
- # here is the same item is _correlate as in _from_obj but the
- # _correlate version has an annotation on it - (specifically
- # RelationshipProperty.Comparator._criterion_exists() does
- # this). Also keep _correlate liberally open with its previous
- # contents, as this set is used for matching, not rendering.
- self._correlate = set(clone(f) for f in self._correlate).union(
- self._correlate
- )
+ self._from_obj = util.OrderedSet(
+ clone(f, **kw) for f in self._from_obj
+ ).union(f for f in new_froms.values() if isinstance(f, Join))
- # do something similar for _correlate_except - this is a more
- # unusual case but same idea applies
+ self._correlate = set(clone(f) for f in self._correlate)
if self._correlate_except:
self._correlate_except = set(
clone(f) for f in self._correlate_except
- ).union(self._correlate_except)
+ )
# 4. clone other things. The difficulty here is that Column
- # objects are not actually cloned, and refer to their original
- # .table, resulting in the wrong "from" parent after a clone
- # operation. Hence _from_cloned and _from_obj supersede what is
- # present here.
+ # objects are usually not altered by a straight clone because they
+ # are dependent on the FROM cloning we just did above in order to
+ # be targeted correctly, or a new FROM we have might be a JOIN
+ # object which doesn't have its own columns. so give the cloner a
+ # hint.
+ def replace(obj, **kw):
+ if isinstance(obj, ColumnClause) and obj.table in new_froms:
+ newelem = new_froms[obj.table].corresponding_column(obj)
+ return newelem
+
+ kw["replace"] = replace
+
+ # TODO: I'd still like to try to leverage the traversal data
self._raw_columns = [clone(c, **kw) for c in self._raw_columns]
for attr in (
"_limit_clause",
if getattr(self, attr) is not None:
setattr(self, attr, clone(getattr(self, attr), **kw))
- # erase _froms collection,
- # etc.
self._reset_memoizations()
def get_children(self, **kwargs):
- """return child elements as per the ClauseElement specification."""
-
- return (
- self._raw_columns
- + list(self._froms)
- + [
- x
- for x in (
- self._whereclause,
- self._having,
- self._order_by_clause,
- self._group_by_clause,
- )
- if x is not None
- ]
- )
-
- def _cache_key(self, **kw):
- return (
- (Select,)
- + ("raw_columns",)
- + tuple(elem._cache_key(**kw) for elem in self._raw_columns)
- + ("elements",)
- + tuple(
- elem._cache_key(**kw) if elem is not None else None
- for elem in (
- self._whereclause,
- self._having,
- self._order_by_clause,
- self._group_by_clause,
- )
- )
- + ("from_obj",)
- + tuple(elem._cache_key(**kw) for elem in self._from_obj)
- + ("correlate",)
- + tuple(
- elem._cache_key(**kw)
- for elem in (
- self._correlate if self._correlate is not None else ()
- )
- )
- + ("correlate_except",)
- + tuple(
- elem._cache_key(**kw)
- for elem in (
- self._correlate_except
- if self._correlate_except is not None
- else ()
- )
- )
- + ("for_update",),
- (
- self._for_update_arg._cache_key(**kw)
- if self._for_update_arg is not None
- else None,
- ),
+ # TODO: define "get_children" traversal items separately?
+ return self._froms + super(Select, self).get_children(
+ omit_attrs=["_from_obj", "_correlate", "_correlate_except"]
)
@_generative
"""
if expr:
expr = [coercions.expect(roles.ByOfRole, e) for e in expr]
- if isinstance(self._distinct, list):
- self._distinct = self._distinct + expr
- else:
- self._distinct = expr
+ self._distinct = True
+ self._distinct_on = self._distinct_on + tuple(expr)
else:
self._distinct = True
__visit_name__ = "textual_select"
+ _traverse_internals = [
+ ("element", InternalTraversal.dp_clauseelement),
+ ("column_args", InternalTraversal.dp_clauseelement_list),
+ ] + SupportsCloneAnnotations._traverse_internals
+
_is_textual = True
def __init__(self, text, columns, positional=False):
c._make_proxy(fromclause) for c in self.column_args
)
- def _copy_internals(self, clone=_clone, **kw):
- self._reset_memoizations()
- self.element = clone(self.element, **kw)
-
- def get_children(self, **kw):
- return [self.element]
-
- def _cache_key(self, **kw):
- return (TextualSelect, self.element._cache_key(**kw)) + tuple(
- col._cache_key(**kw) for col in self.column_args
- )
-
def _scalar_type(self):
return self.column_args[0].type
--- /dev/null
+from collections import deque
+from collections import namedtuple
+
+from . import operators
+from .visitors import ExtendedInternalTraversal
+from .visitors import InternalTraversal
+from .. import inspect
+from .. import util
+
+SKIP_TRAVERSE = util.symbol("skip_traverse")
+COMPARE_FAILED = False
+COMPARE_SUCCEEDED = True
+NO_CACHE = util.symbol("no_cache")
+
+
+def compare(obj1, obj2, **kw):
+ if kw.get("use_proxies", False):
+ strategy = ColIdentityComparatorStrategy()
+ else:
+ strategy = TraversalComparatorStrategy()
+
+ return strategy.compare(obj1, obj2, **kw)
+
+
+class HasCacheKey(object):
+ _cache_key_traversal = NO_CACHE
+
+ def _gen_cache_key(self, anon_map, bindparams):
+ """return an optional cache key.
+
+ The cache key is a tuple which can contain any series of
+ objects that are hashable and also identifies
+ this object uniquely within the presence of a larger SQL expression
+ or statement, for the purposes of caching the resulting query.
+
+ The cache key should be based on the SQL compiled structure that would
+ ultimately be produced. That is, two structures that are composed in
+ exactly the same way should produce the same cache key; any difference
+ in the structures that would affect the SQL string or the type handlers
+ should result in a different cache key.
+
+ If a structure cannot produce a useful cache key, it should raise
+ NotImplementedError, which will result in the entire structure
+ of which it is a part not being useful as a cache key.
+
+
+ """
+
+ if self in anon_map:
+ return (anon_map[self], self.__class__)
+
+ id_ = anon_map[self]
+
+ if self._cache_key_traversal is NO_CACHE:
+ anon_map[NO_CACHE] = True
+ return None
+
+ result = (id_, self.__class__)
+
+ for attrname, obj, meth in _cache_key_traversal.run_generated_dispatch(
+ self, self._cache_key_traversal, "_generated_cache_key_traversal"
+ ):
+ if obj is not None:
+ result += meth(attrname, obj, self, anon_map, bindparams)
+ return result
+
+ def _generate_cache_key(self):
+ """return a cache key.
+
+ The cache key is a tuple which can contain any series of
+ objects that are hashable and also identifies
+ this object uniquely within the presence of a larger SQL expression
+ or statement, for the purposes of caching the resulting query.
+
+ The cache key should be based on the SQL compiled structure that would
+ ultimately be produced. That is, two structures that are composed in
+ exactly the same way should produce the same cache key; any difference
+ in the structures that would affect the SQL string or the type handlers
+ should result in a different cache key.
+
+ The cache key returned by this method is an instance of
+ :class:`.CacheKey`, which consists of a tuple representing the
+ cache key, as well as a list of :class:`.BindParameter` objects
+ which are extracted from the expression. While two expressions
+ that produce identical cache key tuples will themselves generate
+ identical SQL strings, the list of :class:`.BindParameter` objects
+ indicates the bound values which may have different values in
+ each one; these bound parameters must be consulted in order to
+ execute the statement with the correct parameters.
+
+ A :class:`.ClauseElement` structure that does not implement
+ a :meth:`._gen_cache_key` method and does not implement a
+ :attr:`._traverse_internals` attribute will not be cacheable; when
+ such an element is embedded into a larger structure, this method
+ will return None, indicating no cache key is available.
+
+ """
+ bindparams = []
+
+ _anon_map = anon_map()
+ key = self._gen_cache_key(_anon_map, bindparams)
+ if NO_CACHE in _anon_map:
+ return None
+ else:
+ return CacheKey(key, bindparams)
+
+
+class CacheKey(namedtuple("CacheKey", ["key", "bindparams"])):
+ def __hash__(self):
+ return hash(self.key)
+
+ def __eq__(self, other):
+ return self.key == other.key
+
+
+def _clone(element, **kw):
+ return element._clone()
+
+
+class _CacheKey(ExtendedInternalTraversal):
+ def visit_has_cache_key(self, attrname, obj, parent, anon_map, bindparams):
+ return (attrname, obj._gen_cache_key(anon_map, bindparams))
+
+ def visit_inspectable(self, attrname, obj, parent, anon_map, bindparams):
+ return self.visit_has_cache_key(
+ attrname, inspect(obj), parent, anon_map, bindparams
+ )
+
+ def visit_clauseelement(self, attrname, obj, parent, anon_map, bindparams):
+ return (attrname, obj._gen_cache_key(anon_map, bindparams))
+
+ def visit_multi(self, attrname, obj, parent, anon_map, bindparams):
+ return (
+ attrname,
+ obj._gen_cache_key(anon_map, bindparams)
+ if isinstance(obj, HasCacheKey)
+ else obj,
+ )
+
+ def visit_multi_list(self, attrname, obj, parent, anon_map, bindparams):
+ return (
+ attrname,
+ tuple(
+ elem._gen_cache_key(anon_map, bindparams)
+ if isinstance(elem, HasCacheKey)
+ else elem
+ for elem in obj
+ ),
+ )
+
+ def visit_has_cache_key_tuples(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return (
+ attrname,
+ tuple(
+ tuple(
+ elem._gen_cache_key(anon_map, bindparams)
+ for elem in tup_elem
+ )
+ for tup_elem in obj
+ ),
+ )
+
+ def visit_has_cache_key_list(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return (
+ attrname,
+ tuple(elem._gen_cache_key(anon_map, bindparams) for elem in obj),
+ )
+
+ def visit_inspectable_list(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return self.visit_has_cache_key_list(
+ attrname, [inspect(o) for o in obj], parent, anon_map, bindparams
+ )
+
+ def visit_clauseelement_list(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return (
+ attrname,
+ tuple(elem._gen_cache_key(anon_map, bindparams) for elem in obj),
+ )
+
+ def visit_clauseelement_tuples(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return self.visit_has_cache_key_tuples(
+ attrname, obj, parent, anon_map, bindparams
+ )
+
+ def visit_anon_name(self, attrname, obj, parent, anon_map, bindparams):
+ from . import elements
+
+ name = obj
+ if isinstance(name, elements._anonymous_label):
+ name = name.apply_map(anon_map)
+
+ return (attrname, name)
+
+ def visit_fromclause_ordered_set(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return (
+ attrname,
+ tuple(elem._gen_cache_key(anon_map, bindparams) for elem in obj),
+ )
+
+ def visit_clauseelement_unordered_set(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ cache_keys = [
+ elem._gen_cache_key(anon_map, bindparams) for elem in obj
+ ]
+ return (
+ attrname,
+ tuple(
+ sorted(cache_keys)
+ ), # cache keys all start with (id_, class)
+ )
+
+ def visit_named_ddl_element(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return (attrname, obj.name)
+
+ def visit_prefix_sequence(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return (
+ attrname,
+ tuple(
+ (clause._gen_cache_key(anon_map, bindparams), strval)
+ for clause, strval in obj
+ ),
+ )
+
+ def visit_statement_hint_list(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return (attrname, obj)
+
+ def visit_table_hint_list(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return (
+ attrname,
+ tuple(
+ (
+ clause._gen_cache_key(anon_map, bindparams),
+ dialect_name,
+ text,
+ )
+ for (clause, dialect_name), text in obj.items()
+ ),
+ )
+
+ def visit_type(self, attrname, obj, parent, anon_map, bindparams):
+ return (attrname, obj._gen_cache_key)
+
+ def visit_plain_dict(self, attrname, obj, parent, anon_map, bindparams):
+ return (attrname, tuple((key, obj[key]) for key in sorted(obj)))
+
+ def visit_string_clauseelement_dict(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return (
+ attrname,
+ tuple(
+ (key, obj[key]._gen_cache_key(anon_map, bindparams))
+ for key in sorted(obj)
+ ),
+ )
+
+ def visit_string_multi_dict(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return (
+ attrname,
+ tuple(
+ (
+ key,
+ value._gen_cache_key(anon_map, bindparams)
+ if isinstance(value, HasCacheKey)
+ else value,
+ )
+ for key, value in [(key, obj[key]) for key in sorted(obj)]
+ ),
+ )
+
+ def visit_string(self, attrname, obj, parent, anon_map, bindparams):
+ return (attrname, obj)
+
+ def visit_boolean(self, attrname, obj, parent, anon_map, bindparams):
+ return (attrname, obj)
+
+ def visit_operator(self, attrname, obj, parent, anon_map, bindparams):
+ return (attrname, obj)
+
+ def visit_plain_obj(self, attrname, obj, parent, anon_map, bindparams):
+ return (attrname, obj)
+
+ def visit_fromclause_canonical_column_collection(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return (
+ attrname,
+ tuple(col._gen_cache_key(anon_map, bindparams) for col in obj),
+ )
+
+ def visit_annotations_state(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return (
+ attrname,
+ tuple(
+ (
+ key,
+ self.dispatch(sym)(
+ key, obj[key], obj, anon_map, bindparams
+ ),
+ )
+ for key, sym in parent._annotation_traversals
+ ),
+ )
+
+ def visit_unknown_structure(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ anon_map[NO_CACHE] = True
+ return ()
+
+
+_cache_key_traversal = _CacheKey()
+
+
+class _CopyInternals(InternalTraversal):
+ """Generate a _copy_internals internal traversal dispatch for classes
+ with a _traverse_internals collection."""
+
+ def visit_clauseelement(self, parent, element, clone=_clone, **kw):
+ return clone(element, **kw)
+
+ def visit_clauseelement_list(self, parent, element, clone=_clone, **kw):
+ return [clone(clause, **kw) for clause in element]
+
+ def visit_clauseelement_tuples(self, parent, element, clone=_clone, **kw):
+ return [
+ tuple(clone(tup_elem, **kw) for tup_elem in elem)
+ for elem in element
+ ]
+
+ def visit_string_clauseelement_dict(
+ self, parent, element, clone=_clone, **kw
+ ):
+ return dict(
+ (key, clone(value, **kw)) for key, value in element.items()
+ )
+
+
+_copy_internals = _CopyInternals()
+
+
+class _GetChildren(InternalTraversal):
+ """Generate a _children_traversal internal traversal dispatch for classes
+ with a _traverse_internals collection."""
+
+ def visit_has_cache_key(self, element, **kw):
+ return (element,)
+
+ def visit_clauseelement(self, element, **kw):
+ return (element,)
+
+ def visit_clauseelement_list(self, element, **kw):
+ return tuple(element)
+
+ def visit_clauseelement_tuples(self, element, **kw):
+ tup = ()
+ for elem in element:
+ tup += elem
+ return tup
+
+ def visit_fromclause_canonical_column_collection(self, element, **kw):
+ if kw.get("column_collections", False):
+ return tuple(element)
+ else:
+ return ()
+
+ def visit_string_clauseelement_dict(self, element, **kw):
+ return tuple(element.values())
+
+ def visit_fromclause_ordered_set(self, element, **kw):
+ return tuple(element)
+
+ def visit_clauseelement_unordered_set(self, element, **kw):
+ return tuple(element)
+
+
+_get_children = _GetChildren()
+
+
+@util.dependencies("sqlalchemy.sql.elements")
+def _resolve_name_for_compare(elements, element, name, anon_map, **kw):
+ if isinstance(name, elements._anonymous_label):
+ name = name.apply_map(anon_map)
+
+ return name
+
+
+class anon_map(dict):
+ """A map that creates new keys for missing key access.
+
+ Produces an incrementing sequence given a series of unique keys.
+
+ This is similar to the compiler prefix_anon_map class although simpler.
+
+ Inlines the approach taken by :class:`sqlalchemy.util.PopulateDict` which
+ is otherwise usually used for this type of operation.
+
+ """
+
+ def __init__(self):
+ self.index = 0
+
+ def __missing__(self, key):
+ self[key] = val = str(self.index)
+ self.index += 1
+ return val
+
+
+class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
+ __slots__ = "stack", "cache", "anon_map"
+
+ def __init__(self):
+ self.stack = deque()
+ self.cache = set()
+
+ def _memoized_attr_anon_map(self):
+ return (anon_map(), anon_map())
+
+ def compare(self, obj1, obj2, **kw):
+ stack = self.stack
+ cache = self.cache
+
+ compare_annotations = kw.get("compare_annotations", False)
+
+ stack.append((obj1, obj2))
+
+ while stack:
+ left, right = stack.popleft()
+
+ if left is right:
+ continue
+ elif left is None or right is None:
+ # we know they are different so no match
+ return False
+ elif (left, right) in cache:
+ continue
+ cache.add((left, right))
+
+ visit_name = left.__visit_name__
+ if visit_name != right.__visit_name__:
+ return False
+
+ meth = getattr(self, "compare_%s" % visit_name, None)
+
+ if meth:
+ attributes_compared = meth(left, right, **kw)
+ if attributes_compared is COMPARE_FAILED:
+ return False
+ elif attributes_compared is SKIP_TRAVERSE:
+ continue
+
+ # attributes_compared is returned as a list of attribute
+ # names that were "handled" by the comparison method above.
+ # remaining attribute names in the _traverse_internals
+ # will be compared.
+ else:
+ attributes_compared = ()
+
+ for (
+ (left_attrname, left_visit_sym),
+ (right_attrname, right_visit_sym),
+ ) in util.zip_longest(
+ left._traverse_internals,
+ right._traverse_internals,
+ fillvalue=(None, None),
+ ):
+ if (
+ left_attrname != right_attrname
+ or left_visit_sym is not right_visit_sym
+ ):
+ if not compare_annotations and (
+ (
+ left_visit_sym
+ is InternalTraversal.dp_annotations_state,
+ )
+ or (
+ right_visit_sym
+ is InternalTraversal.dp_annotations_state,
+ )
+ ):
+ continue
+
+ return False
+ elif left_attrname in attributes_compared:
+ continue
+
+ dispatch = self.dispatch(left_visit_sym)
+ left_child = getattr(left, left_attrname)
+ right_child = getattr(right, right_attrname)
+ if left_child is None:
+ if right_child is not None:
+ return False
+ else:
+ continue
+
+ comparison = dispatch(
+ left, left_child, right, right_child, **kw
+ )
+ if comparison is COMPARE_FAILED:
+ return False
+
+ return True
+
+ def compare_inner(self, obj1, obj2, **kw):
+ comparator = self.__class__()
+ return comparator.compare(obj1, obj2, **kw)
+
+ def visit_has_cache_key(
+ self, left_parent, left, right_parent, right, **kw
+ ):
+ if left._gen_cache_key(self.anon_map[0], []) != right._gen_cache_key(
+ self.anon_map[1], []
+ ):
+ return COMPARE_FAILED
+
+ def visit_clauseelement(
+ self, left_parent, left, right_parent, right, **kw
+ ):
+ self.stack.append((left, right))
+
+ def visit_fromclause_canonical_column_collection(
+ self, left_parent, left, right_parent, right, **kw
+ ):
+ for lcol, rcol in util.zip_longest(left, right, fillvalue=None):
+ self.stack.append((lcol, rcol))
+
+ def visit_fromclause_derived_column_collection(
+ self, left_parent, left, right_parent, right, **kw
+ ):
+ pass
+
+ def visit_string_clauseelement_dict(
+ self, left_parent, left, right_parent, right, **kw
+ ):
+ for lstr, rstr in util.zip_longest(
+ sorted(left), sorted(right), fillvalue=None
+ ):
+ if lstr != rstr:
+ return COMPARE_FAILED
+ self.stack.append((left[lstr], right[rstr]))
+
+ def visit_annotations_state(
+ self, left_parent, left, right_parent, right, **kw
+ ):
+ if not kw.get("compare_annotations", False):
+ return
+
+ for (lstr, lmeth), (rstr, rmeth) in util.zip_longest(
+ left_parent._annotation_traversals,
+ right_parent._annotation_traversals,
+ fillvalue=(None, None),
+ ):
+ if lstr != rstr or (lmeth is not rmeth):
+ return COMPARE_FAILED
+
+ dispatch = self.dispatch(lmeth)
+ left_child = left[lstr]
+ right_child = right[rstr]
+ if left_child is None:
+ if right_child is not None:
+ return False
+ else:
+ continue
+
+ comparison = dispatch(None, left_child, None, right_child, **kw)
+ if comparison is COMPARE_FAILED:
+ return comparison
+
+ def visit_clauseelement_tuples(
+ self, left_parent, left, right_parent, right, **kw
+ ):
+ for ltup, rtup in util.zip_longest(left, right, fillvalue=None):
+ if ltup is None or rtup is None:
+ return COMPARE_FAILED
+
+ for l, r in util.zip_longest(ltup, rtup, fillvalue=None):
+ self.stack.append((l, r))
+
+ def visit_clauseelement_list(
+ self, left_parent, left, right_parent, right, **kw
+ ):
+ for l, r in util.zip_longest(left, right, fillvalue=None):
+ self.stack.append((l, r))
+
+ def _compare_unordered_sequences(self, seq1, seq2, **kw):
+ if seq1 is None:
+ return seq2 is None
+
+ completed = set()
+ for clause in seq1:
+ for other_clause in set(seq2).difference(completed):
+ if self.compare_inner(clause, other_clause, **kw):
+ completed.add(other_clause)
+ break
+ return len(completed) == len(seq1) == len(seq2)
+
+ def visit_clauseelement_unordered_set(
+ self, left_parent, left, right_parent, right, **kw
+ ):
+ return self._compare_unordered_sequences(left, right, **kw)
+
+ def visit_fromclause_ordered_set(
+ self, left_parent, left, right_parent, right, **kw
+ ):
+ for l, r in util.zip_longest(left, right, fillvalue=None):
+ self.stack.append((l, r))
+
+ def visit_string(self, left_parent, left, right_parent, right, **kw):
+ return left == right
+
+ def visit_anon_name(self, left_parent, left, right_parent, right, **kw):
+ return _resolve_name_for_compare(
+ left_parent, left, self.anon_map[0], **kw
+ ) == _resolve_name_for_compare(
+ right_parent, right, self.anon_map[1], **kw
+ )
+
+ def visit_boolean(self, left_parent, left, right_parent, right, **kw):
+ return left == right
+
+ def visit_operator(self, left_parent, left, right_parent, right, **kw):
+ return left is right
+
+ def visit_type(self, left_parent, left, right_parent, right, **kw):
+ return left._compare_type_affinity(right)
+
+ def visit_plain_dict(self, left_parent, left, right_parent, right, **kw):
+ return left == right
+
+ def visit_plain_obj(self, left_parent, left, right_parent, right, **kw):
+ return left == right
+
+ def visit_named_ddl_element(
+ self, left_parent, left, right_parent, right, **kw
+ ):
+ if left is None:
+ if right is not None:
+ return COMPARE_FAILED
+
+ return left.name == right.name
+
+ def visit_prefix_sequence(
+ self, left_parent, left, right_parent, right, **kw
+ ):
+ for (l_clause, l_str), (r_clause, r_str) in util.zip_longest(
+ left, right, fillvalue=(None, None)
+ ):
+ if l_str != r_str:
+ return COMPARE_FAILED
+ else:
+ self.stack.append((l_clause, r_clause))
+
+ def visit_table_hint_list(
+ self, left_parent, left, right_parent, right, **kw
+ ):
+ left_keys = sorted(left, key=lambda elem: (elem[0].fullname, elem[1]))
+ right_keys = sorted(
+ right, key=lambda elem: (elem[0].fullname, elem[1])
+ )
+ for (ltable, ldialect), (rtable, rdialect) in util.zip_longest(
+ left_keys, right_keys, fillvalue=(None, None)
+ ):
+ if ldialect != rdialect:
+ return COMPARE_FAILED
+ elif left[(ltable, ldialect)] != right[(rtable, rdialect)]:
+ return COMPARE_FAILED
+ else:
+ self.stack.append((ltable, rtable))
+
+ def visit_statement_hint_list(
+ self, left_parent, left, right_parent, right, **kw
+ ):
+ return left == right
+
+ def visit_unknown_structure(
+ self, left_parent, left, right_parent, right, **kw
+ ):
+ raise NotImplementedError()
+
+ def compare_clauselist(self, left, right, **kw):
+ if left.operator is right.operator:
+ if operators.is_associative(left.operator):
+ if self._compare_unordered_sequences(
+ left.clauses, right.clauses, **kw
+ ):
+ return ["operator", "clauses"]
+ else:
+ return COMPARE_FAILED
+ else:
+ return ["operator"]
+ else:
+ return COMPARE_FAILED
+
+ def compare_binary(self, left, right, **kw):
+ if left.operator == right.operator:
+ if operators.is_commutative(left.operator):
+ if (
+ compare(left.left, right.left, **kw)
+ and compare(left.right, right.right, **kw)
+ ) or (
+ compare(left.left, right.right, **kw)
+ and compare(left.right, right.left, **kw)
+ ):
+ return ["operator", "negate", "left", "right"]
+ else:
+ return COMPARE_FAILED
+ else:
+ return ["operator", "negate"]
+ else:
+ return COMPARE_FAILED
+
+
+class ColIdentityComparatorStrategy(TraversalComparatorStrategy):
+ def compare_column_element(
+ self, left, right, use_proxies=True, equivalents=(), **kw
+ ):
+ """Compare ColumnElements using proxies and equivalent collections.
+
+ This is a comparison strategy specific to the ORM.
+ """
+
+ to_compare = (right,)
+ if equivalents and right in equivalents:
+ to_compare = equivalents[right].union(to_compare)
+
+ for oth in to_compare:
+ if use_proxies and left.shares_lineage(oth):
+ return SKIP_TRAVERSE
+ elif hash(left) == hash(right):
+ return SKIP_TRAVERSE
+ else:
+ return COMPARE_FAILED
+
+ def compare_column(self, left, right, **kw):
+ return self.compare_column_element(left, right, **kw)
+
+ def compare_label(self, left, right, **kw):
+ return self.compare_column_element(left, right, **kw)
+
+ def compare_table(self, left, right, **kw):
+ # tables compare on identity, since it's not really feasible to
+ # compare them column by column with the above rules
+ return SKIP_TRAVERSE if left is right else COMPARE_FAILED
from . import operators
from .base import SchemaEventTarget
-from .visitors import Visitable
-from .visitors import VisitableType
+from .visitors import Traversible
+from .visitors import TraversibleType
from .. import exc
from .. import util
_resolve_value_to_type = None
-class TypeEngine(Visitable):
+class TypeEngine(Traversible):
"""The ultimate base class for all SQL datatypes.
Common subclasses of :class:`.TypeEngine` include
return dialect.type_descriptor(self)
@util.memoized_property
- def _cache_key(self):
- return util.constructor_key(self, self.__class__)
+ def _gen_cache_key(self):
+ names = util.get_cls_kwargs(self.__class__)
+ return (self.__class__,) + tuple(
+ (k, self.__dict__[k])
+ for k in names
+ if k in self.__dict__ and not k.startswith("_")
+ )
def adapt(self, cls, **kw):
"""Produce an "adapted" form of this type, given an "impl" class
return util.generic_repr(self)
-class VisitableCheckKWArg(util.EnsureKWArgType, VisitableType):
+class VisitableCheckKWArg(util.EnsureKWArgType, TraversibleType):
pass
return pairs
-class ClauseAdapter(visitors.ReplacingCloningVisitor):
+class ClauseAdapter(visitors.ReplacingExternalTraversal):
"""Clones and modifies clauses based on column correspondence.
E.g.::
from .. import exc
from .. import util
-
+from ..util import langhelpers
+from ..util import symbol
__all__ = [
- "VisitableType",
- "Visitable",
- "ClauseVisitor",
- "CloningVisitor",
- "ReplacingCloningVisitor",
"iterate",
"iterate_depthfirst",
"traverse_using",
"traverse_depthfirst",
"cloned_traverse",
"replacement_traverse",
+ "Traversible",
+ "TraversibleType",
+ "ExternalTraversal",
+ "InternalTraversal",
]
-class VisitableType(type):
- """Metaclass which assigns a ``_compiler_dispatch`` method to classes
- having a ``__visit_name__`` attribute.
+def _generate_compiler_dispatch(cls):
+ """Generate a _compiler_dispatch() external traversal on classes with a
+ __visit_name__ attribute.
+
+ """
+ visit_name = cls.__visit_name__
+
+ if isinstance(visit_name, util.compat.string_types):
+ # There is an optimization opportunity here because the
+ # the string name of the class's __visit_name__ is known at
+ # this early stage (import time) so it can be pre-constructed.
+ getter = operator.attrgetter("visit_%s" % visit_name)
+
+ def _compiler_dispatch(self, visitor, **kw):
+ try:
+ meth = getter(visitor)
+ except AttributeError:
+ raise exc.UnsupportedCompilationError(visitor, cls)
+ else:
+ return meth(self, **kw)
+
+ else:
+ # The optimization opportunity is lost for this case because the
+ # __visit_name__ is not yet a string. As a result, the visit
+ # string has to be recalculated with each compilation.
+ def _compiler_dispatch(self, visitor, **kw):
+ visit_attr = "visit_%s" % self.__visit_name__
+ try:
+ meth = getattr(visitor, visit_attr)
+ except AttributeError:
+ raise exc.UnsupportedCompilationError(visitor, cls)
+ else:
+ return meth(self, **kw)
+
+ _compiler_dispatch.__doc__ = """Look for an attribute named "visit_"
+ + self.__visit_name__ on the visitor, and call it with the same
+ kw params.
+ """
+ cls._compiler_dispatch = _compiler_dispatch
+
+
+class TraversibleType(type):
+ """Metaclass which assigns dispatch attributes to various kinds of
+ "visitable" classes.
- The ``_compiler_dispatch`` attribute becomes an instance method which
- looks approximately like the following::
+ Attributes include:
- def _compiler_dispatch (self, visitor, **kw):
- '''Look for an attribute named "visit_" + self.__visit_name__
- on the visitor, and call it with the same kw params.'''
- visit_attr = 'visit_%s' % self.__visit_name__
- return getattr(visitor, visit_attr)(self, **kw)
+ * The ``_compiler_dispatch`` method, corresponding to ``__visit_name__``.
+ This is called "external traversal" because the caller of each visit()
+ method is responsible for sub-traversing the inner elements of each
+ object. This is appropriate for string compilers and other traversals
+ that need to call upon the inner elements in a specific pattern.
- Classes having no ``__visit_name__`` attribute will remain unaffected.
+ * internal traversal collections ``_children_traversal``,
+ ``_cache_key_traversal``, ``_copy_internals_traversal``, generated from
+ an optional ``_traverse_internals`` collection of symbols which comes
+ from the :class:`.InternalTraversal` list of symbols. This is called
+ "internal traversal" because each class declares, as data, which of its attributes are traversed and how, rather than relying on per-class visit methods.
"""
def __init__(cls, clsname, bases, clsdict):
- if clsname != "Visitable" and hasattr(cls, "__visit_name__"):
- _generate_dispatch(cls)
+ if clsname != "Traversible":
+ if "__visit_name__" in clsdict:
+ _generate_compiler_dispatch(cls)
+
+ super(TraversibleType, cls).__init__(clsname, bases, clsdict)
- super(VisitableType, cls).__init__(clsname, bases, clsdict)
+class Traversible(util.with_metaclass(TraversibleType)):
+ """Base class for visitable objects, applies the
+ :class:`.visitors.TraversibleType` metaclass.
-def _generate_dispatch(cls):
- """Return an optimized visit dispatch function for the cls
- for use by the compiler.
"""
- if "__visit_name__" in cls.__dict__:
- visit_name = cls.__visit_name__
- if isinstance(visit_name, util.compat.string_types):
- # There is an optimization opportunity here because the
- # the string name of the class's __visit_name__ is known at
- # this early stage (import time) so it can be pre-constructed.
- getter = operator.attrgetter("visit_%s" % visit_name)
- def _compiler_dispatch(self, visitor, **kw):
- try:
- meth = getter(visitor)
- except AttributeError:
- raise exc.UnsupportedCompilationError(visitor, cls)
- else:
- return meth(self, **kw)
+class _InternalTraversalType(type):
+ def __init__(cls, clsname, bases, clsdict):
+ if cls.__name__ in ("InternalTraversal", "ExtendedInternalTraversal"):
+ lookup = {}
+ for key, sym in clsdict.items():
+ if key.startswith("dp_"):
+ visit_key = key.replace("dp_", "visit_")
+ sym_name = sym.name
+ assert sym_name not in lookup, sym_name
+ lookup[sym] = lookup[sym_name] = visit_key
+ if hasattr(cls, "_dispatch_lookup"):
+ lookup.update(cls._dispatch_lookup)
+ cls._dispatch_lookup = lookup
+
+ super(_InternalTraversalType, cls).__init__(clsname, bases, clsdict)
+
+
+def _generate_dispatcher(visitor, internal_dispatch, method_name):
+ names = []
+ for attrname, visit_sym in internal_dispatch:
+ meth = visitor.dispatch(visit_sym)
+ if meth:
+ visit_name = ExtendedInternalTraversal._dispatch_lookup[visit_sym]
+ names.append((attrname, visit_name))
+
+ code = (
+ (" return [\n")
+ + (
+ ", \n".join(
+ " (%r, self.%s, visitor.%s)"
+ % (attrname, attrname, visit_name)
+ for attrname, visit_name in names
+ )
+ )
+ + ("\n ]\n")
+ )
+ meth_text = ("def %s(self, visitor):\n" % method_name) + code + "\n"
+ # print(meth_text)
+ return langhelpers._exec_code_in_env(meth_text, {}, method_name)
- else:
- # The optimization opportunity is lost for this case because the
- # __visit_name__ is not yet a string. As a result, the visit
- # string has to be recalculated with each compilation.
- def _compiler_dispatch(self, visitor, **kw):
- visit_attr = "visit_%s" % self.__visit_name__
- try:
- meth = getattr(visitor, visit_attr)
- except AttributeError:
- raise exc.UnsupportedCompilationError(visitor, cls)
- else:
- return meth(self, **kw)
-
- _compiler_dispatch.__doc__ = """Look for an attribute named "visit_" + self.__visit_name__
- on the visitor, and call it with the same kw params.
- """
- cls._compiler_dispatch = _compiler_dispatch
-
-
-class Visitable(util.with_metaclass(VisitableType, object)):
- """Base class for visitable objects, applies the
- :class:`.visitors.VisitableType` metaclass.
- The :class:`.Visitable` class is essentially at the base of the
- :class:`.ClauseElement` hierarchy.
+class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)):
+ r"""Defines visitor symbols used for internal traversal.
+
+ The :class:`.InternalTraversal` class is used in two ways. One is that
+ it can serve as the superclass for an object that implements the
+ various visit methods of the class. The other is that the symbols
+ themselves of :class:`.InternalTraversal` are used within
+ the ``_traverse_internals`` collection. Such as, the :class:`.Case`
+    object defines ``_traverse_internals`` as ::
+
+ _traverse_internals = [
+ ("value", InternalTraversal.dp_clauseelement),
+ ("whens", InternalTraversal.dp_clauseelement_tuples),
+ ("else_", InternalTraversal.dp_clauseelement),
+ ]
+
+ Above, the :class:`.Case` class indicates its internal state as the
+    attributes named ``value``, ``whens``, and ``else\_``. They each
+ link to an :class:`.InternalTraversal` method which indicates the type
+ of datastructure referred towards.
+
+ Using the ``_traverse_internals`` structure, objects of type
+    :class:`.Traversible` will have the following methods automatically
+ implemented:
+
+ * :meth:`.Traversible.get_children`
+
+ * :meth:`.Traversible._copy_internals`
+
+ * :meth:`.Traversible._gen_cache_key`
+
+ Subclasses can also implement these methods directly, particularly for the
+ :meth:`.Traversible._copy_internals` method, when special steps
+ are needed.
+
+ .. versionadded:: 1.4
"""
+ def dispatch(self, visit_symbol):
+ """Given a method from :class:`.InternalTraversal`, return the
+ corresponding method on a subclass.
-class ClauseVisitor(object):
- """Base class for visitor objects which can traverse using
+ """
+ name = self._dispatch_lookup[visit_symbol]
+ return getattr(self, name, None)
+
+ def run_generated_dispatch(
+ self, target, internal_dispatch, generate_dispatcher_name
+ ):
+ try:
+ dispatcher = target.__class__.__dict__[generate_dispatcher_name]
+ except KeyError:
+ dispatcher = _generate_dispatcher(
+ self, internal_dispatch, generate_dispatcher_name
+ )
+ setattr(target.__class__, generate_dispatcher_name, dispatcher)
+ return dispatcher(target, self)
+
+ dp_has_cache_key = symbol("HC")
+ """Visit a :class:`.HasCacheKey` object."""
+
+ dp_clauseelement = symbol("CE")
+ """Visit a :class:`.ClauseElement` object."""
+
+ dp_fromclause_canonical_column_collection = symbol("FC")
+ """Visit a :class:`.FromClause` object in the context of the
+ ``columns`` attribute.
+
+ The column collection is "canonical", meaning it is the originally
+ defined location of the :class:`.ColumnClause` objects. Right now
+ this means that the object being visited is a :class:`.TableClause`
+ or :class:`.Table` object only.
+
+ """
+
+ dp_clauseelement_tuples = symbol("CT")
+ """Visit a list of tuples which contain :class:`.ClauseElement`
+ objects.
+
+ """
+
+ dp_clauseelement_list = symbol("CL")
+ """Visit a list of :class:`.ClauseElement` objects.
+
+ """
+
+ dp_clauseelement_unordered_set = symbol("CU")
+ """Visit an unordered set of :class:`.ClauseElement` objects. """
+
+ dp_fromclause_ordered_set = symbol("CO")
+ """Visit an ordered set of :class:`.FromClause` objects. """
+
+ dp_string = symbol("S")
+ """Visit a plain string value.
+
+ Examples include table and column names, bound parameter keys, special
+ keywords such as "UNION", "UNION ALL".
+
+ The string value is considered to be significant for cache key
+ generation.
+
+ """
+
+ dp_anon_name = symbol("AN")
+ """Visit a potentially "anonymized" string value.
+
+ The string value is considered to be significant for cache key
+ generation.
+
+ """
+
+ dp_boolean = symbol("B")
+ """Visit a boolean value.
+
+ The boolean value is considered to be significant for cache key
+ generation.
+
+ """
+
+ dp_operator = symbol("O")
+ """Visit an operator.
+
+ The operator is a function from the :mod:`sqlalchemy.sql.operators`
+ module.
+
+ The operator value is considered to be significant for cache key
+ generation.
+
+ """
+
+ dp_type = symbol("T")
+ """Visit a :class:`.TypeEngine` object
+
+ The type object is considered to be significant for cache key
+ generation.
+
+ """
+
+ dp_plain_dict = symbol("PD")
+ """Visit a dictionary with string keys.
+
+ The keys of the dictionary should be strings, the values should
+ be immutable and hashable. The dictionary is considered to be
+ significant for cache key generation.
+
+ """
+
+ dp_string_clauseelement_dict = symbol("CD")
+ """Visit a dictionary of string keys to :class:`.ClauseElement`
+ objects.
+
+ """
+
+ dp_string_multi_dict = symbol("MD")
+ """Visit a dictionary of string keys to values which may either be
+ plain immutable/hashable or :class:`.HasCacheKey` objects.
+
+ """
+
+ dp_plain_obj = symbol("PO")
+ """Visit a plain python object.
+
+ The value should be immutable and hashable, such as an integer.
+ The value is considered to be significant for cache key generation.
+
+ """
+
+ dp_annotations_state = symbol("A")
+    """Visit the state of the :class:`.Annotated` version of an object.
+
+ """
+
+ dp_named_ddl_element = symbol("DD")
+ """Visit a simple named DDL element.
+
+ The current object used by this method is the :class:`.Sequence`.
+
+ The object is only considered to be important for cache key generation
+ as far as its name, but not any other aspects of it.
+
+ """
+
+ dp_prefix_sequence = symbol("PS")
+ """Visit the sequence represented by :class:`.HasPrefixes`
+ or :class:`.HasSuffixes`.
+
+ """
+
+ dp_table_hint_list = symbol("TH")
+ """Visit the ``_hints`` collection of a :class:`.Select` object.
+
+ """
+
+ dp_statement_hint_list = symbol("SH")
+ """Visit the ``_statement_hints`` collection of a :class:`.Select`
+ object.
+
+ """
+
+ dp_unknown_structure = symbol("UK")
+ """Visit an unknown structure.
+
+ """
+
+
+class ExtendedInternalTraversal(InternalTraversal):
+ """defines additional symbols that are useful in caching applications.
+
+ Traversals for :class:`.ClauseElement` objects only need to use
+ those symbols present in :class:`.InternalTraversal`. However, for
+ additional caching use cases within the ORM, symbols dealing with the
+ :class:`.HasCacheKey` class are added here.
+
+ """
+
+ dp_ignore = symbol("IG")
+ """Specify an object that should be ignored entirely.
+
+ This currently applies function call argument caching where some
+ arguments should not be considered to be part of a cache key.
+
+ """
+
+ dp_inspectable = symbol("IS")
+    """Visit an inspectable object where the return value is a
+    :class:`.HasCacheKey` object."""
+ object."""
+
+ dp_multi = symbol("M")
+ """Visit an object that may be a :class:`.HasCacheKey` or may be a
+ plain hashable object."""
+
+ dp_multi_list = symbol("MT")
+ """Visit a tuple containing elements that may be :class:`.HasCacheKey` or
+ may be a plain hashable object."""
+
+ dp_has_cache_key_tuples = symbol("HT")
+ """Visit a list of tuples which contain :class:`.HasCacheKey`
+ objects.
+
+ """
+
+ dp_has_cache_key_list = symbol("HL")
+ """Visit a list of :class:`.HasCacheKey` objects."""
+
+ dp_inspectable_list = symbol("IL")
+ """Visit a list of inspectable objects which upon inspection are
+ HasCacheKey objects."""
+
+
+class ExternalTraversal(object):
+ """Base class for visitor objects which can traverse externally using
the :func:`.visitors.traverse` function.
Direct usage of the :func:`.visitors.traverse` function is usually
return self
-class CloningVisitor(ClauseVisitor):
+class CloningExternalTraversal(ExternalTraversal):
"""Base class for visitor objects which can traverse using
the :func:`.visitors.cloned_traverse` function.
)
-class ReplacingCloningVisitor(CloningVisitor):
+class ReplacingExternalTraversal(CloningExternalTraversal):
"""Base class for visitor objects which can traverse using
the :func:`.visitors.replacement_traverse` function.
return replacement_traverse(obj, self.__traverse_options__, replace)
+# backwards compatibility
+Visitable = Traversible
+VisitableType = TraversibleType
+ClauseVisitor = ExternalTraversal
+CloningVisitor = CloningExternalTraversal
+ReplacingCloningVisitor = ReplacingExternalTraversal
+
+
def iterate(obj, opts):
r"""traverse the given expression structure, returning an iterator.
cloned = {}
stop_on = set(opts.get("stop_on", []))
- def clone(elem):
+ def clone(elem, **kw):
if elem in stop_on:
return elem
else:
if id(elem) not in cloned:
+
+ if "replace" in kw:
+ newelem = kw["replace"](elem)
+ if newelem is not None:
+ cloned[id(elem)] = newelem
+ return newelem
+
cloned[id(elem)] = newelem = elem._clone()
newelem._copy_internals(clone=clone)
meth = visitors.get(newelem.__visit_name__, None)
stop_on.add(id(newelem))
return newelem
else:
+
if elem not in cloned:
+ if "replace" in kw:
+ newelem = kw["replace"](elem)
+ if newelem is not None:
+ cloned[elem] = newelem
+ return newelem
+
cloned[elem] = newelem = elem._clone()
newelem._copy_internals(clone=clone, **kw)
return cloned[elem]
configure_mappers()
- def test_generate_cache_key_unbound_branching(self):
+ def test_generate_path_cache_key_unbound_branching(self):
A, B, C, D, E, F, G = self.classes("A", "B", "C", "D", "E", "F", "G")
base = joinedload(A.bs)
@profiling.function_call_count()
def go():
for opt in opts:
- opt._generate_cache_key(cache_path)
+ opt._generate_path_cache_key(cache_path)
go()
- def test_generate_cache_key_bound_branching(self):
+ def test_generate_path_cache_key_bound_branching(self):
A, B, C, D, E, F, G = self.classes("A", "B", "C", "D", "E", "F", "G")
base = Load(A).joinedload(A.bs)
@profiling.function_call_count()
def go():
for opt in opts:
- opt._generate_cache_key(cache_path)
+ opt._generate_path_cache_key(cache_path)
go()
if query._current_path:
query._cache_key = "user7_addresses"
- def _generate_cache_key(self, path):
+ def _generate_path_cache_key(self, path):
return None
return RelationshipCache()
--- /dev/null
+from sqlalchemy import inspect
+from sqlalchemy.orm import aliased
+from sqlalchemy.orm import defaultload
+from sqlalchemy.orm import defer
+from sqlalchemy.orm import joinedload
+from sqlalchemy.orm import Load
+from sqlalchemy.orm import subqueryload
+from sqlalchemy.testing import eq_
+from test.orm import _fixtures
+from ..sql.test_compare import CacheKeyFixture
+
+
+class CacheKeyTest(CacheKeyFixture, _fixtures.FixtureTest):
+ run_setup_mappers = "once"
+ run_inserts = None
+ run_deletes = None
+
+ @classmethod
+ def setup_mappers(cls):
+ cls._setup_stock_mapping()
+
+ def test_mapper_and_aliased(self):
+ User, Address, Keyword = self.classes("User", "Address", "Keyword")
+
+ self._run_cache_key_fixture(
+ lambda: (inspect(User), inspect(Address), inspect(aliased(User)))
+ )
+
+ def test_attributes(self):
+ User, Address, Keyword = self.classes("User", "Address", "Keyword")
+
+ self._run_cache_key_fixture(
+ lambda: (
+ User.id,
+ Address.id,
+ aliased(User).id,
+ aliased(User, name="foo").id,
+ aliased(User, name="bar").id,
+ User.name,
+ User.addresses,
+ Address.email_address,
+ aliased(User).addresses,
+ )
+ )
+
+ def test_unbound_options(self):
+ User, Address, Keyword, Order, Item = self.classes(
+ "User", "Address", "Keyword", "Order", "Item"
+ )
+
+ self._run_cache_key_fixture(
+ lambda: (
+ joinedload(User.addresses),
+ joinedload("addresses"),
+ joinedload(User.orders).selectinload("items"),
+ joinedload(User.orders).selectinload(Order.items),
+ defer(User.id),
+ defer("id"),
+ defer(Address.id),
+ joinedload(User.addresses).defer(Address.id),
+ joinedload(aliased(User).addresses).defer(Address.id),
+ joinedload(User.addresses).defer("id"),
+ joinedload(User.orders).joinedload(Order.items),
+ joinedload(User.orders).subqueryload(Order.items),
+ subqueryload(User.orders).subqueryload(Order.items),
+ subqueryload(User.orders)
+ .subqueryload(Order.items)
+ .defer(Item.description),
+ defaultload(User.orders).defaultload(Order.items),
+ defaultload(User.orders),
+ )
+ )
+
+ def test_bound_options(self):
+ User, Address, Keyword, Order, Item = self.classes(
+ "User", "Address", "Keyword", "Order", "Item"
+ )
+
+ self._run_cache_key_fixture(
+ lambda: (
+ Load(User).joinedload(User.addresses),
+ Load(User).joinedload(User.orders),
+ Load(User).defer(User.id),
+ Load(User).subqueryload("addresses"),
+ Load(Address).defer("id"),
+ Load(aliased(Address)).defer("id"),
+ Load(User).joinedload(User.addresses).defer(Address.id),
+ Load(User).joinedload(User.orders).joinedload(Order.items),
+ Load(User).joinedload(User.orders).subqueryload(Order.items),
+ Load(User).subqueryload(User.orders).subqueryload(Order.items),
+ Load(User)
+ .subqueryload(User.orders)
+ .subqueryload(Order.items)
+ .defer(Item.description),
+ Load(User).defaultload(User.orders).defaultload(Order.items),
+ Load(User).defaultload(User.orders),
+ )
+ )
+
+ def test_bound_options_equiv_on_strname(self):
+ """Bound loader options resolve on string name so test that the cache
+ key for the string version matches the resolved version.
+
+ """
+ User, Address, Keyword, Order, Item = self.classes(
+ "User", "Address", "Keyword", "Order", "Item"
+ )
+
+ for left, right in [
+ (Load(User).defer(User.id), Load(User).defer("id")),
+ (
+ Load(User).joinedload(User.addresses),
+ Load(User).joinedload("addresses"),
+ ),
+ (
+ Load(User).joinedload(User.orders).joinedload(Order.items),
+ Load(User).joinedload("orders").joinedload("items"),
+ ),
+ ]:
+ eq_(left._generate_cache_key(), right._generate_cache_key())
)
-class CacheKeyTest(PathTest, QueryTest):
+class PathedCacheKeyTest(PathTest, QueryTest):
run_create_tables = False
run_inserts = None
opt = joinedload(User.orders).joinedload(Order.items)
eq_(
- opt._generate_cache_key(query_path),
+ opt._generate_path_cache_key(query_path),
(((Order, "items", Item, ("lazy", "joined")),)),
)
opt2 = base.joinedload(Order.address)
eq_(
- opt1._generate_cache_key(query_path),
+ opt1._generate_path_cache_key(query_path),
(((Order, "items", Item, ("lazy", "joined")),)),
)
eq_(
- opt2._generate_cache_key(query_path),
+ opt2._generate_path_cache_key(query_path),
(((Order, "address", Address, ("lazy", "joined")),)),
)
opt2 = base.joinedload(Order.address)
eq_(
- opt1._generate_cache_key(query_path),
+ opt1._generate_path_cache_key(query_path),
(((Order, "items", Item, ("lazy", "joined")),)),
)
eq_(
- opt2._generate_cache_key(query_path),
+ opt2._generate_path_cache_key(query_path),
(((Order, "address", Address, ("lazy", "joined")),)),
)
opt = Load(User).joinedload(User.orders).joinedload(Order.items)
eq_(
- opt._generate_cache_key(query_path),
+ opt._generate_path_cache_key(query_path),
(((Order, "items", Item, ("lazy", "joined")),)),
)
query_path = self._make_path_registry([User, "addresses"])
opt = joinedload(User.orders).joinedload(Order.items)
- eq_(opt._generate_cache_key(query_path), None)
+ eq_(opt._generate_path_cache_key(query_path), None)
def test_bound_cache_key_excluded_on_other(self):
User, Address, Order, Item, SubItem = self.classes(
query_path = self._make_path_registry([User, "addresses"])
opt = Load(User).joinedload(User.orders).joinedload(Order.items)
- eq_(opt._generate_cache_key(query_path), None)
+ eq_(opt._generate_path_cache_key(query_path), None)
def test_unbound_cache_key_excluded_on_aliased(self):
User, Address, Order, Item, SubItem = self.classes(
query_path = self._make_path_registry([User, "orders"])
opt = joinedload(aliased(User).orders).joinedload(Order.items)
- eq_(opt._generate_cache_key(query_path), None)
+ eq_(opt._generate_path_cache_key(query_path), None)
def test_bound_cache_key_wildcard_one(self):
# do not change this test, it is testing
query_path = self._make_path_registry([User, "addresses"])
opt = Load(User).lazyload("*")
- eq_(opt._generate_cache_key(query_path), None)
+ eq_(opt._generate_path_cache_key(query_path), None)
def test_unbound_cache_key_wildcard_one(self):
User, Address = self.classes("User", "Address")
opt = lazyload("*")
eq_(
- opt._generate_cache_key(query_path),
+ opt._generate_path_cache_key(query_path),
(("relationship:_sa_default", ("lazy", "select")),),
)
opt = Load(User).lazyload("orders").lazyload("*")
eq_(
- opt._generate_cache_key(query_path),
+ opt._generate_path_cache_key(query_path),
(
("orders", Order, ("lazy", "select")),
("orders", Order, "relationship:*", ("lazy", "select")),
opt = lazyload("orders").lazyload("*")
eq_(
- opt._generate_cache_key(query_path),
+ opt._generate_path_cache_key(query_path),
(
("orders", Order, ("lazy", "select")),
("orders", Order, "relationship:*", ("lazy", "select")),
)
eq_(
- opt._generate_cache_key(query_path),
+ opt._generate_path_cache_key(query_path),
(
(SubItem, ("lazy", "subquery")),
("extra_keywords", Keyword, ("lazy", "subquery")),
)
eq_(
- opt._generate_cache_key(query_path),
+ opt._generate_path_cache_key(query_path),
(
(SubItem, ("lazy", "subquery")),
("extra_keywords", Keyword, ("lazy", "subquery")),
)
eq_(
- opt._generate_cache_key(query_path),
+ opt._generate_path_cache_key(query_path),
(
(SubItem, ("lazy", "subquery")),
("extra_keywords", Keyword, ("lazy", "subquery")),
)
eq_(
- opt._generate_cache_key(query_path),
+ opt._generate_path_cache_key(query_path),
(
(SubItem, ("lazy", "subquery")),
("extra_keywords", Keyword, ("lazy", "subquery")),
opt = subqueryload(User.orders).subqueryload(
Order.items.of_type(SubItem)
)
- eq_(opt._generate_cache_key(query_path), None)
+ eq_(opt._generate_path_cache_key(query_path), None)
def test_unbound_cache_key_excluded_of_type_unsafe(self):
User, Address, Order, Item, SubItem = self.classes(
opt = subqueryload(User.orders).subqueryload(
Order.items.of_type(aliased(SubItem))
)
- eq_(opt._generate_cache_key(query_path), None)
+ eq_(opt._generate_path_cache_key(query_path), None)
def test_bound_cache_key_excluded_of_type_safe(self):
User, Address, Order, Item, SubItem = self.classes(
.subqueryload(User.orders)
.subqueryload(Order.items.of_type(SubItem))
)
- eq_(opt._generate_cache_key(query_path), None)
+ eq_(opt._generate_path_cache_key(query_path), None)
def test_bound_cache_key_excluded_of_type_unsafe(self):
User, Address, Order, Item, SubItem = self.classes(
.subqueryload(User.orders)
.subqueryload(Order.items.of_type(aliased(SubItem)))
)
- eq_(opt._generate_cache_key(query_path), None)
+ eq_(opt._generate_path_cache_key(query_path), None)
def test_unbound_cache_key_included_of_type_safe(self):
User, Address, Order, Item, SubItem = self.classes(
opt = joinedload(User.orders).joinedload(Order.items.of_type(SubItem))
eq_(
- opt._generate_cache_key(query_path),
+ opt._generate_path_cache_key(query_path),
((Order, "items", SubItem, ("lazy", "joined")),),
)
)
eq_(
- opt._generate_cache_key(query_path),
+ opt._generate_path_cache_key(query_path),
((Order, "items", SubItem, ("lazy", "joined")),),
)
opt = joinedload(User.orders).joinedload(
Order.items.of_type(aliased(SubItem))
)
- eq_(opt._generate_cache_key(query_path), False)
+ eq_(opt._generate_path_cache_key(query_path), False)
def test_unbound_cache_key_included_unsafe_option_two(self):
User, Address, Order, Item, SubItem = self.classes(
opt = joinedload(User.orders).joinedload(
Order.items.of_type(aliased(SubItem))
)
- eq_(opt._generate_cache_key(query_path), False)
+ eq_(opt._generate_path_cache_key(query_path), False)
def test_unbound_cache_key_included_unsafe_option_three(self):
User, Address, Order, Item, SubItem = self.classes(
opt = joinedload(User.orders).joinedload(
Order.items.of_type(aliased(SubItem))
)
- eq_(opt._generate_cache_key(query_path), False)
+ eq_(opt._generate_path_cache_key(query_path), False)
def test_unbound_cache_key_included_unsafe_query(self):
User, Address, Order, Item, SubItem = self.classes(
query_path = self._make_path_registry([inspect(au), "orders"])
opt = joinedload(au.orders).joinedload(Order.items)
- eq_(opt._generate_cache_key(query_path), False)
+ eq_(opt._generate_path_cache_key(query_path), False)
def test_unbound_cache_key_included_safe_w_deferred(self):
User, Address, Order, Item, SubItem = self.classes(
.defer(Address.user_id)
)
eq_(
- opt._generate_cache_key(query_path),
+ opt._generate_path_cache_key(query_path),
(
(
Address,
)
eq_(
- opt1._generate_cache_key(query_path),
+ opt1._generate_path_cache_key(query_path),
((Order, "items", Item, ("lazy", "joined")),),
)
eq_(
- opt2._generate_cache_key(query_path),
+ opt2._generate_path_cache_key(query_path),
(
(Order, "address", Address, ("lazy", "joined")),
(
.defer(Address.user_id)
)
eq_(
- opt._generate_cache_key(query_path),
+ opt._generate_path_cache_key(query_path),
(
(
Address,
)
eq_(
- opt1._generate_cache_key(query_path),
+ opt1._generate_path_cache_key(query_path),
((Order, "items", Item, ("lazy", "joined")),),
)
eq_(
- opt2._generate_cache_key(query_path),
+ opt2._generate_path_cache_key(query_path),
(
(Order, "address", Address, ("lazy", "joined")),
(
query_path = self._make_path_registry([User, "orders"])
eq_(
- opt._generate_cache_key(query_path),
+ opt._generate_path_cache_key(query_path),
(
(
Order,
au = aliased(User)
opt = Load(au).joinedload(au.orders).joinedload(Order.items)
- eq_(opt._generate_cache_key(query_path), None)
+ eq_(opt._generate_path_cache_key(query_path), None)
def test_bound_cache_key_included_unsafe_option_one(self):
User, Address, Order, Item, SubItem = self.classes(
.joinedload(User.orders)
.joinedload(Order.items.of_type(aliased(SubItem)))
)
- eq_(opt._generate_cache_key(query_path), False)
+ eq_(opt._generate_path_cache_key(query_path), False)
def test_bound_cache_key_included_unsafe_option_two(self):
User, Address, Order, Item, SubItem = self.classes(
.joinedload(User.orders)
.joinedload(Order.items.of_type(aliased(SubItem)))
)
- eq_(opt._generate_cache_key(query_path), False)
+ eq_(opt._generate_path_cache_key(query_path), False)
def test_bound_cache_key_included_unsafe_option_three(self):
User, Address, Order, Item, SubItem = self.classes(
.joinedload(User.orders)
.joinedload(Order.items.of_type(aliased(SubItem)))
)
- eq_(opt._generate_cache_key(query_path), False)
+ eq_(opt._generate_path_cache_key(query_path), False)
def test_bound_cache_key_included_unsafe_query(self):
User, Address, Order, Item, SubItem = self.classes(
query_path = self._make_path_registry([inspect(au), "orders"])
opt = Load(au).joinedload(au.orders).joinedload(Order.items)
- eq_(opt._generate_cache_key(query_path), False)
+ eq_(opt._generate_path_cache_key(query_path), False)
def test_bound_cache_key_included_safe_w_option(self):
User, Address, Order, Item, SubItem = self.classes(
query_path = self._make_path_registry([User, "orders"])
eq_(
- opt._generate_cache_key(query_path),
+ opt._generate_path_cache_key(query_path),
(
(
Order,
opt = defaultload(User.addresses).load_only("id", "email_address")
eq_(
- opt._generate_cache_key(query_path),
+ opt._generate_path_cache_key(query_path),
(
(Address, "id", ("deferred", False), ("instrument", True)),
(
Address.id, Address.email_address
)
eq_(
- opt._generate_cache_key(query_path),
+ opt._generate_path_cache_key(query_path),
(
(Address, "id", ("deferred", False), ("instrument", True)),
(
.load_only("id", "email_address")
)
eq_(
- opt._generate_cache_key(query_path),
+ opt._generate_path_cache_key(query_path),
(
(Address, "id", ("deferred", False), ("instrument", True)),
(
opt = defaultload(User.addresses).undefer_group("xyz")
eq_(
- opt._generate_cache_key(query_path),
+ opt._generate_path_cache_key(query_path),
((Address, "column:*", ("undefer_group_xyz", True)),),
)
opt = Load(User).defaultload(User.addresses).undefer_group("xyz")
eq_(
- opt._generate_cache_key(query_path),
+ opt._generate_path_cache_key(query_path),
((Address, "column:*", ("undefer_group_xyz", True)),),
)
from sqlalchemy.sql import True_
from sqlalchemy.sql import type_coerce
from sqlalchemy.sql import visitors
+from sqlalchemy.sql.base import HasCacheKey
from sqlalchemy.sql.elements import _label_reference
from sqlalchemy.sql.elements import _textual_label_reference
from sqlalchemy.sql.elements import Annotated
from sqlalchemy.sql.functions import GenericFunction
from sqlalchemy.sql.functions import ReturnTypeFromArgs
from sqlalchemy.sql.selectable import _OffsetLimitParam
+from sqlalchemy.sql.selectable import AliasedReturnsRows
from sqlalchemy.sql.selectable import FromGrouping
from sqlalchemy.sql.selectable import Selectable
from sqlalchemy.sql.selectable import SelectStatementGrouping
-from sqlalchemy.testing import assert_raises_message
+from sqlalchemy.sql.visitors import InternalTraversal
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
-from sqlalchemy.testing import is_
from sqlalchemy.testing import is_false
from sqlalchemy.testing import is_true
from sqlalchemy.testing import ne_
meta2 = MetaData()
table_a = Table("a", meta, Column("a", Integer), Column("b", String))
+table_b_like_a = Table("b2", meta, Column("a", Integer), Column("b", String))
+
table_a_2 = Table("a", meta2, Column("a", Integer), Column("b", String))
+table_a_2_fs = Table(
+ "a", meta2, Column("a", Integer), Column("b", String), schema="fs"
+)
+table_a_2_bs = Table(
+ "a", meta2, Column("a", Integer), Column("b", String), schema="bs"
+)
+
table_b = Table("b", meta, Column("a", Integer), Column("b", Integer))
table_c = Table("c", meta, Column("x", Integer), Column("y", Integer))
table_d = Table("d", meta, Column("y", Integer), Column("z", Integer))
-class CompareAndCopyTest(fixtures.TestBase):
+class MyEntity(HasCacheKey):
+ def __init__(self, name, element):
+ self.name = name
+ self.element = element
+
+ _cache_key_traversal = [
+ ("name", InternalTraversal.dp_string),
+ ("element", InternalTraversal.dp_clauseelement),
+ ]
+
+class CoreFixtures(object):
# lambdas which return a tuple of ColumnElement objects.
# must return at least two objects that should compare differently.
# to test more varieties of "difference" additional objects can be added.
text("select a, b, c from table").columns(
a=Integer, b=String, c=Integer
),
+ text("select a, b, c from table where foo=:bar").bindparams(
+ bindparam("bar", Integer)
+ ),
+ text("select a, b, c from table where foo=:foo").bindparams(
+ bindparam("foo", Integer)
+ ),
+ text("select a, b, c from table where foo=:bar").bindparams(
+ bindparam("bar", String)
+ ),
),
lambda: (
column("q") == column("x"),
column("q") == column("y"),
column("z") == column("x"),
+ column("z") + column("x"),
+ column("z") - column("x"),
+ column("x") - column("z"),
+ column("z") > column("x"),
+ # note these two are mathematically equivalent but for now they
+ # are considered to be different
+ column("z") >= column("x"),
+ column("x") <= column("z"),
+ column("q").between(5, 6),
+ column("q").between(5, 6, symmetric=True),
+ column("q").like("somstr"),
+ column("q").like("somstr", escape="\\"),
+ column("q").like("somstr", escape="X"),
+ ),
+ lambda: (
+ table_a.c.a,
+ table_a.c.a._annotate({"orm": True}),
+ table_a.c.a._annotate({"orm": True})._annotate({"bar": False}),
+ table_a.c.a._annotate(
+ {"orm": True, "parententity": MyEntity("a", table_a)}
+ ),
+ table_a.c.a._annotate(
+ {"orm": True, "parententity": MyEntity("b", table_a)}
+ ),
+ table_a.c.a._annotate(
+ {"orm": True, "parententity": MyEntity("b", select([table_a]))}
+ ),
),
lambda: (
cast(column("q"), Integer),
.where(table_a.c.b == 5)
.correlate_except(table_b),
),
+ lambda: (
+ select([table_a.c.a]).cte(),
+ select([table_a.c.a]).cte(recursive=True),
+ select([table_a.c.a]).cte(name="some_cte", recursive=True),
+ select([table_a.c.a]).cte(name="some_cte"),
+ select([table_a.c.a]).cte(name="some_cte").alias("other_cte"),
+ select([table_a.c.a])
+ .cte(name="some_cte")
+ .union_all(select([table_a.c.a])),
+ select([table_a.c.a])
+ .cte(name="some_cte")
+ .union_all(select([table_a.c.b])),
+ select([table_a.c.a]).lateral(),
+ select([table_a.c.a]).lateral(name="bar"),
+ table_a.tablesample(func.bernoulli(1)),
+ table_a.tablesample(func.bernoulli(1), seed=func.random()),
+ table_a.tablesample(func.bernoulli(1), seed=func.other_random()),
+ table_a.tablesample(func.hoho(1)),
+ table_a.tablesample(func.bernoulli(1), name="bar"),
+ table_a.tablesample(
+ func.bernoulli(1), name="bar", seed=func.random()
+ ),
+ ),
+ lambda: (
+ select([table_a.c.a]),
+ select([table_a.c.a]).prefix_with("foo"),
+ select([table_a.c.a]).prefix_with("foo", dialect="mysql"),
+ select([table_a.c.a]).prefix_with("foo", dialect="postgresql"),
+ select([table_a.c.a]).prefix_with("bar"),
+ select([table_a.c.a]).suffix_with("bar"),
+ ),
+ lambda: (
+ select([table_a_2.c.a]),
+ select([table_a_2_fs.c.a]),
+ select([table_a_2_bs.c.a]),
+ ),
+ lambda: (
+ select([table_a.c.a]),
+ select([table_a.c.a]).with_hint(None, "some hint"),
+ select([table_a.c.a]).with_hint(None, "some other hint"),
+ select([table_a.c.a]).with_hint(table_a, "some hint"),
+ select([table_a.c.a])
+ .with_hint(table_a, "some hint")
+ .with_hint(None, "some other hint"),
+ select([table_a.c.a]).with_hint(table_a, "some other hint"),
+ select([table_a.c.a]).with_hint(
+ table_a, "some hint", dialect_name="mysql"
+ ),
+ select([table_a.c.a]).with_hint(
+ table_a, "some hint", dialect_name="postgresql"
+ ),
+ ),
lambda: (
table_a.join(table_b, table_a.c.a == table_b.c.a),
table_a.join(
table("a", column("x"), column("y", Integer)),
table("a", column("q"), column("y", Integer)),
),
- lambda: (
- Table("a", MetaData(), Column("q", Integer), Column("b", String)),
- Table("b", MetaData(), Column("q", Integer), Column("b", String)),
- ),
+ lambda: (table_a, table_b),
]
+ def _complex_fixtures():
+ def one():
+ a1 = table_a.alias()
+ a2 = table_b_like_a.alias()
+
+ stmt = (
+ select([table_a.c.a, a1.c.b, a2.c.b])
+ .where(table_a.c.b == a1.c.b)
+ .where(a1.c.b == a2.c.b)
+ .where(a1.c.a == 5)
+ )
+
+ return stmt
+
+ def one_diff():
+ a1 = table_b_like_a.alias()
+ a2 = table_a.alias()
+
+ stmt = (
+ select([table_a.c.a, a1.c.b, a2.c.b])
+ .where(table_a.c.b == a1.c.b)
+ .where(a1.c.b == a2.c.b)
+ .where(a1.c.a == 5)
+ )
+
+ return stmt
+
+ def two():
+ inner = one().subquery()
+
+ stmt = select([table_b.c.a, inner.c.a, inner.c.b]).select_from(
+ table_b.join(inner, table_b.c.b == inner.c.b)
+ )
+
+ return stmt
+
+ def three():
+
+ a1 = table_a.alias()
+ a2 = table_a.alias()
+ ex = exists().where(table_b.c.b == a1.c.a)
+
+ stmt = (
+ select([a1.c.a, a2.c.a])
+ .select_from(a1.join(a2, a1.c.b == a2.c.b))
+ .where(ex)
+ )
+ return stmt
+
+ return [one(), one_diff(), two(), three()]
+
+ fixtures.append(_complex_fixtures)
+
+
+class CacheKeyFixture(object):
+ def _run_cache_key_fixture(self, fixture):
+ case_a = fixture()
+ case_b = fixture()
+
+ for a, b in itertools.combinations_with_replacement(
+ range(len(case_a)), 2
+ ):
+ if a == b:
+ a_key = case_a[a]._generate_cache_key()
+ b_key = case_b[b]._generate_cache_key()
+ eq_(a_key.key, b_key.key)
+
+ for a_param, b_param in zip(
+ a_key.bindparams, b_key.bindparams
+ ):
+ assert a_param.compare(b_param, compare_values=False)
+ else:
+ a_key = case_a[a]._generate_cache_key()
+ b_key = case_b[b]._generate_cache_key()
+
+ if a_key.key == b_key.key:
+ for a_param, b_param in zip(
+ a_key.bindparams, b_key.bindparams
+ ):
+ if not a_param.compare(b_param, compare_values=True):
+ break
+ else:
+ # this fails unconditionally since we could not
+ # find bound parameter values that differed.
+ # Usually we intended to get two distinct keys here
+ # so the failure will be more descriptive using the
+ # ne_() assertion.
+ ne_(a_key.key, b_key.key)
+ else:
+ ne_(a_key.key, b_key.key)
+
+ # ClauseElement-specific test to ensure the cache key
+ # collected all the bound parameters
+ if isinstance(case_a[a], ClauseElement) and isinstance(
+ case_b[b], ClauseElement
+ ):
+ assert_a_params = []
+ assert_b_params = []
+ visitors.traverse_depthfirst(
+ case_a[a], {}, {"bindparam": assert_a_params.append}
+ )
+ visitors.traverse_depthfirst(
+ case_b[b], {}, {"bindparam": assert_b_params.append}
+ )
+
+ # note we're asserting the order of the params as well as
+ # if there are dupes or not. ordering has to be deterministic
+ # and matches what a traversal would provide.
+ # regular traverse_depthfirst does produce dupes in cases like
+ # select([some_alias]).
+ # select_from(join(some_alias, other_table))
+ # where a bound parameter is inside of some_alias. the
+ # cache key case is more minimalistic
+ eq_(
+ sorted(a_key.bindparams, key=lambda b: b.key),
+ sorted(
+ util.unique_list(assert_a_params), key=lambda b: b.key
+ ),
+ )
+ eq_(
+ sorted(b_key.bindparams, key=lambda b: b.key),
+ sorted(
+ util.unique_list(assert_b_params), key=lambda b: b.key
+ ),
+ )
+
+
+class CacheKeyTest(CacheKeyFixture, CoreFixtures, fixtures.TestBase):
+ def test_cache_key(self):
+ for fixture in self.fixtures:
+ self._run_cache_key_fixture(fixture)
+
+ def test_cache_key_unknown_traverse(self):
+ class Foobar1(ClauseElement):
+ _traverse_internals = [
+ ("key", InternalTraversal.dp_anon_name),
+ ("type_", InternalTraversal.dp_unknown_structure),
+ ]
+
+ def __init__(self, key, type_):
+ self.key = key
+ self.type_ = type_
+
+ f1 = Foobar1("foo", String())
+ eq_(f1._generate_cache_key(), None)
+
+ def test_cache_key_no_method(self):
+ class Foobar1(ClauseElement):
+ pass
+
+ class Foobar2(ColumnElement):
+ pass
+
+ # the None for cache key will prevent objects
+ # which contain these elements from being cached.
+ f1 = Foobar1()
+ eq_(f1._generate_cache_key(), None)
+
+ f2 = Foobar2()
+ eq_(f2._generate_cache_key(), None)
+
+ s1 = select([column("q"), Foobar2()])
+
+ eq_(s1._generate_cache_key(), None)
+
+ def test_get_children_no_method(self):
+ class Foobar1(ClauseElement):
+ pass
+
+ class Foobar2(ColumnElement):
+ pass
+
+ f1 = Foobar1()
+ eq_(f1.get_children(), [])
+
+ f2 = Foobar2()
+ eq_(f2.get_children(), [])
+
+ def test_copy_internals_no_method(self):
+ class Foobar1(ClauseElement):
+ pass
+
+ class Foobar2(ColumnElement):
+ pass
+
+ f1 = Foobar1()
+ f2 = Foobar2()
+
+ f1._copy_internals()
+ f2._copy_internals()
+
+
+class CompareAndCopyTest(CoreFixtures, fixtures.TestBase):
@classmethod
def setup_class(cls):
# TODO: we need to get dialects here somehow, perhaps in test_suite?
cls
for cls in class_hierarchy(ClauseElement)
if issubclass(cls, (ColumnElement, Selectable))
- and "__init__" in cls.__dict__
+ and (
+ "__init__" in cls.__dict__
+ or issubclass(cls, AliasedReturnsRows)
+ )
and not issubclass(cls, (Annotated))
and "orm" not in cls.__module__
and "compiler" not in cls.__module__
):
if a == b:
is_true(
- case_a[a].compare(
- case_b[b], arbitrary_expression=True
- ),
+ case_a[a].compare(case_b[b], compare_annotations=True),
"%r != %r" % (case_a[a], case_b[b]),
)
else:
is_false(
- case_a[a].compare(
- case_b[b], arbitrary_expression=True
- ),
+ case_a[a].compare(case_b[b], compare_annotations=True),
"%r == %r" % (case_a[a], case_b[b]),
)
- def test_cache_key(self):
- def assert_params_append(assert_params):
- def append(param):
- if param._value_required_for_cache:
- assert_params.append(param)
- else:
- is_(param.value, None)
-
- return append
-
- for fixture in self.fixtures:
- case_a = fixture()
- case_b = fixture()
-
- for a, b in itertools.combinations_with_replacement(
- range(len(case_a)), 2
- ):
-
- assert_a_params = []
- assert_b_params = []
-
- visitors.traverse_depthfirst(
- case_a[a],
- {},
- {"bindparam": assert_params_append(assert_a_params)},
- )
- visitors.traverse_depthfirst(
- case_b[b],
- {},
- {"bindparam": assert_params_append(assert_b_params)},
- )
- if assert_a_params:
- assert_raises_message(
- NotImplementedError,
- "bindparams collection argument required ",
- case_a[a]._cache_key,
- )
- if assert_b_params:
- assert_raises_message(
- NotImplementedError,
- "bindparams collection argument required ",
- case_b[b]._cache_key,
- )
-
- if not assert_a_params and not assert_b_params:
- if a == b:
- eq_(case_a[a]._cache_key(), case_b[b]._cache_key())
- else:
- ne_(case_a[a]._cache_key(), case_b[b]._cache_key())
-
- def test_cache_key_gather_bindparams(self):
- for fixture in self.fixtures:
- case_a = fixture()
- case_b = fixture()
-
- # in the "bindparams" case, the cache keys for bound parameters
- # with only different values will be the same, but the params
- # themselves are gathered into a collection.
- for a, b in itertools.combinations_with_replacement(
- range(len(case_a)), 2
- ):
- a_params = {"bindparams": []}
- b_params = {"bindparams": []}
- if a == b:
- a_key = case_a[a]._cache_key(**a_params)
- b_key = case_b[b]._cache_key(**b_params)
- eq_(a_key, b_key)
-
- if a_params["bindparams"]:
- for a_param, b_param in zip(
- a_params["bindparams"], b_params["bindparams"]
- ):
- assert a_param.compare(b_param)
- else:
- a_key = case_a[a]._cache_key(**a_params)
- b_key = case_b[b]._cache_key(**b_params)
-
- if a_key == b_key:
- for a_param, b_param in zip(
- a_params["bindparams"], b_params["bindparams"]
- ):
- if not a_param.compare(b_param):
- break
- else:
- assert False, "Bound parameters are all the same"
- else:
- ne_(a_key, b_key)
-
- assert_a_params = []
- assert_b_params = []
- visitors.traverse_depthfirst(
- case_a[a], {}, {"bindparam": assert_a_params.append}
- )
- visitors.traverse_depthfirst(
- case_b[b], {}, {"bindparam": assert_b_params.append}
- )
-
- # note we're asserting the order of the params as well as
- # if there are dupes or not. ordering has to be deterministic
- # and matches what a traversal would provide.
- eq_(a_params["bindparams"], assert_a_params)
- eq_(b_params["bindparams"], assert_b_params)
-
def test_compare_col_identity(self):
stmt1 = (
select([table_a.c.a, table_b.c.b])
assert case_a[0].compare(case_b[0])
- clone = case_a[0]._clone()
- clone._copy_internals()
+ clone = visitors.replacement_traverse(
+ case_a[0], {}, lambda elem: None
+ )
assert clone.compare(case_b[0])
class CompareClausesTest(fixtures.TestBase):
+ def test_compare_metadata_tables(self):
+ # metadata Table objects cache on their own identity, not their
+ # structure. This is mainly to reduce the size of cache keys
+ # as well as reduce computational overhead, as Table objects have
+ # very large internal state and they are also generally global
+ # objects.
+
+ t1 = Table("a", MetaData(), Column("q", Integer), Column("p", Integer))
+ t2 = Table("a", MetaData(), Column("q", Integer), Column("p", Integer))
+
+ ne_(t1._generate_cache_key(), t2._generate_cache_key())
+
+ eq_(t1._generate_cache_key().key, (t1,))
+
+ def test_compare_adhoc_tables(self):
+ # non-metadata tables compare on their structure. these objects are
+ # not commonly used.
+
+ # note this test is a bit redundant as we have a similar test
+ # via the fixtures also
+ t1 = table("a", Column("q", Integer), Column("p", Integer))
+ t2 = table("a", Column("q", Integer), Column("p", Integer))
+ t3 = table("b", Column("q", Integer), Column("p", Integer))
+ t4 = table("a", Column("q", Integer), Column("x", Integer))
+
+ eq_(t1._generate_cache_key(), t2._generate_cache_key())
+
+ ne_(t1._generate_cache_key(), t3._generate_cache_key())
+ ne_(t1._generate_cache_key(), t4._generate_cache_key())
+ ne_(t3._generate_cache_key(), t4._generate_cache_key())
+
def test_compare_comparison_associative(self):
l1 = table_c.c.x == table_d.c.y
is_true(l1.compare(l2))
is_false(l1.compare(l3))
+ def test_compare_comparison_non_commutative_inverses(self):
+ l1 = table_c.c.x >= table_d.c.y
+ l2 = table_d.c.y < table_c.c.x
+ l3 = table_d.c.y <= table_c.c.x
+
+ # we're not doing this kind of commutativity right now.
+ is_false(l1.compare(l2))
+ is_false(l1.compare(l3))
+
def test_compare_clauselist_associative(self):
l1 = and_(table_c.c.x == table_d.c.y, table_c.c.y == table_d.c.z)
use_proxies=True,
)
)
+
+ def test_compare_annotated_clears_mapping(self):
+ t = table("t", column("x"), column("y"))
+ x_a = t.c.x._annotate({"foo": True})
+ x_b = t.c.x._annotate({"foo": True})
+
+ is_true(x_a.compare(x_b, compare_annotations=True))
+ is_false(
+ x_a.compare(x_b._annotate({"bar": True}), compare_annotations=True)
+ )
+
+ s1 = select([t.c.x])._annotate({"foo": True})
+ s2 = select([t.c.x])._annotate({"foo": True})
+
+ is_true(s1.compare(s2, compare_annotations=True))
+
+ is_false(
+ s1.compare(s2._annotate({"bar": True}), compare_annotations=True)
+ )
+
+ def test_compare_annotated_wo_annotations(self):
+ t = table("t", column("x"), column("y"))
+ x_a = t.c.x._annotate({})
+ x_b = t.c.x._annotate({"foo": True})
+
+ is_true(t.c.x.compare(x_a))
+ is_true(x_b.compare(x_a))
+
+ is_true(x_a.compare(t.c.x))
+ is_false(x_a.compare(t.c.y))
+ is_false(t.c.y.compare(x_a))
+ is_true((t.c.x == 5).compare(x_a == 5))
+ is_false((t.c.y == 5).compare(x_a == 5))
+
+ s = select([t]).subquery()
+ x_p = s.c.x
+ is_false(x_a.compare(x_p))
+ is_false(t.c.x.compare(x_p))
+ x_p_a = x_p._annotate({})
+ is_true(x_p_a.compare(x_p))
+ is_true(x_p.compare(x_p_a))
+ is_false(x_p_a.compare(x_a))
# identity semantics.
class A(ClauseElement):
__visit_name__ = "a"
+ _traverse_internals = []
def __init__(self, expr):
self.expr = expr
)
)
+ modifiers = operator(left, right).modifiers
+
assert operator(left, right).compare(
BinaryExpression(
coercions.expect(roles.WhereHavingRole, left),
coercions.expect(roles.WhereHavingRole, right),
operator,
+ modifiers=modifiers,
)
)
s4 = s3.with_only_columns([table2.c.b])
self.assert_compile(s4, "SELECT t2.b FROM t2")
- def test_from_list_warning_against_existing(self):
+ def test_from_list_against_existing_one(self):
c1 = Column("c1", Integer)
s = select([c1])
self.assert_compile(s, "SELECT t.c1 FROM t")
- def test_from_list_recovers_after_warning(self):
+ def test_from_list_against_existing_two(self):
c1 = Column("c1", Integer)
c2 = Column("c2", Integer)
# force a compile.
eq_(str(s), "SELECT c1")
- @testing.emits_warning()
- def go():
- return Table("t", MetaData(), c1, c2)
-
- t = go()
+ t = Table("t", MetaData(), c1, c2)
eq_(c1._from_objects, [t])
eq_(c2._from_objects, [t])
- # 's' has been baked. Can't afford
- # not caching select._froms.
- # hopefully the warning will clue the user
self.assert_compile(s, "SELECT t.c1 FROM t")
self.assert_compile(select([c1]), "SELECT t.c1 FROM t")
self.assert_compile(select([c2]), "SELECT t.c2 FROM t")
"foo",
)
+ def test_whereclause_adapted(self):
+ table1 = table("t1", column("a"))
+
+ s1 = select([table1]).subquery()
+
+ s2 = select([s1]).where(s1.c.a == 5)
+
+ assert s2._whereclause.left.table is s1
+
+ ta = select([table1]).subquery()
+
+ s3 = sql_util.ClauseAdapter(ta).traverse(s2)
+
+ assert s1 not in s3._froms
+
+ # these are new assumptions with the newer approach that
+ # actively swaps out whereclause and others
+ assert s3._whereclause.left.table is not s1
+ assert s3._whereclause.left.table in s3._froms
+
class RefreshForNewColTest(fixtures.TestBase):
def test_join_uninit(self):
annot = obj._annotate({})
ne_(set([obj]), set([annot]))
- def test_compare(self):
- t = table("t", column("x"), column("y"))
- x_a = t.c.x._annotate({})
- assert t.c.x.compare(x_a)
- assert x_a.compare(t.c.x)
- assert not x_a.compare(t.c.y)
- assert not t.c.y.compare(x_a)
- assert (t.c.x == 5).compare(x_a == 5)
- assert not (t.c.y == 5).compare(x_a == 5)
-
- s = select([t]).subquery()
- x_p = s.c.x
- assert not x_a.compare(x_p)
- assert not t.c.x.compare(x_p)
- x_p_a = x_p._annotate({})
- assert x_p_a.compare(x_p)
- assert x_p.compare(x_p_a)
- assert not x_p_a.compare(x_a)
-
def test_proxy_set_iteration_includes_annotated(self):
from sqlalchemy.schema import Column
):
# the columns clause isn't changed at all
assert sel._raw_columns[0].table is a1
- assert sel._froms[0] is sel._froms[1].left
+ assert sel._froms[0].element is sel._froms[1].left.element
eq_(str(s), str(sel))
# when we are modifying annotations sets only
- # partially, each element is copied unconditionally
- # when encountered.
+ # partially, elements are copied uniquely based on id().
+ # this is new as of 1.4, previously they'd be copied every time
for sel in (
sql_util._deep_deannotate(s, {"foo": "bar"}),
sql_util._deep_annotate(s, {"foo": "bar"}),
class MiscTest(fixtures.TestBase):
def test_column_element_no_visit(self):
class MyElement(ColumnElement):
- pass
+ _traverse_internals = []
eq_(sql_util.find_tables(MyElement(), check_columns=True), [])