From: Mike Bayer Date: Thu, 29 Aug 2019 18:45:23 +0000 (-0400) Subject: Add anonymizing context to cache keys, comparison; convert traversal X-Git-Tag: rel_1_4_0b1~638^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=29330ec1596f12462c501a65404ff52005b16b6c;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Add anonymizing context to cache keys, comparison; convert traversal Created new visitor system called "internal traversal" that applies a data-driven approach to the concept of a class that defines its own traversal steps, in contrast to the existing style of traversal now known as "external traversal" where the visitor class defines the traversal, i.e. the SQLCompiler. The internal traversal system now implements get_children(), _copy_internals(), compare() and _cache_key() for most Core elements. Core elements with special needs like Select still implement some of these methods directly; however, most of these methods are no longer explicitly implemented. The data-driven system is also applied to ORM elements that take part in SQL expressions so that these objects, like mappers, aliasedclass, query options, etc. can all participate in the cache key process. Still not considered is that this approach to defining traversability will be used to create some kind of generic introspection system that works across Core / ORM. It's also not clear if real statement caching using the _cache_key() method is feasible, if it is shown that running _cache_key() is nearly as expensive as compiling in any case. Because it is data driven, it is more straightforward to optimize using inlined code, as is the case now, as well as potentially using C code to speed it up. 
In addition, the caching system now accommodates anonymous name labels, which is essential so that constructs which have anonymous labels can be cacheable, that is, their position within a statement in relation to other anonymous names causes them to generate an integer counter relative to that construct which will be the same every time. Gathering of bound parameters from any cache key generation is also now required as there is no use case for a cache key that does not extract bound parameter values. Applies-to: #4639 Change-Id: I0660584def8627cad566719ee98d3be045db4b8d --- diff --git a/doc/build/core/sqlelement.rst b/doc/build/core/sqlelement.rst index 41a27af711..f7f9cab641 100644 --- a/doc/build/core/sqlelement.rst +++ b/doc/build/core/sqlelement.rst @@ -82,6 +82,9 @@ the FROM clause of a SELECT statement. .. autoclass:: BindParameter :members: +.. autoclass:: CacheKey + :members: + .. autoclass:: Case :members: @@ -90,6 +93,7 @@ the FROM clause of a SELECT statement. .. autoclass:: ClauseElement :members: + :inherited-members: .. autoclass:: ClauseList diff --git a/doc/build/core/visitors.rst b/doc/build/core/visitors.rst index 02f6e24fc4..539d664405 100644 --- a/doc/build/core/visitors.rst +++ b/doc/build/core/visitors.rst @@ -23,3 +23,4 @@ as well as when building out custom SQL expressions using the .. 
automodule:: sqlalchemy.sql.visitors :members: + :private-members: \ No newline at end of file diff --git a/lib/sqlalchemy/dialects/mysql/dml.py b/lib/sqlalchemy/dialects/mysql/dml.py index 293aa426da..b43b364fab 100644 --- a/lib/sqlalchemy/dialects/mysql/dml.py +++ b/lib/sqlalchemy/dialects/mysql/dml.py @@ -103,7 +103,6 @@ class Insert(StandardInsert): inserted_alias = getattr(self, "inserted_alias", None) self._post_values_clause = OnDuplicateClause(inserted_alias, values) - return self insert = public_factory(Insert, ".dialects.mysql.insert") diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 909d568a7f..e94f9913cd 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -1658,23 +1658,20 @@ class PGCompiler(compiler.SQLCompiler): return "ONLY " + sqltext def get_select_precolumns(self, select, **kw): - if select._distinct is not False: - if select._distinct is True: - return "DISTINCT " - elif isinstance(select._distinct, (list, tuple)): + if select._distinct or select._distinct_on: + if select._distinct_on: return ( "DISTINCT ON (" + ", ".join( - [self.process(col, **kw) for col in select._distinct] + [ + self.process(col, **kw) + for col in select._distinct_on + ] ) + ") " ) else: - return ( - "DISTINCT ON (" - + self.process(select._distinct, **kw) - + ") " - ) + return "DISTINCT " else: return "" diff --git a/lib/sqlalchemy/dialects/postgresql/dml.py b/lib/sqlalchemy/dialects/postgresql/dml.py index 4e77f5a4c9..f4467976af 100644 --- a/lib/sqlalchemy/dialects/postgresql/dml.py +++ b/lib/sqlalchemy/dialects/postgresql/dml.py @@ -103,7 +103,6 @@ class Insert(StandardInsert): self._post_values_clause = OnConflictDoUpdate( constraint, index_elements, index_where, set_, where ) - return self @_generative def on_conflict_do_nothing( @@ -138,7 +137,6 @@ class Insert(StandardInsert): self._post_values_clause = OnConflictDoNothing( constraint, index_elements, 
index_where ) - return self insert = public_factory(Insert, ".dialects.postgresql.insert") diff --git a/lib/sqlalchemy/ext/baked.py b/lib/sqlalchemy/ext/baked.py index d18a35a407..8e137f141a 100644 --- a/lib/sqlalchemy/ext/baked.py +++ b/lib/sqlalchemy/ext/baked.py @@ -198,7 +198,7 @@ class BakedQuery(object): self.spoil() else: for opt in options: - cache_key = opt._generate_cache_key(cache_path) + cache_key = opt._generate_path_cache_key(cache_path) if cache_key is False: self.spoil() elif cache_key is not None: diff --git a/lib/sqlalchemy/ext/compiler.py b/lib/sqlalchemy/ext/compiler.py index 4a5a8ba9cd..c2b2347587 100644 --- a/lib/sqlalchemy/ext/compiler.py +++ b/lib/sqlalchemy/ext/compiler.py @@ -455,7 +455,7 @@ def deregister(class_): if hasattr(class_, "_compiler_dispatcher"): # regenerate default _compiler_dispatch - visitors._generate_dispatch(class_) + visitors._generate_compiler_dispatch(class_) # remove custom directive del class_._compiler_dispatcher diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 83069f113c..aa2986205b 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -47,6 +47,8 @@ from .base import state_str from .. import event from .. import inspection from .. 
import util +from ..sql import base as sql_base +from ..sql import visitors @inspection._self_inspects @@ -54,6 +56,7 @@ class QueryableAttribute( interfaces._MappedAttribute, interfaces.InspectionAttr, interfaces.PropComparator, + sql_base.HasCacheKey, ): """Base class for :term:`descriptor` objects that intercept attribute events on behalf of a :class:`.MapperProperty` @@ -102,6 +105,13 @@ class QueryableAttribute( if base[key].dispatch._active_history: self.dispatch._active_history = True + _cache_key_traversal = [ + # ("class_", visitors.ExtendedInternalTraversal.dp_plain_obj), + ("key", visitors.ExtendedInternalTraversal.dp_string), + ("_parententity", visitors.ExtendedInternalTraversal.dp_multi), + ("_of_type", visitors.ExtendedInternalTraversal.dp_multi), + ] + @util.memoized_property def _supports_population(self): return self.impl.supports_population diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py index 6f8d192934..a3dea6b0e3 100644 --- a/lib/sqlalchemy/orm/base.py +++ b/lib/sqlalchemy/orm/base.py @@ -216,7 +216,6 @@ def _assertions(*assertions): for assertion in assertions: assertion(self, fn.__name__) fn(self, *args[1:], **kw) - return self return generate diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index e94a81fedb..704ce9df79 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -36,6 +36,8 @@ from .. import inspect from .. import inspection from .. import util from ..sql import operators +from ..sql import visitors +from ..sql.traversals import HasCacheKey __all__ = ( @@ -54,7 +56,9 @@ __all__ = ( ) -class MapperProperty(_MappedAttribute, InspectionAttr, util.MemoizedSlots): +class MapperProperty( + HasCacheKey, _MappedAttribute, InspectionAttr, util.MemoizedSlots +): """Represent a particular class attribute mapped by :class:`.Mapper`. 
The most common occurrences of :class:`.MapperProperty` are the @@ -74,6 +78,11 @@ class MapperProperty(_MappedAttribute, InspectionAttr, util.MemoizedSlots): "info", ) + _cache_key_traversal = [ + ("parent", visitors.ExtendedInternalTraversal.dp_has_cache_key), + ("key", visitors.ExtendedInternalTraversal.dp_string), + ] + cascade = frozenset() """The set of 'cascade' attribute names. @@ -647,7 +656,7 @@ class MapperOption(object): self.process_query(query) - def _generate_cache_key(self, path): + def _generate_path_cache_key(self, path): """Used by the "baked lazy loader" to see if this option can be cached. The "baked lazy loader" refers to the :class:`.Query` that is diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 376ad19233..548eca58db 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -71,7 +71,7 @@ _CONFIGURE_MUTEX = util.threading.RLock() @inspection._self_inspects @log.class_logger -class Mapper(InspectionAttr): +class Mapper(sql_base.HasCacheKey, InspectionAttr): """Define the correlation of class attributes to database table columns. @@ -729,6 +729,10 @@ class Mapper(InspectionAttr): """ return self + _cache_key_traversal = [ + ("class_", visitors.ExtendedInternalTraversal.dp_plain_obj) + ] + @property def entity(self): r"""Part of the inspection API. diff --git a/lib/sqlalchemy/orm/path_registry.py b/lib/sqlalchemy/orm/path_registry.py index 2f680a3a16..585cb80bc3 100644 --- a/lib/sqlalchemy/orm/path_registry.py +++ b/lib/sqlalchemy/orm/path_registry.py @@ -15,7 +15,8 @@ from .base import class_mapper from .. import exc from .. import inspection from .. import util - +from ..sql import visitors +from ..sql.traversals import HasCacheKey log = logging.getLogger(__name__) @@ -28,7 +29,7 @@ _WILDCARD_TOKEN = "*" _DEFAULT_TOKEN = "_sa_default" -class PathRegistry(object): +class PathRegistry(HasCacheKey): """Represent query load paths and registry functions. 
Basically represents structures like: @@ -57,6 +58,10 @@ class PathRegistry(object): is_token = False is_root = False + _cache_key_traversal = [ + ("path", visitors.ExtendedInternalTraversal.dp_has_cache_key_list) + ] + def __eq__(self, other): return other is not None and self.path == other.path @@ -78,6 +83,9 @@ class PathRegistry(object): def __len__(self): return len(self.path) + def __hash__(self): + return id(self) + @property def length(self): return len(self.path) diff --git a/lib/sqlalchemy/orm/strategy_options.py b/lib/sqlalchemy/orm/strategy_options.py index 26f47f6169..99bbbe37c7 100644 --- a/lib/sqlalchemy/orm/strategy_options.py +++ b/lib/sqlalchemy/orm/strategy_options.py @@ -26,11 +26,13 @@ from .. import inspect from .. import util from ..sql import coercions from ..sql import roles +from ..sql import visitors from ..sql.base import _generative from ..sql.base import Generative +from ..sql.traversals import HasCacheKey -class Load(Generative, MapperOption): +class Load(HasCacheKey, Generative, MapperOption): """Represents loader options which modify the state of a :class:`.Query` in order to affect how various mapped attributes are loaded. 
@@ -70,6 +72,17 @@ class Load(Generative, MapperOption): """ + _cache_key_traversal = [ + ("path", visitors.ExtendedInternalTraversal.dp_has_cache_key), + ("strategy", visitors.ExtendedInternalTraversal.dp_plain_obj), + ("_of_type", visitors.ExtendedInternalTraversal.dp_multi), + ( + "_context_cache_key", + visitors.ExtendedInternalTraversal.dp_has_cache_key_tuples, + ), + ("local_opts", visitors.ExtendedInternalTraversal.dp_plain_dict), + ] + def __init__(self, entity): insp = inspect(entity) self.path = insp._path_registry @@ -89,7 +102,16 @@ class Load(Generative, MapperOption): load._of_type = None return load - def _generate_cache_key(self, path): + @property + def _context_cache_key(self): + serialized = [] + for (key, loader_path), obj in self.context.items(): + if key != "loader": + continue + serialized.append(loader_path + (obj,)) + return serialized + + def _generate_path_cache_key(self, path): if path.path[0].is_aliased_class: return False @@ -522,9 +544,16 @@ class _UnboundLoad(Load): self._to_bind = [] self.local_opts = {} + _cache_key_traversal = [ + ("path", visitors.ExtendedInternalTraversal.dp_multi_list), + ("strategy", visitors.ExtendedInternalTraversal.dp_plain_obj), + ("_to_bind", visitors.ExtendedInternalTraversal.dp_has_cache_key_list), + ("local_opts", visitors.ExtendedInternalTraversal.dp_plain_dict), + ] + _is_chain_link = False - def _generate_cache_key(self, path): + def _generate_path_cache_key(self, path): serialized = () for val in self._to_bind: for local_elem, val_elem in zip(self.path, val.path): @@ -533,7 +562,7 @@ class _UnboundLoad(Load): else: opt = val._bind_loader([path.path[0]], None, None, False) if opt: - c_key = opt._generate_cache_key(path) + c_key = opt._generate_path_cache_key(path) if c_key is False: return False elif c_key: @@ -660,7 +689,6 @@ class _UnboundLoad(Load): opt = meth(opt, all_tokens[-1], **kw) opt._is_chain_link = False - return opt def _chop_path(self, to_chop, path): diff --git 
a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 5f0f41e8d9..c869936783 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -30,10 +30,12 @@ from .. import exc as sa_exc from .. import inspection from .. import sql from .. import util +from ..sql import base as sql_base from ..sql import coercions from ..sql import expression from ..sql import roles from ..sql import util as sql_util +from ..sql import visitors all_cascades = frozenset( @@ -530,7 +532,7 @@ class AliasedClass(object): return str(self._aliased_insp) -class AliasedInsp(InspectionAttr): +class AliasedInsp(sql_base.HasCacheKey, InspectionAttr): """Provide an inspection interface for an :class:`.AliasedClass` object. @@ -627,6 +629,12 @@ class AliasedInsp(InspectionAttr): def __clause_element__(self): return self.selectable + _cache_key_traversal = [ + ("name", visitors.ExtendedInternalTraversal.dp_string), + ("_adapt_on_names", visitors.ExtendedInternalTraversal.dp_boolean), + ("selectable", visitors.ExtendedInternalTraversal.dp_clauseelement), + ] + @property def class_(self): """Return the mapped class ultimately represented by this diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py index a0264845e3..0d995ec8a2 100644 --- a/lib/sqlalchemy/sql/annotation.py +++ b/lib/sqlalchemy/sql/annotation.py @@ -12,12 +12,32 @@ associations. """ from . import operators +from .base import HasCacheKey +from .visitors import InternalTraversal from .. 
import util -class SupportsCloneAnnotations(object): +class SupportsAnnotations(object): + @util.memoized_property + def _annotation_traversals(self): + return [ + ( + key, + InternalTraversal.dp_has_cache_key + if isinstance(value, HasCacheKey) + else InternalTraversal.dp_plain_obj, + ) + for key, value in self._annotations.items() + ] + + +class SupportsCloneAnnotations(SupportsAnnotations): _annotations = util.immutabledict() + _traverse_internals = [ + ("_annotations", InternalTraversal.dp_annotations_state) + ] + def _annotate(self, values): """return a copy of this ClauseElement with annotations updated by the given dictionary. @@ -25,6 +45,7 @@ class SupportsCloneAnnotations(object): """ new = self._clone() new._annotations = new._annotations.union(values) + new.__dict__.pop("_annotation_traversals", None) return new def _with_annotations(self, values): @@ -34,6 +55,7 @@ class SupportsCloneAnnotations(object): """ new = self._clone() new._annotations = util.immutabledict(values) + new.__dict__.pop("_annotation_traversals", None) return new def _deannotate(self, values=None, clone=False): @@ -49,12 +71,13 @@ class SupportsCloneAnnotations(object): # the expression for a deep deannotation new = self._clone() new._annotations = {} + new.__dict__.pop("_annotation_traversals", None) return new else: return self -class SupportsWrappingAnnotations(object): +class SupportsWrappingAnnotations(SupportsAnnotations): def _annotate(self, values): """return a copy of this ClauseElement with annotations updated by the given dictionary. 
@@ -123,6 +146,7 @@ class Annotated(object): def __init__(self, element, values): self.__dict__ = element.__dict__.copy() + self.__dict__.pop("_annotation_traversals", None) self.__element = element self._annotations = values self._hash = hash(element) @@ -135,6 +159,7 @@ class Annotated(object): def _with_annotations(self, values): clone = self.__class__.__new__(self.__class__) clone.__dict__ = self.__dict__.copy() + clone.__dict__.pop("_annotation_traversals", None) clone._annotations = values return clone @@ -192,7 +217,17 @@ def _deep_annotate(element, annotations, exclude=None): """ - def clone(elem): + # annotated objects hack the __hash__() method so if we want to + # uniquely process them we have to use id() + + cloned_ids = {} + + def clone(elem, **kw): + id_ = id(elem) + + if id_ in cloned_ids: + return cloned_ids[id_] + if ( exclude and hasattr(elem, "proxy_set") @@ -204,6 +239,7 @@ def _deep_annotate(element, annotations, exclude=None): else: newelem = elem newelem._copy_internals(clone=clone) + cloned_ids[id_] = newelem return newelem if element is not None: @@ -214,23 +250,21 @@ def _deep_annotate(element, annotations, exclude=None): def _deep_deannotate(element, values=None): """Deep copy the given element, removing annotations.""" - cloned = util.column_dict() + cloned = {} - def clone(elem): - # if a values dict is given, - # the elem must be cloned each time it appears, - # as there may be different annotations in source - # elements that are remaining. if totally - # removing all annotations, can assume the same - # slate... 
- if values or elem not in cloned: + def clone(elem, **kw): + if values: + key = id(elem) + else: + key = elem + + if key not in cloned: newelem = elem._deannotate(values=values, clone=True) newelem._copy_internals(clone=clone) - if not values: - cloned[elem] = newelem + cloned[key] = newelem return newelem else: - return cloned[elem] + return cloned[key] if element is not None: element = clone(element) @@ -268,6 +302,11 @@ def _new_annotation_type(cls, base_cls): "Annotated%s" % cls.__name__, (base_cls, cls), {} ) globals()["Annotated%s" % cls.__name__] = anno_cls + + if "_traverse_internals" in cls.__dict__: + anno_cls._traverse_internals = list(cls._traverse_internals) + [ + ("_annotations", InternalTraversal.dp_annotations_state) + ] return anno_cls diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 7e9199bfa8..d11a3a3139 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -14,6 +14,7 @@ import itertools import operator import re +from .traversals import HasCacheKey # noqa from .visitors import ClauseVisitor from .. import exc from .. import util @@ -38,18 +39,41 @@ class Immutable(object): def _clone(self): return self + def _copy_internals(self, **kw): + pass + + +class HasMemoized(object): + def _reset_memoizations(self): + self._memoized_property.expire_instance(self) + + def _reset_exported(self): + self._memoized_property.expire_instance(self) + + def _copy_internals(self, **kw): + super(HasMemoized, self)._copy_internals(**kw) + self._reset_memoizations() + def _from_objects(*elements): return itertools.chain(*[element._from_objects for element in elements]) def _generative(fn): + """non-caching _generative() decorator. + + This is basically the legacy decorator that copies the object and + runs a method on the new copy. 
+ + """ + @util.decorator - def _generative(fn, *args, **kw): + def _generative(fn, self, *args, **kw): """Mark a method as generative.""" - self = args[0]._generate() - fn(self, *args[1:], **kw) + self = self._generate() + x = fn(self, *args, **kw) + assert x is None, "generative methods must have no return value" return self decorated = _generative(fn) @@ -357,10 +381,8 @@ class DialectKWArgs(object): class Generative(object): - """Allow a ClauseElement to generate itself via the - @_generative decorator. - - """ + """Provide a method-chaining pattern in conjunction with the + @_generative decorator.""" def _generate(self): s = self.__class__.__new__(self.__class__) diff --git a/lib/sqlalchemy/sql/clause_compare.py b/lib/sqlalchemy/sql/clause_compare.py deleted file mode 100644 index 30a90348c9..0000000000 --- a/lib/sqlalchemy/sql/clause_compare.py +++ /dev/null @@ -1,334 +0,0 @@ -from collections import deque - -from . import operators -from .. import util - - -SKIP_TRAVERSE = util.symbol("skip_traverse") - - -def compare(obj1, obj2, **kw): - if kw.get("use_proxies", False): - strategy = ColIdentityComparatorStrategy() - else: - strategy = StructureComparatorStrategy() - - return strategy.compare(obj1, obj2, **kw) - - -class StructureComparatorStrategy(object): - __slots__ = "compare_stack", "cache" - - def __init__(self): - self.compare_stack = deque() - self.cache = set() - - def compare(self, obj1, obj2, **kw): - stack = self.compare_stack - cache = self.cache - - stack.append((obj1, obj2)) - - while stack: - left, right = stack.popleft() - - if left is right: - continue - elif left is None or right is None: - # we know they are different so no match - return False - elif (left, right) in cache: - continue - cache.add((left, right)) - - visit_name = left.__visit_name__ - - # we're not exactly looking for identical types, because - # there are things like Column and AnnotatedColumn. 
So the - # visit_name has to at least match up - if visit_name != right.__visit_name__: - return False - - meth = getattr(self, "compare_%s" % visit_name, None) - - if meth: - comparison = meth(left, right, **kw) - if comparison is False: - return False - elif comparison is SKIP_TRAVERSE: - continue - - for c1, c2 in util.zip_longest( - left.get_children(column_collections=False), - right.get_children(column_collections=False), - fillvalue=None, - ): - if c1 is None or c2 is None: - # collections are different sizes, comparison fails - return False - stack.append((c1, c2)) - - return True - - def compare_inner(self, obj1, obj2, **kw): - stack = self.compare_stack - try: - self.compare_stack = deque() - return self.compare(obj1, obj2, **kw) - finally: - self.compare_stack = stack - - def _compare_unordered_sequences(self, seq1, seq2, **kw): - if seq1 is None: - return seq2 is None - - completed = set() - for clause in seq1: - for other_clause in set(seq2).difference(completed): - if self.compare_inner(clause, other_clause, **kw): - completed.add(other_clause) - break - return len(completed) == len(seq1) == len(seq2) - - def compare_bindparam(self, left, right, **kw): - # note the ".key" is often generated from id(self) so can't - # be compared, as far as determining structure. 
- return ( - left.type._compare_type_affinity(right.type) - and left.value == right.value - and left.callable == right.callable - and left._orig_key == right._orig_key - ) - - def compare_clauselist(self, left, right, **kw): - if left.operator is right.operator: - if operators.is_associative(left.operator): - if self._compare_unordered_sequences( - left.clauses, right.clauses - ): - return SKIP_TRAVERSE - else: - return False - else: - # normal ordered traversal - return True - else: - return False - - def compare_unary(self, left, right, **kw): - if left.operator: - disp = self._get_operator_dispatch( - left.operator, "unary", "operator" - ) - if disp is not None: - result = disp(left, right, left.operator, **kw) - if result is not True: - return result - elif left.modifier: - disp = self._get_operator_dispatch( - left.modifier, "unary", "modifier" - ) - if disp is not None: - result = disp(left, right, left.operator, **kw) - if result is not True: - return result - return ( - left.operator == right.operator and left.modifier == right.modifier - ) - - def compare_binary(self, left, right, **kw): - disp = self._get_operator_dispatch(left.operator, "binary", None) - if disp: - result = disp(left, right, left.operator, **kw) - if result is not True: - return result - - if left.operator == right.operator: - if operators.is_commutative(left.operator): - if ( - compare(left.left, right.left, **kw) - and compare(left.right, right.right, **kw) - ) or ( - compare(left.left, right.right, **kw) - and compare(left.right, right.left, **kw) - ): - return SKIP_TRAVERSE - else: - return False - else: - return True - else: - return False - - def _get_operator_dispatch(self, operator_, qualifier1, qualifier2): - # used by compare_binary, compare_unary - attrname = "visit_%s_%s%s" % ( - operator_.__name__, - qualifier1, - "_" + qualifier2 if qualifier2 else "", - ) - return getattr(self, attrname, None) - - def visit_function_as_comparison_op_binary( - self, left, right, operator, 
**kw - ): - return ( - left.left_index == right.left_index - and left.right_index == right.right_index - ) - - def compare_function(self, left, right, **kw): - return left.name == right.name - - def compare_column(self, left, right, **kw): - if left.table is not None: - self.compare_stack.appendleft((left.table, right.table)) - return ( - left.key == right.key - and left.name == right.name - and ( - left.type._compare_type_affinity(right.type) - if left.type is not None - else right.type is None - ) - and left.is_literal == right.is_literal - ) - - def compare_collation(self, left, right, **kw): - return left.collation == right.collation - - def compare_type_coerce(self, left, right, **kw): - return left.type._compare_type_affinity(right.type) - - @util.dependencies("sqlalchemy.sql.elements") - def compare_alias(self, elements, left, right, **kw): - return ( - left.name == right.name - if not isinstance(left.name, elements._anonymous_label) - else isinstance(right.name, elements._anonymous_label) - ) - - def compare_cte(self, elements, left, right, **kw): - raise NotImplementedError("TODO") - - def compare_extract(self, left, right, **kw): - return left.field == right.field - - def compare_textual_label_reference(self, left, right, **kw): - return left.element == right.element - - def compare_slice(self, left, right, **kw): - return ( - left.start == right.start - and left.stop == right.stop - and left.step == right.step - ) - - def compare_over(self, left, right, **kw): - return left.range_ == right.range_ and left.rows == right.rows - - @util.dependencies("sqlalchemy.sql.elements") - def compare_label(self, elements, left, right, **kw): - return left._type._compare_type_affinity(right._type) and ( - left.name == right.name - if not isinstance(left.name, elements._anonymous_label) - else isinstance(right.name, elements._anonymous_label) - ) - - def compare_typeclause(self, left, right, **kw): - return left.type._compare_type_affinity(right.type) - - def 
compare_join(self, left, right, **kw): - return left.isouter == right.isouter and left.full == right.full - - def compare_table(self, left, right, **kw): - if left.name != right.name: - return False - - self.compare_stack.extendleft( - util.zip_longest(left.columns, right.columns) - ) - - def compare_compound_select(self, left, right, **kw): - - if not self._compare_unordered_sequences( - left.selects, right.selects, **kw - ): - return False - - if left.keyword != right.keyword: - return False - - if left._for_update_arg != right._for_update_arg: - return False - - if not self.compare_inner( - left._order_by_clause, right._order_by_clause, **kw - ): - return False - - if not self.compare_inner( - left._group_by_clause, right._group_by_clause, **kw - ): - return False - - return SKIP_TRAVERSE - - def compare_select(self, left, right, **kw): - if not self._compare_unordered_sequences( - left._correlate, right._correlate - ): - return False - if not self._compare_unordered_sequences( - left._correlate_except, right._correlate_except - ): - return False - - if not self._compare_unordered_sequences( - left._from_obj, right._from_obj - ): - return False - - if left._for_update_arg != right._for_update_arg: - return False - - return True - - def compare_textual_select(self, left, right, **kw): - self.compare_stack.extendleft( - util.zip_longest(left.column_args, right.column_args) - ) - return left.positional == right.positional - - -class ColIdentityComparatorStrategy(StructureComparatorStrategy): - def compare_column_element( - self, left, right, use_proxies=True, equivalents=(), **kw - ): - """Compare ColumnElements using proxies and equivalent collections. - - This is a comparison strategy specific to the ORM. 
- """ - - to_compare = (right,) - if equivalents and right in equivalents: - to_compare = equivalents[right].union(to_compare) - - for oth in to_compare: - if use_proxies and left.shares_lineage(oth): - return True - elif hash(left) == hash(right): - return True - else: - return False - - def compare_column(self, left, right, **kw): - return self.compare_column_element(left, right, **kw) - - def compare_label(self, left, right, **kw): - return self.compare_column_element(left, right, **kw) - - def compare_table(self, left, right, **kw): - # tables compare on identity, since it's not really feasible to - # compare them column by column with the above rules - return left is right diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 5ecec7d6c2..546fffc6c4 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -434,6 +434,27 @@ class _CompileLabel(elements.ColumnElement): return self +class prefix_anon_map(dict): + """A map that creates new keys for missing key access. + + Considers keys of the form " " to produce + new symbols "_", where "index" is an incrementing integer + corresponding to . + + Inlines the approach taken by :class:`sqlalchemy.util.PopulateDict` which + is otherwise usually used for this type of operation. + + """ + + def __missing__(self, key): + (ident, derived) = key.split(" ", 1) + anonymous_counter = self.get(derived, 1) + self[derived] = anonymous_counter + 1 + value = derived + "_" + str(anonymous_counter) + self[key] = value + return value + + class SQLCompiler(Compiled): """Default implementation of :class:`.Compiled`. 
@@ -574,7 +595,7 @@ class SQLCompiler(Compiled): # a map which tracks "anonymous" identifiers that are created on # the fly here - self.anon_map = util.PopulateDict(self._process_anon) + self.anon_map = prefix_anon_map() # a map which tracks "truncated" names based on # dialect.label_length or dialect.max_identifier_length @@ -1712,12 +1733,6 @@ class SQLCompiler(Compiled): def _anonymize(self, name): return name % self.anon_map - def _process_anon(self, key): - (ident, derived) = key.split(" ", 1) - anonymous_counter = self.anon_map.get(derived, 1) - self.anon_map[derived] = anonymous_counter + 1 - return derived + "_" + str(anonymous_counter) - def bindparam_string( self, name, diff --git a/lib/sqlalchemy/sql/default_comparator.py b/lib/sqlalchemy/sql/default_comparator.py index 918f7524e5..c0baa85556 100644 --- a/lib/sqlalchemy/sql/default_comparator.py +++ b/lib/sqlalchemy/sql/default_comparator.py @@ -178,6 +178,9 @@ def _unsupported_impl(expr, op, *arg, **kw): def _inv_impl(expr, op, **kw): """See :meth:`.ColumnOperators.__inv__`.""" + + # undocumented element currently used by the ORM for + # relationship.contains() if hasattr(expr, "negation_clause"): return expr.negation_clause else: diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index e6f57b8d11..ba615bc3fa 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -16,23 +16,29 @@ import itertools import operator import re -from . import clause_compare from . import coercions from . import operators from . import roles +from . import traversals from . 
import type_api from .annotation import Annotated from .annotation import SupportsWrappingAnnotations from .base import _clone from .base import _generative from .base import Executable +from .base import HasCacheKey +from .base import HasMemoized from .base import Immutable from .base import NO_ARG from .base import PARSE_AUTOCOMMIT from .coercions import _document_text_coercion +from .traversals import _copy_internals +from .traversals import _get_children +from .traversals import NO_CACHE from .visitors import cloned_traverse +from .visitors import InternalTraversal from .visitors import traverse -from .visitors import Visitable +from .visitors import Traversible from .. import exc from .. import inspection from .. import util @@ -162,7 +168,9 @@ def not_(clause): @inspection._self_inspects -class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable): +class ClauseElement( + roles.SQLRole, SupportsWrappingAnnotations, HasCacheKey, Traversible +): """Base class for elements of a programmatically constructed SQL expression. @@ -190,6 +198,13 @@ class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable): _order_by_label_element = None + @property + def _cache_key_traversal(self): + try: + return self._traverse_internals + except AttributeError: + return NO_CACHE + def _clone(self): """Create a shallow copy of this ClauseElement. @@ -221,28 +236,6 @@ class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable): """ return self - def _cache_key(self, **kw): - """return an optional cache key. - - The cache key is a tuple which can contain any series of - objects that are hashable and also identifies - this object uniquely within the presence of a larger SQL expression - or statement, for the purposes of caching the resulting query. - - The cache key should be based on the SQL compiled structure that would - ultimately be produced. 
That is, two structures that are composed in - exactly the same way should produce the same cache key; any difference - in the strucures that would affect the SQL string or the type handlers - should result in a different cache key. - - If a structure cannot produce a useful cache key, it should raise - NotImplementedError, which will result in the entire structure - for which it's part of not being useful as a cache key. - - - """ - raise NotImplementedError() - @property def _constructor(self): """return the 'constructor' for this ClauseElement. @@ -336,9 +329,9 @@ class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable): (see :class:`.ColumnElement`) """ - return clause_compare.compare(self, other, **kw) + return traversals.compare(self, other, **kw) - def _copy_internals(self, clone=_clone, **kw): + def _copy_internals(self, **kw): """Reassign internal elements to be clones of themselves. Called during a copy-and-traverse operation on newly @@ -349,21 +342,46 @@ class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable): traversal, cloned traversal, annotations). """ - pass - def get_children(self, **kwargs): - r"""Return immediate child elements of this :class:`.ClauseElement`. + try: + traverse_internals = self._traverse_internals + except AttributeError: + return + + for attrname, obj, meth in _copy_internals.run_generated_dispatch( + self, traverse_internals, "_generated_copy_internals_traversal" + ): + if obj is not None: + result = meth(self, obj, **kw) + if result is not None: + setattr(self, attrname, result) + + def get_children(self, omit_attrs=None, **kw): + r"""Return immediate child :class:`.Traversible` elements of this + :class:`.Traversible`. This is used for visit traversal. 
- \**kwargs may contain flags that change the collection that is + \**kw may contain flags that change the collection that is returned, for example to return a subset of items in order to cut down on larger traversals, or to return child items from a different context (such as schema-level collections instead of clause-level). """ - return [] + result = [] + try: + traverse_internals = self._traverse_internals + except AttributeError: + return result + + for attrname, obj, meth in _get_children.run_generated_dispatch( + self, traverse_internals, "_generated_get_children_traversal" + ): + if obj is None or omit_attrs and attrname in omit_attrs: + continue + result.extend(meth(obj, **kw)) + return result def self_group(self, against=None): # type: (Optional[Any]) -> ClauseElement @@ -501,6 +519,8 @@ class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable): return or_(self, other) def __invert__(self): + # undocumented element currently used by the ORM for + # relationship.contains() if hasattr(self, "negation_clause"): return self.negation_clause else: @@ -508,9 +528,7 @@ class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable): def _negate(self): return UnaryExpression( - self.self_group(against=operators.inv), - operator=operators.inv, - negate=None, + self.self_group(against=operators.inv), operator=operators.inv ) def __bool__(self): @@ -731,9 +749,6 @@ class ColumnElement( else: return comparator_factory(self) - def _cache_key(self, **kw): - raise NotImplementedError(self.__class__) - def __getattr__(self, key): try: return getattr(self.comparator, key) @@ -969,6 +984,13 @@ class BindParameter(roles.InElementRole, ColumnElement): __visit_name__ = "bindparam" + _traverse_internals = [ + ("key", InternalTraversal.dp_anon_name), + ("type", InternalTraversal.dp_type), + ("callable", InternalTraversal.dp_plain_dict), + ("value", InternalTraversal.dp_plain_obj), + ] + _is_crud = False _expanding_in_types = () @@ -1321,26 +1343,19 @@ 
class BindParameter(roles.InElementRole, ColumnElement): ) return c - def _cache_key(self, bindparams=None, **kw): - if bindparams is None: - # even though _cache_key is a private method, we would like to - # be super paranoid about this point. You can't include the - # "value" or "callable" in the cache key, because the value is - # not part of the structure of a statement and is likely to - # change every time. However you cannot *throw it away* either, - # because you can't invoke the statement without the parameter - # values that were explicitly placed. So require that they - # are collected here to make sure this happens. - if self._value_required_for_cache: - raise NotImplementedError( - "bindparams collection argument required for _cache_key " - "implementation. Bound parameter cache keys are not safe " - "to use without accommodating for the value or callable " - "within the parameter itself." - ) - else: - bindparams.append(self) - return (BindParameter, self.type._cache_key, self._orig_key) + def _gen_cache_key(self, anon_map, bindparams): + if self in anon_map: + return (anon_map[self], self.__class__) + + id_ = anon_map[self] + bindparams.append(self) + + return ( + id_, + self.__class__, + self.type._gen_cache_key, + traversals._resolve_name_for_compare(self, self.key, anon_map), + ) def _convert_to_unique(self): if not self.unique: @@ -1377,12 +1392,11 @@ class TypeClause(ClauseElement): __visit_name__ = "typeclause" + _traverse_internals = [("type", InternalTraversal.dp_type)] + def __init__(self, type_): self.type = type_ - def _cache_key(self, **kw): - return (TypeClause, self.type._cache_key) - class TextClause( roles.DDLConstraintColumnRole, @@ -1419,6 +1433,11 @@ class TextClause( __visit_name__ = "textclause" + _traverse_internals = [ + ("_bindparams", InternalTraversal.dp_string_clauseelement_dict), + ("text", InternalTraversal.dp_string), + ] + _is_text_clause = True _is_textual = True @@ -1861,19 +1880,6 @@ class TextClause( else: return 
self - def _copy_internals(self, clone=_clone, **kw): - self._bindparams = dict( - (b.key, clone(b, **kw)) for b in self._bindparams.values() - ) - - def get_children(self, **kwargs): - return list(self._bindparams.values()) - - def _cache_key(self, **kw): - return (self.text,) + tuple( - bind._cache_key for bind in self._bindparams.values() - ) - class Null(roles.ConstExprRole, ColumnElement): """Represent the NULL keyword in a SQL statement. @@ -1885,6 +1891,8 @@ class Null(roles.ConstExprRole, ColumnElement): __visit_name__ = "null" + _traverse_internals = [] + @util.memoized_property def type(self): return type_api.NULLTYPE @@ -1895,9 +1903,6 @@ class Null(roles.ConstExprRole, ColumnElement): return Null() - def _cache_key(self, **kw): - return (Null,) - class False_(roles.ConstExprRole, ColumnElement): """Represent the ``false`` keyword, or equivalent, in a SQL statement. @@ -1908,6 +1913,7 @@ class False_(roles.ConstExprRole, ColumnElement): """ __visit_name__ = "false" + _traverse_internals = [] @util.memoized_property def type(self): @@ -1954,9 +1960,6 @@ class False_(roles.ConstExprRole, ColumnElement): return False_() - def _cache_key(self, **kw): - return (False_,) - class True_(roles.ConstExprRole, ColumnElement): """Represent the ``true`` keyword, or equivalent, in a SQL statement. 
@@ -1968,6 +1971,8 @@ class True_(roles.ConstExprRole, ColumnElement): __visit_name__ = "true" + _traverse_internals = [] + @util.memoized_property def type(self): return type_api.BOOLEANTYPE @@ -2020,9 +2025,6 @@ class True_(roles.ConstExprRole, ColumnElement): return True_() - def _cache_key(self, **kw): - return (True_,) - class ClauseList( roles.InElementRole, @@ -2038,6 +2040,11 @@ class ClauseList( __visit_name__ = "clauselist" + _traverse_internals = [ + ("clauses", InternalTraversal.dp_clauseelement_list), + ("operator", InternalTraversal.dp_operator), + ] + def __init__(self, *clauses, **kwargs): self.operator = kwargs.pop("operator", operators.comma_op) self.group = kwargs.pop("group", True) @@ -2082,17 +2089,6 @@ class ClauseList( coercions.expect(self._text_converter_role, clause) ) - def _copy_internals(self, clone=_clone, **kw): - self.clauses = [clone(clause, **kw) for clause in self.clauses] - - def get_children(self, **kwargs): - return self.clauses - - def _cache_key(self, **kw): - return (ClauseList, self.operator) + tuple( - clause._cache_key(**kw) for clause in self.clauses - ) - @property def _from_objects(self): return list(itertools.chain(*[c._from_objects for c in self.clauses])) @@ -2115,11 +2111,6 @@ class BooleanClauseList(ClauseList, ColumnElement): "BooleanClauseList has a private constructor" ) - def _cache_key(self, **kw): - return (BooleanClauseList, self.operator) + tuple( - clause._cache_key(**kw) for clause in self.clauses - ) - @classmethod def _construct(cls, operator, continue_on, skip_on, *clauses, **kw): convert_clauses = [] @@ -2250,6 +2241,8 @@ or_ = BooleanClauseList.or_ class Tuple(ClauseList, ColumnElement): """Represent a SQL tuple.""" + _traverse_internals = ClauseList._traverse_internals + [] + def __init__(self, *clauses, **kw): """Return a :class:`.Tuple`. 
@@ -2289,11 +2282,6 @@ class Tuple(ClauseList, ColumnElement): def _select_iterable(self): return (self,) - def _cache_key(self, **kw): - return (Tuple,) + tuple( - clause._cache_key(**kw) for clause in self.clauses - ) - def _bind_param(self, operator, obj, type_=None): return Tuple( *[ @@ -2339,6 +2327,12 @@ class Case(ColumnElement): __visit_name__ = "case" + _traverse_internals = [ + ("value", InternalTraversal.dp_clauseelement), + ("whens", InternalTraversal.dp_clauseelement_tuples), + ("else_", InternalTraversal.dp_clauseelement), + ] + def __init__(self, whens, value=None, else_=None): r"""Produce a ``CASE`` expression. @@ -2501,40 +2495,6 @@ class Case(ColumnElement): else: self.else_ = None - def _copy_internals(self, clone=_clone, **kw): - if self.value is not None: - self.value = clone(self.value, **kw) - self.whens = [(clone(x, **kw), clone(y, **kw)) for x, y in self.whens] - if self.else_ is not None: - self.else_ = clone(self.else_, **kw) - - def get_children(self, **kwargs): - if self.value is not None: - yield self.value - for x, y in self.whens: - yield x - yield y - if self.else_ is not None: - yield self.else_ - - def _cache_key(self, **kw): - return ( - ( - Case, - self.value._cache_key(**kw) - if self.value is not None - else None, - ) - + tuple( - (x._cache_key(**kw), y._cache_key(**kw)) for x, y in self.whens - ) - + ( - self.else_._cache_key(**kw) - if self.else_ is not None - else None, - ) - ) - @property def _from_objects(self): return list( @@ -2603,6 +2563,11 @@ class Cast(WrapsColumnExpression, ColumnElement): __visit_name__ = "cast" + _traverse_internals = [ + ("clause", InternalTraversal.dp_clauseelement), + ("typeclause", InternalTraversal.dp_clauseelement), + ] + def __init__(self, expression, type_): r"""Produce a ``CAST`` expression. 
@@ -2662,20 +2627,6 @@ class Cast(WrapsColumnExpression, ColumnElement): ) self.typeclause = TypeClause(self.type) - def _copy_internals(self, clone=_clone, **kw): - self.clause = clone(self.clause, **kw) - self.typeclause = clone(self.typeclause, **kw) - - def get_children(self, **kwargs): - return self.clause, self.typeclause - - def _cache_key(self, **kw): - return ( - Cast, - self.clause._cache_key(**kw), - self.typeclause._cache_key(**kw), - ) - @property def _from_objects(self): return self.clause._from_objects @@ -2685,7 +2636,7 @@ class Cast(WrapsColumnExpression, ColumnElement): return self.clause -class TypeCoerce(WrapsColumnExpression, ColumnElement): +class TypeCoerce(HasMemoized, WrapsColumnExpression, ColumnElement): """Represent a Python-side type-coercion wrapper. :class:`.TypeCoerce` supplies the :func:`.expression.type_coerce` @@ -2705,6 +2656,13 @@ class TypeCoerce(WrapsColumnExpression, ColumnElement): __visit_name__ = "type_coerce" + _traverse_internals = [ + ("clause", InternalTraversal.dp_clauseelement), + ("type", InternalTraversal.dp_type), + ] + + _memoized_property = util.group_expirable_memoized_property() + def __init__(self, expression, type_): r"""Associate a SQL expression with a particular type, without rendering ``CAST``. 
@@ -2773,21 +2731,11 @@ class TypeCoerce(WrapsColumnExpression, ColumnElement): roles.ExpressionElementRole, expression, type_=self.type ) - def _copy_internals(self, clone=_clone, **kw): - self.clause = clone(self.clause, **kw) - self.__dict__.pop("typed_expression", None) - - def get_children(self, **kwargs): - return (self.clause,) - - def _cache_key(self, **kw): - return (TypeCoerce, self.type._cache_key, self.clause._cache_key(**kw)) - @property def _from_objects(self): return self.clause._from_objects - @util.memoized_property + @_memoized_property def typed_expression(self): if isinstance(self.clause, BindParameter): bp = self.clause._clone() @@ -2806,6 +2754,11 @@ class Extract(ColumnElement): __visit_name__ = "extract" + _traverse_internals = [ + ("expr", InternalTraversal.dp_clauseelement), + ("field", InternalTraversal.dp_string), + ] + def __init__(self, field, expr, **kwargs): """Return a :class:`.Extract` construct. @@ -2818,15 +2771,6 @@ class Extract(ColumnElement): self.field = field self.expr = coercions.expect(roles.ExpressionElementRole, expr) - def _copy_internals(self, clone=_clone, **kw): - self.expr = clone(self.expr, **kw) - - def get_children(self, **kwargs): - return (self.expr,) - - def _cache_key(self, **kw): - return (Extract, self.field, self.expr._cache_key(**kw)) - @property def _from_objects(self): return self.expr._from_objects @@ -2847,18 +2791,11 @@ class _label_reference(ColumnElement): __visit_name__ = "label_reference" + _traverse_internals = [("element", InternalTraversal.dp_clauseelement)] + def __init__(self, element): self.element = element - def _copy_internals(self, clone=_clone, **kw): - self.element = clone(self.element, **kw) - - def _cache_key(self, **kw): - return (_label_reference, self.element._cache_key(**kw)) - - def get_children(self, **kwargs): - return [self.element] - @property def _from_objects(self): return () @@ -2867,6 +2804,8 @@ class _label_reference(ColumnElement): class 
_textual_label_reference(ColumnElement): __visit_name__ = "textual_label_reference" + _traverse_internals = [("element", InternalTraversal.dp_string)] + def __init__(self, element): self.element = element @@ -2874,9 +2813,6 @@ class _textual_label_reference(ColumnElement): def _text_clause(self): return TextClause._create_text(self.element) - def _cache_key(self, **kw): - return (_textual_label_reference, self.element) - class UnaryExpression(ColumnElement): """Define a 'unary' expression. @@ -2894,13 +2830,18 @@ class UnaryExpression(ColumnElement): __visit_name__ = "unary" + _traverse_internals = [ + ("element", InternalTraversal.dp_clauseelement), + ("operator", InternalTraversal.dp_operator), + ("modifier", InternalTraversal.dp_operator), + ] + def __init__( self, element, operator=None, modifier=None, type_=None, - negate=None, wraps_column_expression=False, ): self.operator = operator @@ -2909,7 +2850,6 @@ class UnaryExpression(ColumnElement): against=self.operator or self.modifier ) self.type = type_api.to_instance(type_) - self.negate = negate self.wraps_column_expression = wraps_column_expression @classmethod @@ -3135,37 +3075,13 @@ class UnaryExpression(ColumnElement): def _from_objects(self): return self.element._from_objects - def _copy_internals(self, clone=_clone, **kw): - self.element = clone(self.element, **kw) - - def _cache_key(self, **kw): - return ( - UnaryExpression, - self.element._cache_key(**kw), - self.operator, - self.modifier, - ) - - def get_children(self, **kwargs): - return (self.element,) - def _negate(self): - if self.negate is not None: - return UnaryExpression( - self.element, - operator=self.negate, - negate=self.operator, - modifier=self.modifier, - type_=self.type, - wraps_column_expression=self.wraps_column_expression, - ) - elif self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity: + if self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity: return UnaryExpression( self.self_group(against=operators.inv), 
operator=operators.inv, type_=type_api.BOOLEANTYPE, wraps_column_expression=self.wraps_column_expression, - negate=None, ) else: return ClauseElement._negate(self) @@ -3286,15 +3202,6 @@ class AsBoolean(WrapsColumnExpression, UnaryExpression): # type: (Optional[Any]) -> ClauseElement return self - def _cache_key(self, **kw): - return ( - self.element._cache_key(**kw), - self.type._cache_key, - self.operator, - self.negate, - self.modifier, - ) - def _negate(self): if isinstance(self.element, (True_, False_)): return self.element._negate() @@ -3318,6 +3225,14 @@ class BinaryExpression(ColumnElement): __visit_name__ = "binary" + _traverse_internals = [ + ("left", InternalTraversal.dp_clauseelement), + ("right", InternalTraversal.dp_clauseelement), + ("operator", InternalTraversal.dp_operator), + ("negate", InternalTraversal.dp_operator), + ("modifiers", InternalTraversal.dp_plain_dict), + ] + _is_implicitly_boolean = True """Indicates that any database will know this is a boolean expression even if the database does not have an explicit boolean datatype. 
@@ -3360,20 +3275,6 @@ class BinaryExpression(ColumnElement): def _from_objects(self): return self.left._from_objects + self.right._from_objects - def _copy_internals(self, clone=_clone, **kw): - self.left = clone(self.left, **kw) - self.right = clone(self.right, **kw) - - def get_children(self, **kwargs): - return self.left, self.right - - def _cache_key(self, **kw): - return ( - BinaryExpression, - self.left._cache_key(**kw), - self.right._cache_key(**kw), - ) - def self_group(self, against=None): # type: (Optional[Any]) -> ClauseElement @@ -3406,6 +3307,12 @@ class Slice(ColumnElement): __visit_name__ = "slice" + _traverse_internals = [ + ("start", InternalTraversal.dp_plain_obj), + ("stop", InternalTraversal.dp_plain_obj), + ("step", InternalTraversal.dp_plain_obj), + ] + def __init__(self, start, stop, step): self.start = start self.stop = stop @@ -3417,9 +3324,6 @@ class Slice(ColumnElement): assert against is operator.getitem return self - def _cache_key(self, **kw): - return (Slice, self.start, self.stop, self.step) - class IndexExpression(BinaryExpression): """Represent the class of expressions that are like an "index" operation. 
@@ -3444,6 +3348,11 @@ class GroupedElement(ClauseElement): class Grouping(GroupedElement, ColumnElement): """Represent a grouping within a column expression""" + _traverse_internals = [ + ("element", InternalTraversal.dp_clauseelement), + ("type", InternalTraversal.dp_type), + ] + def __init__(self, element): self.element = element self.type = getattr(element, "type", type_api.NULLTYPE) @@ -3460,15 +3369,6 @@ class Grouping(GroupedElement, ColumnElement): def _label(self): return getattr(self.element, "_label", None) or self.anon_label - def _copy_internals(self, clone=_clone, **kw): - self.element = clone(self.element, **kw) - - def get_children(self, **kwargs): - return (self.element,) - - def _cache_key(self, **kw): - return (Grouping, self.element._cache_key(**kw)) - @property def _from_objects(self): return self.element._from_objects @@ -3501,6 +3401,14 @@ class Over(ColumnElement): __visit_name__ = "over" + _traverse_internals = [ + ("element", InternalTraversal.dp_clauseelement), + ("order_by", InternalTraversal.dp_clauseelement), + ("partition_by", InternalTraversal.dp_clauseelement), + ("range_", InternalTraversal.dp_plain_obj), + ("rows", InternalTraversal.dp_plain_obj), + ] + order_by = None partition_by = None @@ -3667,30 +3575,6 @@ class Over(ColumnElement): def type(self): return self.element.type - def get_children(self, **kwargs): - return [ - c - for c in (self.element, self.partition_by, self.order_by) - if c is not None - ] - - def _cache_key(self, **kw): - return ( - (Over,) - + tuple( - e._cache_key(**kw) if e is not None else None - for e in (self.element, self.partition_by, self.order_by) - ) - + (self.range_, self.rows) - ) - - def _copy_internals(self, clone=_clone, **kw): - self.element = clone(self.element, **kw) - if self.partition_by is not None: - self.partition_by = clone(self.partition_by, **kw) - if self.order_by is not None: - self.order_by = clone(self.order_by, **kw) - @property def _from_objects(self): return list( @@ -3723,6 
+3607,11 @@ class WithinGroup(ColumnElement): __visit_name__ = "withingroup" + _traverse_internals = [ + ("element", InternalTraversal.dp_clauseelement), + ("order_by", InternalTraversal.dp_clauseelement), + ] + order_by = None def __init__(self, element, *order_by): @@ -3791,25 +3680,6 @@ class WithinGroup(ColumnElement): else: return self.element.type - def get_children(self, **kwargs): - return [c for c in (self.element, self.order_by) if c is not None] - - def _cache_key(self, **kw): - return ( - WithinGroup, - self.element._cache_key(**kw) - if self.element is not None - else None, - self.order_by._cache_key(**kw) - if self.order_by is not None - else None, - ) - - def _copy_internals(self, clone=_clone, **kw): - self.element = clone(self.element, **kw) - if self.order_by is not None: - self.order_by = clone(self.order_by, **kw) - @property def _from_objects(self): return list( @@ -3845,6 +3715,11 @@ class FunctionFilter(ColumnElement): __visit_name__ = "funcfilter" + _traverse_internals = [ + ("func", InternalTraversal.dp_clauseelement), + ("criterion", InternalTraversal.dp_clauseelement), + ] + criterion = None def __init__(self, func, *criterion): @@ -3932,23 +3807,6 @@ class FunctionFilter(ColumnElement): def type(self): return self.func.type - def get_children(self, **kwargs): - return [c for c in (self.func, self.criterion) if c is not None] - - def _copy_internals(self, clone=_clone, **kw): - self.func = clone(self.func, **kw) - if self.criterion is not None: - self.criterion = clone(self.criterion, **kw) - - def _cache_key(self, **kw): - return ( - FunctionFilter, - self.func._cache_key(**kw), - self.criterion._cache_key(**kw) - if self.criterion is not None - else None, - ) - @property def _from_objects(self): return list( @@ -3962,7 +3820,7 @@ class FunctionFilter(ColumnElement): ) -class Label(roles.LabeledColumnExprRole, ColumnElement): +class Label(HasMemoized, roles.LabeledColumnExprRole, ColumnElement): """Represents a column label (AS). 
Represent a label, as typically applied to any column-level @@ -3972,6 +3830,14 @@ class Label(roles.LabeledColumnExprRole, ColumnElement): __visit_name__ = "label" + _traverse_internals = [ + ("name", InternalTraversal.dp_anon_name), + ("_type", InternalTraversal.dp_type), + ("_element", InternalTraversal.dp_clauseelement), + ] + + _memoized_property = util.group_expirable_memoized_property() + def __init__(self, name, element, type_=None): """Return a :class:`Label` object for the given :class:`.ColumnElement`. @@ -4010,14 +3876,11 @@ class Label(roles.LabeledColumnExprRole, ColumnElement): def __reduce__(self): return self.__class__, (self.name, self._element, self._type) - def _cache_key(self, **kw): - return (Label, self.element._cache_key(**kw), self._resolve_label) - @util.memoized_property def _is_implicitly_boolean(self): return self.element._is_implicitly_boolean - @util.memoized_property + @_memoized_property def _allow_label_resolve(self): return self.element._allow_label_resolve @@ -4031,7 +3894,7 @@ class Label(roles.LabeledColumnExprRole, ColumnElement): self._type or getattr(self._element, "type", None) ) - @util.memoized_property + @_memoized_property def element(self): return self._element.self_group(against=operators.as_) @@ -4057,13 +3920,9 @@ class Label(roles.LabeledColumnExprRole, ColumnElement): def foreign_keys(self): return self.element.foreign_keys - def get_children(self, **kwargs): - return (self.element,) - def _copy_internals(self, clone=_clone, anonymize_labels=False, **kw): + self._reset_memoizations() self._element = clone(self._element, **kw) - self.__dict__.pop("element", None) - self.__dict__.pop("_allow_label_resolve", None) if anonymize_labels: self.name = self._resolve_label = _anonymous_label( "%%(%d %s)s" @@ -4124,6 +3983,13 @@ class ColumnClause(roles.LabeledColumnExprRole, Immutable, ColumnElement): __visit_name__ = "column" + _traverse_internals = [ + ("name", InternalTraversal.dp_string), + ("type", 
InternalTraversal.dp_type), + ("table", InternalTraversal.dp_clauseelement), + ("is_literal", InternalTraversal.dp_boolean), + ] + onupdate = default = server_default = server_onupdate = None _is_multiparam_column = False @@ -4254,14 +4120,6 @@ class ColumnClause(roles.LabeledColumnExprRole, Immutable, ColumnElement): table = property(_get_table, _set_table) - def _cache_key(self, **kw): - return ( - self.name, - self.table.name if self.table is not None else None, - self.is_literal, - self.type._cache_key, - ) - @_memoized_property def _from_objects(self): t = self.table @@ -4395,12 +4253,11 @@ class ColumnClause(roles.LabeledColumnExprRole, Immutable, ColumnElement): class CollationClause(ColumnElement): __visit_name__ = "collation" + _traverse_internals = [("collation", InternalTraversal.dp_string)] + def __init__(self, collation): self.collation = collation - def _cache_key(self, **kw): - return (CollationClause, self.collation) - class _IdentifiedClause(Executable, ClauseElement): diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 7ce822669c..08e69f075a 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -86,7 +86,6 @@ __all__ = [ from .base import _from_objects # noqa from .base import ColumnCollection # noqa from .base import Executable # noqa -from .base import Generative # noqa from .base import PARSE_AUTOCOMMIT # noqa from .dml import Delete # noqa from .dml import Insert # noqa @@ -242,7 +241,6 @@ _UnaryExpression = UnaryExpression _Case = Case _Tuple = Tuple _Over = Over -_Generative = Generative _TypeClause = TypeClause _Extract = Extract _Exists = Exists diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index cbc8e539fa..96e64dc284 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -17,7 +17,6 @@ from . import sqltypes from . 
import util as sqlutil from .base import ColumnCollection from .base import Executable -from .elements import _clone from .elements import _type_from_args from .elements import BinaryExpression from .elements import BindParameter @@ -33,7 +32,8 @@ from .elements import WithinGroup from .selectable import Alias from .selectable import FromClause from .selectable import Select -from .visitors import VisitableType +from .visitors import InternalTraversal +from .visitors import TraversibleType from .. import util @@ -78,10 +78,14 @@ class FunctionElement(Executable, ColumnElement, FromClause): """ + _traverse_internals = [("clause_expr", InternalTraversal.dp_clauseelement)] + packagenames = () _has_args = False + _memoized_property = FromClause._memoized_property + def __init__(self, *clauses, **kwargs): r"""Construct a :class:`.FunctionElement`. @@ -136,7 +140,7 @@ class FunctionElement(Executable, ColumnElement, FromClause): col = self.label(None) return ColumnCollection(columns=[(col.key, col)]) - @util.memoized_property + @_memoized_property def clauses(self): """Return the underlying :class:`.ClauseList` which contains the arguments for this :class:`.FunctionElement`. 
@@ -283,17 +287,6 @@ class FunctionElement(Executable, ColumnElement, FromClause): def _from_objects(self): return self.clauses._from_objects - def get_children(self, **kwargs): - return (self.clause_expr,) - - def _cache_key(self, **kw): - return (FunctionElement, self.clause_expr._cache_key(**kw)) - - def _copy_internals(self, clone=_clone, **kw): - self.clause_expr = clone(self.clause_expr, **kw) - self._reset_exported() - FunctionElement.clauses._reset(self) - def within_group_type(self, within_group): """For types that define their return type as based on the criteria within a WITHIN GROUP (ORDER BY) expression, called by the @@ -404,6 +397,13 @@ class FunctionElement(Executable, ColumnElement, FromClause): class FunctionAsBinary(BinaryExpression): + _traverse_internals = [ + ("sql_function", InternalTraversal.dp_clauseelement), + ("left_index", InternalTraversal.dp_plain_obj), + ("right_index", InternalTraversal.dp_plain_obj), + ("modifiers", InternalTraversal.dp_plain_dict), + ] + def __init__(self, fn, left_index, right_index): self.sql_function = fn self.left_index = left_index @@ -431,20 +431,6 @@ class FunctionAsBinary(BinaryExpression): def right(self, value): self.sql_function.clauses.clauses[self.right_index - 1] = value - def _copy_internals(self, clone=_clone, **kw): - self.sql_function = clone(self.sql_function, **kw) - - def get_children(self, **kw): - yield self.sql_function - - def _cache_key(self, **kw): - return ( - FunctionAsBinary, - self.sql_function._cache_key(**kw), - self.left_index, - self.right_index, - ) - class _FunctionGenerator(object): """Generate SQL function expressions. @@ -606,6 +592,12 @@ class Function(FunctionElement): __visit_name__ = "function" + _traverse_internals = FunctionElement._traverse_internals + [ + ("packagenames", InternalTraversal.dp_plain_obj), + ("name", InternalTraversal.dp_string), + ("type", InternalTraversal.dp_type), + ] + def __init__(self, name, *clauses, **kw): """Construct a :class:`.Function`. 
@@ -630,15 +622,8 @@ class Function(FunctionElement): unique=True, ) - def _cache_key(self, **kw): - return ( - (Function,) + tuple(self.packagenames) - if self.packagenames - else () + (self.name, self.clause_expr._cache_key(**kw)) - ) - -class _GenericMeta(VisitableType): +class _GenericMeta(TraversibleType): def __init__(cls, clsname, bases, clsdict): if annotation.Annotated not in cls.__mro__: cls.name = name = clsdict.get("name", clsname) @@ -764,6 +749,10 @@ class next_value(GenericFunction): type = sqltypes.Integer() name = "next_value" + _traverse_internals = [ + ("sequence", InternalTraversal.dp_named_ddl_element) + ] + def __init__(self, seq, **kw): assert isinstance( seq, schema.Sequence @@ -771,21 +760,12 @@ class next_value(GenericFunction): self._bind = kw.get("bind", None) self.sequence = seq - def _cache_key(self, **kw): - return (next_value, self.sequence.name) - def compare(self, other, **kw): return ( isinstance(other, next_value) and self.sequence.name == other.sequence.name ) - def get_children(self, **kwargs): - return [] - - def _copy_internals(self, **kw): - pass - @property def _from_objects(self): return [] diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index 4e8f4a3970..ee7dc61ce8 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -50,6 +50,7 @@ from .elements import ColumnElement from .elements import quoted_name from .elements import TextClause from .selectable import TableClause +from .visitors import InternalTraversal from .. import event from .. import exc from .. 
import inspection @@ -425,6 +426,21 @@ class Table(DialectKWArgs, SchemaItem, TableClause): __visit_name__ = "table" + _traverse_internals = TableClause._traverse_internals + [ + ("schema", InternalTraversal.dp_string) + ] + + def _gen_cache_key(self, anon_map, bindparams): + return (self,) + + @util.deprecated_params( + useexisting=( + "0.7", + "The :paramref:`.Table.useexisting` parameter is deprecated and " + "will be removed in a future release. Please use " + ":paramref:`.Table.extend_existing`.", + ) + ) def __new__(cls, *args, **kw): if not args: # python3k pickle seems to call this @@ -763,6 +779,8 @@ class Table(DialectKWArgs, SchemaItem, TableClause): def get_children( self, column_collections=True, schema_visitor=False, **kw ): + # TODO: consider that we probably don't need column_collections=True + # at all, it does not seem to impact anything if not schema_visitor: return TableClause.get_children( self, column_collections=column_collections, **kw diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 6a7413fc09..4b3844eec1 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -31,6 +31,7 @@ from .base import ColumnSet from .base import DedupeColumnCollection from .base import Executable from .base import Generative +from .base import HasMemoized from .base import Immutable from .coercions import _document_text_coercion from .elements import _anonymous_label @@ -39,11 +40,13 @@ from .elements import and_ from .elements import BindParameter from .elements import ClauseElement from .elements import ClauseList +from .elements import ColumnClause from .elements import GroupedElement from .elements import Grouping from .elements import literal_column from .elements import True_ from .elements import UnaryExpression +from .visitors import InternalTraversal from .. import exc from .. 
import util @@ -201,6 +204,8 @@ class Selectable(ReturnsRows): class HasPrefixes(object): _prefixes = () + _traverse_internals = [("_prefixes", InternalTraversal.dp_prefix_sequence)] + @_generative @_document_text_coercion( "expr", @@ -252,6 +257,8 @@ class HasPrefixes(object): class HasSuffixes(object): _suffixes = () + _traverse_internals = [("_suffixes", InternalTraversal.dp_prefix_sequence)] + @_generative @_document_text_coercion( "expr", @@ -295,7 +302,7 @@ class HasSuffixes(object): ) -class FromClause(roles.AnonymizedFromClauseRole, Selectable): +class FromClause(HasMemoized, roles.AnonymizedFromClauseRole, Selectable): """Represent an element that can be used within the ``FROM`` clause of a ``SELECT`` statement. @@ -529,11 +536,6 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): """ return getattr(self, "name", self.__class__.__name__ + " object") - def _reset_exported(self): - """delete memoized collections when a FromClause is cloned.""" - - self._memoized_property.expire_instance(self) - def _generate_fromclause_column_proxies(self, fromclause): fromclause._columns._populate_separate_keys( col._make_proxy(fromclause) for col in self.c @@ -668,6 +670,14 @@ class Join(FromClause): __visit_name__ = "join" + _traverse_internals = [ + ("left", InternalTraversal.dp_clauseelement), + ("right", InternalTraversal.dp_clauseelement), + ("onclause", InternalTraversal.dp_clauseelement), + ("isouter", InternalTraversal.dp_boolean), + ("full", InternalTraversal.dp_boolean), + ] + _is_join = True def __init__(self, left, right, onclause=None, isouter=False, full=False): @@ -805,25 +815,6 @@ class Join(FromClause): self.left._refresh_for_new_column(column) self.right._refresh_for_new_column(column) - def _copy_internals(self, clone=_clone, **kw): - self._reset_exported() - self.left = clone(self.left, **kw) - self.right = clone(self.right, **kw) - self.onclause = clone(self.onclause, **kw) - - def get_children(self, **kwargs): - return self.left, 
self.right, self.onclause - - def _cache_key(self, **kw): - return ( - Join, - self.isouter, - self.full, - self.left._cache_key(**kw), - self.right._cache_key(**kw), - self.onclause._cache_key(**kw), - ) - def _match_primaries(self, left, right): if isinstance(left, Join): left_right = left.right @@ -1175,6 +1166,11 @@ class AliasedReturnsRows(FromClause): _is_from_container = True named_with_column = True + _traverse_internals = [ + ("element", InternalTraversal.dp_clauseelement), + ("name", InternalTraversal.dp_anon_name), + ] + def __init__(self, *arg, **kw): raise NotImplementedError( "The %s class is not intended to be constructed " @@ -1243,18 +1239,13 @@ class AliasedReturnsRows(FromClause): def _copy_internals(self, clone=_clone, **kw): element = clone(self.element, **kw) + + # the element clone is usually against a Table that returns the + # same object. don't reset exported .c. collections and other + # memoized details if nothing changed if element is not self.element: self._reset_exported() - self.element = element - - def get_children(self, column_collections=True, **kw): - if column_collections: - for c in self.c: - yield c - yield self.element - - def _cache_key(self, **kw): - return (self.__class__, self.element._cache_key(**kw), self._orig_name) + self.element = element @property def _from_objects(self): @@ -1396,6 +1387,11 @@ class TableSample(AliasedReturnsRows): __visit_name__ = "tablesample" + _traverse_internals = AliasedReturnsRows._traverse_internals + [ + ("sampling", InternalTraversal.dp_clauseelement), + ("seed", InternalTraversal.dp_clauseelement), + ] + @classmethod def _factory(cls, selectable, sampling, name=None, seed=None): """Return a :class:`.TableSample` object. 
@@ -1466,6 +1462,16 @@ class CTE(Generative, HasSuffixes, AliasedReturnsRows): __visit_name__ = "cte" + _traverse_internals = ( + AliasedReturnsRows._traverse_internals + + [ + ("_cte_alias", InternalTraversal.dp_clauseelement), + ("_restates", InternalTraversal.dp_clauseelement_unordered_set), + ("recursive", InternalTraversal.dp_boolean), + ] + + HasSuffixes._traverse_internals + ) + @classmethod def _factory(cls, selectable, name=None, recursive=False): r"""Return a new :class:`.CTE`, or Common Table Expression instance. @@ -1495,15 +1501,13 @@ class CTE(Generative, HasSuffixes, AliasedReturnsRows): def _copy_internals(self, clone=_clone, **kw): super(CTE, self)._copy_internals(clone, **kw) + # TODO: I don't like that we can't use the traversal data here if self._cte_alias is not None: self._cte_alias = clone(self._cte_alias, **kw) self._restates = frozenset( [clone(elem, **kw) for elem in self._restates] ) - def _cache_key(self, *arg, **kw): - raise NotImplementedError("TODO") - def alias(self, name=None, flat=False): """Return an :class:`.Alias` of this :class:`.CTE`. 
@@ -1764,6 +1768,8 @@ class Subquery(AliasedReturnsRows): class FromGrouping(GroupedElement, FromClause): """Represent a grouping of a FROM clause""" + _traverse_internals = [("element", InternalTraversal.dp_clauseelement)] + def __init__(self, element): self.element = coercions.expect(roles.FromClauseRole, element) @@ -1792,15 +1798,6 @@ class FromGrouping(GroupedElement, FromClause): def _hide_froms(self): return self.element._hide_froms - def get_children(self, **kwargs): - return (self.element,) - - def _copy_internals(self, clone=_clone, **kw): - self.element = clone(self.element, **kw) - - def _cache_key(self, **kw): - return (FromGrouping, self.element._cache_key(**kw)) - @property def _from_objects(self): return self.element._from_objects @@ -1843,6 +1840,14 @@ class TableClause(Immutable, FromClause): __visit_name__ = "table" + _traverse_internals = [ + ( + "columns", + InternalTraversal.dp_fromclause_canonical_column_collection, + ), + ("name", InternalTraversal.dp_string), + ] + named_with_column = True implicit_returning = False @@ -1895,17 +1900,6 @@ class TableClause(Immutable, FromClause): self._columns.add(c) c.table = self - def get_children(self, column_collections=True, **kwargs): - if column_collections: - return [c for c in self.c] - else: - return [] - - def _cache_key(self, **kw): - return (TableClause, self.name) + tuple( - col._cache_key(**kw) for col in self._columns - ) - @util.dependencies("sqlalchemy.sql.dml") def insert(self, dml, values=None, inline=False, **kwargs): """Generate an :func:`.insert` construct against this @@ -1965,6 +1959,13 @@ class TableClause(Immutable, FromClause): class ForUpdateArg(ClauseElement): + _traverse_internals = [ + ("of", InternalTraversal.dp_clauseelement_list), + ("nowait", InternalTraversal.dp_boolean), + ("read", InternalTraversal.dp_boolean), + ("skip_locked", InternalTraversal.dp_boolean), + ] + @classmethod def parse_legacy_select(self, arg): """Parse the for_update argument of :func:`.select`. 
@@ -2029,19 +2030,6 @@ class ForUpdateArg(ClauseElement): def __hash__(self): return id(self) - def _copy_internals(self, clone=_clone, **kw): - if self.of is not None: - self.of = [clone(col, **kw) for col in self.of] - - def _cache_key(self, **kw): - return ( - ForUpdateArg, - self.nowait, - self.read, - self.skip_locked, - self.of._cache_key(**kw) if self.of is not None else None, - ) - def __init__( self, nowait=False, @@ -2074,6 +2062,7 @@ class SelectBase( roles.DMLSelectRole, roles.CompoundElementRole, roles.InElementRole, + HasMemoized, HasCTE, Executable, SupportsCloneAnnotations, @@ -2092,9 +2081,6 @@ class SelectBase( _memoized_property = util.group_expirable_memoized_property() - def _reset_memoizations(self): - self._memoized_property.expire_instance(self) - def _generate_fromclause_column_proxies(self, fromclause): # type: (FromClause) raise NotImplementedError() @@ -2339,6 +2325,7 @@ class SelectStatementGrouping(GroupedElement, SelectBase): """ __visit_name__ = "grouping" + _traverse_internals = [("element", InternalTraversal.dp_clauseelement)] _is_select_container = True @@ -2350,9 +2337,6 @@ class SelectStatementGrouping(GroupedElement, SelectBase): def select_statement(self): return self.element - def get_children(self, **kwargs): - return (self.element,) - def self_group(self, against=None): # type: (Optional[Any]) -> FromClause return self @@ -2377,12 +2361,6 @@ class SelectStatementGrouping(GroupedElement, SelectBase): """ return self.element.selected_columns - def _copy_internals(self, clone=_clone, **kw): - self.element = clone(self.element, **kw) - - def _cache_key(self, **kw): - return (SelectStatementGrouping, self.element._cache_key(**kw)) - @property def _from_objects(self): return self.element._from_objects @@ -2758,9 +2736,6 @@ class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase): def _label_resolve_dict(self): raise NotImplementedError() - def _copy_internals(self, clone=_clone, **kw): - raise NotImplementedError() - 
class CompoundSelect(GenerativeSelect): """Forms the basis of ``UNION``, ``UNION ALL``, and other @@ -2785,6 +2760,16 @@ class CompoundSelect(GenerativeSelect): __visit_name__ = "compound_select" + _traverse_internals = [ + ("selects", InternalTraversal.dp_clauseelement_list), + ("_limit_clause", InternalTraversal.dp_clauseelement), + ("_offset_clause", InternalTraversal.dp_clauseelement), + ("_order_by_clause", InternalTraversal.dp_clauseelement), + ("_group_by_clause", InternalTraversal.dp_clauseelement), + ("_for_update_arg", InternalTraversal.dp_clauseelement), + ("keyword", InternalTraversal.dp_string), + ] + SupportsCloneAnnotations._traverse_internals + UNION = util.symbol("UNION") UNION_ALL = util.symbol("UNION ALL") EXCEPT = util.symbol("EXCEPT") @@ -3004,47 +2989,6 @@ class CompoundSelect(GenerativeSelect): """ return self.selects[0].selected_columns - def _copy_internals(self, clone=_clone, **kw): - self._reset_memoizations() - self.selects = [clone(s, **kw) for s in self.selects] - if hasattr(self, "_col_map"): - del self._col_map - for attr in ( - "_limit_clause", - "_offset_clause", - "_order_by_clause", - "_group_by_clause", - "_for_update_arg", - ): - if getattr(self, attr) is not None: - setattr(self, attr, clone(getattr(self, attr), **kw)) - - def get_children(self, **kwargs): - return [self._order_by_clause, self._group_by_clause] + list( - self.selects - ) - - def _cache_key(self, **kw): - return ( - (CompoundSelect, self.keyword) - + tuple(stmt._cache_key(**kw) for stmt in self.selects) - + ( - self._order_by_clause._cache_key(**kw) - if self._order_by_clause is not None - else None, - ) - + ( - self._group_by_clause._cache_key(**kw) - if self._group_by_clause is not None - else None, - ) - + ( - self._for_update_arg._cache_key(**kw) - if self._for_update_arg is not None - else None, - ) - ) - def bind(self): if self._bind: return self._bind @@ -3193,11 +3137,35 @@ class Select( _hints = util.immutabledict() _statement_hints = () _distinct = 
False - _from_cloned = None + _distinct_on = () _correlate = () _correlate_except = None _memoized_property = SelectBase._memoized_property + _traverse_internals = ( + [ + ("_from_obj", InternalTraversal.dp_fromclause_ordered_set), + ("_raw_columns", InternalTraversal.dp_clauseelement_list), + ("_whereclause", InternalTraversal.dp_clauseelement), + ("_having", InternalTraversal.dp_clauseelement), + ("_order_by_clause", InternalTraversal.dp_clauseelement_list), + ("_group_by_clause", InternalTraversal.dp_clauseelement_list), + ("_correlate", InternalTraversal.dp_clauseelement_unordered_set), + ( + "_correlate_except", + InternalTraversal.dp_clauseelement_unordered_set, + ), + ("_for_update_arg", InternalTraversal.dp_clauseelement), + ("_statement_hints", InternalTraversal.dp_statement_hint_list), + ("_hints", InternalTraversal.dp_table_hint_list), + ("_distinct", InternalTraversal.dp_boolean), + ("_distinct_on", InternalTraversal.dp_clauseelement_list), + ] + + HasPrefixes._traverse_internals + + HasSuffixes._traverse_internals + + SupportsCloneAnnotations._traverse_internals + ) + @util.deprecated_params( autocommit=( "0.6", @@ -3416,13 +3384,14 @@ class Select( """ self._auto_correlate = correlate if distinct is not False: - if distinct is True: - self._distinct = True - else: - self._distinct = [ - coercions.expect(roles.WhereHavingRole, e) - for e in util.to_list(distinct) - ] + self._distinct = True + if not isinstance(distinct, bool): + self._distinct_on = tuple( + [ + coercions.expect(roles.WhereHavingRole, e) + for e in util.to_list(distinct) + ] + ) if from_obj is not None: self._from_obj = util.OrderedSet( @@ -3472,15 +3441,17 @@ class Select( GenerativeSelect.__init__(self, **kwargs) + # @_memoized_property @property def _froms(self): - # would love to cache this, - # but there's just enough edge cases, particularly now that - # declarative encourages construction of SQL expressions - # without tables present, to just regen this each time. 
+ # current roadblock to caching is two tests that test that the + # SELECT can be compiled to a string, then a Table is created against + # columns, then it can be compiled again and works. this is somewhat + # valid as people make select() against declarative class where + # columns don't have their Table yet and perhaps some operations + # call upon _froms and cache it too soon. froms = [] seen = set() - translate = self._from_cloned for item in itertools.chain( _from_objects(*self._raw_columns), @@ -3493,8 +3464,6 @@ class Select( raise exc.InvalidRequestError( "select() construct refers to itself as a FROM" ) - if translate and item in translate: - item = translate[item] if not seen.intersection(item._cloned_set): froms.append(item) seen.update(item._cloned_set) @@ -3518,15 +3487,6 @@ class Select( itertools.chain(*[_expand_cloned(f._hide_froms) for f in froms]) ) if toremove: - # if we're maintaining clones of froms, - # add the copies out to the toremove list. only include - # clones that are lexical equivalents. - if self._from_cloned: - toremove.update( - self._from_cloned[f] - for f in toremove.intersection(self._from_cloned) - if self._from_cloned[f]._is_lexical_equivalent(f) - ) # filter out to FROM clauses not in the list, # using a list to maintain ordering froms = [f for f in froms if f not in toremove] @@ -3707,7 +3667,6 @@ class Select( return False def _copy_internals(self, clone=_clone, **kw): - # Select() object has been cloned and probably adapted by the # given clone function. Apply the cloning function to internal # objects @@ -3719,37 +3678,42 @@ class Select( # as of 0.7.4 we also put the current version of _froms, which # gets cleared on each generation. previously we were "baking" # _froms into self._from_obj. - self._from_cloned = from_cloned = dict( - (f, clone(f, **kw)) for f in self._from_obj.union(self._froms) - ) - # 3. update persistent _from_obj with the cloned versions. 
- self._from_obj = util.OrderedSet( - from_cloned[f] for f in self._from_obj + all_the_froms = list( + itertools.chain( + _from_objects(*self._raw_columns), + _from_objects(self._whereclause) + if self._whereclause is not None + else (), + ) ) + new_froms = {f: clone(f, **kw) for f in all_the_froms} + # copy FROM collections - # the _correlate collection is done separately, what can happen - # here is the same item is _correlate as in _from_obj but the - # _correlate version has an annotation on it - (specifically - # RelationshipProperty.Comparator._criterion_exists() does - # this). Also keep _correlate liberally open with its previous - # contents, as this set is used for matching, not rendering. - self._correlate = set(clone(f) for f in self._correlate).union( - self._correlate - ) + self._from_obj = util.OrderedSet( + clone(f, **kw) for f in self._from_obj + ).union(f for f in new_froms.values() if isinstance(f, Join)) - # do something similar for _correlate_except - this is a more - # unusual case but same idea applies + self._correlate = set(clone(f) for f in self._correlate) if self._correlate_except: self._correlate_except = set( clone(f) for f in self._correlate_except - ).union(self._correlate_except) + ) # 4. clone other things. The difficulty here is that Column - # objects are not actually cloned, and refer to their original - # .table, resulting in the wrong "from" parent after a clone - # operation. Hence _from_cloned and _from_obj supersede what is - # present here. + # objects are usually not altered by a straight clone because they + # are dependent on the FROM cloning we just did above in order to + # be targeted correctly, or a new FROM we have might be a JOIN + # object which doesn't have its own columns. so give the cloner a + # hint. 
+ def replace(obj, **kw): + if isinstance(obj, ColumnClause) and obj.table in new_froms: + newelem = new_froms[obj.table].corresponding_column(obj) + return newelem + + kw["replace"] = replace + + # TODO: I'd still like to try to leverage the traversal data self._raw_columns = [clone(c, **kw) for c in self._raw_columns] for attr in ( "_limit_clause", @@ -3763,67 +3727,12 @@ class Select( if getattr(self, attr) is not None: setattr(self, attr, clone(getattr(self, attr), **kw)) - # erase _froms collection, - # etc. self._reset_memoizations() def get_children(self, **kwargs): - """return child elements as per the ClauseElement specification.""" - - return ( - self._raw_columns - + list(self._froms) - + [ - x - for x in ( - self._whereclause, - self._having, - self._order_by_clause, - self._group_by_clause, - ) - if x is not None - ] - ) - - def _cache_key(self, **kw): - return ( - (Select,) - + ("raw_columns",) - + tuple(elem._cache_key(**kw) for elem in self._raw_columns) - + ("elements",) - + tuple( - elem._cache_key(**kw) if elem is not None else None - for elem in ( - self._whereclause, - self._having, - self._order_by_clause, - self._group_by_clause, - ) - ) - + ("from_obj",) - + tuple(elem._cache_key(**kw) for elem in self._from_obj) - + ("correlate",) - + tuple( - elem._cache_key(**kw) - for elem in ( - self._correlate if self._correlate is not None else () - ) - ) - + ("correlate_except",) - + tuple( - elem._cache_key(**kw) - for elem in ( - self._correlate_except - if self._correlate_except is not None - else () - ) - ) - + ("for_update",), - ( - self._for_update_arg._cache_key(**kw) - if self._for_update_arg is not None - else None, - ), + # TODO: define "get_children" traversal items separately? 
+ return self._froms + super(Select, self).get_children( + omit_attrs=["_from_obj", "_correlate", "_correlate_except"] ) @_generative @@ -3987,10 +3896,8 @@ class Select( """ if expr: expr = [coercions.expect(roles.ByOfRole, e) for e in expr] - if isinstance(self._distinct, list): - self._distinct = self._distinct + expr - else: - self._distinct = expr + self._distinct = True + self._distinct_on = self._distinct_on + tuple(expr) else: self._distinct = True @@ -4489,6 +4396,11 @@ class TextualSelect(SelectBase): __visit_name__ = "textual_select" + _traverse_internals = [ + ("element", InternalTraversal.dp_clauseelement), + ("column_args", InternalTraversal.dp_clauseelement_list), + ] + SupportsCloneAnnotations._traverse_internals + _is_textual = True def __init__(self, text, columns, positional=False): @@ -4534,18 +4446,6 @@ class TextualSelect(SelectBase): c._make_proxy(fromclause) for c in self.column_args ) - def _copy_internals(self, clone=_clone, **kw): - self._reset_memoizations() - self.element = clone(self.element, **kw) - - def get_children(self, **kw): - return [self.element] - - def _cache_key(self, **kw): - return (TextualSelect, self.element._cache_key(**kw)) + tuple( - col._cache_key(**kw) for col in self.column_args - ) - def _scalar_type(self): return self.column_args[0].type diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py new file mode 100644 index 0000000000..c0782ce486 --- /dev/null +++ b/lib/sqlalchemy/sql/traversals.py @@ -0,0 +1,768 @@ +from collections import deque +from collections import namedtuple + +from . import operators +from .visitors import ExtendedInternalTraversal +from .visitors import InternalTraversal +from .. import inspect +from .. 
import util + +SKIP_TRAVERSE = util.symbol("skip_traverse") +COMPARE_FAILED = False +COMPARE_SUCCEEDED = True +NO_CACHE = util.symbol("no_cache") + + +def compare(obj1, obj2, **kw): + if kw.get("use_proxies", False): + strategy = ColIdentityComparatorStrategy() + else: + strategy = TraversalComparatorStrategy() + + return strategy.compare(obj1, obj2, **kw) + + +class HasCacheKey(object): + _cache_key_traversal = NO_CACHE + + def _gen_cache_key(self, anon_map, bindparams): + """return an optional cache key. + + The cache key is a tuple which can contain any series of + objects that are hashable and also identifies + this object uniquely within the presence of a larger SQL expression + or statement, for the purposes of caching the resulting query. + + The cache key should be based on the SQL compiled structure that would + ultimately be produced. That is, two structures that are composed in + exactly the same way should produce the same cache key; any difference + in the strucures that would affect the SQL string or the type handlers + should result in a different cache key. + + If a structure cannot produce a useful cache key, it should raise + NotImplementedError, which will result in the entire structure + for which it's part of not being useful as a cache key. + + + """ + + if self in anon_map: + return (anon_map[self], self.__class__) + + id_ = anon_map[self] + + if self._cache_key_traversal is NO_CACHE: + anon_map[NO_CACHE] = True + return None + + result = (id_, self.__class__) + + for attrname, obj, meth in _cache_key_traversal.run_generated_dispatch( + self, self._cache_key_traversal, "_generated_cache_key_traversal" + ): + if obj is not None: + result += meth(attrname, obj, self, anon_map, bindparams) + return result + + def _generate_cache_key(self): + """return a cache key. 
+ + The cache key is a tuple which can contain any series of + objects that are hashable and also identifies + this object uniquely within the presence of a larger SQL expression + or statement, for the purposes of caching the resulting query. + + The cache key should be based on the SQL compiled structure that would + ultimately be produced. That is, two structures that are composed in + exactly the same way should produce the same cache key; any difference + in the strucures that would affect the SQL string or the type handlers + should result in a different cache key. + + The cache key returned by this method is an instance of + :class:`.CacheKey`, which consists of a tuple representing the + cache key, as well as a list of :class:`.BindParameter` objects + which are extracted from the expression. While two expressions + that produce identical cache key tuples will themselves generate + identical SQL strings, the list of :class:`.BindParameter` objects + indicates the bound values which may have different values in + each one; these bound parameters must be consulted in order to + execute the statement with the correct parameters. + + a :class:`.ClauseElement` structure that does not implement + a :meth:`._gen_cache_key` method and does not implement a + :attr:`.traverse_internals` attribute will not be cacheable; when + such an element is embedded into a larger structure, this method + will return None, indicating no cache key is available. 
+ + """ + bindparams = [] + + _anon_map = anon_map() + key = self._gen_cache_key(_anon_map, bindparams) + if NO_CACHE in _anon_map: + return None + else: + return CacheKey(key, bindparams) + + +class CacheKey(namedtuple("CacheKey", ["key", "bindparams"])): + def __hash__(self): + return hash(self.key) + + def __eq__(self, other): + return self.key == other.key + + +def _clone(element, **kw): + return element._clone() + + +class _CacheKey(ExtendedInternalTraversal): + def visit_has_cache_key(self, attrname, obj, parent, anon_map, bindparams): + return (attrname, obj._gen_cache_key(anon_map, bindparams)) + + def visit_inspectable(self, attrname, obj, parent, anon_map, bindparams): + return self.visit_has_cache_key( + attrname, inspect(obj), parent, anon_map, bindparams + ) + + def visit_clauseelement(self, attrname, obj, parent, anon_map, bindparams): + return (attrname, obj._gen_cache_key(anon_map, bindparams)) + + def visit_multi(self, attrname, obj, parent, anon_map, bindparams): + return ( + attrname, + obj._gen_cache_key(anon_map, bindparams) + if isinstance(obj, HasCacheKey) + else obj, + ) + + def visit_multi_list(self, attrname, obj, parent, anon_map, bindparams): + return ( + attrname, + tuple( + elem._gen_cache_key(anon_map, bindparams) + if isinstance(elem, HasCacheKey) + else elem + for elem in obj + ), + ) + + def visit_has_cache_key_tuples( + self, attrname, obj, parent, anon_map, bindparams + ): + return ( + attrname, + tuple( + tuple( + elem._gen_cache_key(anon_map, bindparams) + for elem in tup_elem + ) + for tup_elem in obj + ), + ) + + def visit_has_cache_key_list( + self, attrname, obj, parent, anon_map, bindparams + ): + return ( + attrname, + tuple(elem._gen_cache_key(anon_map, bindparams) for elem in obj), + ) + + def visit_inspectable_list( + self, attrname, obj, parent, anon_map, bindparams + ): + return self.visit_has_cache_key_list( + attrname, [inspect(o) for o in obj], parent, anon_map, bindparams + ) + + def visit_clauseelement_list( + 
self, attrname, obj, parent, anon_map, bindparams + ): + return ( + attrname, + tuple(elem._gen_cache_key(anon_map, bindparams) for elem in obj), + ) + + def visit_clauseelement_tuples( + self, attrname, obj, parent, anon_map, bindparams + ): + return self.visit_has_cache_key_tuples( + attrname, obj, parent, anon_map, bindparams + ) + + def visit_anon_name(self, attrname, obj, parent, anon_map, bindparams): + from . import elements + + name = obj + if isinstance(name, elements._anonymous_label): + name = name.apply_map(anon_map) + + return (attrname, name) + + def visit_fromclause_ordered_set( + self, attrname, obj, parent, anon_map, bindparams + ): + return ( + attrname, + tuple(elem._gen_cache_key(anon_map, bindparams) for elem in obj), + ) + + def visit_clauseelement_unordered_set( + self, attrname, obj, parent, anon_map, bindparams + ): + cache_keys = [ + elem._gen_cache_key(anon_map, bindparams) for elem in obj + ] + return ( + attrname, + tuple( + sorted(cache_keys) + ), # cache keys all start with (id_, class) + ) + + def visit_named_ddl_element( + self, attrname, obj, parent, anon_map, bindparams + ): + return (attrname, obj.name) + + def visit_prefix_sequence( + self, attrname, obj, parent, anon_map, bindparams + ): + return ( + attrname, + tuple( + (clause._gen_cache_key(anon_map, bindparams), strval) + for clause, strval in obj + ), + ) + + def visit_statement_hint_list( + self, attrname, obj, parent, anon_map, bindparams + ): + return (attrname, obj) + + def visit_table_hint_list( + self, attrname, obj, parent, anon_map, bindparams + ): + return ( + attrname, + tuple( + ( + clause._gen_cache_key(anon_map, bindparams), + dialect_name, + text, + ) + for (clause, dialect_name), text in obj.items() + ), + ) + + def visit_type(self, attrname, obj, parent, anon_map, bindparams): + return (attrname, obj._gen_cache_key) + + def visit_plain_dict(self, attrname, obj, parent, anon_map, bindparams): + return (attrname, tuple((key, obj[key]) for key in sorted(obj))) 
+ + def visit_string_clauseelement_dict( + self, attrname, obj, parent, anon_map, bindparams + ): + return ( + attrname, + tuple( + (key, obj[key]._gen_cache_key(anon_map, bindparams)) + for key in sorted(obj) + ), + ) + + def visit_string_multi_dict( + self, attrname, obj, parent, anon_map, bindparams + ): + return ( + attrname, + tuple( + ( + key, + value._gen_cache_key(anon_map, bindparams) + if isinstance(value, HasCacheKey) + else value, + ) + for key, value in [(key, obj[key]) for key in sorted(obj)] + ), + ) + + def visit_string(self, attrname, obj, parent, anon_map, bindparams): + return (attrname, obj) + + def visit_boolean(self, attrname, obj, parent, anon_map, bindparams): + return (attrname, obj) + + def visit_operator(self, attrname, obj, parent, anon_map, bindparams): + return (attrname, obj) + + def visit_plain_obj(self, attrname, obj, parent, anon_map, bindparams): + return (attrname, obj) + + def visit_fromclause_canonical_column_collection( + self, attrname, obj, parent, anon_map, bindparams + ): + return ( + attrname, + tuple(col._gen_cache_key(anon_map, bindparams) for col in obj), + ) + + def visit_annotations_state( + self, attrname, obj, parent, anon_map, bindparams + ): + return ( + attrname, + tuple( + ( + key, + self.dispatch(sym)( + key, obj[key], obj, anon_map, bindparams + ), + ) + for key, sym in parent._annotation_traversals + ), + ) + + def visit_unknown_structure( + self, attrname, obj, parent, anon_map, bindparams + ): + anon_map[NO_CACHE] = True + return () + + +_cache_key_traversal = _CacheKey() + + +class _CopyInternals(InternalTraversal): + """Generate a _copy_internals internal traversal dispatch for classes + with a _traverse_internals collection.""" + + def visit_clauseelement(self, parent, element, clone=_clone, **kw): + return clone(element, **kw) + + def visit_clauseelement_list(self, parent, element, clone=_clone, **kw): + return [clone(clause, **kw) for clause in element] + + def visit_clauseelement_tuples(self, parent, 
element, clone=_clone, **kw): + return [ + tuple(clone(tup_elem, **kw) for tup_elem in elem) + for elem in element + ] + + def visit_string_clauseelement_dict( + self, parent, element, clone=_clone, **kw + ): + return dict( + (key, clone(value, **kw)) for key, value in element.items() + ) + + +_copy_internals = _CopyInternals() + + +class _GetChildren(InternalTraversal): + """Generate a _children_traversal internal traversal dispatch for classes + with a _traverse_internals collection.""" + + def visit_has_cache_key(self, element, **kw): + return (element,) + + def visit_clauseelement(self, element, **kw): + return (element,) + + def visit_clauseelement_list(self, element, **kw): + return tuple(element) + + def visit_clauseelement_tuples(self, element, **kw): + tup = () + for elem in element: + tup += elem + return tup + + def visit_fromclause_canonical_column_collection(self, element, **kw): + if kw.get("column_collections", False): + return tuple(element) + else: + return () + + def visit_string_clauseelement_dict(self, element, **kw): + return tuple(element.values()) + + def visit_fromclause_ordered_set(self, element, **kw): + return tuple(element) + + def visit_clauseelement_unordered_set(self, element, **kw): + return tuple(element) + + +_get_children = _GetChildren() + + +@util.dependencies("sqlalchemy.sql.elements") +def _resolve_name_for_compare(elements, element, name, anon_map, **kw): + if isinstance(name, elements._anonymous_label): + name = name.apply_map(anon_map) + + return name + + +class anon_map(dict): + """A map that creates new keys for missing key access. + + Produces an incrementing sequence given a series of unique keys. + + This is similar to the compiler prefix_anon_map class although simpler. + + Inlines the approach taken by :class:`sqlalchemy.util.PopulateDict` which + is otherwise usually used for this type of operation. 
+ + """ + + def __init__(self): + self.index = 0 + + def __missing__(self, key): + self[key] = val = str(self.index) + self.index += 1 + return val + + +class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): + __slots__ = "stack", "cache", "anon_map" + + def __init__(self): + self.stack = deque() + self.cache = set() + + def _memoized_attr_anon_map(self): + return (anon_map(), anon_map()) + + def compare(self, obj1, obj2, **kw): + stack = self.stack + cache = self.cache + + compare_annotations = kw.get("compare_annotations", False) + + stack.append((obj1, obj2)) + + while stack: + left, right = stack.popleft() + + if left is right: + continue + elif left is None or right is None: + # we know they are different so no match + return False + elif (left, right) in cache: + continue + cache.add((left, right)) + + visit_name = left.__visit_name__ + if visit_name != right.__visit_name__: + return False + + meth = getattr(self, "compare_%s" % visit_name, None) + + if meth: + attributes_compared = meth(left, right, **kw) + if attributes_compared is COMPARE_FAILED: + return False + elif attributes_compared is SKIP_TRAVERSE: + continue + + # attributes_compared is returned as a list of attribute + # names that were "handled" by the comparison method above. + # remaining attribute names in the _traverse_internals + # will be compared. 
+ else: + attributes_compared = () + + for ( + (left_attrname, left_visit_sym), + (right_attrname, right_visit_sym), + ) in util.zip_longest( + left._traverse_internals, + right._traverse_internals, + fillvalue=(None, None), + ): + if ( + left_attrname != right_attrname + or left_visit_sym is not right_visit_sym + ): + if not compare_annotations and ( + ( + left_visit_sym + is InternalTraversal.dp_annotations_state, + ) + or ( + right_visit_sym + is InternalTraversal.dp_annotations_state, + ) + ): + continue + + return False + elif left_attrname in attributes_compared: + continue + + dispatch = self.dispatch(left_visit_sym) + left_child = getattr(left, left_attrname) + right_child = getattr(right, right_attrname) + if left_child is None: + if right_child is not None: + return False + else: + continue + + comparison = dispatch( + left, left_child, right, right_child, **kw + ) + if comparison is COMPARE_FAILED: + return False + + return True + + def compare_inner(self, obj1, obj2, **kw): + comparator = self.__class__() + return comparator.compare(obj1, obj2, **kw) + + def visit_has_cache_key( + self, left_parent, left, right_parent, right, **kw + ): + if left._gen_cache_key(self.anon_map[0], []) != right._gen_cache_key( + self.anon_map[1], [] + ): + return COMPARE_FAILED + + def visit_clauseelement( + self, left_parent, left, right_parent, right, **kw + ): + self.stack.append((left, right)) + + def visit_fromclause_canonical_column_collection( + self, left_parent, left, right_parent, right, **kw + ): + for lcol, rcol in util.zip_longest(left, right, fillvalue=None): + self.stack.append((lcol, rcol)) + + def visit_fromclause_derived_column_collection( + self, left_parent, left, right_parent, right, **kw + ): + pass + + def visit_string_clauseelement_dict( + self, left_parent, left, right_parent, right, **kw + ): + for lstr, rstr in util.zip_longest( + sorted(left), sorted(right), fillvalue=None + ): + if lstr != rstr: + return COMPARE_FAILED + 
self.stack.append((left[lstr], right[rstr])) + + def visit_annotations_state( + self, left_parent, left, right_parent, right, **kw + ): + if not kw.get("compare_annotations", False): + return + + for (lstr, lmeth), (rstr, rmeth) in util.zip_longest( + left_parent._annotation_traversals, + right_parent._annotation_traversals, + fillvalue=(None, None), + ): + if lstr != rstr or (lmeth is not rmeth): + return COMPARE_FAILED + + dispatch = self.dispatch(lmeth) + left_child = left[lstr] + right_child = right[rstr] + if left_child is None: + if right_child is not None: + return False + else: + continue + + comparison = dispatch(None, left_child, None, right_child, **kw) + if comparison is COMPARE_FAILED: + return comparison + + def visit_clauseelement_tuples( + self, left_parent, left, right_parent, right, **kw + ): + for ltup, rtup in util.zip_longest(left, right, fillvalue=None): + if ltup is None or rtup is None: + return COMPARE_FAILED + + for l, r in util.zip_longest(ltup, rtup, fillvalue=None): + self.stack.append((l, r)) + + def visit_clauseelement_list( + self, left_parent, left, right_parent, right, **kw + ): + for l, r in util.zip_longest(left, right, fillvalue=None): + self.stack.append((l, r)) + + def _compare_unordered_sequences(self, seq1, seq2, **kw): + if seq1 is None: + return seq2 is None + + completed = set() + for clause in seq1: + for other_clause in set(seq2).difference(completed): + if self.compare_inner(clause, other_clause, **kw): + completed.add(other_clause) + break + return len(completed) == len(seq1) == len(seq2) + + def visit_clauseelement_unordered_set( + self, left_parent, left, right_parent, right, **kw + ): + return self._compare_unordered_sequences(left, right, **kw) + + def visit_fromclause_ordered_set( + self, left_parent, left, right_parent, right, **kw + ): + for l, r in util.zip_longest(left, right, fillvalue=None): + self.stack.append((l, r)) + + def visit_string(self, left_parent, left, right_parent, right, **kw): + return left 
== right + + def visit_anon_name(self, left_parent, left, right_parent, right, **kw): + return _resolve_name_for_compare( + left_parent, left, self.anon_map[0], **kw + ) == _resolve_name_for_compare( + right_parent, right, self.anon_map[1], **kw + ) + + def visit_boolean(self, left_parent, left, right_parent, right, **kw): + return left == right + + def visit_operator(self, left_parent, left, right_parent, right, **kw): + return left is right + + def visit_type(self, left_parent, left, right_parent, right, **kw): + return left._compare_type_affinity(right) + + def visit_plain_dict(self, left_parent, left, right_parent, right, **kw): + return left == right + + def visit_plain_obj(self, left_parent, left, right_parent, right, **kw): + return left == right + + def visit_named_ddl_element( + self, left_parent, left, right_parent, right, **kw + ): + if left is None: + if right is not None: + return COMPARE_FAILED + + return left.name == right.name + + def visit_prefix_sequence( + self, left_parent, left, right_parent, right, **kw + ): + for (l_clause, l_str), (r_clause, r_str) in util.zip_longest( + left, right, fillvalue=(None, None) + ): + if l_str != r_str: + return COMPARE_FAILED + else: + self.stack.append((l_clause, r_clause)) + + def visit_table_hint_list( + self, left_parent, left, right_parent, right, **kw + ): + left_keys = sorted(left, key=lambda elem: (elem[0].fullname, elem[1])) + right_keys = sorted( + right, key=lambda elem: (elem[0].fullname, elem[1]) + ) + for (ltable, ldialect), (rtable, rdialect) in util.zip_longest( + left_keys, right_keys, fillvalue=(None, None) + ): + if ldialect != rdialect: + return COMPARE_FAILED + elif left[(ltable, ldialect)] != right[(rtable, rdialect)]: + return COMPARE_FAILED + else: + self.stack.append((ltable, rtable)) + + def visit_statement_hint_list( + self, left_parent, left, right_parent, right, **kw + ): + return left == right + + def visit_unknown_structure( + self, left_parent, left, right_parent, right, **kw + ): 
+ raise NotImplementedError() + + def compare_clauselist(self, left, right, **kw): + if left.operator is right.operator: + if operators.is_associative(left.operator): + if self._compare_unordered_sequences( + left.clauses, right.clauses, **kw + ): + return ["operator", "clauses"] + else: + return COMPARE_FAILED + else: + return ["operator"] + else: + return COMPARE_FAILED + + def compare_binary(self, left, right, **kw): + if left.operator == right.operator: + if operators.is_commutative(left.operator): + if ( + compare(left.left, right.left, **kw) + and compare(left.right, right.right, **kw) + ) or ( + compare(left.left, right.right, **kw) + and compare(left.right, right.left, **kw) + ): + return ["operator", "negate", "left", "right"] + else: + return COMPARE_FAILED + else: + return ["operator", "negate"] + else: + return COMPARE_FAILED + + +class ColIdentityComparatorStrategy(TraversalComparatorStrategy): + def compare_column_element( + self, left, right, use_proxies=True, equivalents=(), **kw + ): + """Compare ColumnElements using proxies and equivalent collections. + + This is a comparison strategy specific to the ORM. 
+ """ + + to_compare = (right,) + if equivalents and right in equivalents: + to_compare = equivalents[right].union(to_compare) + + for oth in to_compare: + if use_proxies and left.shares_lineage(oth): + return SKIP_TRAVERSE + elif hash(left) == hash(right): + return SKIP_TRAVERSE + else: + return COMPARE_FAILED + + def compare_column(self, left, right, **kw): + return self.compare_column_element(left, right, **kw) + + def compare_label(self, left, right, **kw): + return self.compare_column_element(left, right, **kw) + + def compare_table(self, left, right, **kw): + # tables compare on identity, since it's not really feasible to + # compare them column by column with the above rules + return SKIP_TRAVERSE if left is right else COMPARE_FAILED diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index 9c5f5dd475..d09bb28bbe 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -12,8 +12,8 @@ from . import operators from .base import SchemaEventTarget -from .visitors import Visitable -from .visitors import VisitableType +from .visitors import Traversible +from .visitors import TraversibleType from .. import exc from .. import util @@ -28,7 +28,7 @@ INDEXABLE = None _resolve_value_to_type = None -class TypeEngine(Visitable): +class TypeEngine(Traversible): """The ultimate base class for all SQL datatypes. 
Common subclasses of :class:`.TypeEngine` include @@ -535,8 +535,13 @@ class TypeEngine(Visitable): return dialect.type_descriptor(self) @util.memoized_property - def _cache_key(self): - return util.constructor_key(self, self.__class__) + def _gen_cache_key(self): + names = util.get_cls_kwargs(self.__class__) + return (self.__class__,) + tuple( + (k, self.__dict__[k]) + for k in names + if k in self.__dict__ and not k.startswith("_") + ) def adapt(self, cls, **kw): """Produce an "adapted" form of this type, given an "impl" class @@ -617,7 +622,7 @@ class TypeEngine(Visitable): return util.generic_repr(self) -class VisitableCheckKWArg(util.EnsureKWArgType, VisitableType): +class VisitableCheckKWArg(util.EnsureKWArgType, TraversibleType): pass diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index e109852a2b..8539f4845a 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -734,7 +734,7 @@ def criterion_as_pairs( return pairs -class ClauseAdapter(visitors.ReplacingCloningVisitor): +class ClauseAdapter(visitors.ReplacingExternalTraversal): """Clones and modifies clauses based on column correspondence. E.g.:: diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 7b2ac285a9..8c06eb8afd 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -28,14 +28,10 @@ import operator from .. import exc from .. import util - +from ..util import langhelpers +from ..util import symbol __all__ = [ - "VisitableType", - "Visitable", - "ClauseVisitor", - "CloningVisitor", - "ReplacingCloningVisitor", "iterate", "iterate_depthfirst", "traverse_using", @@ -43,85 +39,382 @@ __all__ = [ "traverse_depthfirst", "cloned_traverse", "replacement_traverse", + "Traversible", + "TraversibleType", + "ExternalTraversal", + "InternalTraversal", ] -class VisitableType(type): - """Metaclass which assigns a ``_compiler_dispatch`` method to classes - having a ``__visit_name__`` attribute. 
+def _generate_compiler_dispatch(cls): + """Generate a _compiler_dispatch() external traversal on classes with a + __visit_name__ attribute. + + """ + visit_name = cls.__visit_name__ + + if isinstance(visit_name, util.compat.string_types): + # There is an optimization opportunity here because the + # the string name of the class's __visit_name__ is known at + # this early stage (import time) so it can be pre-constructed. + getter = operator.attrgetter("visit_%s" % visit_name) + + def _compiler_dispatch(self, visitor, **kw): + try: + meth = getter(visitor) + except AttributeError: + raise exc.UnsupportedCompilationError(visitor, cls) + else: + return meth(self, **kw) + + else: + # The optimization opportunity is lost for this case because the + # __visit_name__ is not yet a string. As a result, the visit + # string has to be recalculated with each compilation. + def _compiler_dispatch(self, visitor, **kw): + visit_attr = "visit_%s" % self.__visit_name__ + try: + meth = getattr(visitor, visit_attr) + except AttributeError: + raise exc.UnsupportedCompilationError(visitor, cls) + else: + return meth(self, **kw) + + _compiler_dispatch.__doc__ = """Look for an attribute named "visit_" + + self.__visit_name__ on the visitor, and call it with the same + kw params. + """ + cls._compiler_dispatch = _compiler_dispatch + + +class TraversibleType(type): + """Metaclass which assigns dispatch attributes to various kinds of + "visitable" classes. - The ``_compiler_dispatch`` attribute becomes an instance method which - looks approximately like the following:: + Attributes include: - def _compiler_dispatch (self, visitor, **kw): - '''Look for an attribute named "visit_" + self.__visit_name__ - on the visitor, and call it with the same kw params.''' - visit_attr = 'visit_%s' % self.__visit_name__ - return getattr(visitor, visit_attr)(self, **kw) + * The ``_compiler_dispatch`` method, corresponding to ``__visit_name__``. 
+ This is called "external traversal" because the caller of each visit() + method is responsible for sub-traversing the inner elements of each + object. This is appropriate for string compilers and other traversals + that need to call upon the inner elements in a specific pattern. - Classes having no ``__visit_name__`` attribute will remain unaffected. + * internal traversal collections ``_children_traversal``, + ``_cache_key_traversal``, ``_copy_internals_traversal``, generated from + an optional ``_traverse_internals`` collection of symbols which comes + from the :class:`.InternalTraversal` list of symbols. This is called + "internal traversal" MARKMARK """ def __init__(cls, clsname, bases, clsdict): - if clsname != "Visitable" and hasattr(cls, "__visit_name__"): - _generate_dispatch(cls) + if clsname != "Traversible": + if "__visit_name__" in clsdict: + _generate_compiler_dispatch(cls) + + super(TraversibleType, cls).__init__(clsname, bases, clsdict) - super(VisitableType, cls).__init__(clsname, bases, clsdict) +class Traversible(util.with_metaclass(TraversibleType)): + """Base class for visitable objects, applies the + :class:`.visitors.TraversibleType` metaclass. -def _generate_dispatch(cls): - """Return an optimized visit dispatch function for the cls - for use by the compiler. """ - if "__visit_name__" in cls.__dict__: - visit_name = cls.__visit_name__ - if isinstance(visit_name, util.compat.string_types): - # There is an optimization opportunity here because the - # the string name of the class's __visit_name__ is known at - # this early stage (import time) so it can be pre-constructed. 
- getter = operator.attrgetter("visit_%s" % visit_name) - def _compiler_dispatch(self, visitor, **kw): - try: - meth = getter(visitor) - except AttributeError: - raise exc.UnsupportedCompilationError(visitor, cls) - else: - return meth(self, **kw) +class _InternalTraversalType(type): + def __init__(cls, clsname, bases, clsdict): + if cls.__name__ in ("InternalTraversal", "ExtendedInternalTraversal"): + lookup = {} + for key, sym in clsdict.items(): + if key.startswith("dp_"): + visit_key = key.replace("dp_", "visit_") + sym_name = sym.name + assert sym_name not in lookup, sym_name + lookup[sym] = lookup[sym_name] = visit_key + if hasattr(cls, "_dispatch_lookup"): + lookup.update(cls._dispatch_lookup) + cls._dispatch_lookup = lookup + + super(_InternalTraversalType, cls).__init__(clsname, bases, clsdict) + + +def _generate_dispatcher(visitor, internal_dispatch, method_name): + names = [] + for attrname, visit_sym in internal_dispatch: + meth = visitor.dispatch(visit_sym) + if meth: + visit_name = ExtendedInternalTraversal._dispatch_lookup[visit_sym] + names.append((attrname, visit_name)) + + code = ( + (" return [\n") + + ( + ", \n".join( + " (%r, self.%s, visitor.%s)" + % (attrname, attrname, visit_name) + for attrname, visit_name in names + ) + ) + + ("\n ]\n") + ) + meth_text = ("def %s(self, visitor):\n" % method_name) + code + "\n" + # print(meth_text) + return langhelpers._exec_code_in_env(meth_text, {}, method_name) - else: - # The optimization opportunity is lost for this case because the - # __visit_name__ is not yet a string. As a result, the visit - # string has to be recalculated with each compilation. 
- def _compiler_dispatch(self, visitor, **kw): - visit_attr = "visit_%s" % self.__visit_name__ - try: - meth = getattr(visitor, visit_attr) - except AttributeError: - raise exc.UnsupportedCompilationError(visitor, cls) - else: - return meth(self, **kw) - - _compiler_dispatch.__doc__ = """Look for an attribute named "visit_" + self.__visit_name__ - on the visitor, and call it with the same kw params. - """ - cls._compiler_dispatch = _compiler_dispatch - - -class Visitable(util.with_metaclass(VisitableType, object)): - """Base class for visitable objects, applies the - :class:`.visitors.VisitableType` metaclass. - The :class:`.Visitable` class is essentially at the base of the - :class:`.ClauseElement` hierarchy. +class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)): + r"""Defines visitor symbols used for internal traversal. + + The :class:`.InternalTraversal` class is used in two ways. One is that + it can serve as the superclass for an object that implements the + various visit methods of the class. The other is that the symbols + themselves of :class:`.InternalTraversal` are used within + the ``_traverse_internals`` collection. Such as, the :class:`.Case` + object defines ``_travserse_internals`` as :: + + _traverse_internals = [ + ("value", InternalTraversal.dp_clauseelement), + ("whens", InternalTraversal.dp_clauseelement_tuples), + ("else_", InternalTraversal.dp_clauseelement), + ] + + Above, the :class:`.Case` class indicates its internal state as the + attribtues named ``value``, ``whens``, and ``else\_``. They each + link to an :class:`.InternalTraversal` method which indicates the type + of datastructure referred towards. 
+ + Using the ``_traverse_internals`` structure, objects of type + :class:`.InternalTraversible` will have the following methods automatically + implemented: + + * :meth:`.Traversible.get_children` + + * :meth:`.Traversible._copy_internals` + + * :meth:`.Traversible._gen_cache_key` + + Subclasses can also implement these methods directly, particularly for the + :meth:`.Traversible._copy_internals` method, when special steps + are needed. + + .. versionadded:: 1.4 """ + def dispatch(self, visit_symbol): + """Given a method from :class:`.InternalTraversal`, return the + corresponding method on a subclass. -class ClauseVisitor(object): - """Base class for visitor objects which can traverse using + """ + name = self._dispatch_lookup[visit_symbol] + return getattr(self, name, None) + + def run_generated_dispatch( + self, target, internal_dispatch, generate_dispatcher_name + ): + try: + dispatcher = target.__class__.__dict__[generate_dispatcher_name] + except KeyError: + dispatcher = _generate_dispatcher( + self, internal_dispatch, generate_dispatcher_name + ) + setattr(target.__class__, generate_dispatcher_name, dispatcher) + return dispatcher(target, self) + + dp_has_cache_key = symbol("HC") + """Visit a :class:`.HasCacheKey` object.""" + + dp_clauseelement = symbol("CE") + """Visit a :class:`.ClauseElement` object.""" + + dp_fromclause_canonical_column_collection = symbol("FC") + """Visit a :class:`.FromClause` object in the context of the + ``columns`` attribute. + + The column collection is "canonical", meaning it is the originally + defined location of the :class:`.ColumnClause` objects. Right now + this means that the object being visited is a :class:`.TableClause` + or :class:`.Table` object only. + + """ + + dp_clauseelement_tuples = symbol("CT") + """Visit a list of tuples which contain :class:`.ClauseElement` + objects. + + """ + + dp_clauseelement_list = symbol("CL") + """Visit a list of :class:`.ClauseElement` objects. 
+ + """ + + dp_clauseelement_unordered_set = symbol("CU") + """Visit an unordered set of :class:`.ClauseElement` objects. """ + + dp_fromclause_ordered_set = symbol("CO") + """Visit an ordered set of :class:`.FromClause` objects. """ + + dp_string = symbol("S") + """Visit a plain string value. + + Examples include table and column names, bound parameter keys, special + keywords such as "UNION", "UNION ALL". + + The string value is considered to be significant for cache key + generation. + + """ + + dp_anon_name = symbol("AN") + """Visit a potentially "anonymized" string value. + + The string value is considered to be significant for cache key + generation. + + """ + + dp_boolean = symbol("B") + """Visit a boolean value. + + The boolean value is considered to be significant for cache key + generation. + + """ + + dp_operator = symbol("O") + """Visit an operator. + + The operator is a function from the :mod:`sqlalchemy.sql.operators` + module. + + The operator value is considered to be significant for cache key + generation. + + """ + + dp_type = symbol("T") + """Visit a :class:`.TypeEngine` object + + The type object is considered to be significant for cache key + generation. + + """ + + dp_plain_dict = symbol("PD") + """Visit a dictionary with string keys. + + The keys of the dictionary should be strings, the values should + be immutable and hashable. The dictionary is considered to be + significant for cache key generation. + + """ + + dp_string_clauseelement_dict = symbol("CD") + """Visit a dictionary of string keys to :class:`.ClauseElement` + objects. + + """ + + dp_string_multi_dict = symbol("MD") + """Visit a dictionary of string keys to values which may either be + plain immutable/hashable or :class:`.HasCacheKey` objects. + + """ + + dp_plain_obj = symbol("PO") + """Visit a plain python object. + + The value should be immutable and hashable, such as an integer. + The value is considered to be significant for cache key generation. 
+ + """ + + dp_annotations_state = symbol("A") + """Visit the state of the :class:`.Annotatated` version of an object. + + """ + + dp_named_ddl_element = symbol("DD") + """Visit a simple named DDL element. + + The current object used by this method is the :class:`.Sequence`. + + The object is only considered to be important for cache key generation + as far as its name, but not any other aspects of it. + + """ + + dp_prefix_sequence = symbol("PS") + """Visit the sequence represented by :class:`.HasPrefixes` + or :class:`.HasSuffixes`. + + """ + + dp_table_hint_list = symbol("TH") + """Visit the ``_hints`` collection of a :class:`.Select` object. + + """ + + dp_statement_hint_list = symbol("SH") + """Visit the ``_statement_hints`` collection of a :class:`.Select` + object. + + """ + + dp_unknown_structure = symbol("UK") + """Visit an unknown structure. + + """ + + +class ExtendedInternalTraversal(InternalTraversal): + """defines additional symbols that are useful in caching applications. + + Traversals for :class:`.ClauseElement` objects only need to use + those symbols present in :class:`.InternalTraversal`. However, for + additional caching use cases within the ORM, symbols dealing with the + :class:`.HasCacheKey` class are added here. + + """ + + dp_ignore = symbol("IG") + """Specify an object that should be ignored entirely. + + This currently applies function call argument caching where some + arguments should not be considered to be part of a cache key. + + """ + + dp_inspectable = symbol("IS") + """Visit an inspectable object where the return value is a HasCacheKey` + object.""" + + dp_multi = symbol("M") + """Visit an object that may be a :class:`.HasCacheKey` or may be a + plain hashable object.""" + + dp_multi_list = symbol("MT") + """Visit a tuple containing elements that may be :class:`.HasCacheKey` or + may be a plain hashable object.""" + + dp_has_cache_key_tuples = symbol("HT") + """Visit a list of tuples which contain :class:`.HasCacheKey` + objects. 
+ + """ + + dp_has_cache_key_list = symbol("HL") + """Visit a list of :class:`.HasCacheKey` objects.""" + + dp_inspectable_list = symbol("IL") + """Visit a list of inspectable objects which upon inspection are + HasCacheKey objects.""" + + +class ExternalTraversal(object): + """Base class for visitor objects which can traverse externally using the :func:`.visitors.traverse` function. Direct usage of the :func:`.visitors.traverse` function is usually @@ -178,7 +471,7 @@ class ClauseVisitor(object): return self -class CloningVisitor(ClauseVisitor): +class CloningExternalTraversal(ExternalTraversal): """Base class for visitor objects which can traverse using the :func:`.visitors.cloned_traverse` function. @@ -203,7 +496,7 @@ class CloningVisitor(ClauseVisitor): ) -class ReplacingCloningVisitor(CloningVisitor): +class ReplacingExternalTraversal(CloningExternalTraversal): """Base class for visitor objects which can traverse using the :func:`.visitors.replacement_traverse` function. @@ -233,6 +526,14 @@ class ReplacingCloningVisitor(CloningVisitor): return replacement_traverse(obj, self.__traverse_options__, replace) +# backwards compatibility +Visitable = Traversible +VisitableType = TraversibleType +ClauseVisitor = ExternalTraversal +CloningVisitor = CloningExternalTraversal +ReplacingCloningVisitor = ReplacingExternalTraversal + + def iterate(obj, opts): r"""traverse the given expression structure, returning an iterator. 
@@ -405,11 +706,18 @@ def cloned_traverse(obj, opts, visitors): cloned = {} stop_on = set(opts.get("stop_on", [])) - def clone(elem): + def clone(elem, **kw): if elem in stop_on: return elem else: if id(elem) not in cloned: + + if "replace" in kw: + newelem = kw["replace"](elem) + if newelem is not None: + cloned[id(elem)] = newelem + return newelem + cloned[id(elem)] = newelem = elem._clone() newelem._copy_internals(clone=clone) meth = visitors.get(newelem.__visit_name__, None) @@ -461,7 +769,14 @@ def replacement_traverse(obj, opts, replace): stop_on.add(id(newelem)) return newelem else: + if elem not in cloned: + if "replace" in kw: + newelem = kw["replace"](elem) + if newelem is not None: + cloned[elem] = newelem + return newelem + cloned[elem] = newelem = elem._clone() newelem._copy_internals(clone=clone, **kw) return cloned[elem] diff --git a/test/aaa_profiling/test_orm.py b/test/aaa_profiling/test_orm.py index 632f559373..209bc02e3f 100644 --- a/test/aaa_profiling/test_orm.py +++ b/test/aaa_profiling/test_orm.py @@ -934,7 +934,7 @@ class BranchedOptionTest(fixtures.MappedTest): configure_mappers() - def test_generate_cache_key_unbound_branching(self): + def test_generate_path_cache_key_unbound_branching(self): A, B, C, D, E, F, G = self.classes("A", "B", "C", "D", "E", "F", "G") base = joinedload(A.bs) @@ -950,11 +950,11 @@ class BranchedOptionTest(fixtures.MappedTest): @profiling.function_call_count() def go(): for opt in opts: - opt._generate_cache_key(cache_path) + opt._generate_path_cache_key(cache_path) go() - def test_generate_cache_key_bound_branching(self): + def test_generate_path_cache_key_bound_branching(self): A, B, C, D, E, F, G = self.classes("A", "B", "C", "D", "E", "F", "G") base = Load(A).joinedload(A.bs) @@ -970,7 +970,7 @@ class BranchedOptionTest(fixtures.MappedTest): @profiling.function_call_count() def go(): for opt in opts: - opt._generate_cache_key(cache_path) + opt._generate_path_cache_key(cache_path) go() diff --git 
a/test/ext/test_baked.py b/test/ext/test_baked.py index acefe625ab..01f0e267f2 100644 --- a/test/ext/test_baked.py +++ b/test/ext/test_baked.py @@ -1533,7 +1533,7 @@ class CustomIntegrationTest(testing.AssertsCompiledSQL, BakedTest): if query._current_path: query._cache_key = "user7_addresses" - def _generate_cache_key(self, path): + def _generate_path_cache_key(self, path): return None return RelationshipCache() diff --git a/test/orm/test_cache_key.py b/test/orm/test_cache_key.py new file mode 100644 index 0000000000..79a94848ea --- /dev/null +++ b/test/orm/test_cache_key.py @@ -0,0 +1,120 @@ +from sqlalchemy import inspect +from sqlalchemy.orm import aliased +from sqlalchemy.orm import defaultload +from sqlalchemy.orm import defer +from sqlalchemy.orm import joinedload +from sqlalchemy.orm import Load +from sqlalchemy.orm import subqueryload +from sqlalchemy.testing import eq_ +from test.orm import _fixtures +from ..sql.test_compare import CacheKeyFixture + + +class CacheKeyTest(CacheKeyFixture, _fixtures.FixtureTest): + run_setup_mappers = "once" + run_inserts = None + run_deletes = None + + @classmethod + def setup_mappers(cls): + cls._setup_stock_mapping() + + def test_mapper_and_aliased(self): + User, Address, Keyword = self.classes("User", "Address", "Keyword") + + self._run_cache_key_fixture( + lambda: (inspect(User), inspect(Address), inspect(aliased(User))) + ) + + def test_attributes(self): + User, Address, Keyword = self.classes("User", "Address", "Keyword") + + self._run_cache_key_fixture( + lambda: ( + User.id, + Address.id, + aliased(User).id, + aliased(User, name="foo").id, + aliased(User, name="bar").id, + User.name, + User.addresses, + Address.email_address, + aliased(User).addresses, + ) + ) + + def test_unbound_options(self): + User, Address, Keyword, Order, Item = self.classes( + "User", "Address", "Keyword", "Order", "Item" + ) + + self._run_cache_key_fixture( + lambda: ( + joinedload(User.addresses), + joinedload("addresses"), + 
joinedload(User.orders).selectinload("items"), + joinedload(User.orders).selectinload(Order.items), + defer(User.id), + defer("id"), + defer(Address.id), + joinedload(User.addresses).defer(Address.id), + joinedload(aliased(User).addresses).defer(Address.id), + joinedload(User.addresses).defer("id"), + joinedload(User.orders).joinedload(Order.items), + joinedload(User.orders).subqueryload(Order.items), + subqueryload(User.orders).subqueryload(Order.items), + subqueryload(User.orders) + .subqueryload(Order.items) + .defer(Item.description), + defaultload(User.orders).defaultload(Order.items), + defaultload(User.orders), + ) + ) + + def test_bound_options(self): + User, Address, Keyword, Order, Item = self.classes( + "User", "Address", "Keyword", "Order", "Item" + ) + + self._run_cache_key_fixture( + lambda: ( + Load(User).joinedload(User.addresses), + Load(User).joinedload(User.orders), + Load(User).defer(User.id), + Load(User).subqueryload("addresses"), + Load(Address).defer("id"), + Load(aliased(Address)).defer("id"), + Load(User).joinedload(User.addresses).defer(Address.id), + Load(User).joinedload(User.orders).joinedload(Order.items), + Load(User).joinedload(User.orders).subqueryload(Order.items), + Load(User).subqueryload(User.orders).subqueryload(Order.items), + Load(User) + .subqueryload(User.orders) + .subqueryload(Order.items) + .defer(Item.description), + Load(User).defaultload(User.orders).defaultload(Order.items), + Load(User).defaultload(User.orders), + ) + ) + + def test_bound_options_equiv_on_strname(self): + """Bound loader options resolve on string name so test that the cache + key for the string version matches the resolved version. 
+ + """ + User, Address, Keyword, Order, Item = self.classes( + "User", "Address", "Keyword", "Order", "Item" + ) + + for left, right in [ + (Load(User).defer(User.id), Load(User).defer("id")), + ( + Load(User).joinedload(User.addresses), + Load(User).joinedload("addresses"), + ), + ( + Load(User).joinedload(User.orders).joinedload(Order.items), + Load(User).joinedload("orders").joinedload("items"), + ), + ]: + eq_(left._generate_cache_key(), right._generate_cache_key()) diff --git a/test/orm/test_options.py b/test/orm/test_options.py index bf099e7e6d..e84d5950c8 100644 --- a/test/orm/test_options.py +++ b/test/orm/test_options.py @@ -1790,7 +1790,7 @@ class SubOptionsTest(PathTest, QueryTest): ) -class CacheKeyTest(PathTest, QueryTest): +class PathedCacheKeyTest(PathTest, QueryTest): run_create_tables = False run_inserts = None @@ -1805,7 +1805,7 @@ class CacheKeyTest(PathTest, QueryTest): opt = joinedload(User.orders).joinedload(Order.items) eq_( - opt._generate_cache_key(query_path), + opt._generate_path_cache_key(query_path), (((Order, "items", Item, ("lazy", "joined")),)), ) @@ -1821,12 +1821,12 @@ class CacheKeyTest(PathTest, QueryTest): opt2 = base.joinedload(Order.address) eq_( - opt1._generate_cache_key(query_path), + opt1._generate_path_cache_key(query_path), (((Order, "items", Item, ("lazy", "joined")),)), ) eq_( - opt2._generate_cache_key(query_path), + opt2._generate_path_cache_key(query_path), (((Order, "address", Address, ("lazy", "joined")),)), ) @@ -1842,12 +1842,12 @@ class CacheKeyTest(PathTest, QueryTest): opt2 = base.joinedload(Order.address) eq_( - opt1._generate_cache_key(query_path), + opt1._generate_path_cache_key(query_path), (((Order, "items", Item, ("lazy", "joined")),)), ) eq_( - opt2._generate_cache_key(query_path), + opt2._generate_path_cache_key(query_path), (((Order, "address", Address, ("lazy", "joined")),)), ) @@ -1860,7 +1860,7 @@ class CacheKeyTest(PathTest, QueryTest): opt = 
Load(User).joinedload(User.orders).joinedload(Order.items) eq_( - opt._generate_cache_key(query_path), + opt._generate_path_cache_key(query_path), (((Order, "items", Item, ("lazy", "joined")),)), ) @@ -1872,7 +1872,7 @@ class CacheKeyTest(PathTest, QueryTest): query_path = self._make_path_registry([User, "addresses"]) opt = joinedload(User.orders).joinedload(Order.items) - eq_(opt._generate_cache_key(query_path), None) + eq_(opt._generate_path_cache_key(query_path), None) def test_bound_cache_key_excluded_on_other(self): User, Address, Order, Item, SubItem = self.classes( @@ -1882,7 +1882,7 @@ class CacheKeyTest(PathTest, QueryTest): query_path = self._make_path_registry([User, "addresses"]) opt = Load(User).joinedload(User.orders).joinedload(Order.items) - eq_(opt._generate_cache_key(query_path), None) + eq_(opt._generate_path_cache_key(query_path), None) def test_unbound_cache_key_excluded_on_aliased(self): User, Address, Order, Item, SubItem = self.classes( @@ -1901,7 +1901,7 @@ class CacheKeyTest(PathTest, QueryTest): query_path = self._make_path_registry([User, "orders"]) opt = joinedload(aliased(User).orders).joinedload(Order.items) - eq_(opt._generate_cache_key(query_path), None) + eq_(opt._generate_path_cache_key(query_path), None) def test_bound_cache_key_wildcard_one(self): # do not change this test, it is testing @@ -1911,7 +1911,7 @@ class CacheKeyTest(PathTest, QueryTest): query_path = self._make_path_registry([User, "addresses"]) opt = Load(User).lazyload("*") - eq_(opt._generate_cache_key(query_path), None) + eq_(opt._generate_path_cache_key(query_path), None) def test_unbound_cache_key_wildcard_one(self): User, Address = self.classes("User", "Address") @@ -1920,7 +1920,7 @@ class CacheKeyTest(PathTest, QueryTest): opt = lazyload("*") eq_( - opt._generate_cache_key(query_path), + opt._generate_path_cache_key(query_path), (("relationship:_sa_default", ("lazy", "select")),), ) @@ -1933,7 +1933,7 @@ class CacheKeyTest(PathTest, QueryTest): opt = 
Load(User).lazyload("orders").lazyload("*") eq_( - opt._generate_cache_key(query_path), + opt._generate_path_cache_key(query_path), ( ("orders", Order, ("lazy", "select")), ("orders", Order, "relationship:*", ("lazy", "select")), @@ -1949,7 +1949,7 @@ class CacheKeyTest(PathTest, QueryTest): opt = lazyload("orders").lazyload("*") eq_( - opt._generate_cache_key(query_path), + opt._generate_path_cache_key(query_path), ( ("orders", Order, ("lazy", "select")), ("orders", Order, "relationship:*", ("lazy", "select")), @@ -1968,7 +1968,7 @@ class CacheKeyTest(PathTest, QueryTest): ) eq_( - opt._generate_cache_key(query_path), + opt._generate_path_cache_key(query_path), ( (SubItem, ("lazy", "subquery")), ("extra_keywords", Keyword, ("lazy", "subquery")), @@ -1987,7 +1987,7 @@ class CacheKeyTest(PathTest, QueryTest): ) eq_( - opt._generate_cache_key(query_path), + opt._generate_path_cache_key(query_path), ( (SubItem, ("lazy", "subquery")), ("extra_keywords", Keyword, ("lazy", "subquery")), @@ -2008,7 +2008,7 @@ class CacheKeyTest(PathTest, QueryTest): ) eq_( - opt._generate_cache_key(query_path), + opt._generate_path_cache_key(query_path), ( (SubItem, ("lazy", "subquery")), ("extra_keywords", Keyword, ("lazy", "subquery")), @@ -2029,7 +2029,7 @@ class CacheKeyTest(PathTest, QueryTest): ) eq_( - opt._generate_cache_key(query_path), + opt._generate_path_cache_key(query_path), ( (SubItem, ("lazy", "subquery")), ("extra_keywords", Keyword, ("lazy", "subquery")), @@ -2056,7 +2056,7 @@ class CacheKeyTest(PathTest, QueryTest): opt = subqueryload(User.orders).subqueryload( Order.items.of_type(SubItem) ) - eq_(opt._generate_cache_key(query_path), None) + eq_(opt._generate_path_cache_key(query_path), None) def test_unbound_cache_key_excluded_of_type_unsafe(self): User, Address, Order, Item, SubItem = self.classes( @@ -2078,7 +2078,7 @@ class CacheKeyTest(PathTest, QueryTest): opt = subqueryload(User.orders).subqueryload( Order.items.of_type(aliased(SubItem)) ) - 
eq_(opt._generate_cache_key(query_path), None) + eq_(opt._generate_path_cache_key(query_path), None) def test_bound_cache_key_excluded_of_type_safe(self): User, Address, Order, Item, SubItem = self.classes( @@ -2102,7 +2102,7 @@ class CacheKeyTest(PathTest, QueryTest): .subqueryload(User.orders) .subqueryload(Order.items.of_type(SubItem)) ) - eq_(opt._generate_cache_key(query_path), None) + eq_(opt._generate_path_cache_key(query_path), None) def test_bound_cache_key_excluded_of_type_unsafe(self): User, Address, Order, Item, SubItem = self.classes( @@ -2126,7 +2126,7 @@ class CacheKeyTest(PathTest, QueryTest): .subqueryload(User.orders) .subqueryload(Order.items.of_type(aliased(SubItem))) ) - eq_(opt._generate_cache_key(query_path), None) + eq_(opt._generate_path_cache_key(query_path), None) def test_unbound_cache_key_included_of_type_safe(self): User, Address, Order, Item, SubItem = self.classes( @@ -2137,7 +2137,7 @@ class CacheKeyTest(PathTest, QueryTest): opt = joinedload(User.orders).joinedload(Order.items.of_type(SubItem)) eq_( - opt._generate_cache_key(query_path), + opt._generate_path_cache_key(query_path), ((Order, "items", SubItem, ("lazy", "joined")),), ) @@ -2155,7 +2155,7 @@ class CacheKeyTest(PathTest, QueryTest): ) eq_( - opt._generate_cache_key(query_path), + opt._generate_path_cache_key(query_path), ((Order, "items", SubItem, ("lazy", "joined")),), ) @@ -2169,7 +2169,7 @@ class CacheKeyTest(PathTest, QueryTest): opt = joinedload(User.orders).joinedload( Order.items.of_type(aliased(SubItem)) ) - eq_(opt._generate_cache_key(query_path), False) + eq_(opt._generate_path_cache_key(query_path), False) def test_unbound_cache_key_included_unsafe_option_two(self): User, Address, Order, Item, SubItem = self.classes( @@ -2181,7 +2181,7 @@ class CacheKeyTest(PathTest, QueryTest): opt = joinedload(User.orders).joinedload( Order.items.of_type(aliased(SubItem)) ) - eq_(opt._generate_cache_key(query_path), False) + eq_(opt._generate_path_cache_key(query_path), 
False) def test_unbound_cache_key_included_unsafe_option_three(self): User, Address, Order, Item, SubItem = self.classes( @@ -2193,7 +2193,7 @@ class CacheKeyTest(PathTest, QueryTest): opt = joinedload(User.orders).joinedload( Order.items.of_type(aliased(SubItem)) ) - eq_(opt._generate_cache_key(query_path), False) + eq_(opt._generate_path_cache_key(query_path), False) def test_unbound_cache_key_included_unsafe_query(self): User, Address, Order, Item, SubItem = self.classes( @@ -2204,7 +2204,7 @@ class CacheKeyTest(PathTest, QueryTest): query_path = self._make_path_registry([inspect(au), "orders"]) opt = joinedload(au.orders).joinedload(Order.items) - eq_(opt._generate_cache_key(query_path), False) + eq_(opt._generate_path_cache_key(query_path), False) def test_unbound_cache_key_included_safe_w_deferred(self): User, Address, Order, Item, SubItem = self.classes( @@ -2219,7 +2219,7 @@ class CacheKeyTest(PathTest, QueryTest): .defer(Address.user_id) ) eq_( - opt._generate_cache_key(query_path), + opt._generate_path_cache_key(query_path), ( ( Address, @@ -2247,12 +2247,12 @@ class CacheKeyTest(PathTest, QueryTest): ) eq_( - opt1._generate_cache_key(query_path), + opt1._generate_path_cache_key(query_path), ((Order, "items", Item, ("lazy", "joined")),), ) eq_( - opt2._generate_cache_key(query_path), + opt2._generate_path_cache_key(query_path), ( (Order, "address", Address, ("lazy", "joined")), ( @@ -2288,7 +2288,7 @@ class CacheKeyTest(PathTest, QueryTest): .defer(Address.user_id) ) eq_( - opt._generate_cache_key(query_path), + opt._generate_path_cache_key(query_path), ( ( Address, @@ -2316,12 +2316,12 @@ class CacheKeyTest(PathTest, QueryTest): ) eq_( - opt1._generate_cache_key(query_path), + opt1._generate_path_cache_key(query_path), ((Order, "items", Item, ("lazy", "joined")),), ) eq_( - opt2._generate_cache_key(query_path), + opt2._generate_path_cache_key(query_path), ( (Order, "address", Address, ("lazy", "joined")), ( @@ -2356,7 +2356,7 @@ class 
CacheKeyTest(PathTest, QueryTest): query_path = self._make_path_registry([User, "orders"]) eq_( - opt._generate_cache_key(query_path), + opt._generate_path_cache_key(query_path), ( ( Order, @@ -2385,7 +2385,7 @@ class CacheKeyTest(PathTest, QueryTest): au = aliased(User) opt = Load(au).joinedload(au.orders).joinedload(Order.items) - eq_(opt._generate_cache_key(query_path), None) + eq_(opt._generate_path_cache_key(query_path), None) def test_bound_cache_key_included_unsafe_option_one(self): User, Address, Order, Item, SubItem = self.classes( @@ -2399,7 +2399,7 @@ class CacheKeyTest(PathTest, QueryTest): .joinedload(User.orders) .joinedload(Order.items.of_type(aliased(SubItem))) ) - eq_(opt._generate_cache_key(query_path), False) + eq_(opt._generate_path_cache_key(query_path), False) def test_bound_cache_key_included_unsafe_option_two(self): User, Address, Order, Item, SubItem = self.classes( @@ -2413,7 +2413,7 @@ class CacheKeyTest(PathTest, QueryTest): .joinedload(User.orders) .joinedload(Order.items.of_type(aliased(SubItem))) ) - eq_(opt._generate_cache_key(query_path), False) + eq_(opt._generate_path_cache_key(query_path), False) def test_bound_cache_key_included_unsafe_option_three(self): User, Address, Order, Item, SubItem = self.classes( @@ -2427,7 +2427,7 @@ class CacheKeyTest(PathTest, QueryTest): .joinedload(User.orders) .joinedload(Order.items.of_type(aliased(SubItem))) ) - eq_(opt._generate_cache_key(query_path), False) + eq_(opt._generate_path_cache_key(query_path), False) def test_bound_cache_key_included_unsafe_query(self): User, Address, Order, Item, SubItem = self.classes( @@ -2438,7 +2438,7 @@ class CacheKeyTest(PathTest, QueryTest): query_path = self._make_path_registry([inspect(au), "orders"]) opt = Load(au).joinedload(au.orders).joinedload(Order.items) - eq_(opt._generate_cache_key(query_path), False) + eq_(opt._generate_path_cache_key(query_path), False) def test_bound_cache_key_included_safe_w_option(self): User, Address, Order, Item, SubItem = 
self.classes( @@ -2454,7 +2454,7 @@ class CacheKeyTest(PathTest, QueryTest): query_path = self._make_path_registry([User, "orders"]) eq_( - opt._generate_cache_key(query_path), + opt._generate_path_cache_key(query_path), ( ( Order, @@ -2483,7 +2483,7 @@ class CacheKeyTest(PathTest, QueryTest): opt = defaultload(User.addresses).load_only("id", "email_address") eq_( - opt._generate_cache_key(query_path), + opt._generate_path_cache_key(query_path), ( (Address, "id", ("deferred", False), ("instrument", True)), ( @@ -2513,7 +2513,7 @@ class CacheKeyTest(PathTest, QueryTest): Address.id, Address.email_address ) eq_( - opt._generate_cache_key(query_path), + opt._generate_path_cache_key(query_path), ( (Address, "id", ("deferred", False), ("instrument", True)), ( @@ -2545,7 +2545,7 @@ class CacheKeyTest(PathTest, QueryTest): .load_only("id", "email_address") ) eq_( - opt._generate_cache_key(query_path), + opt._generate_path_cache_key(query_path), ( (Address, "id", ("deferred", False), ("instrument", True)), ( @@ -2572,7 +2572,7 @@ class CacheKeyTest(PathTest, QueryTest): opt = defaultload(User.addresses).undefer_group("xyz") eq_( - opt._generate_cache_key(query_path), + opt._generate_path_cache_key(query_path), ((Address, "column:*", ("undefer_group_xyz", True)),), ) @@ -2584,6 +2584,6 @@ class CacheKeyTest(PathTest, QueryTest): opt = Load(User).defaultload(User.addresses).undefer_group("xyz") eq_( - opt._generate_cache_key(query_path), + opt._generate_path_cache_key(query_path), ((Address, "column:*", ("undefer_group_xyz", True)),), ) diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py index d48a8ed338..5d21960b70 100644 --- a/test/sql/test_compare.py +++ b/test/sql/test_compare.py @@ -32,6 +32,7 @@ from sqlalchemy.sql import operators from sqlalchemy.sql import True_ from sqlalchemy.sql import type_coerce from sqlalchemy.sql import visitors +from sqlalchemy.sql.base import HasCacheKey from sqlalchemy.sql.elements import _label_reference from 
sqlalchemy.sql.elements import _textual_label_reference from sqlalchemy.sql.elements import Annotated @@ -46,13 +47,13 @@ from sqlalchemy.sql.functions import FunctionElement from sqlalchemy.sql.functions import GenericFunction from sqlalchemy.sql.functions import ReturnTypeFromArgs from sqlalchemy.sql.selectable import _OffsetLimitParam +from sqlalchemy.sql.selectable import AliasedReturnsRows from sqlalchemy.sql.selectable import FromGrouping from sqlalchemy.sql.selectable import Selectable from sqlalchemy.sql.selectable import SelectStatementGrouping -from sqlalchemy.testing import assert_raises_message +from sqlalchemy.sql.visitors import InternalTraversal from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures -from sqlalchemy.testing import is_ from sqlalchemy.testing import is_false from sqlalchemy.testing import is_true from sqlalchemy.testing import ne_ @@ -63,8 +64,17 @@ meta = MetaData() meta2 = MetaData() table_a = Table("a", meta, Column("a", Integer), Column("b", String)) +table_b_like_a = Table("b2", meta, Column("a", Integer), Column("b", String)) + table_a_2 = Table("a", meta2, Column("a", Integer), Column("b", String)) +table_a_2_fs = Table( + "a", meta2, Column("a", Integer), Column("b", String), schema="fs" +) +table_a_2_bs = Table( + "a", meta2, Column("a", Integer), Column("b", String), schema="bs" +) + table_b = Table("b", meta, Column("a", Integer), Column("b", Integer)) table_c = Table("c", meta, Column("x", Integer), Column("y", Integer)) @@ -72,8 +82,18 @@ table_c = Table("c", meta, Column("x", Integer), Column("y", Integer)) table_d = Table("d", meta, Column("y", Integer), Column("z", Integer)) -class CompareAndCopyTest(fixtures.TestBase): +class MyEntity(HasCacheKey): + def __init__(self, name, element): + self.name = name + self.element = element + + _cache_key_traversal = [ + ("name", InternalTraversal.dp_string), + ("element", InternalTraversal.dp_clauseelement), + ] + +class CoreFixtures(object): # lambdas which 
return a tuple of ColumnElement objects. # must return at least two objects that should compare differently. # to test more varieties of "difference" additional objects can be added. @@ -100,11 +120,47 @@ class CompareAndCopyTest(fixtures.TestBase): text("select a, b, c from table").columns( a=Integer, b=String, c=Integer ), + text("select a, b, c from table where foo=:bar").bindparams( + bindparam("bar", Integer) + ), + text("select a, b, c from table where foo=:foo").bindparams( + bindparam("foo", Integer) + ), + text("select a, b, c from table where foo=:bar").bindparams( + bindparam("bar", String) + ), ), lambda: ( column("q") == column("x"), column("q") == column("y"), column("z") == column("x"), + column("z") + column("x"), + column("z") - column("x"), + column("x") - column("z"), + column("z") > column("x"), + # note these two are mathematically equivalent but for now they + # are considered to be different + column("z") >= column("x"), + column("x") <= column("z"), + column("q").between(5, 6), + column("q").between(5, 6, symmetric=True), + column("q").like("somstr"), + column("q").like("somstr", escape="\\"), + column("q").like("somstr", escape="X"), + ), + lambda: ( + table_a.c.a, + table_a.c.a._annotate({"orm": True}), + table_a.c.a._annotate({"orm": True})._annotate({"bar": False}), + table_a.c.a._annotate( + {"orm": True, "parententity": MyEntity("a", table_a)} + ), + table_a.c.a._annotate( + {"orm": True, "parententity": MyEntity("b", table_a)} + ), + table_a.c.a._annotate( + {"orm": True, "parententity": MyEntity("b", select([table_a]))} + ), ), lambda: ( cast(column("q"), Integer), @@ -225,6 +281,58 @@ class CompareAndCopyTest(fixtures.TestBase): .where(table_a.c.b == 5) .correlate_except(table_b), ), + lambda: ( + select([table_a.c.a]).cte(), + select([table_a.c.a]).cte(recursive=True), + select([table_a.c.a]).cte(name="some_cte", recursive=True), + select([table_a.c.a]).cte(name="some_cte"), + 
select([table_a.c.a]).cte(name="some_cte").alias("other_cte"), + select([table_a.c.a]) + .cte(name="some_cte") + .union_all(select([table_a.c.a])), + select([table_a.c.a]) + .cte(name="some_cte") + .union_all(select([table_a.c.b])), + select([table_a.c.a]).lateral(), + select([table_a.c.a]).lateral(name="bar"), + table_a.tablesample(func.bernoulli(1)), + table_a.tablesample(func.bernoulli(1), seed=func.random()), + table_a.tablesample(func.bernoulli(1), seed=func.other_random()), + table_a.tablesample(func.hoho(1)), + table_a.tablesample(func.bernoulli(1), name="bar"), + table_a.tablesample( + func.bernoulli(1), name="bar", seed=func.random() + ), + ), + lambda: ( + select([table_a.c.a]), + select([table_a.c.a]).prefix_with("foo"), + select([table_a.c.a]).prefix_with("foo", dialect="mysql"), + select([table_a.c.a]).prefix_with("foo", dialect="postgresql"), + select([table_a.c.a]).prefix_with("bar"), + select([table_a.c.a]).suffix_with("bar"), + ), + lambda: ( + select([table_a_2.c.a]), + select([table_a_2_fs.c.a]), + select([table_a_2_bs.c.a]), + ), + lambda: ( + select([table_a.c.a]), + select([table_a.c.a]).with_hint(None, "some hint"), + select([table_a.c.a]).with_hint(None, "some other hint"), + select([table_a.c.a]).with_hint(table_a, "some hint"), + select([table_a.c.a]) + .with_hint(table_a, "some hint") + .with_hint(None, "some other hint"), + select([table_a.c.a]).with_hint(table_a, "some other hint"), + select([table_a.c.a]).with_hint( + table_a, "some hint", dialect_name="mysql" + ), + select([table_a.c.a]).with_hint( + table_a, "some hint", dialect_name="postgresql" + ), + ), lambda: ( table_a.join(table_b, table_a.c.a == table_b.c.a), table_a.join( @@ -273,12 +381,202 @@ class CompareAndCopyTest(fixtures.TestBase): table("a", column("x"), column("y", Integer)), table("a", column("q"), column("y", Integer)), ), - lambda: ( - Table("a", MetaData(), Column("q", Integer), Column("b", String)), - Table("b", MetaData(), Column("q", Integer), Column("b", 
String)), - ), + lambda: (table_a, table_b), ] + def _complex_fixtures(): + def one(): + a1 = table_a.alias() + a2 = table_b_like_a.alias() + + stmt = ( + select([table_a.c.a, a1.c.b, a2.c.b]) + .where(table_a.c.b == a1.c.b) + .where(a1.c.b == a2.c.b) + .where(a1.c.a == 5) + ) + + return stmt + + def one_diff(): + a1 = table_b_like_a.alias() + a2 = table_a.alias() + + stmt = ( + select([table_a.c.a, a1.c.b, a2.c.b]) + .where(table_a.c.b == a1.c.b) + .where(a1.c.b == a2.c.b) + .where(a1.c.a == 5) + ) + + return stmt + + def two(): + inner = one().subquery() + + stmt = select([table_b.c.a, inner.c.a, inner.c.b]).select_from( + table_b.join(inner, table_b.c.b == inner.c.b) + ) + + return stmt + + def three(): + + a1 = table_a.alias() + a2 = table_a.alias() + ex = exists().where(table_b.c.b == a1.c.a) + + stmt = ( + select([a1.c.a, a2.c.a]) + .select_from(a1.join(a2, a1.c.b == a2.c.b)) + .where(ex) + ) + return stmt + + return [one(), one_diff(), two(), three()] + + fixtures.append(_complex_fixtures) + + +class CacheKeyFixture(object): + def _run_cache_key_fixture(self, fixture): + case_a = fixture() + case_b = fixture() + + for a, b in itertools.combinations_with_replacement( + range(len(case_a)), 2 + ): + if a == b: + a_key = case_a[a]._generate_cache_key() + b_key = case_b[b]._generate_cache_key() + eq_(a_key.key, b_key.key) + + for a_param, b_param in zip( + a_key.bindparams, b_key.bindparams + ): + assert a_param.compare(b_param, compare_values=False) + else: + a_key = case_a[a]._generate_cache_key() + b_key = case_b[b]._generate_cache_key() + + if a_key.key == b_key.key: + for a_param, b_param in zip( + a_key.bindparams, b_key.bindparams + ): + if not a_param.compare(b_param, compare_values=True): + break + else: + # this fails unconditionally since we could not + # find bound parameter values that differed. + # Usually we intended to get two distinct keys here + # so the failure will be more descriptive using the + # ne_() assertion. 
+ ne_(a_key.key, b_key.key) + else: + ne_(a_key.key, b_key.key) + + # ClauseElement-specific test to ensure the cache key + # collected all the bound parameters + if isinstance(case_a[a], ClauseElement) and isinstance( + case_b[b], ClauseElement + ): + assert_a_params = [] + assert_b_params = [] + visitors.traverse_depthfirst( + case_a[a], {}, {"bindparam": assert_a_params.append} + ) + visitors.traverse_depthfirst( + case_b[b], {}, {"bindparam": assert_b_params.append} + ) + + # note we're asserting the order of the params as well as + # if there are dupes or not. ordering has to be deterministic + # and matches what a traversal would provide. + # regular traverse_depthfirst does produce dupes in cases like + # select([some_alias]). + # select_from(join(some_alias, other_table)) + # where a bound parameter is inside of some_alias. the + # cache key case is more minimalistic + eq_( + sorted(a_key.bindparams, key=lambda b: b.key), + sorted( + util.unique_list(assert_a_params), key=lambda b: b.key + ), + ) + eq_( + sorted(b_key.bindparams, key=lambda b: b.key), + sorted( + util.unique_list(assert_b_params), key=lambda b: b.key + ), + ) + + +class CacheKeyTest(CacheKeyFixture, CoreFixtures, fixtures.TestBase): + def test_cache_key(self): + for fixture in self.fixtures: + self._run_cache_key_fixture(fixture) + + def test_cache_key_unknown_traverse(self): + class Foobar1(ClauseElement): + _traverse_internals = [ + ("key", InternalTraversal.dp_anon_name), + ("type_", InternalTraversal.dp_unknown_structure), + ] + + def __init__(self, key, type_): + self.key = key + self.type_ = type_ + + f1 = Foobar1("foo", String()) + eq_(f1._generate_cache_key(), None) + + def test_cache_key_no_method(self): + class Foobar1(ClauseElement): + pass + + class Foobar2(ColumnElement): + pass + + # the None for cache key will prevent objects + # which contain these elements from being cached. 
+ f1 = Foobar1() + eq_(f1._generate_cache_key(), None) + + f2 = Foobar2() + eq_(f2._generate_cache_key(), None) + + s1 = select([column("q"), Foobar2()]) + + eq_(s1._generate_cache_key(), None) + + def test_get_children_no_method(self): + class Foobar1(ClauseElement): + pass + + class Foobar2(ColumnElement): + pass + + f1 = Foobar1() + eq_(f1.get_children(), []) + + f2 = Foobar2() + eq_(f2.get_children(), []) + + def test_copy_internals_no_method(self): + class Foobar1(ClauseElement): + pass + + class Foobar2(ColumnElement): + pass + + f1 = Foobar1() + f2 = Foobar2() + + f1._copy_internals() + f2._copy_internals() + + +class CompareAndCopyTest(CoreFixtures, fixtures.TestBase): @classmethod def setup_class(cls): # TODO: we need to get dialects here somehow, perhaps in test_suite? @@ -293,7 +591,10 @@ class CompareAndCopyTest(fixtures.TestBase): cls for cls in class_hierarchy(ClauseElement) if issubclass(cls, (ColumnElement, Selectable)) - and "__init__" in cls.__dict__ + and ( + "__init__" in cls.__dict__ + or issubclass(cls, AliasedReturnsRows) + ) and not issubclass(cls, (Annotated)) and "orm" not in cls.__module__ and "compiler" not in cls.__module__ @@ -318,123 +619,16 @@ class CompareAndCopyTest(fixtures.TestBase): ): if a == b: is_true( - case_a[a].compare( - case_b[b], arbitrary_expression=True - ), + case_a[a].compare(case_b[b], compare_annotations=True), "%r != %r" % (case_a[a], case_b[b]), ) else: is_false( - case_a[a].compare( - case_b[b], arbitrary_expression=True - ), + case_a[a].compare(case_b[b], compare_annotations=True), "%r == %r" % (case_a[a], case_b[b]), ) - def test_cache_key(self): - def assert_params_append(assert_params): - def append(param): - if param._value_required_for_cache: - assert_params.append(param) - else: - is_(param.value, None) - - return append - - for fixture in self.fixtures: - case_a = fixture() - case_b = fixture() - - for a, b in itertools.combinations_with_replacement( - range(len(case_a)), 2 - ): - - assert_a_params = [] 
- assert_b_params = [] - - visitors.traverse_depthfirst( - case_a[a], - {}, - {"bindparam": assert_params_append(assert_a_params)}, - ) - visitors.traverse_depthfirst( - case_b[b], - {}, - {"bindparam": assert_params_append(assert_b_params)}, - ) - if assert_a_params: - assert_raises_message( - NotImplementedError, - "bindparams collection argument required ", - case_a[a]._cache_key, - ) - if assert_b_params: - assert_raises_message( - NotImplementedError, - "bindparams collection argument required ", - case_b[b]._cache_key, - ) - - if not assert_a_params and not assert_b_params: - if a == b: - eq_(case_a[a]._cache_key(), case_b[b]._cache_key()) - else: - ne_(case_a[a]._cache_key(), case_b[b]._cache_key()) - - def test_cache_key_gather_bindparams(self): - for fixture in self.fixtures: - case_a = fixture() - case_b = fixture() - - # in the "bindparams" case, the cache keys for bound parameters - # with only different values will be the same, but the params - # themselves are gathered into a collection. 
- for a, b in itertools.combinations_with_replacement( - range(len(case_a)), 2 - ): - a_params = {"bindparams": []} - b_params = {"bindparams": []} - if a == b: - a_key = case_a[a]._cache_key(**a_params) - b_key = case_b[b]._cache_key(**b_params) - eq_(a_key, b_key) - - if a_params["bindparams"]: - for a_param, b_param in zip( - a_params["bindparams"], b_params["bindparams"] - ): - assert a_param.compare(b_param) - else: - a_key = case_a[a]._cache_key(**a_params) - b_key = case_b[b]._cache_key(**b_params) - - if a_key == b_key: - for a_param, b_param in zip( - a_params["bindparams"], b_params["bindparams"] - ): - if not a_param.compare(b_param): - break - else: - assert False, "Bound parameters are all the same" - else: - ne_(a_key, b_key) - - assert_a_params = [] - assert_b_params = [] - visitors.traverse_depthfirst( - case_a[a], {}, {"bindparam": assert_a_params.append} - ) - visitors.traverse_depthfirst( - case_b[b], {}, {"bindparam": assert_b_params.append} - ) - - # note we're asserting the order of the params as well as - # if there are dupes or not. ordering has to be deterministic - # and matches what a traversal would provide. - eq_(a_params["bindparams"], assert_a_params) - eq_(b_params["bindparams"], assert_b_params) - def test_compare_col_identity(self): stmt1 = ( select([table_a.c.a, table_b.c.b]) @@ -473,8 +667,9 @@ class CompareAndCopyTest(fixtures.TestBase): assert case_a[0].compare(case_b[0]) - clone = case_a[0]._clone() - clone._copy_internals() + clone = visitors.replacement_traverse( + case_a[0], {}, lambda elem: None + ) assert clone.compare(case_b[0]) @@ -511,6 +706,37 @@ class CompareAndCopyTest(fixtures.TestBase): class CompareClausesTest(fixtures.TestBase): + def test_compare_metadata_tables(self): + # metadata Table objects cache on their own identity, not their + # structure. 
This is mainly to reduce the size of cache keys + # as well as reduce computational overhead, as Table objects have + # very large internal state and they are also generally global + # objects. + + t1 = Table("a", MetaData(), Column("q", Integer), Column("p", Integer)) + t2 = Table("a", MetaData(), Column("q", Integer), Column("p", Integer)) + + ne_(t1._generate_cache_key(), t2._generate_cache_key()) + + eq_(t1._generate_cache_key().key, (t1,)) + + def test_compare_adhoc_tables(self): + # non-metadata tables compare on their structure. these objects are + # not commonly used. + + # note this test is a bit redundant as we have a similar test + # via the fixtures also + t1 = table("a", Column("q", Integer), Column("p", Integer)) + t2 = table("a", Column("q", Integer), Column("p", Integer)) + t3 = table("b", Column("q", Integer), Column("p", Integer)) + t4 = table("a", Column("q", Integer), Column("x", Integer)) + + eq_(t1._generate_cache_key(), t2._generate_cache_key()) + + ne_(t1._generate_cache_key(), t3._generate_cache_key()) + ne_(t1._generate_cache_key(), t4._generate_cache_key()) + ne_(t3._generate_cache_key(), t4._generate_cache_key()) + def test_compare_comparison_associative(self): l1 = table_c.c.x == table_d.c.y @@ -521,6 +747,15 @@ class CompareClausesTest(fixtures.TestBase): is_true(l1.compare(l2)) is_false(l1.compare(l3)) + def test_compare_comparison_non_commutative_inverses(self): + l1 = table_c.c.x >= table_d.c.y + l2 = table_d.c.y < table_c.c.x + l3 = table_d.c.y <= table_c.c.x + + # we're not doing this kind of commutativity right now. 
+ is_false(l1.compare(l2)) + is_false(l1.compare(l3)) + def test_compare_clauselist_associative(self): l1 = and_(table_c.c.x == table_d.c.y, table_c.c.y == table_d.c.z) @@ -624,3 +859,45 @@ class CompareClausesTest(fixtures.TestBase): use_proxies=True, ) ) + + def test_compare_annotated_clears_mapping(self): + t = table("t", column("x"), column("y")) + x_a = t.c.x._annotate({"foo": True}) + x_b = t.c.x._annotate({"foo": True}) + + is_true(x_a.compare(x_b, compare_annotations=True)) + is_false( + x_a.compare(x_b._annotate({"bar": True}), compare_annotations=True) + ) + + s1 = select([t.c.x])._annotate({"foo": True}) + s2 = select([t.c.x])._annotate({"foo": True}) + + is_true(s1.compare(s2, compare_annotations=True)) + + is_false( + s1.compare(s2._annotate({"bar": True}), compare_annotations=True) + ) + + def test_compare_annotated_wo_annotations(self): + t = table("t", column("x"), column("y")) + x_a = t.c.x._annotate({}) + x_b = t.c.x._annotate({"foo": True}) + + is_true(t.c.x.compare(x_a)) + is_true(x_b.compare(x_a)) + + is_true(x_a.compare(t.c.x)) + is_false(x_a.compare(t.c.y)) + is_false(t.c.y.compare(x_a)) + is_true((t.c.x == 5).compare(x_a == 5)) + is_false((t.c.y == 5).compare(x_a == 5)) + + s = select([t]).subquery() + x_p = s.c.x + is_false(x_a.compare(x_p)) + is_false(t.c.x.compare(x_p)) + x_p_a = x_p._annotate({}) + is_true(x_p_a.compare(x_p)) + is_true(x_p.compare(x_p_a)) + is_false(x_p_a.compare(x_a)) diff --git a/test/sql/test_generative.py b/test/sql/test_external_traversal.py similarity index 99% rename from test/sql/test_generative.py rename to test/sql/test_external_traversal.py index 8d347a522a..8bfe5cf6f8 100644 --- a/test/sql/test_generative.py +++ b/test/sql/test_external_traversal.py @@ -55,6 +55,7 @@ class TraversalTest(fixtures.TestBase, AssertsExecutionResults): # identity semantics. 
class A(ClauseElement): __visit_name__ = "a" + _traverse_internals = [] def __init__(self, expr): self.expr = expr diff --git a/test/sql/test_operators.py b/test/sql/test_operators.py index 637f1f8a5a..06cfdc4b5a 100644 --- a/test/sql/test_operators.py +++ b/test/sql/test_operators.py @@ -118,11 +118,14 @@ class DefaultColumnComparatorTest(fixtures.TestBase): ) ) + modifiers = operator(left, right).modifiers + assert operator(left, right).compare( BinaryExpression( coercions.expect(roles.WhereHavingRole, left), coercions.expect(roles.WhereHavingRole, right), operator, + modifiers=modifiers, ) ) diff --git a/test/sql/test_selectable.py b/test/sql/test_selectable.py index 2bc7ccc931..184e4a99c2 100644 --- a/test/sql/test_selectable.py +++ b/test/sql/test_selectable.py @@ -1070,7 +1070,7 @@ class SelectableTest( s4 = s3.with_only_columns([table2.c.b]) self.assert_compile(s4, "SELECT t2.b FROM t2") - def test_from_list_warning_against_existing(self): + def test_from_list_against_existing_one(self): c1 = Column("c1", Integer) s = select([c1]) @@ -1081,7 +1081,7 @@ class SelectableTest( self.assert_compile(s, "SELECT t.c1 FROM t") - def test_from_list_recovers_after_warning(self): + def test_from_list_against_existing_two(self): c1 = Column("c1", Integer) c2 = Column("c2", Integer) @@ -1090,18 +1090,11 @@ class SelectableTest( # force a compile. eq_(str(s), "SELECT c1") - @testing.emits_warning() - def go(): - return Table("t", MetaData(), c1, c2) - - t = go() + t = Table("t", MetaData(), c1, c2) eq_(c1._from_objects, [t]) eq_(c2._from_objects, [t]) - # 's' has been baked. Can't afford - # not caching select._froms. 
- # hopefully the warning will clue the user self.assert_compile(s, "SELECT t.c1 FROM t") self.assert_compile(select([c1]), "SELECT t.c1 FROM t") self.assert_compile(select([c2]), "SELECT t.c2 FROM t") @@ -1124,6 +1117,26 @@ class SelectableTest( "foo", ) + def test_whereclause_adapted(self): + table1 = table("t1", column("a")) + + s1 = select([table1]).subquery() + + s2 = select([s1]).where(s1.c.a == 5) + + assert s2._whereclause.left.table is s1 + + ta = select([table1]).subquery() + + s3 = sql_util.ClauseAdapter(ta).traverse(s2) + + assert s1 not in s3._froms + + # these are new assumptions with the newer approach that + # actively swaps out whereclause and others + assert s3._whereclause.left.table is not s1 + assert s3._whereclause.left.table in s3._froms + class RefreshForNewColTest(fixtures.TestBase): def test_join_uninit(self): @@ -2241,25 +2254,6 @@ class AnnotationsTest(fixtures.TestBase): annot = obj._annotate({}) ne_(set([obj]), set([annot])) - def test_compare(self): - t = table("t", column("x"), column("y")) - x_a = t.c.x._annotate({}) - assert t.c.x.compare(x_a) - assert x_a.compare(t.c.x) - assert not x_a.compare(t.c.y) - assert not t.c.y.compare(x_a) - assert (t.c.x == 5).compare(x_a == 5) - assert not (t.c.y == 5).compare(x_a == 5) - - s = select([t]).subquery() - x_p = s.c.x - assert not x_a.compare(x_p) - assert not t.c.x.compare(x_p) - x_p_a = x_p._annotate({}) - assert x_p_a.compare(x_p) - assert x_p.compare(x_p_a) - assert not x_p_a.compare(x_a) - def test_proxy_set_iteration_includes_annotated(self): from sqlalchemy.schema import Column @@ -2542,13 +2536,13 @@ class AnnotationsTest(fixtures.TestBase): ): # the columns clause isn't changed at all assert sel._raw_columns[0].table is a1 - assert sel._froms[0] is sel._froms[1].left + assert sel._froms[0].element is sel._froms[1].left.element eq_(str(s), str(sel)) # when we are modifying annotations sets only - # partially, each element is copied unconditionally - # when encountered. 
+ # partially, elements are copied uniquely based on id(). + # this is new as of 1.4, previously they'd be copied every time for sel in ( sql_util._deep_deannotate(s, {"foo": "bar"}), sql_util._deep_annotate(s, {"foo": "bar"}), diff --git a/test/sql/test_utils.py b/test/sql/test_utils.py index 988d5331eb..48d6de6db0 100644 --- a/test/sql/test_utils.py +++ b/test/sql/test_utils.py @@ -7,6 +7,6 @@ from sqlalchemy.testing import fixtures class MiscTest(fixtures.TestBase): def test_column_element_no_visit(self): class MyElement(ColumnElement): - pass + _traverse_internals = [] eq_(sql_util.find_tables(MyElement(), check_columns=True), [])