Add anonymizing context to cache keys, comparison; convert traversal
author Mike Bayer <mike_mp@zzzcomputing.com>
Thu, 29 Aug 2019 18:45:23 +0000 (14:45 -0400)
committer Mike Bayer <mike_mp@zzzcomputing.com>
Mon, 4 Nov 2019 18:22:43 +0000 (13:22 -0500)
Created a new visitor system called "internal traversal" that
applies a data-driven approach to the concept of a class that
defines its own traversal steps, in contrast to the existing
style of traversal now known as "external traversal" where
the visitor class defines the traversal, i.e. the SQLCompiler.

The internal traversal system now implements get_children(),
_copy_internals(), compare() and _cache_key() for most Core elements.
Core elements with special needs, like Select, still implement
some of these methods directly; however, most of these methods
are no longer explicitly implemented.
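
As a rough sketch of the pattern (the element and attribute names
here are hypothetical; the dp_* symbols are the ones this patch
defines in sqlalchemy.sql.visitors), a Core element now declares
its traversable attributes once, and the generic get_children(),
_copy_internals(), compare() and _cache_key() implementations are
driven from that list:

    from sqlalchemy.sql.elements import ColumnElement
    from sqlalchemy.sql.visitors import InternalTraversal

    class MyWrapper(ColumnElement):
        """Hypothetical element, used only to illustrate the
        data-driven declaration."""

        __visit_name__ = "my_wrapper"

        _traverse_internals = [
            ("element", InternalTraversal.dp_clauseelement),
            ("opts", InternalTraversal.dp_plain_dict),
        ]

        def __init__(self, element, opts=None):
            self.element = element
            self.opts = opts or {}

With the declaration in place, expression comparison runs off the
same metadata, e.g. expr_one.compare(expr_two) returns True for two
structurally identical expressions without MyWrapper implementing
compare() itself.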

The data-driven system is also applied to ORM elements that
take part in SQL expressions, so that these objects, such as mappers,
AliasedClass, query options, etc., can all participate in the
cache key process.
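
The ORM-side pattern has the same shape; a sketch mirroring what
the patch does for Mapper, PathRegistry and the loader options
(class and attribute names here are illustrative, not part of the
patch):

    from sqlalchemy.sql import visitors
    from sqlalchemy.sql.traversals import HasCacheKey

    class MyLoaderOption(HasCacheKey):
        """Illustrative ORM-level object taking part in cache keys."""

        _cache_key_traversal = [
            ("strategy", visitors.ExtendedInternalTraversal.dp_plain_obj),
            ("local_opts", visitors.ExtendedInternalTraversal.dp_plain_dict),
        ]

        def __init__(self, strategy, local_opts=None):
            self.strategy = strategy
            self.local_opts = local_opts or {}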

Still not considered is that this approach to defining traversability
will be used to create some kind of generic introspection system
that works across Core / ORM.  It's also not yet clear whether
real statement caching using the _cache_key() method is feasible,
in the event that running _cache_key() turns out to be nearly as
expensive as compiling in any case.  Because the system is
data-driven, it is more straightforward to optimize, using inlined
code as is the case now, as well as potentially using C code to
speed it up.

In addition, the caching system now accommodates anonymous
name labels, which is essential so that constructs which have
anonymous labels can be cacheable; that is, a construct's position
within a statement in relation to other anonymous names generates
an integer counter relative to that construct which will be the
same every time.  Gathering of bound parameters during any cache
key generation is also now required, as there is no use case for
a cache key that does not extract bound parameter values.
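
The same position-relative counter idea can be seen in miniature
in the prefix_anon_map class this patch adds to
sqlalchemy.sql.compiler (the cache-key code uses its own map in the
new traversals.py, not shown here): the counter depends only on the
order in which anonymous names are first encountered, so two
identically-structured statements resolve their anonymous names to
the same symbols.  A small sketch of that behavior, with made-up
key strings (real keys carry a per-construct identifier):

    from sqlalchemy.sql.compiler import prefix_anon_map

    amap = prefix_anon_map()
    # keys are "<ident> <name>"; values are "<name>_<index>",
    # assigned in order of first appearance and stable on re-access
    assert amap["1234 anon"] == "anon_1"
    assert amap["5678 anon"] == "anon_2"
    assert amap["1234 anon"] == "anon_1"

Bound parameters, meanwhile, are appended to the bindparams
collection passed into _gen_cache_key() as the key is built, so the
values that must accompany a cached statement are always collected
as a side effect.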

Applies-to: #4639
Change-Id: I0660584def8627cad566719ee98d3be045db4b8d

37 files changed:
doc/build/core/sqlelement.rst
doc/build/core/visitors.rst
lib/sqlalchemy/dialects/mysql/dml.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/dialects/postgresql/dml.py
lib/sqlalchemy/ext/baked.py
lib/sqlalchemy/ext/compiler.py
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/base.py
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/path_registry.py
lib/sqlalchemy/orm/strategy_options.py
lib/sqlalchemy/orm/util.py
lib/sqlalchemy/sql/annotation.py
lib/sqlalchemy/sql/base.py
lib/sqlalchemy/sql/clause_compare.py [deleted file]
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/default_comparator.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/sql/functions.py
lib/sqlalchemy/sql/schema.py
lib/sqlalchemy/sql/selectable.py
lib/sqlalchemy/sql/traversals.py [new file with mode: 0644]
lib/sqlalchemy/sql/type_api.py
lib/sqlalchemy/sql/util.py
lib/sqlalchemy/sql/visitors.py
test/aaa_profiling/test_orm.py
test/ext/test_baked.py
test/orm/test_cache_key.py [new file with mode: 0644]
test/orm/test_options.py
test/sql/test_compare.py
test/sql/test_external_traversal.py [moved from test/sql/test_generative.py with 99% similarity]
test/sql/test_operators.py
test/sql/test_selectable.py
test/sql/test_utils.py

index 41a27af71148117787a008e12059cae14ce73625..f7f9cab641ff7a558956bb0581c421393e77a2d2 100644 (file)
@@ -82,6 +82,9 @@ the FROM clause of a SELECT statement.
 .. autoclass:: BindParameter
    :members:
 
+.. autoclass:: CacheKey
+   :members:
+
 .. autoclass:: Case
    :members:
 
@@ -90,6 +93,7 @@ the FROM clause of a SELECT statement.
 
 .. autoclass:: ClauseElement
    :members:
+   :inherited-members:
 
 
 .. autoclass:: ClauseList
index 02f6e24fc44aaa632b1bfc2c9ac8312b1a0207e0..539d664405c25e87ea70ea0b6f4a2626f1c224b1 100644 (file)
@@ -23,3 +23,4 @@ as well as when building out custom SQL expressions using the
 
 .. automodule:: sqlalchemy.sql.visitors
    :members:
+   :private-members:
\ No newline at end of file
index 293aa426da122713c606d645b3b0e4fbf6cac0c0..b43b364fabe6a7df80e214e40520d5a72d3eb512 100644 (file)
@@ -103,7 +103,6 @@ class Insert(StandardInsert):
 
         inserted_alias = getattr(self, "inserted_alias", None)
         self._post_values_clause = OnDuplicateClause(inserted_alias, values)
-        return self
 
 
 insert = public_factory(Insert, ".dialects.mysql.insert")
index 909d568a7f291fb04b4854407628e591c7f1842b..e94f9913cd63a78823e36bb6b90a422e7f0224b5 100644 (file)
@@ -1658,23 +1658,20 @@ class PGCompiler(compiler.SQLCompiler):
         return "ONLY " + sqltext
 
     def get_select_precolumns(self, select, **kw):
-        if select._distinct is not False:
-            if select._distinct is True:
-                return "DISTINCT "
-            elif isinstance(select._distinct, (list, tuple)):
+        if select._distinct or select._distinct_on:
+            if select._distinct_on:
                 return (
                     "DISTINCT ON ("
                     + ", ".join(
-                        [self.process(col, **kw) for col in select._distinct]
+                        [
+                            self.process(col, **kw)
+                            for col in select._distinct_on
+                        ]
                     )
                     + ") "
                 )
             else:
-                return (
-                    "DISTINCT ON ("
-                    + self.process(select._distinct, **kw)
-                    + ") "
-                )
+                return "DISTINCT "
         else:
             return ""
 
index 4e77f5a4c92da70e7e8b798b134834b75b323ff7..f4467976afe03d8e57c4640fba3acb7c46915b7f 100644 (file)
@@ -103,7 +103,6 @@ class Insert(StandardInsert):
         self._post_values_clause = OnConflictDoUpdate(
             constraint, index_elements, index_where, set_, where
         )
-        return self
 
     @_generative
     def on_conflict_do_nothing(
@@ -138,7 +137,6 @@ class Insert(StandardInsert):
         self._post_values_clause = OnConflictDoNothing(
             constraint, index_elements, index_where
         )
-        return self
 
 
 insert = public_factory(Insert, ".dialects.postgresql.insert")
index d18a35a4072938d8b9bd49317f227d950a610d54..8e137f141a22f7973bbfea3f1a7c9c301258b8a0 100644 (file)
@@ -198,7 +198,7 @@ class BakedQuery(object):
             self.spoil()
         else:
             for opt in options:
-                cache_key = opt._generate_cache_key(cache_path)
+                cache_key = opt._generate_path_cache_key(cache_path)
                 if cache_key is False:
                     self.spoil()
                 elif cache_key is not None:
index 4a5a8ba9cdeba9f111cbf9f1cdd51129dd230290..c2b2347587a93ec8eff0e8ad12f9dec4d766457f 100644 (file)
@@ -455,7 +455,7 @@ def deregister(class_):
 
     if hasattr(class_, "_compiler_dispatcher"):
         # regenerate default _compiler_dispatch
-        visitors._generate_dispatch(class_)
+        visitors._generate_compiler_dispatch(class_)
         # remove custom directive
         del class_._compiler_dispatcher
 
index 83069f113c6ba89869c6dcb2a8cc1d231476595e..aa2986205bcba67a9617063d5155e4991eabd224 100644 (file)
@@ -47,6 +47,8 @@ from .base import state_str
 from .. import event
 from .. import inspection
 from .. import util
+from ..sql import base as sql_base
+from ..sql import visitors
 
 
 @inspection._self_inspects
@@ -54,6 +56,7 @@ class QueryableAttribute(
     interfaces._MappedAttribute,
     interfaces.InspectionAttr,
     interfaces.PropComparator,
+    sql_base.HasCacheKey,
 ):
     """Base class for :term:`descriptor` objects that intercept
     attribute events on behalf of a :class:`.MapperProperty`
@@ -102,6 +105,13 @@ class QueryableAttribute(
                     if base[key].dispatch._active_history:
                         self.dispatch._active_history = True
 
+    _cache_key_traversal = [
+        # ("class_", visitors.ExtendedInternalTraversal.dp_plain_obj),
+        ("key", visitors.ExtendedInternalTraversal.dp_string),
+        ("_parententity", visitors.ExtendedInternalTraversal.dp_multi),
+        ("_of_type", visitors.ExtendedInternalTraversal.dp_multi),
+    ]
+
     @util.memoized_property
     def _supports_population(self):
         return self.impl.supports_population
index 6f8d1929345b85221afbd218186bc3cd4aef74a9..a3dea6b0e3daedf88f25ac6cb01fdb7ee05ec690 100644 (file)
@@ -216,7 +216,6 @@ def _assertions(*assertions):
         for assertion in assertions:
             assertion(self, fn.__name__)
         fn(self, *args[1:], **kw)
-        return self
 
     return generate
 
index e94a81fedbb6ed0d0ee5d9bcb7adb020dd698b78..704ce9df79375b2364b135809584973d3b96393e 100644 (file)
@@ -36,6 +36,8 @@ from .. import inspect
 from .. import inspection
 from .. import util
 from ..sql import operators
+from ..sql import visitors
+from ..sql.traversals import HasCacheKey
 
 
 __all__ = (
@@ -54,7 +56,9 @@ __all__ = (
 )
 
 
-class MapperProperty(_MappedAttribute, InspectionAttr, util.MemoizedSlots):
+class MapperProperty(
+    HasCacheKey, _MappedAttribute, InspectionAttr, util.MemoizedSlots
+):
     """Represent a particular class attribute mapped by :class:`.Mapper`.
 
     The most common occurrences of :class:`.MapperProperty` are the
@@ -74,6 +78,11 @@ class MapperProperty(_MappedAttribute, InspectionAttr, util.MemoizedSlots):
         "info",
     )
 
+    _cache_key_traversal = [
+        ("parent", visitors.ExtendedInternalTraversal.dp_has_cache_key),
+        ("key", visitors.ExtendedInternalTraversal.dp_string),
+    ]
+
     cascade = frozenset()
     """The set of 'cascade' attribute names.
 
@@ -647,7 +656,7 @@ class MapperOption(object):
 
         self.process_query(query)
 
-    def _generate_cache_key(self, path):
+    def _generate_path_cache_key(self, path):
         """Used by the "baked lazy loader" to see if this option can be cached.
 
         The "baked lazy loader" refers to the :class:`.Query` that is
index 376ad1923326d43f13f70f4ee62bc96b4cde51a0..548eca58db7485985661bea302aedbd5309ad439 100644 (file)
@@ -71,7 +71,7 @@ _CONFIGURE_MUTEX = util.threading.RLock()
 
 @inspection._self_inspects
 @log.class_logger
-class Mapper(InspectionAttr):
+class Mapper(sql_base.HasCacheKey, InspectionAttr):
     """Define the correlation of class attributes to database table
     columns.
 
@@ -729,6 +729,10 @@ class Mapper(InspectionAttr):
         """
         return self
 
+    _cache_key_traversal = [
+        ("class_", visitors.ExtendedInternalTraversal.dp_plain_obj)
+    ]
+
     @property
     def entity(self):
         r"""Part of the inspection API.
index 2f680a3a162e4955cbf0f85182b782d10f961abc..585cb80bc328bd5ece2219dba87175a6aa97e794 100644 (file)
@@ -15,7 +15,8 @@ from .base import class_mapper
 from .. import exc
 from .. import inspection
 from .. import util
-
+from ..sql import visitors
+from ..sql.traversals import HasCacheKey
 
 log = logging.getLogger(__name__)
 
@@ -28,7 +29,7 @@ _WILDCARD_TOKEN = "*"
 _DEFAULT_TOKEN = "_sa_default"
 
 
-class PathRegistry(object):
+class PathRegistry(HasCacheKey):
     """Represent query load paths and registry functions.
 
     Basically represents structures like:
@@ -57,6 +58,10 @@ class PathRegistry(object):
     is_token = False
     is_root = False
 
+    _cache_key_traversal = [
+        ("path", visitors.ExtendedInternalTraversal.dp_has_cache_key_list)
+    ]
+
     def __eq__(self, other):
         return other is not None and self.path == other.path
 
@@ -78,6 +83,9 @@ class PathRegistry(object):
     def __len__(self):
         return len(self.path)
 
+    def __hash__(self):
+        return id(self)
+
     @property
     def length(self):
         return len(self.path)
index 26f47f6169fcd1f0b9d5852aedf811ff14aa54f2..99bbbe37c7c95a37b0425bd0bbf3b768d01808ec 100644 (file)
@@ -26,11 +26,13 @@ from .. import inspect
 from .. import util
 from ..sql import coercions
 from ..sql import roles
+from ..sql import visitors
 from ..sql.base import _generative
 from ..sql.base import Generative
+from ..sql.traversals import HasCacheKey
 
 
-class Load(Generative, MapperOption):
+class Load(HasCacheKey, Generative, MapperOption):
     """Represents loader options which modify the state of a
     :class:`.Query` in order to affect how various mapped attributes are
     loaded.
@@ -70,6 +72,17 @@ class Load(Generative, MapperOption):
 
     """
 
+    _cache_key_traversal = [
+        ("path", visitors.ExtendedInternalTraversal.dp_has_cache_key),
+        ("strategy", visitors.ExtendedInternalTraversal.dp_plain_obj),
+        ("_of_type", visitors.ExtendedInternalTraversal.dp_multi),
+        (
+            "_context_cache_key",
+            visitors.ExtendedInternalTraversal.dp_has_cache_key_tuples,
+        ),
+        ("local_opts", visitors.ExtendedInternalTraversal.dp_plain_dict),
+    ]
+
     def __init__(self, entity):
         insp = inspect(entity)
         self.path = insp._path_registry
@@ -89,7 +102,16 @@ class Load(Generative, MapperOption):
         load._of_type = None
         return load
 
-    def _generate_cache_key(self, path):
+    @property
+    def _context_cache_key(self):
+        serialized = []
+        for (key, loader_path), obj in self.context.items():
+            if key != "loader":
+                continue
+            serialized.append(loader_path + (obj,))
+        return serialized
+
+    def _generate_path_cache_key(self, path):
         if path.path[0].is_aliased_class:
             return False
 
@@ -522,9 +544,16 @@ class _UnboundLoad(Load):
         self._to_bind = []
         self.local_opts = {}
 
+    _cache_key_traversal = [
+        ("path", visitors.ExtendedInternalTraversal.dp_multi_list),
+        ("strategy", visitors.ExtendedInternalTraversal.dp_plain_obj),
+        ("_to_bind", visitors.ExtendedInternalTraversal.dp_has_cache_key_list),
+        ("local_opts", visitors.ExtendedInternalTraversal.dp_plain_dict),
+    ]
+
     _is_chain_link = False
 
-    def _generate_cache_key(self, path):
+    def _generate_path_cache_key(self, path):
         serialized = ()
         for val in self._to_bind:
             for local_elem, val_elem in zip(self.path, val.path):
@@ -533,7 +562,7 @@ class _UnboundLoad(Load):
             else:
                 opt = val._bind_loader([path.path[0]], None, None, False)
                 if opt:
-                    c_key = opt._generate_cache_key(path)
+                    c_key = opt._generate_path_cache_key(path)
                     if c_key is False:
                         return False
                     elif c_key:
@@ -660,7 +689,6 @@ class _UnboundLoad(Load):
 
         opt = meth(opt, all_tokens[-1], **kw)
         opt._is_chain_link = False
-
         return opt
 
     def _chop_path(self, to_chop, path):
index 5f0f41e8d95dfe13b9162bdbbd717badbdd3a75b..c86993678317257f943aa08013365d787e826af6 100644 (file)
@@ -30,10 +30,12 @@ from .. import exc as sa_exc
 from .. import inspection
 from .. import sql
 from .. import util
+from ..sql import base as sql_base
 from ..sql import coercions
 from ..sql import expression
 from ..sql import roles
 from ..sql import util as sql_util
+from ..sql import visitors
 
 
 all_cascades = frozenset(
@@ -530,7 +532,7 @@ class AliasedClass(object):
         return str(self._aliased_insp)
 
 
-class AliasedInsp(InspectionAttr):
+class AliasedInsp(sql_base.HasCacheKey, InspectionAttr):
     """Provide an inspection interface for an
     :class:`.AliasedClass` object.
 
@@ -627,6 +629,12 @@ class AliasedInsp(InspectionAttr):
     def __clause_element__(self):
         return self.selectable
 
+    _cache_key_traversal = [
+        ("name", visitors.ExtendedInternalTraversal.dp_string),
+        ("_adapt_on_names", visitors.ExtendedInternalTraversal.dp_boolean),
+        ("selectable", visitors.ExtendedInternalTraversal.dp_clauseelement),
+    ]
+
     @property
     def class_(self):
         """Return the mapped class ultimately represented by this
index a0264845e382cdabb77d2f9bb77b7010dd2cfd19..0d995ec8a23ad956e980ae3076e4a14a94460149 100644 (file)
@@ -12,12 +12,32 @@ associations.
 """
 
 from . import operators
+from .base import HasCacheKey
+from .visitors import InternalTraversal
 from .. import util
 
 
-class SupportsCloneAnnotations(object):
+class SupportsAnnotations(object):
+    @util.memoized_property
+    def _annotation_traversals(self):
+        return [
+            (
+                key,
+                InternalTraversal.dp_has_cache_key
+                if isinstance(value, HasCacheKey)
+                else InternalTraversal.dp_plain_obj,
+            )
+            for key, value in self._annotations.items()
+        ]
+
+
+class SupportsCloneAnnotations(SupportsAnnotations):
     _annotations = util.immutabledict()
 
+    _traverse_internals = [
+        ("_annotations", InternalTraversal.dp_annotations_state)
+    ]
+
     def _annotate(self, values):
         """return a copy of this ClauseElement with annotations
         updated by the given dictionary.
@@ -25,6 +45,7 @@ class SupportsCloneAnnotations(object):
         """
         new = self._clone()
         new._annotations = new._annotations.union(values)
+        new.__dict__.pop("_annotation_traversals", None)
         return new
 
     def _with_annotations(self, values):
@@ -34,6 +55,7 @@ class SupportsCloneAnnotations(object):
         """
         new = self._clone()
         new._annotations = util.immutabledict(values)
+        new.__dict__.pop("_annotation_traversals", None)
         return new
 
     def _deannotate(self, values=None, clone=False):
@@ -49,12 +71,13 @@ class SupportsCloneAnnotations(object):
             # the expression for a deep deannotation
             new = self._clone()
             new._annotations = {}
+            new.__dict__.pop("_annotation_traversals", None)
             return new
         else:
             return self
 
 
-class SupportsWrappingAnnotations(object):
+class SupportsWrappingAnnotations(SupportsAnnotations):
     def _annotate(self, values):
         """return a copy of this ClauseElement with annotations
         updated by the given dictionary.
@@ -123,6 +146,7 @@ class Annotated(object):
 
     def __init__(self, element, values):
         self.__dict__ = element.__dict__.copy()
+        self.__dict__.pop("_annotation_traversals", None)
         self.__element = element
         self._annotations = values
         self._hash = hash(element)
@@ -135,6 +159,7 @@ class Annotated(object):
     def _with_annotations(self, values):
         clone = self.__class__.__new__(self.__class__)
         clone.__dict__ = self.__dict__.copy()
+        clone.__dict__.pop("_annotation_traversals", None)
         clone._annotations = values
         return clone
 
@@ -192,7 +217,17 @@ def _deep_annotate(element, annotations, exclude=None):
 
     """
 
-    def clone(elem):
+    # annotated objects hack the __hash__() method so if we want to
+    # uniquely process them we have to use id()
+
+    cloned_ids = {}
+
+    def clone(elem, **kw):
+        id_ = id(elem)
+
+        if id_ in cloned_ids:
+            return cloned_ids[id_]
+
         if (
             exclude
             and hasattr(elem, "proxy_set")
@@ -204,6 +239,7 @@ def _deep_annotate(element, annotations, exclude=None):
         else:
             newelem = elem
         newelem._copy_internals(clone=clone)
+        cloned_ids[id_] = newelem
         return newelem
 
     if element is not None:
@@ -214,23 +250,21 @@ def _deep_annotate(element, annotations, exclude=None):
 def _deep_deannotate(element, values=None):
     """Deep copy the given element, removing annotations."""
 
-    cloned = util.column_dict()
+    cloned = {}
 
-    def clone(elem):
-        # if a values dict is given,
-        # the elem must be cloned each time it appears,
-        # as there may be different annotations in source
-        # elements that are remaining.  if totally
-        # removing all annotations, can assume the same
-        # slate...
-        if values or elem not in cloned:
+    def clone(elem, **kw):
+        if values:
+            key = id(elem)
+        else:
+            key = elem
+
+        if key not in cloned:
             newelem = elem._deannotate(values=values, clone=True)
             newelem._copy_internals(clone=clone)
-            if not values:
-                cloned[elem] = newelem
+            cloned[key] = newelem
             return newelem
         else:
-            return cloned[elem]
+            return cloned[key]
 
     if element is not None:
         element = clone(element)
@@ -268,6 +302,11 @@ def _new_annotation_type(cls, base_cls):
         "Annotated%s" % cls.__name__, (base_cls, cls), {}
     )
     globals()["Annotated%s" % cls.__name__] = anno_cls
+
+    if "_traverse_internals" in cls.__dict__:
+        anno_cls._traverse_internals = list(cls._traverse_internals) + [
+            ("_annotations", InternalTraversal.dp_annotations_state)
+        ]
     return anno_cls
 
 
index 7e9199bfa8aa1128ad4aa9753ab920db0c057922..d11a3a31397ab4c30468a29ee449f140c9f01a36 100644 (file)
@@ -14,6 +14,7 @@ import itertools
 import operator
 import re
 
+from .traversals import HasCacheKey  # noqa
 from .visitors import ClauseVisitor
 from .. import exc
 from .. import util
@@ -38,18 +39,41 @@ class Immutable(object):
     def _clone(self):
         return self
 
+    def _copy_internals(self, **kw):
+        pass
+
+
+class HasMemoized(object):
+    def _reset_memoizations(self):
+        self._memoized_property.expire_instance(self)
+
+    def _reset_exported(self):
+        self._memoized_property.expire_instance(self)
+
+    def _copy_internals(self, **kw):
+        super(HasMemoized, self)._copy_internals(**kw)
+        self._reset_memoizations()
+
 
 def _from_objects(*elements):
     return itertools.chain(*[element._from_objects for element in elements])
 
 
 def _generative(fn):
+    """non-caching _generative() decorator.
+
+    This is basically the legacy decorator that copies the object and
+    runs a method on the new copy.
+
+    """
+
     @util.decorator
-    def _generative(fn, *args, **kw):
+    def _generative(fn, self, *args, **kw):
         """Mark a method as generative."""
 
-        self = args[0]._generate()
-        fn(self, *args[1:], **kw)
+        self = self._generate()
+        x = fn(self, *args, **kw)
+        assert x is None, "generative methods must have no return value"
         return self
 
     decorated = _generative(fn)
@@ -357,10 +381,8 @@ class DialectKWArgs(object):
 
 
 class Generative(object):
-    """Allow a ClauseElement to generate itself via the
-    @_generative decorator.
-
-    """
+    """Provide a method-chaining pattern in conjunction with the
+    @_generative decorator."""
 
     def _generate(self):
         s = self.__class__.__new__(self.__class__)
diff --git a/lib/sqlalchemy/sql/clause_compare.py b/lib/sqlalchemy/sql/clause_compare.py
deleted file mode 100644 (file)
index 30a9034..0000000
+++ /dev/null
@@ -1,334 +0,0 @@
-from collections import deque
-
-from . import operators
-from .. import util
-
-
-SKIP_TRAVERSE = util.symbol("skip_traverse")
-
-
-def compare(obj1, obj2, **kw):
-    if kw.get("use_proxies", False):
-        strategy = ColIdentityComparatorStrategy()
-    else:
-        strategy = StructureComparatorStrategy()
-
-    return strategy.compare(obj1, obj2, **kw)
-
-
-class StructureComparatorStrategy(object):
-    __slots__ = "compare_stack", "cache"
-
-    def __init__(self):
-        self.compare_stack = deque()
-        self.cache = set()
-
-    def compare(self, obj1, obj2, **kw):
-        stack = self.compare_stack
-        cache = self.cache
-
-        stack.append((obj1, obj2))
-
-        while stack:
-            left, right = stack.popleft()
-
-            if left is right:
-                continue
-            elif left is None or right is None:
-                # we know they are different so no match
-                return False
-            elif (left, right) in cache:
-                continue
-            cache.add((left, right))
-
-            visit_name = left.__visit_name__
-
-            # we're not exactly looking for identical types, because
-            # there are things like Column and AnnotatedColumn.  So the
-            # visit_name has to at least match up
-            if visit_name != right.__visit_name__:
-                return False
-
-            meth = getattr(self, "compare_%s" % visit_name, None)
-
-            if meth:
-                comparison = meth(left, right, **kw)
-                if comparison is False:
-                    return False
-                elif comparison is SKIP_TRAVERSE:
-                    continue
-
-            for c1, c2 in util.zip_longest(
-                left.get_children(column_collections=False),
-                right.get_children(column_collections=False),
-                fillvalue=None,
-            ):
-                if c1 is None or c2 is None:
-                    # collections are different sizes, comparison fails
-                    return False
-                stack.append((c1, c2))
-
-        return True
-
-    def compare_inner(self, obj1, obj2, **kw):
-        stack = self.compare_stack
-        try:
-            self.compare_stack = deque()
-            return self.compare(obj1, obj2, **kw)
-        finally:
-            self.compare_stack = stack
-
-    def _compare_unordered_sequences(self, seq1, seq2, **kw):
-        if seq1 is None:
-            return seq2 is None
-
-        completed = set()
-        for clause in seq1:
-            for other_clause in set(seq2).difference(completed):
-                if self.compare_inner(clause, other_clause, **kw):
-                    completed.add(other_clause)
-                    break
-        return len(completed) == len(seq1) == len(seq2)
-
-    def compare_bindparam(self, left, right, **kw):
-        # note the ".key" is often generated from id(self) so can't
-        # be compared, as far as determining structure.
-        return (
-            left.type._compare_type_affinity(right.type)
-            and left.value == right.value
-            and left.callable == right.callable
-            and left._orig_key == right._orig_key
-        )
-
-    def compare_clauselist(self, left, right, **kw):
-        if left.operator is right.operator:
-            if operators.is_associative(left.operator):
-                if self._compare_unordered_sequences(
-                    left.clauses, right.clauses
-                ):
-                    return SKIP_TRAVERSE
-                else:
-                    return False
-            else:
-                # normal ordered traversal
-                return True
-        else:
-            return False
-
-    def compare_unary(self, left, right, **kw):
-        if left.operator:
-            disp = self._get_operator_dispatch(
-                left.operator, "unary", "operator"
-            )
-            if disp is not None:
-                result = disp(left, right, left.operator, **kw)
-                if result is not True:
-                    return result
-        elif left.modifier:
-            disp = self._get_operator_dispatch(
-                left.modifier, "unary", "modifier"
-            )
-            if disp is not None:
-                result = disp(left, right, left.operator, **kw)
-                if result is not True:
-                    return result
-        return (
-            left.operator == right.operator and left.modifier == right.modifier
-        )
-
-    def compare_binary(self, left, right, **kw):
-        disp = self._get_operator_dispatch(left.operator, "binary", None)
-        if disp:
-            result = disp(left, right, left.operator, **kw)
-            if result is not True:
-                return result
-
-        if left.operator == right.operator:
-            if operators.is_commutative(left.operator):
-                if (
-                    compare(left.left, right.left, **kw)
-                    and compare(left.right, right.right, **kw)
-                ) or (
-                    compare(left.left, right.right, **kw)
-                    and compare(left.right, right.left, **kw)
-                ):
-                    return SKIP_TRAVERSE
-                else:
-                    return False
-            else:
-                return True
-        else:
-            return False
-
-    def _get_operator_dispatch(self, operator_, qualifier1, qualifier2):
-        # used by compare_binary, compare_unary
-        attrname = "visit_%s_%s%s" % (
-            operator_.__name__,
-            qualifier1,
-            "_" + qualifier2 if qualifier2 else "",
-        )
-        return getattr(self, attrname, None)
-
-    def visit_function_as_comparison_op_binary(
-        self, left, right, operator, **kw
-    ):
-        return (
-            left.left_index == right.left_index
-            and left.right_index == right.right_index
-        )
-
-    def compare_function(self, left, right, **kw):
-        return left.name == right.name
-
-    def compare_column(self, left, right, **kw):
-        if left.table is not None:
-            self.compare_stack.appendleft((left.table, right.table))
-        return (
-            left.key == right.key
-            and left.name == right.name
-            and (
-                left.type._compare_type_affinity(right.type)
-                if left.type is not None
-                else right.type is None
-            )
-            and left.is_literal == right.is_literal
-        )
-
-    def compare_collation(self, left, right, **kw):
-        return left.collation == right.collation
-
-    def compare_type_coerce(self, left, right, **kw):
-        return left.type._compare_type_affinity(right.type)
-
-    @util.dependencies("sqlalchemy.sql.elements")
-    def compare_alias(self, elements, left, right, **kw):
-        return (
-            left.name == right.name
-            if not isinstance(left.name, elements._anonymous_label)
-            else isinstance(right.name, elements._anonymous_label)
-        )
-
-    def compare_cte(self, elements, left, right, **kw):
-        raise NotImplementedError("TODO")
-
-    def compare_extract(self, left, right, **kw):
-        return left.field == right.field
-
-    def compare_textual_label_reference(self, left, right, **kw):
-        return left.element == right.element
-
-    def compare_slice(self, left, right, **kw):
-        return (
-            left.start == right.start
-            and left.stop == right.stop
-            and left.step == right.step
-        )
-
-    def compare_over(self, left, right, **kw):
-        return left.range_ == right.range_ and left.rows == right.rows
-
-    @util.dependencies("sqlalchemy.sql.elements")
-    def compare_label(self, elements, left, right, **kw):
-        return left._type._compare_type_affinity(right._type) and (
-            left.name == right.name
-            if not isinstance(left.name, elements._anonymous_label)
-            else isinstance(right.name, elements._anonymous_label)
-        )
-
-    def compare_typeclause(self, left, right, **kw):
-        return left.type._compare_type_affinity(right.type)
-
-    def compare_join(self, left, right, **kw):
-        return left.isouter == right.isouter and left.full == right.full
-
-    def compare_table(self, left, right, **kw):
-        if left.name != right.name:
-            return False
-
-        self.compare_stack.extendleft(
-            util.zip_longest(left.columns, right.columns)
-        )
-
-    def compare_compound_select(self, left, right, **kw):
-
-        if not self._compare_unordered_sequences(
-            left.selects, right.selects, **kw
-        ):
-            return False
-
-        if left.keyword != right.keyword:
-            return False
-
-        if left._for_update_arg != right._for_update_arg:
-            return False
-
-        if not self.compare_inner(
-            left._order_by_clause, right._order_by_clause, **kw
-        ):
-            return False
-
-        if not self.compare_inner(
-            left._group_by_clause, right._group_by_clause, **kw
-        ):
-            return False
-
-        return SKIP_TRAVERSE
-
-    def compare_select(self, left, right, **kw):
-        if not self._compare_unordered_sequences(
-            left._correlate, right._correlate
-        ):
-            return False
-        if not self._compare_unordered_sequences(
-            left._correlate_except, right._correlate_except
-        ):
-            return False
-
-        if not self._compare_unordered_sequences(
-            left._from_obj, right._from_obj
-        ):
-            return False
-
-        if left._for_update_arg != right._for_update_arg:
-            return False
-
-        return True
-
-    def compare_textual_select(self, left, right, **kw):
-        self.compare_stack.extendleft(
-            util.zip_longest(left.column_args, right.column_args)
-        )
-        return left.positional == right.positional
-
-
-class ColIdentityComparatorStrategy(StructureComparatorStrategy):
-    def compare_column_element(
-        self, left, right, use_proxies=True, equivalents=(), **kw
-    ):
-        """Compare ColumnElements using proxies and equivalent collections.
-
-        This is a comparison strategy specific to the ORM.
-        """
-
-        to_compare = (right,)
-        if equivalents and right in equivalents:
-            to_compare = equivalents[right].union(to_compare)
-
-        for oth in to_compare:
-            if use_proxies and left.shares_lineage(oth):
-                return True
-            elif hash(left) == hash(right):
-                return True
-        else:
-            return False
-
-    def compare_column(self, left, right, **kw):
-        return self.compare_column_element(left, right, **kw)
-
-    def compare_label(self, left, right, **kw):
-        return self.compare_column_element(left, right, **kw)
-
-    def compare_table(self, left, right, **kw):
-        # tables compare on identity, since it's not really feasible to
-        # compare them column by column with the above rules
-        return left is right
index 5ecec7d6c208c8cb31418938f7b4df1bea76ecf8..546fffc6c4c2cfb582debb161f9865b9db2dc8ae 100644 (file)
@@ -434,6 +434,27 @@ class _CompileLabel(elements.ColumnElement):
         return self
 
 
+class prefix_anon_map(dict):
+    """A map that creates new keys for missing key access.
+
+    Considers keys of the form "<ident> <name>" to produce
+    new symbols "<name>_<index>", where "index" is an incrementing integer
+    corresponding to <name>.
+
+    Inlines the approach taken by :class:`sqlalchemy.util.PopulateDict` which
+    is otherwise usually used for this type of operation.
+
+    """
+
+    def __missing__(self, key):
+        (ident, derived) = key.split(" ", 1)
+        anonymous_counter = self.get(derived, 1)
+        self[derived] = anonymous_counter + 1
+        value = derived + "_" + str(anonymous_counter)
+        self[key] = value
+        return value
+
+
 class SQLCompiler(Compiled):
     """Default implementation of :class:`.Compiled`.
 
@@ -574,7 +595,7 @@ class SQLCompiler(Compiled):
 
         # a map which tracks "anonymous" identifiers that are created on
         # the fly here
-        self.anon_map = util.PopulateDict(self._process_anon)
+        self.anon_map = prefix_anon_map()
 
         # a map which tracks "truncated" names based on
         # dialect.label_length or dialect.max_identifier_length
@@ -1712,12 +1733,6 @@ class SQLCompiler(Compiled):
     def _anonymize(self, name):
         return name % self.anon_map
 
-    def _process_anon(self, key):
-        (ident, derived) = key.split(" ", 1)
-        anonymous_counter = self.anon_map.get(derived, 1)
-        self.anon_map[derived] = anonymous_counter + 1
-        return derived + "_" + str(anonymous_counter)
-
     def bindparam_string(
         self,
         name,
index 918f7524e53cb21d2e365e665ce8645ee4238353..c0baa8555682f74136ea2c6b6219416bfc0b3a37 100644 (file)
@@ -178,6 +178,9 @@ def _unsupported_impl(expr, op, *arg, **kw):
 
 def _inv_impl(expr, op, **kw):
     """See :meth:`.ColumnOperators.__inv__`."""
+
+    # undocumented element currently used by the ORM for
+    # relationship.contains()
     if hasattr(expr, "negation_clause"):
         return expr.negation_clause
     else:
index e6f57b8d1114bf1586574104975f6942179b8683..ba615bc3fad36d869eb241050a2586e60d517a6b 100644 (file)
@@ -16,23 +16,29 @@ import itertools
 import operator
 import re
 
-from . import clause_compare
 from . import coercions
 from . import operators
 from . import roles
+from . import traversals
 from . import type_api
 from .annotation import Annotated
 from .annotation import SupportsWrappingAnnotations
 from .base import _clone
 from .base import _generative
 from .base import Executable
+from .base import HasCacheKey
+from .base import HasMemoized
 from .base import Immutable
 from .base import NO_ARG
 from .base import PARSE_AUTOCOMMIT
 from .coercions import _document_text_coercion
+from .traversals import _copy_internals
+from .traversals import _get_children
+from .traversals import NO_CACHE
 from .visitors import cloned_traverse
+from .visitors import InternalTraversal
 from .visitors import traverse
-from .visitors import Visitable
+from .visitors import Traversible
 from .. import exc
 from .. import inspection
 from .. import util
@@ -162,7 +168,9 @@ def not_(clause):
 
 
 @inspection._self_inspects
-class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable):
+class ClauseElement(
+    roles.SQLRole, SupportsWrappingAnnotations, HasCacheKey, Traversible
+):
     """Base class for elements of a programmatically constructed SQL
     expression.
 
@@ -190,6 +198,13 @@ class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable):
 
     _order_by_label_element = None
 
+    @property
+    def _cache_key_traversal(self):
+        try:
+            return self._traverse_internals
+        except AttributeError:
+            return NO_CACHE
+
     def _clone(self):
         """Create a shallow copy of this ClauseElement.
 
@@ -221,28 +236,6 @@ class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable):
         """
         return self
 
-    def _cache_key(self, **kw):
-        """return an optional cache key.
-
-        The cache key is a tuple which can contain any series of
-        objects that are hashable and also identifies
-        this object uniquely within the presence of a larger SQL expression
-        or statement, for the purposes of caching the resulting query.
-
-        The cache key should be based on the SQL compiled structure that would
-        ultimately be produced.   That is, two structures that are composed in
-        exactly the same way should produce the same cache key; any difference
-        in the strucures that would affect the SQL string or the type handlers
-        should result in a different cache key.
-
-        If a structure cannot produce a useful cache key, it should raise
-        NotImplementedError, which will result in the entire structure
-        for which it's part of not being useful as a cache key.
-
-
-        """
-        raise NotImplementedError()
-
     @property
     def _constructor(self):
         """return the 'constructor' for this ClauseElement.
@@ -336,9 +329,9 @@ class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable):
         (see :class:`.ColumnElement`)
 
         """
-        return clause_compare.compare(self, other, **kw)
+        return traversals.compare(self, other, **kw)
 
-    def _copy_internals(self, clone=_clone, **kw):
+    def _copy_internals(self, **kw):
         """Reassign internal elements to be clones of themselves.
 
         Called during a copy-and-traverse operation on newly
@@ -349,21 +342,46 @@ class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable):
         traversal, cloned traversal, annotations).
 
         """
-        pass
 
-    def get_children(self, **kwargs):
-        r"""Return immediate child elements of this :class:`.ClauseElement`.
+        try:
+            traverse_internals = self._traverse_internals
+        except AttributeError:
+            return
+
+        for attrname, obj, meth in _copy_internals.run_generated_dispatch(
+            self, traverse_internals, "_generated_copy_internals_traversal"
+        ):
+            if obj is not None:
+                result = meth(self, obj, **kw)
+                if result is not None:
+                    setattr(self, attrname, result)
+
+    def get_children(self, omit_attrs=None, **kw):
+        r"""Return immediate child :class:`.Traversible` elements of this
+        :class:`.Traversible`.
 
         This is used for visit traversal.
 
-        \**kwargs may contain flags that change the collection that is
+        \**kw may contain flags that change the collection that is
         returned, for example to return a subset of items in order to
         cut down on larger traversals, or to return child items from a
         different context (such as schema-level collections instead of
         clause-level).
 
         """
-        return []
+        result = []
+        try:
+            traverse_internals = self._traverse_internals
+        except AttributeError:
+            return result
+
+        for attrname, obj, meth in _get_children.run_generated_dispatch(
+            self, traverse_internals, "_generated_get_children_traversal"
+        ):
+            if obj is None or omit_attrs and attrname in omit_attrs:
+                continue
+            result.extend(meth(obj, **kw))
+        return result
 
     def self_group(self, against=None):
         # type: (Optional[Any]) -> ClauseElement
@@ -501,6 +519,8 @@ class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable):
         return or_(self, other)
 
     def __invert__(self):
+        # undocumented element currently used by the ORM for
+        # relationship.contains()
         if hasattr(self, "negation_clause"):
             return self.negation_clause
         else:
@@ -508,9 +528,7 @@ class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable):
 
     def _negate(self):
         return UnaryExpression(
-            self.self_group(against=operators.inv),
-            operator=operators.inv,
-            negate=None,
+            self.self_group(against=operators.inv), operator=operators.inv
         )
 
     def __bool__(self):
@@ -731,9 +749,6 @@ class ColumnElement(
         else:
             return comparator_factory(self)
 
-    def _cache_key(self, **kw):
-        raise NotImplementedError(self.__class__)
-
     def __getattr__(self, key):
         try:
             return getattr(self.comparator, key)
@@ -969,6 +984,13 @@ class BindParameter(roles.InElementRole, ColumnElement):
 
     __visit_name__ = "bindparam"
 
+    _traverse_internals = [
+        ("key", InternalTraversal.dp_anon_name),
+        ("type", InternalTraversal.dp_type),
+        ("callable", InternalTraversal.dp_plain_dict),
+        ("value", InternalTraversal.dp_plain_obj),
+    ]
+
     _is_crud = False
     _expanding_in_types = ()
 
@@ -1321,26 +1343,19 @@ class BindParameter(roles.InElementRole, ColumnElement):
             )
         return c
 
-    def _cache_key(self, bindparams=None, **kw):
-        if bindparams is None:
-            # even though _cache_key is a private method, we would like to
-            # be super paranoid about this point.   You can't include the
-            # "value" or "callable" in the cache key, because the value is
-            # not part of the structure of a statement and is likely to
-            # change every time.  However you cannot *throw it away* either,
-            # because you can't invoke the statement without the parameter
-            # values that were explicitly placed.    So require that they
-            # are collected here to make sure this happens.
-            if self._value_required_for_cache:
-                raise NotImplementedError(
-                    "bindparams collection argument required for _cache_key "
-                    "implementation.  Bound parameter cache keys are not safe "
-                    "to use without accommodating for the value or callable "
-                    "within the parameter itself."
-                )
-        else:
-            bindparams.append(self)
-        return (BindParameter, self.type._cache_key, self._orig_key)
+    def _gen_cache_key(self, anon_map, bindparams):
+        if self in anon_map:
+            return (anon_map[self], self.__class__)
+
+        id_ = anon_map[self]
+        bindparams.append(self)
+
+        return (
+            id_,
+            self.__class__,
+            self.type._gen_cache_key,
+            traversals._resolve_name_for_compare(self, self.key, anon_map),
+        )
 
     def _convert_to_unique(self):
         if not self.unique:
@@ -1377,12 +1392,11 @@ class TypeClause(ClauseElement):
 
     __visit_name__ = "typeclause"
 
+    _traverse_internals = [("type", InternalTraversal.dp_type)]
+
     def __init__(self, type_):
         self.type = type_
 
-    def _cache_key(self, **kw):
-        return (TypeClause, self.type._cache_key)
-
 
 class TextClause(
     roles.DDLConstraintColumnRole,
@@ -1419,6 +1433,11 @@ class TextClause(
 
     __visit_name__ = "textclause"
 
+    _traverse_internals = [
+        ("_bindparams", InternalTraversal.dp_string_clauseelement_dict),
+        ("text", InternalTraversal.dp_string),
+    ]
+
     _is_text_clause = True
 
     _is_textual = True
@@ -1861,19 +1880,6 @@ class TextClause(
         else:
             return self
 
-    def _copy_internals(self, clone=_clone, **kw):
-        self._bindparams = dict(
-            (b.key, clone(b, **kw)) for b in self._bindparams.values()
-        )
-
-    def get_children(self, **kwargs):
-        return list(self._bindparams.values())
-
-    def _cache_key(self, **kw):
-        return (self.text,) + tuple(
-            bind._cache_key for bind in self._bindparams.values()
-        )
-
 
 class Null(roles.ConstExprRole, ColumnElement):
     """Represent the NULL keyword in a SQL statement.
@@ -1885,6 +1891,8 @@ class Null(roles.ConstExprRole, ColumnElement):
 
     __visit_name__ = "null"
 
+    _traverse_internals = []
+
     @util.memoized_property
     def type(self):
         return type_api.NULLTYPE
@@ -1895,9 +1903,6 @@ class Null(roles.ConstExprRole, ColumnElement):
 
         return Null()
 
-    def _cache_key(self, **kw):
-        return (Null,)
-
 
 class False_(roles.ConstExprRole, ColumnElement):
     """Represent the ``false`` keyword, or equivalent, in a SQL statement.
@@ -1908,6 +1913,7 @@ class False_(roles.ConstExprRole, ColumnElement):
     """
 
     __visit_name__ = "false"
+    _traverse_internals = []
 
     @util.memoized_property
     def type(self):
@@ -1954,9 +1960,6 @@ class False_(roles.ConstExprRole, ColumnElement):
 
         return False_()
 
-    def _cache_key(self, **kw):
-        return (False_,)
-
 
 class True_(roles.ConstExprRole, ColumnElement):
     """Represent the ``true`` keyword, or equivalent, in a SQL statement.
@@ -1968,6 +1971,8 @@ class True_(roles.ConstExprRole, ColumnElement):
 
     __visit_name__ = "true"
 
+    _traverse_internals = []
+
     @util.memoized_property
     def type(self):
         return type_api.BOOLEANTYPE
@@ -2020,9 +2025,6 @@ class True_(roles.ConstExprRole, ColumnElement):
 
         return True_()
 
-    def _cache_key(self, **kw):
-        return (True_,)
-
 
 class ClauseList(
     roles.InElementRole,
@@ -2038,6 +2040,11 @@ class ClauseList(
 
     __visit_name__ = "clauselist"
 
+    _traverse_internals = [
+        ("clauses", InternalTraversal.dp_clauseelement_list),
+        ("operator", InternalTraversal.dp_operator),
+    ]
+
     def __init__(self, *clauses, **kwargs):
         self.operator = kwargs.pop("operator", operators.comma_op)
         self.group = kwargs.pop("group", True)
@@ -2082,17 +2089,6 @@ class ClauseList(
                 coercions.expect(self._text_converter_role, clause)
             )
 
-    def _copy_internals(self, clone=_clone, **kw):
-        self.clauses = [clone(clause, **kw) for clause in self.clauses]
-
-    def get_children(self, **kwargs):
-        return self.clauses
-
-    def _cache_key(self, **kw):
-        return (ClauseList, self.operator) + tuple(
-            clause._cache_key(**kw) for clause in self.clauses
-        )
-
     @property
     def _from_objects(self):
         return list(itertools.chain(*[c._from_objects for c in self.clauses]))
@@ -2115,11 +2111,6 @@ class BooleanClauseList(ClauseList, ColumnElement):
             "BooleanClauseList has a private constructor"
         )
 
-    def _cache_key(self, **kw):
-        return (BooleanClauseList, self.operator) + tuple(
-            clause._cache_key(**kw) for clause in self.clauses
-        )
-
     @classmethod
     def _construct(cls, operator, continue_on, skip_on, *clauses, **kw):
         convert_clauses = []
@@ -2250,6 +2241,8 @@ or_ = BooleanClauseList.or_
 class Tuple(ClauseList, ColumnElement):
     """Represent a SQL tuple."""
 
+    _traverse_internals = ClauseList._traverse_internals + []
+
     def __init__(self, *clauses, **kw):
         """Return a :class:`.Tuple`.
 
@@ -2289,11 +2282,6 @@ class Tuple(ClauseList, ColumnElement):
     def _select_iterable(self):
         return (self,)
 
-    def _cache_key(self, **kw):
-        return (Tuple,) + tuple(
-            clause._cache_key(**kw) for clause in self.clauses
-        )
-
     def _bind_param(self, operator, obj, type_=None):
         return Tuple(
             *[
@@ -2339,6 +2327,12 @@ class Case(ColumnElement):
 
     __visit_name__ = "case"
 
+    _traverse_internals = [
+        ("value", InternalTraversal.dp_clauseelement),
+        ("whens", InternalTraversal.dp_clauseelement_tuples),
+        ("else_", InternalTraversal.dp_clauseelement),
+    ]
+
     def __init__(self, whens, value=None, else_=None):
         r"""Produce a ``CASE`` expression.
 
@@ -2501,40 +2495,6 @@ class Case(ColumnElement):
         else:
             self.else_ = None
 
-    def _copy_internals(self, clone=_clone, **kw):
-        if self.value is not None:
-            self.value = clone(self.value, **kw)
-        self.whens = [(clone(x, **kw), clone(y, **kw)) for x, y in self.whens]
-        if self.else_ is not None:
-            self.else_ = clone(self.else_, **kw)
-
-    def get_children(self, **kwargs):
-        if self.value is not None:
-            yield self.value
-        for x, y in self.whens:
-            yield x
-            yield y
-        if self.else_ is not None:
-            yield self.else_
-
-    def _cache_key(self, **kw):
-        return (
-            (
-                Case,
-                self.value._cache_key(**kw)
-                if self.value is not None
-                else None,
-            )
-            + tuple(
-                (x._cache_key(**kw), y._cache_key(**kw)) for x, y in self.whens
-            )
-            + (
-                self.else_._cache_key(**kw)
-                if self.else_ is not None
-                else None,
-            )
-        )
-
     @property
     def _from_objects(self):
         return list(
@@ -2603,6 +2563,11 @@ class Cast(WrapsColumnExpression, ColumnElement):
 
     __visit_name__ = "cast"
 
+    _traverse_internals = [
+        ("clause", InternalTraversal.dp_clauseelement),
+        ("typeclause", InternalTraversal.dp_clauseelement),
+    ]
+
     def __init__(self, expression, type_):
         r"""Produce a ``CAST`` expression.
 
@@ -2662,20 +2627,6 @@ class Cast(WrapsColumnExpression, ColumnElement):
         )
         self.typeclause = TypeClause(self.type)
 
-    def _copy_internals(self, clone=_clone, **kw):
-        self.clause = clone(self.clause, **kw)
-        self.typeclause = clone(self.typeclause, **kw)
-
-    def get_children(self, **kwargs):
-        return self.clause, self.typeclause
-
-    def _cache_key(self, **kw):
-        return (
-            Cast,
-            self.clause._cache_key(**kw),
-            self.typeclause._cache_key(**kw),
-        )
-
     @property
     def _from_objects(self):
         return self.clause._from_objects
@@ -2685,7 +2636,7 @@ class Cast(WrapsColumnExpression, ColumnElement):
         return self.clause
 
 
-class TypeCoerce(WrapsColumnExpression, ColumnElement):
+class TypeCoerce(HasMemoized, WrapsColumnExpression, ColumnElement):
     """Represent a Python-side type-coercion wrapper.
 
     :class:`.TypeCoerce` supplies the :func:`.expression.type_coerce`
@@ -2705,6 +2656,13 @@ class TypeCoerce(WrapsColumnExpression, ColumnElement):
 
     __visit_name__ = "type_coerce"
 
+    _traverse_internals = [
+        ("clause", InternalTraversal.dp_clauseelement),
+        ("type", InternalTraversal.dp_type),
+    ]
+
+    _memoized_property = util.group_expirable_memoized_property()
+
     def __init__(self, expression, type_):
         r"""Associate a SQL expression with a particular type, without rendering
         ``CAST``.
@@ -2773,21 +2731,11 @@ class TypeCoerce(WrapsColumnExpression, ColumnElement):
             roles.ExpressionElementRole, expression, type_=self.type
         )
 
-    def _copy_internals(self, clone=_clone, **kw):
-        self.clause = clone(self.clause, **kw)
-        self.__dict__.pop("typed_expression", None)
-
-    def get_children(self, **kwargs):
-        return (self.clause,)
-
-    def _cache_key(self, **kw):
-        return (TypeCoerce, self.type._cache_key, self.clause._cache_key(**kw))
-
     @property
     def _from_objects(self):
         return self.clause._from_objects
 
-    @util.memoized_property
+    @_memoized_property
     def typed_expression(self):
         if isinstance(self.clause, BindParameter):
             bp = self.clause._clone()
@@ -2806,6 +2754,11 @@ class Extract(ColumnElement):
 
     __visit_name__ = "extract"
 
+    _traverse_internals = [
+        ("expr", InternalTraversal.dp_clauseelement),
+        ("field", InternalTraversal.dp_string),
+    ]
+
     def __init__(self, field, expr, **kwargs):
         """Return a :class:`.Extract` construct.
 
@@ -2818,15 +2771,6 @@ class Extract(ColumnElement):
         self.field = field
         self.expr = coercions.expect(roles.ExpressionElementRole, expr)
 
-    def _copy_internals(self, clone=_clone, **kw):
-        self.expr = clone(self.expr, **kw)
-
-    def get_children(self, **kwargs):
-        return (self.expr,)
-
-    def _cache_key(self, **kw):
-        return (Extract, self.field, self.expr._cache_key(**kw))
-
     @property
     def _from_objects(self):
         return self.expr._from_objects
@@ -2847,18 +2791,11 @@ class _label_reference(ColumnElement):
 
     __visit_name__ = "label_reference"
 
+    _traverse_internals = [("element", InternalTraversal.dp_clauseelement)]
+
     def __init__(self, element):
         self.element = element
 
-    def _copy_internals(self, clone=_clone, **kw):
-        self.element = clone(self.element, **kw)
-
-    def _cache_key(self, **kw):
-        return (_label_reference, self.element._cache_key(**kw))
-
-    def get_children(self, **kwargs):
-        return [self.element]
-
     @property
     def _from_objects(self):
         return ()
@@ -2867,6 +2804,8 @@ class _label_reference(ColumnElement):
 class _textual_label_reference(ColumnElement):
     __visit_name__ = "textual_label_reference"
 
+    _traverse_internals = [("element", InternalTraversal.dp_string)]
+
     def __init__(self, element):
         self.element = element
 
@@ -2874,9 +2813,6 @@ class _textual_label_reference(ColumnElement):
     def _text_clause(self):
         return TextClause._create_text(self.element)
 
-    def _cache_key(self, **kw):
-        return (_textual_label_reference, self.element)
-
 
 class UnaryExpression(ColumnElement):
     """Define a 'unary' expression.
@@ -2894,13 +2830,18 @@ class UnaryExpression(ColumnElement):
 
     __visit_name__ = "unary"
 
+    _traverse_internals = [
+        ("element", InternalTraversal.dp_clauseelement),
+        ("operator", InternalTraversal.dp_operator),
+        ("modifier", InternalTraversal.dp_operator),
+    ]
+
     def __init__(
         self,
         element,
         operator=None,
         modifier=None,
         type_=None,
-        negate=None,
         wraps_column_expression=False,
     ):
         self.operator = operator
@@ -2909,7 +2850,6 @@ class UnaryExpression(ColumnElement):
             against=self.operator or self.modifier
         )
         self.type = type_api.to_instance(type_)
-        self.negate = negate
         self.wraps_column_expression = wraps_column_expression
 
     @classmethod
@@ -3135,37 +3075,13 @@ class UnaryExpression(ColumnElement):
     def _from_objects(self):
         return self.element._from_objects
 
-    def _copy_internals(self, clone=_clone, **kw):
-        self.element = clone(self.element, **kw)
-
-    def _cache_key(self, **kw):
-        return (
-            UnaryExpression,
-            self.element._cache_key(**kw),
-            self.operator,
-            self.modifier,
-        )
-
-    def get_children(self, **kwargs):
-        return (self.element,)
-
     def _negate(self):
-        if self.negate is not None:
-            return UnaryExpression(
-                self.element,
-                operator=self.negate,
-                negate=self.operator,
-                modifier=self.modifier,
-                type_=self.type,
-                wraps_column_expression=self.wraps_column_expression,
-            )
-        elif self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity:
+        if self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity:
             return UnaryExpression(
                 self.self_group(against=operators.inv),
                 operator=operators.inv,
                 type_=type_api.BOOLEANTYPE,
                 wraps_column_expression=self.wraps_column_expression,
-                negate=None,
             )
         else:
             return ClauseElement._negate(self)
@@ -3286,15 +3202,6 @@ class AsBoolean(WrapsColumnExpression, UnaryExpression):
         # type: (Optional[Any]) -> ClauseElement
         return self
 
-    def _cache_key(self, **kw):
-        return (
-            self.element._cache_key(**kw),
-            self.type._cache_key,
-            self.operator,
-            self.negate,
-            self.modifier,
-        )
-
     def _negate(self):
         if isinstance(self.element, (True_, False_)):
             return self.element._negate()
@@ -3318,6 +3225,14 @@ class BinaryExpression(ColumnElement):
 
     __visit_name__ = "binary"
 
+    _traverse_internals = [
+        ("left", InternalTraversal.dp_clauseelement),
+        ("right", InternalTraversal.dp_clauseelement),
+        ("operator", InternalTraversal.dp_operator),
+        ("negate", InternalTraversal.dp_operator),
+        ("modifiers", InternalTraversal.dp_plain_dict),
+    ]
+
     _is_implicitly_boolean = True
     """Indicates that any database will know this is a boolean expression
     even if the database does not have an explicit boolean datatype.
@@ -3360,20 +3275,6 @@ class BinaryExpression(ColumnElement):
     def _from_objects(self):
         return self.left._from_objects + self.right._from_objects
 
-    def _copy_internals(self, clone=_clone, **kw):
-        self.left = clone(self.left, **kw)
-        self.right = clone(self.right, **kw)
-
-    def get_children(self, **kwargs):
-        return self.left, self.right
-
-    def _cache_key(self, **kw):
-        return (
-            BinaryExpression,
-            self.left._cache_key(**kw),
-            self.right._cache_key(**kw),
-        )
-
     def self_group(self, against=None):
         # type: (Optional[Any]) -> ClauseElement
 
@@ -3406,6 +3307,12 @@ class Slice(ColumnElement):
 
     __visit_name__ = "slice"
 
+    _traverse_internals = [
+        ("start", InternalTraversal.dp_plain_obj),
+        ("stop", InternalTraversal.dp_plain_obj),
+        ("step", InternalTraversal.dp_plain_obj),
+    ]
+
     def __init__(self, start, stop, step):
         self.start = start
         self.stop = stop
@@ -3417,9 +3324,6 @@ class Slice(ColumnElement):
         assert against is operator.getitem
         return self
 
-    def _cache_key(self, **kw):
-        return (Slice, self.start, self.stop, self.step)
-
 
 class IndexExpression(BinaryExpression):
     """Represent the class of expressions that are like an "index" operation.
@@ -3444,6 +3348,11 @@ class GroupedElement(ClauseElement):
 class Grouping(GroupedElement, ColumnElement):
     """Represent a grouping within a column expression"""
 
+    _traverse_internals = [
+        ("element", InternalTraversal.dp_clauseelement),
+        ("type", InternalTraversal.dp_type),
+    ]
+
     def __init__(self, element):
         self.element = element
         self.type = getattr(element, "type", type_api.NULLTYPE)
@@ -3460,15 +3369,6 @@ class Grouping(GroupedElement, ColumnElement):
     def _label(self):
         return getattr(self.element, "_label", None) or self.anon_label
 
-    def _copy_internals(self, clone=_clone, **kw):
-        self.element = clone(self.element, **kw)
-
-    def get_children(self, **kwargs):
-        return (self.element,)
-
-    def _cache_key(self, **kw):
-        return (Grouping, self.element._cache_key(**kw))
-
     @property
     def _from_objects(self):
         return self.element._from_objects
@@ -3501,6 +3401,14 @@ class Over(ColumnElement):
 
     __visit_name__ = "over"
 
+    _traverse_internals = [
+        ("element", InternalTraversal.dp_clauseelement),
+        ("order_by", InternalTraversal.dp_clauseelement),
+        ("partition_by", InternalTraversal.dp_clauseelement),
+        ("range_", InternalTraversal.dp_plain_obj),
+        ("rows", InternalTraversal.dp_plain_obj),
+    ]
+
     order_by = None
     partition_by = None
 
@@ -3667,30 +3575,6 @@ class Over(ColumnElement):
     def type(self):
         return self.element.type
 
-    def get_children(self, **kwargs):
-        return [
-            c
-            for c in (self.element, self.partition_by, self.order_by)
-            if c is not None
-        ]
-
-    def _cache_key(self, **kw):
-        return (
-            (Over,)
-            + tuple(
-                e._cache_key(**kw) if e is not None else None
-                for e in (self.element, self.partition_by, self.order_by)
-            )
-            + (self.range_, self.rows)
-        )
-
-    def _copy_internals(self, clone=_clone, **kw):
-        self.element = clone(self.element, **kw)
-        if self.partition_by is not None:
-            self.partition_by = clone(self.partition_by, **kw)
-        if self.order_by is not None:
-            self.order_by = clone(self.order_by, **kw)
-
     @property
     def _from_objects(self):
         return list(
@@ -3723,6 +3607,11 @@ class WithinGroup(ColumnElement):
 
     __visit_name__ = "withingroup"
 
+    _traverse_internals = [
+        ("element", InternalTraversal.dp_clauseelement),
+        ("order_by", InternalTraversal.dp_clauseelement),
+    ]
+
     order_by = None
 
     def __init__(self, element, *order_by):
@@ -3791,25 +3680,6 @@ class WithinGroup(ColumnElement):
         else:
             return self.element.type
 
-    def get_children(self, **kwargs):
-        return [c for c in (self.element, self.order_by) if c is not None]
-
-    def _cache_key(self, **kw):
-        return (
-            WithinGroup,
-            self.element._cache_key(**kw)
-            if self.element is not None
-            else None,
-            self.order_by._cache_key(**kw)
-            if self.order_by is not None
-            else None,
-        )
-
-    def _copy_internals(self, clone=_clone, **kw):
-        self.element = clone(self.element, **kw)
-        if self.order_by is not None:
-            self.order_by = clone(self.order_by, **kw)
-
     @property
     def _from_objects(self):
         return list(
@@ -3845,6 +3715,11 @@ class FunctionFilter(ColumnElement):
 
     __visit_name__ = "funcfilter"
 
+    _traverse_internals = [
+        ("func", InternalTraversal.dp_clauseelement),
+        ("criterion", InternalTraversal.dp_clauseelement),
+    ]
+
     criterion = None
 
     def __init__(self, func, *criterion):
@@ -3932,23 +3807,6 @@ class FunctionFilter(ColumnElement):
     def type(self):
         return self.func.type
 
-    def get_children(self, **kwargs):
-        return [c for c in (self.func, self.criterion) if c is not None]
-
-    def _copy_internals(self, clone=_clone, **kw):
-        self.func = clone(self.func, **kw)
-        if self.criterion is not None:
-            self.criterion = clone(self.criterion, **kw)
-
-    def _cache_key(self, **kw):
-        return (
-            FunctionFilter,
-            self.func._cache_key(**kw),
-            self.criterion._cache_key(**kw)
-            if self.criterion is not None
-            else None,
-        )
-
     @property
     def _from_objects(self):
         return list(
@@ -3962,7 +3820,7 @@ class FunctionFilter(ColumnElement):
         )
 
 
-class Label(roles.LabeledColumnExprRole, ColumnElement):
+class Label(HasMemoized, roles.LabeledColumnExprRole, ColumnElement):
     """Represents a column label (AS).
 
     Represent a label, as typically applied to any column-level
@@ -3972,6 +3830,14 @@ class Label(roles.LabeledColumnExprRole, ColumnElement):
 
     __visit_name__ = "label"
 
+    _traverse_internals = [
+        ("name", InternalTraversal.dp_anon_name),
+        ("_type", InternalTraversal.dp_type),
+        ("_element", InternalTraversal.dp_clauseelement),
+    ]
+
+    _memoized_property = util.group_expirable_memoized_property()
+
     def __init__(self, name, element, type_=None):
         """Return a :class:`Label` object for the
         given :class:`.ColumnElement`.
@@ -4010,14 +3876,11 @@ class Label(roles.LabeledColumnExprRole, ColumnElement):
     def __reduce__(self):
         return self.__class__, (self.name, self._element, self._type)
 
-    def _cache_key(self, **kw):
-        return (Label, self.element._cache_key(**kw), self._resolve_label)
-
     @util.memoized_property
     def _is_implicitly_boolean(self):
         return self.element._is_implicitly_boolean
 
-    @util.memoized_property
+    @_memoized_property
     def _allow_label_resolve(self):
         return self.element._allow_label_resolve
 
@@ -4031,7 +3894,7 @@ class Label(roles.LabeledColumnExprRole, ColumnElement):
             self._type or getattr(self._element, "type", None)
         )
 
-    @util.memoized_property
+    @_memoized_property
     def element(self):
         return self._element.self_group(against=operators.as_)
 
@@ -4057,13 +3920,9 @@ class Label(roles.LabeledColumnExprRole, ColumnElement):
     def foreign_keys(self):
         return self.element.foreign_keys
 
-    def get_children(self, **kwargs):
-        return (self.element,)
-
     def _copy_internals(self, clone=_clone, anonymize_labels=False, **kw):
+        self._reset_memoizations()
         self._element = clone(self._element, **kw)
-        self.__dict__.pop("element", None)
-        self.__dict__.pop("_allow_label_resolve", None)
         if anonymize_labels:
             self.name = self._resolve_label = _anonymous_label(
                 "%%(%d %s)s"
@@ -4124,6 +3983,13 @@ class ColumnClause(roles.LabeledColumnExprRole, Immutable, ColumnElement):
 
     __visit_name__ = "column"
 
+    _traverse_internals = [
+        ("name", InternalTraversal.dp_string),
+        ("type", InternalTraversal.dp_type),
+        ("table", InternalTraversal.dp_clauseelement),
+        ("is_literal", InternalTraversal.dp_boolean),
+    ]
+
     onupdate = default = server_default = server_onupdate = None
 
     _is_multiparam_column = False
@@ -4254,14 +4120,6 @@ class ColumnClause(roles.LabeledColumnExprRole, Immutable, ColumnElement):
 
     table = property(_get_table, _set_table)
 
-    def _cache_key(self, **kw):
-        return (
-            self.name,
-            self.table.name if self.table is not None else None,
-            self.is_literal,
-            self.type._cache_key,
-        )
-
     @_memoized_property
     def _from_objects(self):
         t = self.table
@@ -4395,12 +4253,11 @@ class ColumnClause(roles.LabeledColumnExprRole, Immutable, ColumnElement):
 class CollationClause(ColumnElement):
     __visit_name__ = "collation"
 
+    _traverse_internals = [("collation", InternalTraversal.dp_string)]
+
     def __init__(self, collation):
         self.collation = collation
 
-    def _cache_key(self, **kw):
-        return (CollationClause, self.collation)
-
 
 class _IdentifiedClause(Executable, ClauseElement):
 
index 7ce822669c17a5fc54a3892f1477449634df0070..08e69f075af73a94eaef9236342204d05f6a1965 100644 (file)
@@ -86,7 +86,6 @@ __all__ = [
 from .base import _from_objects  # noqa
 from .base import ColumnCollection  # noqa
 from .base import Executable  # noqa
-from .base import Generative  # noqa
 from .base import PARSE_AUTOCOMMIT  # noqa
 from .dml import Delete  # noqa
 from .dml import Insert  # noqa
@@ -242,7 +241,6 @@ _UnaryExpression = UnaryExpression
 _Case = Case
 _Tuple = Tuple
 _Over = Over
-_Generative = Generative
 _TypeClause = TypeClause
 _Extract = Extract
 _Exists = Exists
index cbc8e539fa1f5e47dc85f8886077ace45ab2f67c..96e64dc284f56039f37f364fdffdbce9eba83d56 100644 (file)
@@ -17,7 +17,6 @@ from . import sqltypes
 from . import util as sqlutil
 from .base import ColumnCollection
 from .base import Executable
-from .elements import _clone
 from .elements import _type_from_args
 from .elements import BinaryExpression
 from .elements import BindParameter
@@ -33,7 +32,8 @@ from .elements import WithinGroup
 from .selectable import Alias
 from .selectable import FromClause
 from .selectable import Select
-from .visitors import VisitableType
+from .visitors import InternalTraversal
+from .visitors import TraversibleType
 from .. import util
 
 
@@ -78,10 +78,14 @@ class FunctionElement(Executable, ColumnElement, FromClause):
 
     """
 
+    _traverse_internals = [("clause_expr", InternalTraversal.dp_clauseelement)]
+
     packagenames = ()
 
     _has_args = False
 
+    _memoized_property = FromClause._memoized_property
+
     def __init__(self, *clauses, **kwargs):
         r"""Construct a :class:`.FunctionElement`.
 
@@ -136,7 +140,7 @@ class FunctionElement(Executable, ColumnElement, FromClause):
         col = self.label(None)
         return ColumnCollection(columns=[(col.key, col)])
 
-    @util.memoized_property
+    @_memoized_property
     def clauses(self):
         """Return the underlying :class:`.ClauseList` which contains
         the arguments for this :class:`.FunctionElement`.
@@ -283,17 +287,6 @@ class FunctionElement(Executable, ColumnElement, FromClause):
     def _from_objects(self):
         return self.clauses._from_objects
 
-    def get_children(self, **kwargs):
-        return (self.clause_expr,)
-
-    def _cache_key(self, **kw):
-        return (FunctionElement, self.clause_expr._cache_key(**kw))
-
-    def _copy_internals(self, clone=_clone, **kw):
-        self.clause_expr = clone(self.clause_expr, **kw)
-        self._reset_exported()
-        FunctionElement.clauses._reset(self)
-
     def within_group_type(self, within_group):
         """For types that define their return type as based on the criteria
         within a WITHIN GROUP (ORDER BY) expression, called by the
@@ -404,6 +397,13 @@ class FunctionElement(Executable, ColumnElement, FromClause):
 
 
 class FunctionAsBinary(BinaryExpression):
+    _traverse_internals = [
+        ("sql_function", InternalTraversal.dp_clauseelement),
+        ("left_index", InternalTraversal.dp_plain_obj),
+        ("right_index", InternalTraversal.dp_plain_obj),
+        ("modifiers", InternalTraversal.dp_plain_dict),
+    ]
+
     def __init__(self, fn, left_index, right_index):
         self.sql_function = fn
         self.left_index = left_index
@@ -431,20 +431,6 @@ class FunctionAsBinary(BinaryExpression):
     def right(self, value):
         self.sql_function.clauses.clauses[self.right_index - 1] = value
 
-    def _copy_internals(self, clone=_clone, **kw):
-        self.sql_function = clone(self.sql_function, **kw)
-
-    def get_children(self, **kw):
-        yield self.sql_function
-
-    def _cache_key(self, **kw):
-        return (
-            FunctionAsBinary,
-            self.sql_function._cache_key(**kw),
-            self.left_index,
-            self.right_index,
-        )
-
 
 class _FunctionGenerator(object):
     """Generate SQL function expressions.
@@ -606,6 +592,12 @@ class Function(FunctionElement):
 
     __visit_name__ = "function"
 
+    _traverse_internals = FunctionElement._traverse_internals + [
+        ("packagenames", InternalTraversal.dp_plain_obj),
+        ("name", InternalTraversal.dp_string),
+        ("type", InternalTraversal.dp_type),
+    ]
+
     def __init__(self, name, *clauses, **kw):
         """Construct a :class:`.Function`.
 
@@ -630,15 +622,8 @@ class Function(FunctionElement):
             unique=True,
         )
 
-    def _cache_key(self, **kw):
-        return (
-            (Function,) + tuple(self.packagenames)
-            if self.packagenames
-            else () + (self.name, self.clause_expr._cache_key(**kw))
-        )
-
 
-class _GenericMeta(VisitableType):
+class _GenericMeta(TraversibleType):
     def __init__(cls, clsname, bases, clsdict):
         if annotation.Annotated not in cls.__mro__:
             cls.name = name = clsdict.get("name", clsname)
@@ -764,6 +749,10 @@ class next_value(GenericFunction):
     type = sqltypes.Integer()
     name = "next_value"
 
+    _traverse_internals = [
+        ("sequence", InternalTraversal.dp_named_ddl_element)
+    ]
+
     def __init__(self, seq, **kw):
         assert isinstance(
             seq, schema.Sequence
@@ -771,21 +760,12 @@ class next_value(GenericFunction):
         self._bind = kw.get("bind", None)
         self.sequence = seq
 
-    def _cache_key(self, **kw):
-        return (next_value, self.sequence.name)
-
     def compare(self, other, **kw):
         return (
             isinstance(other, next_value)
             and self.sequence.name == other.sequence.name
         )
 
-    def get_children(self, **kwargs):
-        return []
-
-    def _copy_internals(self, **kw):
-        pass
-
     @property
     def _from_objects(self):
         return []
index 4e8f4a39701c908138e9c58da822fe63d5f18286..ee7dc61ce8973ff378b6adcfd8ca1c2478c8c3a1 100644 (file)
@@ -50,6 +50,7 @@ from .elements import ColumnElement
 from .elements import quoted_name
 from .elements import TextClause
 from .selectable import TableClause
+from .visitors import InternalTraversal
 from .. import event
 from .. import exc
 from .. import inspection
@@ -425,6 +426,21 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
 
     __visit_name__ = "table"
 
+    _traverse_internals = TableClause._traverse_internals + [
+        ("schema", InternalTraversal.dp_string)
+    ]
+
+    def _gen_cache_key(self, anon_map, bindparams):
+        return (self,)
+
+    @util.deprecated_params(
+        useexisting=(
+            "0.7",
+            "The :paramref:`.Table.useexisting` parameter is deprecated and "
+            "will be removed in a future release.  Please use "
+            ":paramref:`.Table.extend_existing`.",
+        )
+    )
     def __new__(cls, *args, **kw):
         if not args:
             # python3k pickle seems to call this
@@ -763,6 +779,8 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
     def get_children(
         self, column_collections=True, schema_visitor=False, **kw
     ):
+        # TODO: consider that we probably don't need column_collections=True
+        # at all; it does not seem to impact anything
         if not schema_visitor:
             return TableClause.get_children(
                 self, column_collections=column_collections, **kw
index 6a7413fc098ae7cb9e70ba252cd81eef1baab97a..4b3844eec142ccb3f00b44734d72b7bcc2a22266 100644 (file)
@@ -31,6 +31,7 @@ from .base import ColumnSet
 from .base import DedupeColumnCollection
 from .base import Executable
 from .base import Generative
+from .base import HasMemoized
 from .base import Immutable
 from .coercions import _document_text_coercion
 from .elements import _anonymous_label
@@ -39,11 +40,13 @@ from .elements import and_
 from .elements import BindParameter
 from .elements import ClauseElement
 from .elements import ClauseList
+from .elements import ColumnClause
 from .elements import GroupedElement
 from .elements import Grouping
 from .elements import literal_column
 from .elements import True_
 from .elements import UnaryExpression
+from .visitors import InternalTraversal
 from .. import exc
 from .. import util
 
@@ -201,6 +204,8 @@ class Selectable(ReturnsRows):
 class HasPrefixes(object):
     _prefixes = ()
 
+    _traverse_internals = [("_prefixes", InternalTraversal.dp_prefix_sequence)]
+
     @_generative
     @_document_text_coercion(
         "expr",
@@ -252,6 +257,8 @@ class HasPrefixes(object):
 class HasSuffixes(object):
     _suffixes = ()
 
+    _traverse_internals = [("_suffixes", InternalTraversal.dp_prefix_sequence)]
+
     @_generative
     @_document_text_coercion(
         "expr",
@@ -295,7 +302,7 @@ class HasSuffixes(object):
         )
 
 
-class FromClause(roles.AnonymizedFromClauseRole, Selectable):
+class FromClause(HasMemoized, roles.AnonymizedFromClauseRole, Selectable):
     """Represent an element that can be used within the ``FROM``
     clause of a ``SELECT`` statement.
 
@@ -529,11 +536,6 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
         """
         return getattr(self, "name", self.__class__.__name__ + " object")
 
-    def _reset_exported(self):
-        """delete memoized collections when a FromClause is cloned."""
-
-        self._memoized_property.expire_instance(self)
-
     def _generate_fromclause_column_proxies(self, fromclause):
         fromclause._columns._populate_separate_keys(
             col._make_proxy(fromclause) for col in self.c
@@ -668,6 +670,14 @@ class Join(FromClause):
 
     __visit_name__ = "join"
 
+    _traverse_internals = [
+        ("left", InternalTraversal.dp_clauseelement),
+        ("right", InternalTraversal.dp_clauseelement),
+        ("onclause", InternalTraversal.dp_clauseelement),
+        ("isouter", InternalTraversal.dp_boolean),
+        ("full", InternalTraversal.dp_boolean),
+    ]
+
     _is_join = True
 
     def __init__(self, left, right, onclause=None, isouter=False, full=False):
@@ -805,25 +815,6 @@ class Join(FromClause):
         self.left._refresh_for_new_column(column)
         self.right._refresh_for_new_column(column)
 
-    def _copy_internals(self, clone=_clone, **kw):
-        self._reset_exported()
-        self.left = clone(self.left, **kw)
-        self.right = clone(self.right, **kw)
-        self.onclause = clone(self.onclause, **kw)
-
-    def get_children(self, **kwargs):
-        return self.left, self.right, self.onclause
-
-    def _cache_key(self, **kw):
-        return (
-            Join,
-            self.isouter,
-            self.full,
-            self.left._cache_key(**kw),
-            self.right._cache_key(**kw),
-            self.onclause._cache_key(**kw),
-        )
-
     def _match_primaries(self, left, right):
         if isinstance(left, Join):
             left_right = left.right
@@ -1175,6 +1166,11 @@ class AliasedReturnsRows(FromClause):
     _is_from_container = True
     named_with_column = True
 
+    _traverse_internals = [
+        ("element", InternalTraversal.dp_clauseelement),
+        ("name", InternalTraversal.dp_anon_name),
+    ]
+
     def __init__(self, *arg, **kw):
         raise NotImplementedError(
             "The %s class is not intended to be constructed "
@@ -1243,18 +1239,13 @@ class AliasedReturnsRows(FromClause):
 
     def _copy_internals(self, clone=_clone, **kw):
         element = clone(self.element, **kw)
+
+        # the element clone is usually against a Table that returns the
+        # same object.  don't reset exported .c. collections and other
+        # memoized details if nothing changed
         if element is not self.element:
             self._reset_exported()
-        self.element = element
-
-    def get_children(self, column_collections=True, **kw):
-        if column_collections:
-            for c in self.c:
-                yield c
-        yield self.element
-
-    def _cache_key(self, **kw):
-        return (self.__class__, self.element._cache_key(**kw), self._orig_name)
+            self.element = element
 
     @property
     def _from_objects(self):
@@ -1396,6 +1387,11 @@ class TableSample(AliasedReturnsRows):
 
     __visit_name__ = "tablesample"
 
+    _traverse_internals = AliasedReturnsRows._traverse_internals + [
+        ("sampling", InternalTraversal.dp_clauseelement),
+        ("seed", InternalTraversal.dp_clauseelement),
+    ]
+
     @classmethod
     def _factory(cls, selectable, sampling, name=None, seed=None):
         """Return a :class:`.TableSample` object.
@@ -1466,6 +1462,16 @@ class CTE(Generative, HasSuffixes, AliasedReturnsRows):
 
     __visit_name__ = "cte"
 
+    _traverse_internals = (
+        AliasedReturnsRows._traverse_internals
+        + [
+            ("_cte_alias", InternalTraversal.dp_clauseelement),
+            ("_restates", InternalTraversal.dp_clauseelement_unordered_set),
+            ("recursive", InternalTraversal.dp_boolean),
+        ]
+        + HasSuffixes._traverse_internals
+    )
+
     @classmethod
     def _factory(cls, selectable, name=None, recursive=False):
         r"""Return a new :class:`.CTE`, or Common Table Expression instance.
@@ -1495,15 +1501,13 @@ class CTE(Generative, HasSuffixes, AliasedReturnsRows):
 
     def _copy_internals(self, clone=_clone, **kw):
         super(CTE, self)._copy_internals(clone, **kw)
+        # TODO: I don't like that we can't use the traversal data here
         if self._cte_alias is not None:
             self._cte_alias = clone(self._cte_alias, **kw)
         self._restates = frozenset(
             [clone(elem, **kw) for elem in self._restates]
         )
 
-    def _cache_key(self, *arg, **kw):
-        raise NotImplementedError("TODO")
-
     def alias(self, name=None, flat=False):
         """Return an :class:`.Alias` of this :class:`.CTE`.
 
@@ -1764,6 +1768,8 @@ class Subquery(AliasedReturnsRows):
 class FromGrouping(GroupedElement, FromClause):
     """Represent a grouping of a FROM clause"""
 
+    _traverse_internals = [("element", InternalTraversal.dp_clauseelement)]
+
     def __init__(self, element):
         self.element = coercions.expect(roles.FromClauseRole, element)
 
@@ -1792,15 +1798,6 @@ class FromGrouping(GroupedElement, FromClause):
     def _hide_froms(self):
         return self.element._hide_froms
 
-    def get_children(self, **kwargs):
-        return (self.element,)
-
-    def _copy_internals(self, clone=_clone, **kw):
-        self.element = clone(self.element, **kw)
-
-    def _cache_key(self, **kw):
-        return (FromGrouping, self.element._cache_key(**kw))
-
     @property
     def _from_objects(self):
         return self.element._from_objects
@@ -1843,6 +1840,14 @@ class TableClause(Immutable, FromClause):
 
     __visit_name__ = "table"
 
+    _traverse_internals = [
+        (
+            "columns",
+            InternalTraversal.dp_fromclause_canonical_column_collection,
+        ),
+        ("name", InternalTraversal.dp_string),
+    ]
+
     named_with_column = True
 
     implicit_returning = False
@@ -1895,17 +1900,6 @@ class TableClause(Immutable, FromClause):
         self._columns.add(c)
         c.table = self
 
-    def get_children(self, column_collections=True, **kwargs):
-        if column_collections:
-            return [c for c in self.c]
-        else:
-            return []
-
-    def _cache_key(self, **kw):
-        return (TableClause, self.name) + tuple(
-            col._cache_key(**kw) for col in self._columns
-        )
-
     @util.dependencies("sqlalchemy.sql.dml")
     def insert(self, dml, values=None, inline=False, **kwargs):
         """Generate an :func:`.insert` construct against this
@@ -1965,6 +1959,13 @@ class TableClause(Immutable, FromClause):
 
 
 class ForUpdateArg(ClauseElement):
+    _traverse_internals = [
+        ("of", InternalTraversal.dp_clauseelement_list),
+        ("nowait", InternalTraversal.dp_boolean),
+        ("read", InternalTraversal.dp_boolean),
+        ("skip_locked", InternalTraversal.dp_boolean),
+    ]
+
     @classmethod
     def parse_legacy_select(self, arg):
         """Parse the for_update argument of :func:`.select`.
@@ -2029,19 +2030,6 @@ class ForUpdateArg(ClauseElement):
     def __hash__(self):
         return id(self)
 
-    def _copy_internals(self, clone=_clone, **kw):
-        if self.of is not None:
-            self.of = [clone(col, **kw) for col in self.of]
-
-    def _cache_key(self, **kw):
-        return (
-            ForUpdateArg,
-            self.nowait,
-            self.read,
-            self.skip_locked,
-            self.of._cache_key(**kw) if self.of is not None else None,
-        )
-
     def __init__(
         self,
         nowait=False,
@@ -2074,6 +2062,7 @@ class SelectBase(
     roles.DMLSelectRole,
     roles.CompoundElementRole,
     roles.InElementRole,
+    HasMemoized,
     HasCTE,
     Executable,
     SupportsCloneAnnotations,
@@ -2092,9 +2081,6 @@ class SelectBase(
 
     _memoized_property = util.group_expirable_memoized_property()
 
-    def _reset_memoizations(self):
-        self._memoized_property.expire_instance(self)
-
     def _generate_fromclause_column_proxies(self, fromclause):
         # type: (FromClause)
         raise NotImplementedError()
@@ -2339,6 +2325,7 @@ class SelectStatementGrouping(GroupedElement, SelectBase):
     """
 
     __visit_name__ = "grouping"
+    _traverse_internals = [("element", InternalTraversal.dp_clauseelement)]
 
     _is_select_container = True
 
@@ -2350,9 +2337,6 @@ class SelectStatementGrouping(GroupedElement, SelectBase):
     def select_statement(self):
         return self.element
 
-    def get_children(self, **kwargs):
-        return (self.element,)
-
     def self_group(self, against=None):
         # type: (Optional[Any]) -> FromClause
         return self
@@ -2377,12 +2361,6 @@ class SelectStatementGrouping(GroupedElement, SelectBase):
         """
         return self.element.selected_columns
 
-    def _copy_internals(self, clone=_clone, **kw):
-        self.element = clone(self.element, **kw)
-
-    def _cache_key(self, **kw):
-        return (SelectStatementGrouping, self.element._cache_key(**kw))
-
     @property
     def _from_objects(self):
         return self.element._from_objects
@@ -2758,9 +2736,6 @@ class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase):
     def _label_resolve_dict(self):
         raise NotImplementedError()
 
-    def _copy_internals(self, clone=_clone, **kw):
-        raise NotImplementedError()
-
 
 class CompoundSelect(GenerativeSelect):
     """Forms the basis of ``UNION``, ``UNION ALL``, and other
@@ -2785,6 +2760,16 @@ class CompoundSelect(GenerativeSelect):
 
     __visit_name__ = "compound_select"
 
+    _traverse_internals = [
+        ("selects", InternalTraversal.dp_clauseelement_list),
+        ("_limit_clause", InternalTraversal.dp_clauseelement),
+        ("_offset_clause", InternalTraversal.dp_clauseelement),
+        ("_order_by_clause", InternalTraversal.dp_clauseelement),
+        ("_group_by_clause", InternalTraversal.dp_clauseelement),
+        ("_for_update_arg", InternalTraversal.dp_clauseelement),
+        ("keyword", InternalTraversal.dp_string),
+    ] + SupportsCloneAnnotations._traverse_internals
+
     UNION = util.symbol("UNION")
     UNION_ALL = util.symbol("UNION ALL")
     EXCEPT = util.symbol("EXCEPT")
@@ -3004,47 +2989,6 @@ class CompoundSelect(GenerativeSelect):
         """
         return self.selects[0].selected_columns
 
-    def _copy_internals(self, clone=_clone, **kw):
-        self._reset_memoizations()
-        self.selects = [clone(s, **kw) for s in self.selects]
-        if hasattr(self, "_col_map"):
-            del self._col_map
-        for attr in (
-            "_limit_clause",
-            "_offset_clause",
-            "_order_by_clause",
-            "_group_by_clause",
-            "_for_update_arg",
-        ):
-            if getattr(self, attr) is not None:
-                setattr(self, attr, clone(getattr(self, attr), **kw))
-
-    def get_children(self, **kwargs):
-        return [self._order_by_clause, self._group_by_clause] + list(
-            self.selects
-        )
-
-    def _cache_key(self, **kw):
-        return (
-            (CompoundSelect, self.keyword)
-            + tuple(stmt._cache_key(**kw) for stmt in self.selects)
-            + (
-                self._order_by_clause._cache_key(**kw)
-                if self._order_by_clause is not None
-                else None,
-            )
-            + (
-                self._group_by_clause._cache_key(**kw)
-                if self._group_by_clause is not None
-                else None,
-            )
-            + (
-                self._for_update_arg._cache_key(**kw)
-                if self._for_update_arg is not None
-                else None,
-            )
-        )
-
     def bind(self):
         if self._bind:
             return self._bind
@@ -3193,11 +3137,35 @@ class Select(
     _hints = util.immutabledict()
     _statement_hints = ()
     _distinct = False
-    _from_cloned = None
+    _distinct_on = ()
     _correlate = ()
     _correlate_except = None
     _memoized_property = SelectBase._memoized_property
 
+    _traverse_internals = (
+        [
+            ("_from_obj", InternalTraversal.dp_fromclause_ordered_set),
+            ("_raw_columns", InternalTraversal.dp_clauseelement_list),
+            ("_whereclause", InternalTraversal.dp_clauseelement),
+            ("_having", InternalTraversal.dp_clauseelement),
+            ("_order_by_clause", InternalTraversal.dp_clauseelement_list),
+            ("_group_by_clause", InternalTraversal.dp_clauseelement_list),
+            ("_correlate", InternalTraversal.dp_clauseelement_unordered_set),
+            (
+                "_correlate_except",
+                InternalTraversal.dp_clauseelement_unordered_set,
+            ),
+            ("_for_update_arg", InternalTraversal.dp_clauseelement),
+            ("_statement_hints", InternalTraversal.dp_statement_hint_list),
+            ("_hints", InternalTraversal.dp_table_hint_list),
+            ("_distinct", InternalTraversal.dp_boolean),
+            ("_distinct_on", InternalTraversal.dp_clauseelement_list),
+        ]
+        + HasPrefixes._traverse_internals
+        + HasSuffixes._traverse_internals
+        + SupportsCloneAnnotations._traverse_internals
+    )
+
     @util.deprecated_params(
         autocommit=(
             "0.6",
@@ -3416,13 +3384,14 @@ class Select(
         """
         self._auto_correlate = correlate
         if distinct is not False:
-            if distinct is True:
-                self._distinct = True
-            else:
-                self._distinct = [
-                    coercions.expect(roles.WhereHavingRole, e)
-                    for e in util.to_list(distinct)
-                ]
+            self._distinct = True
+            if not isinstance(distinct, bool):
+                self._distinct_on = tuple(
+                    [
+                        coercions.expect(roles.WhereHavingRole, e)
+                        for e in util.to_list(distinct)
+                    ]
+                )
 
         if from_obj is not None:
             self._from_obj = util.OrderedSet(
@@ -3472,15 +3441,17 @@ class Select(
 
         GenerativeSelect.__init__(self, **kwargs)
 
+    # @_memoized_property
     @property
     def _froms(self):
-        # would love to cache this,
-        # but there's just enough edge cases, particularly now that
-        # declarative encourages construction of SQL expressions
-        # without tables present, to just regen this each time.
+        # the current roadblock to caching is two tests which assert that a
+        # SELECT can be compiled to a string, then a Table is created against
+        # its columns, and then it can be compiled again and still works.
+        # this is somewhat valid, as people make select() against a
+        # declarative class whose columns don't have their Table yet, and
+        # perhaps some operations call upon _froms and cache it too soon.
         froms = []
         seen = set()
-        translate = self._from_cloned
 
         for item in itertools.chain(
             _from_objects(*self._raw_columns),
@@ -3493,8 +3464,6 @@ class Select(
                 raise exc.InvalidRequestError(
                     "select() construct refers to itself as a FROM"
                 )
-            if translate and item in translate:
-                item = translate[item]
             if not seen.intersection(item._cloned_set):
                 froms.append(item)
             seen.update(item._cloned_set)
@@ -3518,15 +3487,6 @@ class Select(
             itertools.chain(*[_expand_cloned(f._hide_froms) for f in froms])
         )
         if toremove:
-            # if we're maintaining clones of froms,
-            # add the copies out to the toremove list.  only include
-            # clones that are lexical equivalents.
-            if self._from_cloned:
-                toremove.update(
-                    self._from_cloned[f]
-                    for f in toremove.intersection(self._from_cloned)
-                    if self._from_cloned[f]._is_lexical_equivalent(f)
-                )
             # filter out to FROM clauses not in the list,
             # using a list to maintain ordering
             froms = [f for f in froms if f not in toremove]
@@ -3707,7 +3667,6 @@ class Select(
         return False
 
     def _copy_internals(self, clone=_clone, **kw):
-
         # Select() object has been cloned and probably adapted by the
         # given clone function.  Apply the cloning function to internal
         # objects
@@ -3719,37 +3678,42 @@ class Select(
         # as of 0.7.4 we also put the current version of _froms, which
         # gets cleared on each generation.  previously we were "baking"
         # _froms into self._from_obj.
-        self._from_cloned = from_cloned = dict(
-            (f, clone(f, **kw)) for f in self._from_obj.union(self._froms)
-        )
 
-        # 3. update persistent _from_obj with the cloned versions.
-        self._from_obj = util.OrderedSet(
-            from_cloned[f] for f in self._from_obj
+        all_the_froms = list(
+            itertools.chain(
+                _from_objects(*self._raw_columns),
+                _from_objects(self._whereclause)
+                if self._whereclause is not None
+                else (),
+            )
         )
+        new_froms = {f: clone(f, **kw) for f in all_the_froms}
+        # copy FROM collections
 
-        # the _correlate collection is done separately, what can happen
-        # here is the same item is _correlate as in _from_obj but the
-        # _correlate version has an annotation on it - (specifically
-        # RelationshipProperty.Comparator._criterion_exists() does
-        # this). Also keep _correlate liberally open with its previous
-        # contents, as this set is used for matching, not rendering.
-        self._correlate = set(clone(f) for f in self._correlate).union(
-            self._correlate
-        )
+        self._from_obj = util.OrderedSet(
+            clone(f, **kw) for f in self._from_obj
+        ).union(f for f in new_froms.values() if isinstance(f, Join))
 
-        # do something similar for _correlate_except - this is a more
-        # unusual case but same idea applies
+        self._correlate = set(clone(f) for f in self._correlate)
         if self._correlate_except:
             self._correlate_except = set(
                 clone(f) for f in self._correlate_except
-            ).union(self._correlate_except)
+            )
 
         # 4. clone other things.   The difficulty here is that Column
-        # objects are not actually cloned, and refer to their original
-        # .table, resulting in the wrong "from" parent after a clone
-        # operation.  Hence _from_cloned and _from_obj supersede what is
-        # present here.
+        # objects are usually not altered by a straight clone because they
+        # depend on the FROM cloning we just did above in order to be
+        # targeted correctly; additionally, a new FROM may be a JOIN object
+        # which doesn't have its own columns.  So give the cloner a hint.
+        def replace(obj, **kw):
+            if isinstance(obj, ColumnClause) and obj.table in new_froms:
+                newelem = new_froms[obj.table].corresponding_column(obj)
+                return newelem
+
+        kw["replace"] = replace
+
+        # TODO: I'd still like to try to leverage the traversal data
         self._raw_columns = [clone(c, **kw) for c in self._raw_columns]
         for attr in (
             "_limit_clause",
@@ -3763,67 +3727,12 @@ class Select(
             if getattr(self, attr) is not None:
                 setattr(self, attr, clone(getattr(self, attr), **kw))
 
-        # erase _froms collection,
-        # etc.
         self._reset_memoizations()
 
     def get_children(self, **kwargs):
-        """return child elements as per the ClauseElement specification."""
-
-        return (
-            self._raw_columns
-            + list(self._froms)
-            + [
-                x
-                for x in (
-                    self._whereclause,
-                    self._having,
-                    self._order_by_clause,
-                    self._group_by_clause,
-                )
-                if x is not None
-            ]
-        )
-
-    def _cache_key(self, **kw):
-        return (
-            (Select,)
-            + ("raw_columns",)
-            + tuple(elem._cache_key(**kw) for elem in self._raw_columns)
-            + ("elements",)
-            + tuple(
-                elem._cache_key(**kw) if elem is not None else None
-                for elem in (
-                    self._whereclause,
-                    self._having,
-                    self._order_by_clause,
-                    self._group_by_clause,
-                )
-            )
-            + ("from_obj",)
-            + tuple(elem._cache_key(**kw) for elem in self._from_obj)
-            + ("correlate",)
-            + tuple(
-                elem._cache_key(**kw)
-                for elem in (
-                    self._correlate if self._correlate is not None else ()
-                )
-            )
-            + ("correlate_except",)
-            + tuple(
-                elem._cache_key(**kw)
-                for elem in (
-                    self._correlate_except
-                    if self._correlate_except is not None
-                    else ()
-                )
-            )
-            + ("for_update",),
-            (
-                self._for_update_arg._cache_key(**kw)
-                if self._for_update_arg is not None
-                else None,
-            ),
+        # TODO: define "get_children" traversal items separately?
+        return self._froms + super(Select, self).get_children(
+            omit_attrs=["_from_obj", "_correlate", "_correlate_except"]
         )
 
     @_generative
@@ -3987,10 +3896,8 @@ class Select(
         """
         if expr:
             expr = [coercions.expect(roles.ByOfRole, e) for e in expr]
-            if isinstance(self._distinct, list):
-                self._distinct = self._distinct + expr
-            else:
-                self._distinct = expr
+            self._distinct = True
+            self._distinct_on = self._distinct_on + tuple(expr)
         else:
             self._distinct = True
 
@@ -4489,6 +4396,11 @@ class TextualSelect(SelectBase):
 
     __visit_name__ = "textual_select"
 
+    _traverse_internals = [
+        ("element", InternalTraversal.dp_clauseelement),
+        ("column_args", InternalTraversal.dp_clauseelement_list),
+    ] + SupportsCloneAnnotations._traverse_internals
+
     _is_textual = True
 
     def __init__(self, text, columns, positional=False):
@@ -4534,18 +4446,6 @@ class TextualSelect(SelectBase):
             c._make_proxy(fromclause) for c in self.column_args
         )
 
-    def _copy_internals(self, clone=_clone, **kw):
-        self._reset_memoizations()
-        self.element = clone(self.element, **kw)
-
-    def get_children(self, **kw):
-        return [self.element]
-
-    def _cache_key(self, **kw):
-        return (TextualSelect, self.element._cache_key(**kw)) + tuple(
-            col._cache_key(**kw) for col in self.column_args
-        )
-
     def _scalar_type(self):
         return self.column_args[0].type
 
diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py
new file mode 100644 (file)
index 0000000..c0782ce
--- /dev/null
@@ -0,0 +1,768 @@
+from collections import deque
+from collections import namedtuple
+
+from . import operators
+from .visitors import ExtendedInternalTraversal
+from .visitors import InternalTraversal
+from .. import inspect
+from .. import util
+
+SKIP_TRAVERSE = util.symbol("skip_traverse")
+COMPARE_FAILED = False
+COMPARE_SUCCEEDED = True
+NO_CACHE = util.symbol("no_cache")
+
+
+def compare(obj1, obj2, **kw):
+    if kw.get("use_proxies", False):
+        strategy = ColIdentityComparatorStrategy()
+    else:
+        strategy = TraversalComparatorStrategy()
+
+    return strategy.compare(obj1, obj2, **kw)
+
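
# [Editor's note: illustrative sketch, not part of this commit.]  A minimal
# exercise of the compare() entrypoint defined above, using public column()
# constructs; the expected results are inferred from the comparator
# strategies defined later in this module.
from sqlalchemy import column
from sqlalchemy.sql.traversals import compare

# identical structure compares equal, even though built from separate objects
assert compare(column("x") == column("y"), column("x") == column("y"))
# a differing column name is a structural difference
assert not compare(column("x") == column("y"), column("x") == column("z"))
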
+
+class HasCacheKey(object):
+    _cache_key_traversal = NO_CACHE
+
+    def _gen_cache_key(self, anon_map, bindparams):
+        """return an optional cache key.
+
+        The cache key is a tuple which can contain any series of
+        objects that are hashable and also identifies
+        this object uniquely within the presence of a larger SQL expression
+        or statement, for the purposes of caching the resulting query.
+
+        The cache key should be based on the SQL compiled structure that would
+        ultimately be produced.   That is, two structures that are composed in
+        exactly the same way should produce the same cache key; any difference
+        in the structures that would affect the SQL string or the type handlers
+        should result in a different cache key.
+
+        If a structure cannot produce a useful cache key, it should raise
+        NotImplementedError, which will result in the entire structure
+        of which it is a part not being useful as a cache key.
+
+
+        """
+
+        if self in anon_map:
+            return (anon_map[self], self.__class__)
+
+        id_ = anon_map[self]
+
+        if self._cache_key_traversal is NO_CACHE:
+            anon_map[NO_CACHE] = True
+            return None
+
+        result = (id_, self.__class__)
+
+        for attrname, obj, meth in _cache_key_traversal.run_generated_dispatch(
+            self, self._cache_key_traversal, "_generated_cache_key_traversal"
+        ):
+            if obj is not None:
+                result += meth(attrname, obj, self, anon_map, bindparams)
+        return result
+
+    def _generate_cache_key(self):
+        """return a cache key.
+
+        The cache key is a tuple which can contain any series of
+        objects that are hashable and also identifies
+        this object uniquely within the presence of a larger SQL expression
+        or statement, for the purposes of caching the resulting query.
+
+        The cache key should be based on the SQL compiled structure that would
+        ultimately be produced.   That is, two structures that are composed in
+        exactly the same way should produce the same cache key; any difference
+        in the structures that would affect the SQL string or the type handlers
+        should result in a different cache key.
+
+        The cache key returned by this method is an instance of
+        :class:`.CacheKey`, which consists of a tuple representing the
+        cache key, as well as a list of :class:`.BindParameter` objects
+        which are extracted from the expression.   While two expressions
+        that produce identical cache key tuples will themselves generate
+        identical SQL strings, the list of :class:`.BindParameter` objects
+        indicates the bound values which may have different values in
+        each one; these bound parameters must be consulted in order to
+        execute the statement with the correct parameters.
+
+        A :class:`.ClauseElement` structure that implements neither a
+        :meth:`._gen_cache_key` method nor a :attr:`._traverse_internals`
+        attribute will not be cacheable; when such an element is embedded
+        into a larger structure, this method will return None, indicating
+        no cache key is available.
+
+        """
+        bindparams = []
+
+        _anon_map = anon_map()
+        key = self._gen_cache_key(_anon_map, bindparams)
+        if NO_CACHE in _anon_map:
+            return None
+        else:
+            return CacheKey(key, bindparams)
+
+
+class CacheKey(namedtuple("CacheKey", ["key", "bindparams"])):
+    def __hash__(self):
+        return hash(self.key)
+
+    def __eq__(self, other):
+        return self.key == other.key
+
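
# [Editor's note: illustrative sketch, not part of this commit.]  One way the
# cache key API above might be used, assuming ClauseElement incorporates
# HasCacheKey as the rest of this patch suggests; the dict-based statement
# cache shown here is hypothetical.
from sqlalchemy import column, select

stmt = select([column("q")]).where(column("q") == 5)
ck = stmt._generate_cache_key()          # CacheKey(key, bindparams), or None
if ck is not None:
    statement_cache = {}
    statement_cache[ck] = str(stmt)      # CacheKey hashes/compares on .key only
    # bound values travel separately from the key and may differ per execution
    values = {bp.key: bp.value for bp in ck.bindparams}
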
+
+def _clone(element, **kw):
+    return element._clone()
+
+
+class _CacheKey(ExtendedInternalTraversal):
+    def visit_has_cache_key(self, attrname, obj, parent, anon_map, bindparams):
+        return (attrname, obj._gen_cache_key(anon_map, bindparams))
+
+    def visit_inspectable(self, attrname, obj, parent, anon_map, bindparams):
+        return self.visit_has_cache_key(
+            attrname, inspect(obj), parent, anon_map, bindparams
+        )
+
+    def visit_clauseelement(self, attrname, obj, parent, anon_map, bindparams):
+        return (attrname, obj._gen_cache_key(anon_map, bindparams))
+
+    def visit_multi(self, attrname, obj, parent, anon_map, bindparams):
+        return (
+            attrname,
+            obj._gen_cache_key(anon_map, bindparams)
+            if isinstance(obj, HasCacheKey)
+            else obj,
+        )
+
+    def visit_multi_list(self, attrname, obj, parent, anon_map, bindparams):
+        return (
+            attrname,
+            tuple(
+                elem._gen_cache_key(anon_map, bindparams)
+                if isinstance(elem, HasCacheKey)
+                else elem
+                for elem in obj
+            ),
+        )
+
+    def visit_has_cache_key_tuples(
+        self, attrname, obj, parent, anon_map, bindparams
+    ):
+        return (
+            attrname,
+            tuple(
+                tuple(
+                    elem._gen_cache_key(anon_map, bindparams)
+                    for elem in tup_elem
+                )
+                for tup_elem in obj
+            ),
+        )
+
+    def visit_has_cache_key_list(
+        self, attrname, obj, parent, anon_map, bindparams
+    ):
+        return (
+            attrname,
+            tuple(elem._gen_cache_key(anon_map, bindparams) for elem in obj),
+        )
+
+    def visit_inspectable_list(
+        self, attrname, obj, parent, anon_map, bindparams
+    ):
+        return self.visit_has_cache_key_list(
+            attrname, [inspect(o) for o in obj], parent, anon_map, bindparams
+        )
+
+    def visit_clauseelement_list(
+        self, attrname, obj, parent, anon_map, bindparams
+    ):
+        return (
+            attrname,
+            tuple(elem._gen_cache_key(anon_map, bindparams) for elem in obj),
+        )
+
+    def visit_clauseelement_tuples(
+        self, attrname, obj, parent, anon_map, bindparams
+    ):
+        return self.visit_has_cache_key_tuples(
+            attrname, obj, parent, anon_map, bindparams
+        )
+
+    def visit_anon_name(self, attrname, obj, parent, anon_map, bindparams):
+        from . import elements
+
+        name = obj
+        if isinstance(name, elements._anonymous_label):
+            name = name.apply_map(anon_map)
+
+        return (attrname, name)
+
+    def visit_fromclause_ordered_set(
+        self, attrname, obj, parent, anon_map, bindparams
+    ):
+        return (
+            attrname,
+            tuple(elem._gen_cache_key(anon_map, bindparams) for elem in obj),
+        )
+
+    def visit_clauseelement_unordered_set(
+        self, attrname, obj, parent, anon_map, bindparams
+    ):
+        cache_keys = [
+            elem._gen_cache_key(anon_map, bindparams) for elem in obj
+        ]
+        return (
+            attrname,
+            tuple(
+                sorted(cache_keys)
+            ),  # cache keys all start with (id_, class)
+        )
+
+    def visit_named_ddl_element(
+        self, attrname, obj, parent, anon_map, bindparams
+    ):
+        return (attrname, obj.name)
+
+    def visit_prefix_sequence(
+        self, attrname, obj, parent, anon_map, bindparams
+    ):
+        return (
+            attrname,
+            tuple(
+                (clause._gen_cache_key(anon_map, bindparams), strval)
+                for clause, strval in obj
+            ),
+        )
+
+    def visit_statement_hint_list(
+        self, attrname, obj, parent, anon_map, bindparams
+    ):
+        return (attrname, obj)
+
+    def visit_table_hint_list(
+        self, attrname, obj, parent, anon_map, bindparams
+    ):
+        return (
+            attrname,
+            tuple(
+                (
+                    clause._gen_cache_key(anon_map, bindparams),
+                    dialect_name,
+                    text,
+                )
+                for (clause, dialect_name), text in obj.items()
+            ),
+        )
+
+    def visit_type(self, attrname, obj, parent, anon_map, bindparams):
+        return (attrname, obj._gen_cache_key)
+
+    def visit_plain_dict(self, attrname, obj, parent, anon_map, bindparams):
+        return (attrname, tuple((key, obj[key]) for key in sorted(obj)))
+
+    def visit_string_clauseelement_dict(
+        self, attrname, obj, parent, anon_map, bindparams
+    ):
+        return (
+            attrname,
+            tuple(
+                (key, obj[key]._gen_cache_key(anon_map, bindparams))
+                for key in sorted(obj)
+            ),
+        )
+
+    def visit_string_multi_dict(
+        self, attrname, obj, parent, anon_map, bindparams
+    ):
+        return (
+            attrname,
+            tuple(
+                (
+                    key,
+                    value._gen_cache_key(anon_map, bindparams)
+                    if isinstance(value, HasCacheKey)
+                    else value,
+                )
+                for key, value in [(key, obj[key]) for key in sorted(obj)]
+            ),
+        )
+
+    def visit_string(self, attrname, obj, parent, anon_map, bindparams):
+        return (attrname, obj)
+
+    def visit_boolean(self, attrname, obj, parent, anon_map, bindparams):
+        return (attrname, obj)
+
+    def visit_operator(self, attrname, obj, parent, anon_map, bindparams):
+        return (attrname, obj)
+
+    def visit_plain_obj(self, attrname, obj, parent, anon_map, bindparams):
+        return (attrname, obj)
+
+    def visit_fromclause_canonical_column_collection(
+        self, attrname, obj, parent, anon_map, bindparams
+    ):
+        return (
+            attrname,
+            tuple(col._gen_cache_key(anon_map, bindparams) for col in obj),
+        )
+
+    def visit_annotations_state(
+        self, attrname, obj, parent, anon_map, bindparams
+    ):
+        return (
+            attrname,
+            tuple(
+                (
+                    key,
+                    self.dispatch(sym)(
+                        key, obj[key], obj, anon_map, bindparams
+                    ),
+                )
+                for key, sym in parent._annotation_traversals
+            ),
+        )
+
+    def visit_unknown_structure(
+        self, attrname, obj, parent, anon_map, bindparams
+    ):
+        anon_map[NO_CACHE] = True
+        return ()
+
+
+_cache_key_traversal = _CacheKey()
+
+
+class _CopyInternals(InternalTraversal):
+    """Generate a _copy_internals internal traversal dispatch for classes
+    with a _traverse_internals collection."""
+
+    def visit_clauseelement(self, parent, element, clone=_clone, **kw):
+        return clone(element, **kw)
+
+    def visit_clauseelement_list(self, parent, element, clone=_clone, **kw):
+        return [clone(clause, **kw) for clause in element]
+
+    def visit_clauseelement_tuples(self, parent, element, clone=_clone, **kw):
+        return [
+            tuple(clone(tup_elem, **kw) for tup_elem in elem)
+            for elem in element
+        ]
+
+    def visit_string_clauseelement_dict(
+        self, parent, element, clone=_clone, **kw
+    ):
+        return dict(
+            (key, clone(value, **kw)) for key, value in element.items()
+        )
+
+
+_copy_internals = _CopyInternals()
+
+
+class _GetChildren(InternalTraversal):
+    """Generate a _children_traversal internal traversal dispatch for classes
+    with a _traverse_internals collection."""
+
+    def visit_has_cache_key(self, element, **kw):
+        return (element,)
+
+    def visit_clauseelement(self, element, **kw):
+        return (element,)
+
+    def visit_clauseelement_list(self, element, **kw):
+        return tuple(element)
+
+    def visit_clauseelement_tuples(self, element, **kw):
+        tup = ()
+        for elem in element:
+            tup += elem
+        return tup
+
+    def visit_fromclause_canonical_column_collection(self, element, **kw):
+        if kw.get("column_collections", False):
+            return tuple(element)
+        else:
+            return ()
+
+    def visit_string_clauseelement_dict(self, element, **kw):
+        return tuple(element.values())
+
+    def visit_fromclause_ordered_set(self, element, **kw):
+        return tuple(element)
+
+    def visit_clauseelement_unordered_set(self, element, **kw):
+        return tuple(element)
+
+
+_get_children = _GetChildren()
+
+
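
As a rough sketch of the pattern these singleton visitors serve (hypothetical element layout, not part of the patch itself): a class lists its traversable attributes in ``_traverse_internals``, and ``dispatch()`` resolves each symbol to the matching ``visit_*`` method on one of the instances above. The real methods are generated and cached by ``_generate_dispatcher()`` in visitors.py; this hand-written loop only illustrates the idea::

    # hypothetical _traverse_internals for a two-operand element
    traverse_internals = [
        ("left", InternalTraversal.dp_clauseelement),
        ("right", InternalTraversal.dp_clauseelement),
    ]

    def sketch_get_children(elem):
        children = ()
        for attrname, sym in traverse_internals:
            # resolve e.g. dp_clauseelement -> _GetChildren.visit_clauseelement
            meth = _get_children.dispatch(sym)
            children += meth(getattr(elem, attrname))
        return children
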
+@util.dependencies("sqlalchemy.sql.elements")
+def _resolve_name_for_compare(elements, element, name, anon_map, **kw):
+    if isinstance(name, elements._anonymous_label):
+        name = name.apply_map(anon_map)
+
+    return name
+
+
+class anon_map(dict):
+    """A map that creates new keys for missing key access.
+
+    Produces an incrementing sequence given a series of unique keys.
+
+    This is similar to the compiler prefix_anon_map class although simpler.
+
+    Inlines the approach taken by :class:`sqlalchemy.util.PopulateDict` which
+    is otherwise usually used for this type of operation.
+
+    """
+
+    def __init__(self):
+        self.index = 0
+
+    def __missing__(self, key):
+        self[key] = val = str(self.index)
+        self.index += 1
+        return val
+
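
As a brief illustration of the counter behavior described above (using the class just defined): each previously unseen key is assigned the next integer, rendered as a string, and repeated access returns the value already assigned::

    m = anon_map()
    print(m["x"], m["y"], m["x"])   # 0 1 0
    print(m.index)                  # 2 distinct keys assigned so far
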
+
+class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
+    __slots__ = "stack", "cache", "anon_map"
+
+    def __init__(self):
+        self.stack = deque()
+        self.cache = set()
+
+    def _memoized_attr_anon_map(self):
+        return (anon_map(), anon_map())
+
+    def compare(self, obj1, obj2, **kw):
+        stack = self.stack
+        cache = self.cache
+
+        compare_annotations = kw.get("compare_annotations", False)
+
+        stack.append((obj1, obj2))
+
+        while stack:
+            left, right = stack.popleft()
+
+            if left is right:
+                continue
+            elif left is None or right is None:
+                # we know they are different so no match
+                return False
+            elif (left, right) in cache:
+                continue
+            cache.add((left, right))
+
+            visit_name = left.__visit_name__
+            if visit_name != right.__visit_name__:
+                return False
+
+            meth = getattr(self, "compare_%s" % visit_name, None)
+
+            if meth:
+                attributes_compared = meth(left, right, **kw)
+                if attributes_compared is COMPARE_FAILED:
+                    return False
+                elif attributes_compared is SKIP_TRAVERSE:
+                    continue
+
+                # attributes_compared is returned as a list of attribute
+                # names that were "handled" by the comparison method above.
+                # remaining attribute names in the _traverse_internals
+                # collection will then be compared.
+            else:
+                attributes_compared = ()
+
+            for (
+                (left_attrname, left_visit_sym),
+                (right_attrname, right_visit_sym),
+            ) in util.zip_longest(
+                left._traverse_internals,
+                right._traverse_internals,
+                fillvalue=(None, None),
+            ):
+                if (
+                    left_attrname != right_attrname
+                    or left_visit_sym is not right_visit_sym
+                ):
+                    if not compare_annotations and (
+                        (
+                            left_visit_sym
+                            is InternalTraversal.dp_annotations_state,
+                        )
+                        or (
+                            right_visit_sym
+                            is InternalTraversal.dp_annotations_state,
+                        )
+                    ):
+                        continue
+
+                    return False
+                elif left_attrname in attributes_compared:
+                    continue
+
+                dispatch = self.dispatch(left_visit_sym)
+                left_child = getattr(left, left_attrname)
+                right_child = getattr(right, right_attrname)
+                if left_child is None:
+                    if right_child is not None:
+                        return False
+                    else:
+                        continue
+
+                comparison = dispatch(
+                    left, left_child, right, right_child, **kw
+                )
+                if comparison is COMPARE_FAILED:
+                    return False
+
+        return True
+
+    def compare_inner(self, obj1, obj2, **kw):
+        comparator = self.__class__()
+        return comparator.compare(obj1, obj2, **kw)
+
+    def visit_has_cache_key(
+        self, left_parent, left, right_parent, right, **kw
+    ):
+        if left._gen_cache_key(self.anon_map[0], []) != right._gen_cache_key(
+            self.anon_map[1], []
+        ):
+            return COMPARE_FAILED
+
+    def visit_clauseelement(
+        self, left_parent, left, right_parent, right, **kw
+    ):
+        self.stack.append((left, right))
+
+    def visit_fromclause_canonical_column_collection(
+        self, left_parent, left, right_parent, right, **kw
+    ):
+        for lcol, rcol in util.zip_longest(left, right, fillvalue=None):
+            self.stack.append((lcol, rcol))
+
+    def visit_fromclause_derived_column_collection(
+        self, left_parent, left, right_parent, right, **kw
+    ):
+        pass
+
+    def visit_string_clauseelement_dict(
+        self, left_parent, left, right_parent, right, **kw
+    ):
+        for lstr, rstr in util.zip_longest(
+            sorted(left), sorted(right), fillvalue=None
+        ):
+            if lstr != rstr:
+                return COMPARE_FAILED
+            self.stack.append((left[lstr], right[rstr]))
+
+    def visit_annotations_state(
+        self, left_parent, left, right_parent, right, **kw
+    ):
+        if not kw.get("compare_annotations", False):
+            return
+
+        for (lstr, lmeth), (rstr, rmeth) in util.zip_longest(
+            left_parent._annotation_traversals,
+            right_parent._annotation_traversals,
+            fillvalue=(None, None),
+        ):
+            if lstr != rstr or (lmeth is not rmeth):
+                return COMPARE_FAILED
+
+            dispatch = self.dispatch(lmeth)
+            left_child = left[lstr]
+            right_child = right[rstr]
+            if left_child is None:
+                if right_child is not None:
+                    return False
+                else:
+                    continue
+
+            comparison = dispatch(None, left_child, None, right_child, **kw)
+            if comparison is COMPARE_FAILED:
+                return comparison
+
+    def visit_clauseelement_tuples(
+        self, left_parent, left, right_parent, right, **kw
+    ):
+        for ltup, rtup in util.zip_longest(left, right, fillvalue=None):
+            if ltup is None or rtup is None:
+                return COMPARE_FAILED
+
+            for l, r in util.zip_longest(ltup, rtup, fillvalue=None):
+                self.stack.append((l, r))
+
+    def visit_clauseelement_list(
+        self, left_parent, left, right_parent, right, **kw
+    ):
+        for l, r in util.zip_longest(left, right, fillvalue=None):
+            self.stack.append((l, r))
+
+    def _compare_unordered_sequences(self, seq1, seq2, **kw):
+        if seq1 is None:
+            return seq2 is None
+
+        completed = set()
+        for clause in seq1:
+            for other_clause in set(seq2).difference(completed):
+                if self.compare_inner(clause, other_clause, **kw):
+                    completed.add(other_clause)
+                    break
+        return len(completed) == len(seq1) == len(seq2)
+
+    def visit_clauseelement_unordered_set(
+        self, left_parent, left, right_parent, right, **kw
+    ):
+        return self._compare_unordered_sequences(left, right, **kw)
+
+    def visit_fromclause_ordered_set(
+        self, left_parent, left, right_parent, right, **kw
+    ):
+        for l, r in util.zip_longest(left, right, fillvalue=None):
+            self.stack.append((l, r))
+
+    def visit_string(self, left_parent, left, right_parent, right, **kw):
+        return left == right
+
+    def visit_anon_name(self, left_parent, left, right_parent, right, **kw):
+        return _resolve_name_for_compare(
+            left_parent, left, self.anon_map[0], **kw
+        ) == _resolve_name_for_compare(
+            right_parent, right, self.anon_map[1], **kw
+        )
+
+    def visit_boolean(self, left_parent, left, right_parent, right, **kw):
+        return left == right
+
+    def visit_operator(self, left_parent, left, right_parent, right, **kw):
+        return left is right
+
+    def visit_type(self, left_parent, left, right_parent, right, **kw):
+        return left._compare_type_affinity(right)
+
+    def visit_plain_dict(self, left_parent, left, right_parent, right, **kw):
+        return left == right
+
+    def visit_plain_obj(self, left_parent, left, right_parent, right, **kw):
+        return left == right
+
+    def visit_named_ddl_element(
+        self, left_parent, left, right_parent, right, **kw
+    ):
+        if left is None:
+            if right is not None:
+                return COMPARE_FAILED
+
+        return left.name == right.name
+
+    def visit_prefix_sequence(
+        self, left_parent, left, right_parent, right, **kw
+    ):
+        for (l_clause, l_str), (r_clause, r_str) in util.zip_longest(
+            left, right, fillvalue=(None, None)
+        ):
+            if l_str != r_str:
+                return COMPARE_FAILED
+            else:
+                self.stack.append((l_clause, r_clause))
+
+    def visit_table_hint_list(
+        self, left_parent, left, right_parent, right, **kw
+    ):
+        left_keys = sorted(left, key=lambda elem: (elem[0].fullname, elem[1]))
+        right_keys = sorted(
+            right, key=lambda elem: (elem[0].fullname, elem[1])
+        )
+        for (ltable, ldialect), (rtable, rdialect) in util.zip_longest(
+            left_keys, right_keys, fillvalue=(None, None)
+        ):
+            if ldialect != rdialect:
+                return COMPARE_FAILED
+            elif left[(ltable, ldialect)] != right[(rtable, rdialect)]:
+                return COMPARE_FAILED
+            else:
+                self.stack.append((ltable, rtable))
+
+    def visit_statement_hint_list(
+        self, left_parent, left, right_parent, right, **kw
+    ):
+        return left == right
+
+    def visit_unknown_structure(
+        self, left_parent, left, right_parent, right, **kw
+    ):
+        raise NotImplementedError()
+
+    def compare_clauselist(self, left, right, **kw):
+        if left.operator is right.operator:
+            if operators.is_associative(left.operator):
+                if self._compare_unordered_sequences(
+                    left.clauses, right.clauses, **kw
+                ):
+                    return ["operator", "clauses"]
+                else:
+                    return COMPARE_FAILED
+            else:
+                return ["operator"]
+        else:
+            return COMPARE_FAILED
+
+    def compare_binary(self, left, right, **kw):
+        if left.operator == right.operator:
+            if operators.is_commutative(left.operator):
+                if (
+                    compare(left.left, right.left, **kw)
+                    and compare(left.right, right.right, **kw)
+                ) or (
+                    compare(left.left, right.right, **kw)
+                    and compare(left.right, right.left, **kw)
+                ):
+                    return ["operator", "negate", "left", "right"]
+                else:
+                    return COMPARE_FAILED
+            else:
+                return ["operator", "negate"]
+        else:
+            return COMPARE_FAILED
+
+
+class ColIdentityComparatorStrategy(TraversalComparatorStrategy):
+    def compare_column_element(
+        self, left, right, use_proxies=True, equivalents=(), **kw
+    ):
+        """Compare ColumnElements using proxies and equivalent collections.
+
+        This is a comparison strategy specific to the ORM.
+        """
+
+        to_compare = (right,)
+        if equivalents and right in equivalents:
+            to_compare = equivalents[right].union(to_compare)
+
+        for oth in to_compare:
+            if use_proxies and left.shares_lineage(oth):
+                return SKIP_TRAVERSE
+            elif hash(left) == hash(right):
+                return SKIP_TRAVERSE
+        else:
+            return COMPARE_FAILED
+
+    def compare_column(self, left, right, **kw):
+        return self.compare_column_element(left, right, **kw)
+
+    def compare_label(self, left, right, **kw):
+        return self.compare_column_element(left, right, **kw)
+
+    def compare_table(self, left, right, **kw):
+        # tables compare on identity, since it's not really feasible to
+        # compare them column by column with the above rules
+        return SKIP_TRAVERSE if left is right else COMPARE_FAILED
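
As a usage sketch for the strategy classes above (result values are indicative; the module-level ``compare()`` referenced in ``compare_binary`` is assumed to wrap this strategy): structural comparison can be driven directly through a fresh :class:`.TraversalComparatorStrategy` per comparison, since its ``stack`` and ``cache`` are stateful::

    from sqlalchemy import column

    e1 = column("q") == column("x")
    e2 = column("x") == column("q")   # operands swapped; == is commutative
    e3 = column("q") == column("y")

    print(TraversalComparatorStrategy().compare(e1, e2))  # expected: True
    print(TraversalComparatorStrategy().compare(e1, e3))  # expected: False
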
index 9c5f5dd475da21b134146a56b2a0eb069c279504..d09bb28bbe69d279e8516cbba9d5ee287156248e 100644 (file)
@@ -12,8 +12,8 @@
 
 from . import operators
 from .base import SchemaEventTarget
-from .visitors import Visitable
-from .visitors import VisitableType
+from .visitors import Traversible
+from .visitors import TraversibleType
 from .. import exc
 from .. import util
 
@@ -28,7 +28,7 @@ INDEXABLE = None
 _resolve_value_to_type = None
 
 
-class TypeEngine(Visitable):
+class TypeEngine(Traversible):
     """The ultimate base class for all SQL datatypes.
 
     Common subclasses of :class:`.TypeEngine` include
@@ -535,8 +535,13 @@ class TypeEngine(Visitable):
         return dialect.type_descriptor(self)
 
     @util.memoized_property
-    def _cache_key(self):
-        return util.constructor_key(self, self.__class__)
+    def _gen_cache_key(self):
+        names = util.get_cls_kwargs(self.__class__)
+        return (self.__class__,) + tuple(
+            (k, self.__dict__[k])
+            for k in names
+            if k in self.__dict__ and not k.startswith("_")
+        )
 
     def adapt(self, cls, **kw):
         """Produce an "adapted" form of this type, given an "impl" class
@@ -617,7 +622,7 @@ class TypeEngine(Visitable):
         return util.generic_repr(self)
 
 
-class VisitableCheckKWArg(util.EnsureKWArgType, VisitableType):
+class VisitableCheckKWArg(util.EnsureKWArgType, TraversibleType):
     pass
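
As a rough illustration of the ``_gen_cache_key`` memoized property added above: the key begins with the type class itself, followed by ``(name, value)`` pairs for whichever non-underscored constructor keywords are present in the instance ``__dict__`` (the exact pairs therefore vary by type and arguments)::

    from sqlalchemy import String

    key = String(50)._gen_cache_key
    # e.g. something like (String, ('length', 50), ('collation', None), ...)
    print(key[0] is String)        # True
    print(("length", 50) in key)   # True
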
 
 
index e109852a2b09e6628e1a319c3b1eaea1d75a5c04..8539f4845a145f674404a7c15c0ca80334d182fc 100644 (file)
@@ -734,7 +734,7 @@ def criterion_as_pairs(
     return pairs
 
 
-class ClauseAdapter(visitors.ReplacingCloningVisitor):
+class ClauseAdapter(visitors.ReplacingExternalTraversal):
     """Clones and modifies clauses based on column correspondence.
 
     E.g.::
index 7b2ac285a9cff63432cb6252280e221d8cf24d74..8c06eb8afd7b2e395b80f2189203108b0024456a 100644 (file)
@@ -28,14 +28,10 @@ import operator
 
 from .. import exc
 from .. import util
-
+from ..util import langhelpers
+from ..util import symbol
 
 __all__ = [
-    "VisitableType",
-    "Visitable",
-    "ClauseVisitor",
-    "CloningVisitor",
-    "ReplacingCloningVisitor",
     "iterate",
     "iterate_depthfirst",
     "traverse_using",
@@ -43,85 +39,382 @@ __all__ = [
     "traverse_depthfirst",
     "cloned_traverse",
     "replacement_traverse",
+    "Traversible",
+    "TraversibleType",
+    "ExternalTraversal",
+    "InternalTraversal",
 ]
 
 
-class VisitableType(type):
-    """Metaclass which assigns a ``_compiler_dispatch`` method to classes
-    having a ``__visit_name__`` attribute.
+def _generate_compiler_dispatch(cls):
+    """Generate a _compiler_dispatch() external traversal on classes with a
+    __visit_name__ attribute.
+
+    """
+    visit_name = cls.__visit_name__
+
+    if isinstance(visit_name, util.compat.string_types):
+        # There is an optimization opportunity here because the
+        # string name of the class's __visit_name__ is known at
+        # this early stage (import time) so it can be pre-constructed.
+        getter = operator.attrgetter("visit_%s" % visit_name)
+
+        def _compiler_dispatch(self, visitor, **kw):
+            try:
+                meth = getter(visitor)
+            except AttributeError:
+                raise exc.UnsupportedCompilationError(visitor, cls)
+            else:
+                return meth(self, **kw)
+
+    else:
+        # The optimization opportunity is lost for this case because the
+        # __visit_name__ is not yet a string. As a result, the visit
+        # string has to be recalculated with each compilation.
+        def _compiler_dispatch(self, visitor, **kw):
+            visit_attr = "visit_%s" % self.__visit_name__
+            try:
+                meth = getattr(visitor, visit_attr)
+            except AttributeError:
+                raise exc.UnsupportedCompilationError(visitor, cls)
+            else:
+                return meth(self, **kw)
+
+    _compiler_dispatch.__doc__ = """Look for an attribute named "visit_"
+        + self.__visit_name__ on the visitor, and call it with the same
+        kw params.
+        """
+    cls._compiler_dispatch = _compiler_dispatch
+
+
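
A minimal sketch of what the generated method does (hypothetical classes; ``Traversible`` and its metaclass are defined just below): a class declaring ``__visit_name__`` gains a ``_compiler_dispatch()`` that looks up ``visit_<name>`` on the visitor and calls it with the element::

    class MyNode(Traversible):
        __visit_name__ = "my_node"

    class MyVisitor(object):
        def visit_my_node(self, node, **kw):
            return "handled my_node"

    print(MyNode()._compiler_dispatch(MyVisitor()))   # handled my_node
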
+class TraversibleType(type):
+    """Metaclass which assigns dispatch attributes to various kinds of
+    "visitable" classes.
 
-    The ``_compiler_dispatch`` attribute becomes an instance method which
-    looks approximately like the following::
+    Attributes include:
 
-        def _compiler_dispatch (self, visitor, **kw):
-            '''Look for an attribute named "visit_" + self.__visit_name__
-            on the visitor, and call it with the same kw params.'''
-            visit_attr = 'visit_%s' % self.__visit_name__
-            return getattr(visitor, visit_attr)(self, **kw)
+    * The ``_compiler_dispatch`` method, corresponding to ``__visit_name__``.
+      This is called "external traversal" because the caller of each visit()
+      method is responsible for sub-traversing the inner elements of each
+      object. This is appropriate for string compilers and other traversals
+      that need to call upon the inner elements in a specific pattern.
 
-    Classes having no ``__visit_name__`` attribute will remain unaffected.
+    * internal traversal collections ``_children_traversal``,
+      ``_cache_key_traversal``, ``_copy_internals_traversal``, generated from
+      an optional ``_traverse_internals`` collection of symbols which comes
+      from the :class:`.InternalTraversal` list of symbols.  This is called
+      "internal traversal" MARKMARK
 
     """
 
     def __init__(cls, clsname, bases, clsdict):
-        if clsname != "Visitable" and hasattr(cls, "__visit_name__"):
-            _generate_dispatch(cls)
+        if clsname != "Traversible":
+            if "__visit_name__" in clsdict:
+                _generate_compiler_dispatch(cls)
+
+        super(TraversibleType, cls).__init__(clsname, bases, clsdict)
 
-        super(VisitableType, cls).__init__(clsname, bases, clsdict)
 
+class Traversible(util.with_metaclass(TraversibleType)):
+    """Base class for visitable objects, applies the
+    :class:`.visitors.TraversibleType` metaclass.
 
-def _generate_dispatch(cls):
-    """Return an optimized visit dispatch function for the cls
-    for use by the compiler.
     """
-    if "__visit_name__" in cls.__dict__:
-        visit_name = cls.__visit_name__
 
-        if isinstance(visit_name, util.compat.string_types):
-            # There is an optimization opportunity here because the
-            # the string name of the class's __visit_name__ is known at
-            # this early stage (import time) so it can be pre-constructed.
-            getter = operator.attrgetter("visit_%s" % visit_name)
 
-            def _compiler_dispatch(self, visitor, **kw):
-                try:
-                    meth = getter(visitor)
-                except AttributeError:
-                    raise exc.UnsupportedCompilationError(visitor, cls)
-                else:
-                    return meth(self, **kw)
+class _InternalTraversalType(type):
+    def __init__(cls, clsname, bases, clsdict):
+        if cls.__name__ in ("InternalTraversal", "ExtendedInternalTraversal"):
+            lookup = {}
+            for key, sym in clsdict.items():
+                if key.startswith("dp_"):
+                    visit_key = key.replace("dp_", "visit_")
+                    sym_name = sym.name
+                    assert sym_name not in lookup, sym_name
+                    lookup[sym] = lookup[sym_name] = visit_key
+            if hasattr(cls, "_dispatch_lookup"):
+                lookup.update(cls._dispatch_lookup)
+            cls._dispatch_lookup = lookup
+
+        super(_InternalTraversalType, cls).__init__(clsname, bases, clsdict)
+
+
+def _generate_dispatcher(visitor, internal_dispatch, method_name):
+    names = []
+    for attrname, visit_sym in internal_dispatch:
+        meth = visitor.dispatch(visit_sym)
+        if meth:
+            visit_name = ExtendedInternalTraversal._dispatch_lookup[visit_sym]
+            names.append((attrname, visit_name))
+
+    code = (
+        ("    return [\n")
+        + (
+            ", \n".join(
+                "        (%r, self.%s, visitor.%s)"
+                % (attrname, attrname, visit_name)
+                for attrname, visit_name in names
+            )
+        )
+        + ("\n    ]\n")
+    )
+    meth_text = ("def %s(self, visitor):\n" % method_name) + code + "\n"
+    # print(meth_text)
+    return langhelpers._exec_code_in_env(meth_text, {}, method_name)
 
-        else:
-            # The optimization opportunity is lost for this case because the
-            # __visit_name__ is not yet a string. As a result, the visit
-            # string has to be recalculated with each compilation.
-            def _compiler_dispatch(self, visitor, **kw):
-                visit_attr = "visit_%s" % self.__visit_name__
-                try:
-                    meth = getattr(visitor, visit_attr)
-                except AttributeError:
-                    raise exc.UnsupportedCompilationError(visitor, cls)
-                else:
-                    return meth(self, **kw)
-
-        _compiler_dispatch.__doc__ = """Look for an attribute named "visit_" + self.__visit_name__
-            on the visitor, and call it with the same kw params.
-            """
-        cls._compiler_dispatch = _compiler_dispatch
-
-
-class Visitable(util.with_metaclass(VisitableType, object)):
-    """Base class for visitable objects, applies the
-    :class:`.visitors.VisitableType` metaclass.
 
-    The :class:`.Visitable` class is essentially at the base of the
-    :class:`.ClauseElement` hierarchy.
+class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)):
+    r"""Defines visitor symbols used for internal traversal.
+
+    The :class:`.InternalTraversal` class is used in two ways.  One is that
+    it can serve as the superclass for an object that implements the
+    various visit methods of the class.   The other is that the symbols
+    themselves of :class:`.InternalTraversal` are used within
+    the ``_traverse_internals`` collection.   For example, the :class:`.Case`
+    object defines ``_traverse_internals`` as::
+
+        _traverse_internals = [
+            ("value", InternalTraversal.dp_clauseelement),
+            ("whens", InternalTraversal.dp_clauseelement_tuples),
+            ("else_", InternalTraversal.dp_clauseelement),
+        ]
+
+    Above, the :class:`.Case` class indicates its internal state as the
+    attributes named ``value``, ``whens``, and ``else\_``.    They each
+    link to an :class:`.InternalTraversal` method which indicates the type
+    of data structure referred to.
+
+    Using the ``_traverse_internals`` structure, objects of type
+    :class:`.Traversible` will have the following methods automatically
+    implemented:
+
+    * :meth:`.Traversible.get_children`
+
+    * :meth:`.Traversible._copy_internals`
+
+    * :meth:`.Traversible._gen_cache_key`
+
+    Subclasses can also implement these methods directly, particularly for the
+    :meth:`.Traversible._copy_internals` method, when special steps
+    are needed.
+
+    .. versionadded:: 1.4
 
     """
 
+    def dispatch(self, visit_symbol):
+        """Given a method from :class:`.InternalTraversal`, return the
+        corresponding method on a subclass.
 
-class ClauseVisitor(object):
-    """Base class for visitor objects which can traverse using
+        """
+        name = self._dispatch_lookup[visit_symbol]
+        return getattr(self, name, None)
+
+    def run_generated_dispatch(
+        self, target, internal_dispatch, generate_dispatcher_name
+    ):
+        try:
+            dispatcher = target.__class__.__dict__[generate_dispatcher_name]
+        except KeyError:
+            dispatcher = _generate_dispatcher(
+                self, internal_dispatch, generate_dispatcher_name
+            )
+            setattr(target.__class__, generate_dispatcher_name, dispatcher)
+        return dispatcher(target, self)
+
+    dp_has_cache_key = symbol("HC")
+    """Visit a :class:`.HasCacheKey` object."""
+
+    dp_clauseelement = symbol("CE")
+    """Visit a :class:`.ClauseElement` object."""
+
+    dp_fromclause_canonical_column_collection = symbol("FC")
+    """Visit a :class:`.FromClause` object in the context of the
+    ``columns`` attribute.
+
+    The column collection is "canonical", meaning it is the originally
+    defined location of the :class:`.ColumnClause` objects.   Right now
+    this means that the object being visited is a :class:`.TableClause`
+    or :class:`.Table` object only.
+
+    """
+
+    dp_clauseelement_tuples = symbol("CT")
+    """Visit a list of tuples which contain :class:`.ClauseElement`
+    objects.
+
+    """
+
+    dp_clauseelement_list = symbol("CL")
+    """Visit a list of :class:`.ClauseElement` objects.
+
+    """
+
+    dp_clauseelement_unordered_set = symbol("CU")
+    """Visit an unordered set of :class:`.ClauseElement` objects. """
+
+    dp_fromclause_ordered_set = symbol("CO")
+    """Visit an ordered set of :class:`.FromClause` objects. """
+
+    dp_string = symbol("S")
+    """Visit a plain string value.
+
+    Examples include table and column names, bound parameter keys, special
+    keywords such as "UNION", "UNION ALL".
+
+    The string value is considered to be significant for cache key
+    generation.
+
+    """
+
+    dp_anon_name = symbol("AN")
+    """Visit a potentially "anonymized" string value.
+
+    The string value is considered to be significant for cache key
+    generation.
+
+    """
+
+    dp_boolean = symbol("B")
+    """Visit a boolean value.
+
+    The boolean value is considered to be significant for cache key
+    generation.
+
+    """
+
+    dp_operator = symbol("O")
+    """Visit an operator.
+
+    The operator is a function from the :mod:`sqlalchemy.sql.operators`
+    module.
+
+    The operator value is considered to be significant for cache key
+    generation.
+
+    """
+
+    dp_type = symbol("T")
+    """Visit a :class:`.TypeEngine` object
+
+    The type object is considered to be significant for cache key
+    generation.
+
+    """
+
+    dp_plain_dict = symbol("PD")
+    """Visit a dictionary with string keys.
+
+    The keys of the dictionary should be strings, the values should
+    be immutable and hashable.   The dictionary is considered to be
+    significant for cache key generation.
+
+    """
+
+    dp_string_clauseelement_dict = symbol("CD")
+    """Visit a dictionary of string keys to :class:`.ClauseElement`
+    objects.
+
+    """
+
+    dp_string_multi_dict = symbol("MD")
+    """Visit a dictionary of string keys to values which may either be
+    plain immutable/hashable or :class:`.HasCacheKey` objects.
+
+    """
+
+    dp_plain_obj = symbol("PO")
+    """Visit a plain python object.
+
+    The value should be immutable and hashable, such as an integer.
+    The value is considered to be significant for cache key generation.
+
+    """
+
+    dp_annotations_state = symbol("A")
+    """Visit the state of the :class:`.Annotatated` version of an object.
+
+    """
+
+    dp_named_ddl_element = symbol("DD")
+    """Visit a simple named DDL element.
+
+    The current object used by this method is the :class:`.Sequence`.
+
+    The object is only considered to be important for cache key generation
+    as far as its name, but not any other aspects of it.
+
+    """
+
+    dp_prefix_sequence = symbol("PS")
+    """Visit the sequence represented by :class:`.HasPrefixes`
+    or :class:`.HasSuffixes`.
+
+    """
+
+    dp_table_hint_list = symbol("TH")
+    """Visit the ``_hints`` collection of a :class:`.Select` object.
+
+    """
+
+    dp_statement_hint_list = symbol("SH")
+    """Visit the ``_statement_hints`` collection of a :class:`.Select`
+    object.
+
+    """
+
+    dp_unknown_structure = symbol("UK")
+    """Visit an unknown structure.
+
+    """
+
+
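
As a small sketch of how these symbols are consumed (hypothetical subclass): an :class:`.InternalTraversal` subclass supplies ``visit_*`` methods, and ``dispatch()`` resolves a ``dp_*`` symbol to the matching method, or ``None`` when the subclass does not implement it::

    class StringOnlyVisitor(InternalTraversal):
        def visit_string(self, attrname, obj, parent, anon_map, bindparams):
            return (attrname, obj)

    v = StringOnlyVisitor()
    meth = v.dispatch(InternalTraversal.dp_string)
    print(meth("name", "some_table", None, None, None))   # ('name', 'some_table')
    print(v.dispatch(InternalTraversal.dp_boolean))        # None
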
+class ExtendedInternalTraversal(InternalTraversal):
+    """defines additional symbols that are useful in caching applications.
+
+    Traversals for :class:`.ClauseElement` objects only need to use
+    those symbols present in :class:`.InternalTraversal`.  However, for
+    additional caching use cases within the ORM, symbols dealing with the
+    :class:`.HasCacheKey` class are added here.
+
+    """
+
+    dp_ignore = symbol("IG")
+    """Specify an object that should be ignored entirely.
+
+    This currently applies to function call argument caching, where some
+    arguments should not be considered to be part of a cache key.
+
+    """
+
+    dp_inspectable = symbol("IS")
+    """Visit an inspectable object where the return value is a HasCacheKey`
+    object."""
+
+    dp_multi = symbol("M")
+    """Visit an object that may be a :class:`.HasCacheKey` or may be a
+    plain hashable object."""
+
+    dp_multi_list = symbol("MT")
+    """Visit a tuple containing elements that may be :class:`.HasCacheKey` or
+    may be a plain hashable object."""
+
+    dp_has_cache_key_tuples = symbol("HT")
+    """Visit a list of tuples which contain :class:`.HasCacheKey`
+    objects.
+
+    """
+
+    dp_has_cache_key_list = symbol("HL")
+    """Visit a list of :class:`.HasCacheKey` objects."""
+
+    dp_inspectable_list = symbol("IL")
+    """Visit a list of inspectable objects which upon inspection are
+    HasCacheKey objects."""
+
+
+class ExternalTraversal(object):
+    """Base class for visitor objects which can traverse externally using
     the :func:`.visitors.traverse` function.
 
     Direct usage of the :func:`.visitors.traverse` function is usually
@@ -178,7 +471,7 @@ class ClauseVisitor(object):
         return self
 
 
-class CloningVisitor(ClauseVisitor):
+class CloningExternalTraversal(ExternalTraversal):
     """Base class for visitor objects which can traverse using
     the :func:`.visitors.cloned_traverse` function.
 
@@ -203,7 +496,7 @@ class CloningVisitor(ClauseVisitor):
         )
 
 
-class ReplacingCloningVisitor(CloningVisitor):
+class ReplacingExternalTraversal(CloningExternalTraversal):
     """Base class for visitor objects which can traverse using
     the :func:`.visitors.replacement_traverse` function.
 
@@ -233,6 +526,14 @@ class ReplacingCloningVisitor(CloningVisitor):
         return replacement_traverse(obj, self.__traverse_options__, replace)
 
 
+# backwards compatibility
+Visitable = Traversible
+VisitableType = TraversibleType
+ClauseVisitor = ExternalTraversal
+CloningVisitor = CloningExternalTraversal
+ReplacingCloningVisitor = ReplacingExternalTraversal
+
+
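
The legacy names remain importable and continue to refer to the renamed classes, so for example::

    from sqlalchemy.sql import visitors

    assert visitors.ClauseVisitor is visitors.ExternalTraversal
    assert visitors.Visitable is visitors.Traversible
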
 def iterate(obj, opts):
     r"""traverse the given expression structure, returning an iterator.
 
@@ -405,11 +706,18 @@ def cloned_traverse(obj, opts, visitors):
     cloned = {}
     stop_on = set(opts.get("stop_on", []))
 
-    def clone(elem):
+    def clone(elem, **kw):
         if elem in stop_on:
             return elem
         else:
             if id(elem) not in cloned:
+
+                if "replace" in kw:
+                    newelem = kw["replace"](elem)
+                    if newelem is not None:
+                        cloned[id(elem)] = newelem
+                        return newelem
+
                 cloned[id(elem)] = newelem = elem._clone()
                 newelem._copy_internals(clone=clone)
                 meth = visitors.get(newelem.__visit_name__, None)
@@ -461,7 +769,14 @@ def replacement_traverse(obj, opts, replace):
                 stop_on.add(id(newelem))
                 return newelem
             else:
+
                 if elem not in cloned:
+                    if "replace" in kw:
+                        newelem = kw["replace"](elem)
+                        if newelem is not None:
+                            cloned[elem] = newelem
+                            return newelem
+
                     cloned[elem] = newelem = elem._clone()
                     newelem._copy_internals(clone=clone, **kw)
                 return cloned[elem]
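
For context on the functions modified above, a small usage sketch of the public ``replacement_traverse()`` entrypoint, whose internal ``clone()`` callable gains the ``replace`` hook in this change (output is indicative)::

    from sqlalchemy import column
    from sqlalchemy.sql import visitors

    expr = column("x") + column("y")

    def replace(elem):
        # substitute the column named "x"; returning None leaves an
        # element unchanged
        if getattr(elem, "name", None) == "x":
            return column("q")

    new_expr = visitors.replacement_traverse(expr, {}, replace)
    print(new_expr)   # q + y
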
index 632f559373d9bb0c3162946cc365d71f9a62bf57..209bc02e3f37a4e7d0bade6fe760a839b91da2d5 100644 (file)
@@ -934,7 +934,7 @@ class BranchedOptionTest(fixtures.MappedTest):
 
         configure_mappers()
 
-    def test_generate_cache_key_unbound_branching(self):
+    def test_generate_path_cache_key_unbound_branching(self):
         A, B, C, D, E, F, G = self.classes("A", "B", "C", "D", "E", "F", "G")
 
         base = joinedload(A.bs)
@@ -950,11 +950,11 @@ class BranchedOptionTest(fixtures.MappedTest):
         @profiling.function_call_count()
         def go():
             for opt in opts:
-                opt._generate_cache_key(cache_path)
+                opt._generate_path_cache_key(cache_path)
 
         go()
 
-    def test_generate_cache_key_bound_branching(self):
+    def test_generate_path_cache_key_bound_branching(self):
         A, B, C, D, E, F, G = self.classes("A", "B", "C", "D", "E", "F", "G")
 
         base = Load(A).joinedload(A.bs)
@@ -970,7 +970,7 @@ class BranchedOptionTest(fixtures.MappedTest):
         @profiling.function_call_count()
         def go():
             for opt in opts:
-                opt._generate_cache_key(cache_path)
+                opt._generate_path_cache_key(cache_path)
 
         go()
 
index acefe625aba41bd75f6abedec717f6a0639d36f7..01f0e267f2337032d319fbc9747fd8cfab160f85 100644 (file)
@@ -1533,7 +1533,7 @@ class CustomIntegrationTest(testing.AssertsCompiledSQL, BakedTest):
                 if query._current_path:
                     query._cache_key = "user7_addresses"
 
-            def _generate_cache_key(self, path):
+            def _generate_path_cache_key(self, path):
                 return None
 
         return RelationshipCache()
diff --git a/test/orm/test_cache_key.py b/test/orm/test_cache_key.py
new file mode 100644 (file)
index 0000000..79a9484
--- /dev/null
@@ -0,0 +1,120 @@
+from sqlalchemy import inspect
+from sqlalchemy.orm import aliased
+from sqlalchemy.orm import defaultload
+from sqlalchemy.orm import defer
+from sqlalchemy.orm import joinedload
+from sqlalchemy.orm import Load
+from sqlalchemy.orm import subqueryload
+from sqlalchemy.testing import eq_
+from test.orm import _fixtures
+from ..sql.test_compare import CacheKeyFixture
+
+
+class CacheKeyTest(CacheKeyFixture, _fixtures.FixtureTest):
+    run_setup_mappers = "once"
+    run_inserts = None
+    run_deletes = None
+
+    @classmethod
+    def setup_mappers(cls):
+        cls._setup_stock_mapping()
+
+    def test_mapper_and_aliased(self):
+        User, Address, Keyword = self.classes("User", "Address", "Keyword")
+
+        self._run_cache_key_fixture(
+            lambda: (inspect(User), inspect(Address), inspect(aliased(User)))
+        )
+
+    def test_attributes(self):
+        User, Address, Keyword = self.classes("User", "Address", "Keyword")
+
+        self._run_cache_key_fixture(
+            lambda: (
+                User.id,
+                Address.id,
+                aliased(User).id,
+                aliased(User, name="foo").id,
+                aliased(User, name="bar").id,
+                User.name,
+                User.addresses,
+                Address.email_address,
+                aliased(User).addresses,
+            )
+        )
+
+    def test_unbound_options(self):
+        User, Address, Keyword, Order, Item = self.classes(
+            "User", "Address", "Keyword", "Order", "Item"
+        )
+
+        self._run_cache_key_fixture(
+            lambda: (
+                joinedload(User.addresses),
+                joinedload("addresses"),
+                joinedload(User.orders).selectinload("items"),
+                joinedload(User.orders).selectinload(Order.items),
+                defer(User.id),
+                defer("id"),
+                defer(Address.id),
+                joinedload(User.addresses).defer(Address.id),
+                joinedload(aliased(User).addresses).defer(Address.id),
+                joinedload(User.addresses).defer("id"),
+                joinedload(User.orders).joinedload(Order.items),
+                joinedload(User.orders).subqueryload(Order.items),
+                subqueryload(User.orders).subqueryload(Order.items),
+                subqueryload(User.orders)
+                .subqueryload(Order.items)
+                .defer(Item.description),
+                defaultload(User.orders).defaultload(Order.items),
+                defaultload(User.orders),
+            )
+        )
+
+    def test_bound_options(self):
+        User, Address, Keyword, Order, Item = self.classes(
+            "User", "Address", "Keyword", "Order", "Item"
+        )
+
+        self._run_cache_key_fixture(
+            lambda: (
+                Load(User).joinedload(User.addresses),
+                Load(User).joinedload(User.orders),
+                Load(User).defer(User.id),
+                Load(User).subqueryload("addresses"),
+                Load(Address).defer("id"),
+                Load(aliased(Address)).defer("id"),
+                Load(User).joinedload(User.addresses).defer(Address.id),
+                Load(User).joinedload(User.orders).joinedload(Order.items),
+                Load(User).joinedload(User.orders).subqueryload(Order.items),
+                Load(User).subqueryload(User.orders).subqueryload(Order.items),
+                Load(User)
+                .subqueryload(User.orders)
+                .subqueryload(Order.items)
+                .defer(Item.description),
+                Load(User).defaultload(User.orders).defaultload(Order.items),
+                Load(User).defaultload(User.orders),
+            )
+        )
+
+    def test_bound_options_equiv_on_strname(self):
+        """Bound loader options resolve on string name so test that the cache
+        key for the string version matches the resolved version.
+
+        """
+        User, Address, Keyword, Order, Item = self.classes(
+            "User", "Address", "Keyword", "Order", "Item"
+        )
+
+        for left, right in [
+            (Load(User).defer(User.id), Load(User).defer("id")),
+            (
+                Load(User).joinedload(User.addresses),
+                Load(User).joinedload("addresses"),
+            ),
+            (
+                Load(User).joinedload(User.orders).joinedload(Order.items),
+                Load(User).joinedload("orders").joinedload("items"),
+            ),
+        ]:
+            eq_(left._generate_cache_key(), right._generate_cache_key())
index bf099e7e6d7e4ccfcc3140eda1eaf59706d6111b..e84d5950c8274f5014e4a949eddd64c73a9f0300 100644 (file)
@@ -1790,7 +1790,7 @@ class SubOptionsTest(PathTest, QueryTest):
         )
 
 
-class CacheKeyTest(PathTest, QueryTest):
+class PathedCacheKeyTest(PathTest, QueryTest):
 
     run_create_tables = False
     run_inserts = None
@@ -1805,7 +1805,7 @@ class CacheKeyTest(PathTest, QueryTest):
 
         opt = joinedload(User.orders).joinedload(Order.items)
         eq_(
-            opt._generate_cache_key(query_path),
+            opt._generate_path_cache_key(query_path),
             (((Order, "items", Item, ("lazy", "joined")),)),
         )
 
@@ -1821,12 +1821,12 @@ class CacheKeyTest(PathTest, QueryTest):
         opt2 = base.joinedload(Order.address)
 
         eq_(
-            opt1._generate_cache_key(query_path),
+            opt1._generate_path_cache_key(query_path),
             (((Order, "items", Item, ("lazy", "joined")),)),
         )
 
         eq_(
-            opt2._generate_cache_key(query_path),
+            opt2._generate_path_cache_key(query_path),
             (((Order, "address", Address, ("lazy", "joined")),)),
         )
 
@@ -1842,12 +1842,12 @@ class CacheKeyTest(PathTest, QueryTest):
         opt2 = base.joinedload(Order.address)
 
         eq_(
-            opt1._generate_cache_key(query_path),
+            opt1._generate_path_cache_key(query_path),
             (((Order, "items", Item, ("lazy", "joined")),)),
         )
 
         eq_(
-            opt2._generate_cache_key(query_path),
+            opt2._generate_path_cache_key(query_path),
             (((Order, "address", Address, ("lazy", "joined")),)),
         )
 
@@ -1860,7 +1860,7 @@ class CacheKeyTest(PathTest, QueryTest):
 
         opt = Load(User).joinedload(User.orders).joinedload(Order.items)
         eq_(
-            opt._generate_cache_key(query_path),
+            opt._generate_path_cache_key(query_path),
             (((Order, "items", Item, ("lazy", "joined")),)),
         )
 
@@ -1872,7 +1872,7 @@ class CacheKeyTest(PathTest, QueryTest):
         query_path = self._make_path_registry([User, "addresses"])
 
         opt = joinedload(User.orders).joinedload(Order.items)
-        eq_(opt._generate_cache_key(query_path), None)
+        eq_(opt._generate_path_cache_key(query_path), None)
 
     def test_bound_cache_key_excluded_on_other(self):
         User, Address, Order, Item, SubItem = self.classes(
@@ -1882,7 +1882,7 @@ class CacheKeyTest(PathTest, QueryTest):
         query_path = self._make_path_registry([User, "addresses"])
 
         opt = Load(User).joinedload(User.orders).joinedload(Order.items)
-        eq_(opt._generate_cache_key(query_path), None)
+        eq_(opt._generate_path_cache_key(query_path), None)
 
     def test_unbound_cache_key_excluded_on_aliased(self):
         User, Address, Order, Item, SubItem = self.classes(
@@ -1901,7 +1901,7 @@ class CacheKeyTest(PathTest, QueryTest):
         query_path = self._make_path_registry([User, "orders"])
 
         opt = joinedload(aliased(User).orders).joinedload(Order.items)
-        eq_(opt._generate_cache_key(query_path), None)
+        eq_(opt._generate_path_cache_key(query_path), None)
 
     def test_bound_cache_key_wildcard_one(self):
         # do not change this test, it is testing
@@ -1911,7 +1911,7 @@ class CacheKeyTest(PathTest, QueryTest):
         query_path = self._make_path_registry([User, "addresses"])
 
         opt = Load(User).lazyload("*")
-        eq_(opt._generate_cache_key(query_path), None)
+        eq_(opt._generate_path_cache_key(query_path), None)
 
     def test_unbound_cache_key_wildcard_one(self):
         User, Address = self.classes("User", "Address")
@@ -1920,7 +1920,7 @@ class CacheKeyTest(PathTest, QueryTest):
 
         opt = lazyload("*")
         eq_(
-            opt._generate_cache_key(query_path),
+            opt._generate_path_cache_key(query_path),
             (("relationship:_sa_default", ("lazy", "select")),),
         )
 
@@ -1933,7 +1933,7 @@ class CacheKeyTest(PathTest, QueryTest):
 
         opt = Load(User).lazyload("orders").lazyload("*")
         eq_(
-            opt._generate_cache_key(query_path),
+            opt._generate_path_cache_key(query_path),
             (
                 ("orders", Order, ("lazy", "select")),
                 ("orders", Order, "relationship:*", ("lazy", "select")),
@@ -1949,7 +1949,7 @@ class CacheKeyTest(PathTest, QueryTest):
 
         opt = lazyload("orders").lazyload("*")
         eq_(
-            opt._generate_cache_key(query_path),
+            opt._generate_path_cache_key(query_path),
             (
                 ("orders", Order, ("lazy", "select")),
                 ("orders", Order, "relationship:*", ("lazy", "select")),
@@ -1968,7 +1968,7 @@ class CacheKeyTest(PathTest, QueryTest):
         )
 
         eq_(
-            opt._generate_cache_key(query_path),
+            opt._generate_path_cache_key(query_path),
             (
                 (SubItem, ("lazy", "subquery")),
                 ("extra_keywords", Keyword, ("lazy", "subquery")),
@@ -1987,7 +1987,7 @@ class CacheKeyTest(PathTest, QueryTest):
         )
 
         eq_(
-            opt._generate_cache_key(query_path),
+            opt._generate_path_cache_key(query_path),
             (
                 (SubItem, ("lazy", "subquery")),
                 ("extra_keywords", Keyword, ("lazy", "subquery")),
@@ -2008,7 +2008,7 @@ class CacheKeyTest(PathTest, QueryTest):
         )
 
         eq_(
-            opt._generate_cache_key(query_path),
+            opt._generate_path_cache_key(query_path),
             (
                 (SubItem, ("lazy", "subquery")),
                 ("extra_keywords", Keyword, ("lazy", "subquery")),
@@ -2029,7 +2029,7 @@ class CacheKeyTest(PathTest, QueryTest):
         )
 
         eq_(
-            opt._generate_cache_key(query_path),
+            opt._generate_path_cache_key(query_path),
             (
                 (SubItem, ("lazy", "subquery")),
                 ("extra_keywords", Keyword, ("lazy", "subquery")),
@@ -2056,7 +2056,7 @@ class CacheKeyTest(PathTest, QueryTest):
         opt = subqueryload(User.orders).subqueryload(
             Order.items.of_type(SubItem)
         )
-        eq_(opt._generate_cache_key(query_path), None)
+        eq_(opt._generate_path_cache_key(query_path), None)
 
     def test_unbound_cache_key_excluded_of_type_unsafe(self):
         User, Address, Order, Item, SubItem = self.classes(
@@ -2078,7 +2078,7 @@ class CacheKeyTest(PathTest, QueryTest):
         opt = subqueryload(User.orders).subqueryload(
             Order.items.of_type(aliased(SubItem))
         )
-        eq_(opt._generate_cache_key(query_path), None)
+        eq_(opt._generate_path_cache_key(query_path), None)
 
     def test_bound_cache_key_excluded_of_type_safe(self):
         User, Address, Order, Item, SubItem = self.classes(
@@ -2102,7 +2102,7 @@ class CacheKeyTest(PathTest, QueryTest):
             .subqueryload(User.orders)
             .subqueryload(Order.items.of_type(SubItem))
         )
-        eq_(opt._generate_cache_key(query_path), None)
+        eq_(opt._generate_path_cache_key(query_path), None)
 
     def test_bound_cache_key_excluded_of_type_unsafe(self):
         User, Address, Order, Item, SubItem = self.classes(
@@ -2126,7 +2126,7 @@ class CacheKeyTest(PathTest, QueryTest):
             .subqueryload(User.orders)
             .subqueryload(Order.items.of_type(aliased(SubItem)))
         )
-        eq_(opt._generate_cache_key(query_path), None)
+        eq_(opt._generate_path_cache_key(query_path), None)
 
     def test_unbound_cache_key_included_of_type_safe(self):
         User, Address, Order, Item, SubItem = self.classes(
@@ -2137,7 +2137,7 @@ class CacheKeyTest(PathTest, QueryTest):
 
         opt = joinedload(User.orders).joinedload(Order.items.of_type(SubItem))
         eq_(
-            opt._generate_cache_key(query_path),
+            opt._generate_path_cache_key(query_path),
             ((Order, "items", SubItem, ("lazy", "joined")),),
         )
 
@@ -2155,7 +2155,7 @@ class CacheKeyTest(PathTest, QueryTest):
         )
 
         eq_(
-            opt._generate_cache_key(query_path),
+            opt._generate_path_cache_key(query_path),
             ((Order, "items", SubItem, ("lazy", "joined")),),
         )
 
@@ -2169,7 +2169,7 @@ class CacheKeyTest(PathTest, QueryTest):
         opt = joinedload(User.orders).joinedload(
             Order.items.of_type(aliased(SubItem))
         )
-        eq_(opt._generate_cache_key(query_path), False)
+        eq_(opt._generate_path_cache_key(query_path), False)
 
     def test_unbound_cache_key_included_unsafe_option_two(self):
         User, Address, Order, Item, SubItem = self.classes(
@@ -2181,7 +2181,7 @@ class CacheKeyTest(PathTest, QueryTest):
         opt = joinedload(User.orders).joinedload(
             Order.items.of_type(aliased(SubItem))
         )
-        eq_(opt._generate_cache_key(query_path), False)
+        eq_(opt._generate_path_cache_key(query_path), False)
 
     def test_unbound_cache_key_included_unsafe_option_three(self):
         User, Address, Order, Item, SubItem = self.classes(
@@ -2193,7 +2193,7 @@ class CacheKeyTest(PathTest, QueryTest):
         opt = joinedload(User.orders).joinedload(
             Order.items.of_type(aliased(SubItem))
         )
-        eq_(opt._generate_cache_key(query_path), False)
+        eq_(opt._generate_path_cache_key(query_path), False)
 
     def test_unbound_cache_key_included_unsafe_query(self):
         User, Address, Order, Item, SubItem = self.classes(
@@ -2204,7 +2204,7 @@ class CacheKeyTest(PathTest, QueryTest):
         query_path = self._make_path_registry([inspect(au), "orders"])
 
         opt = joinedload(au.orders).joinedload(Order.items)
-        eq_(opt._generate_cache_key(query_path), False)
+        eq_(opt._generate_path_cache_key(query_path), False)
 
     def test_unbound_cache_key_included_safe_w_deferred(self):
         User, Address, Order, Item, SubItem = self.classes(
@@ -2219,7 +2219,7 @@ class CacheKeyTest(PathTest, QueryTest):
             .defer(Address.user_id)
         )
         eq_(
-            opt._generate_cache_key(query_path),
+            opt._generate_path_cache_key(query_path),
             (
                 (
                     Address,
@@ -2247,12 +2247,12 @@ class CacheKeyTest(PathTest, QueryTest):
         )
 
         eq_(
-            opt1._generate_cache_key(query_path),
+            opt1._generate_path_cache_key(query_path),
             ((Order, "items", Item, ("lazy", "joined")),),
         )
 
         eq_(
-            opt2._generate_cache_key(query_path),
+            opt2._generate_path_cache_key(query_path),
             (
                 (Order, "address", Address, ("lazy", "joined")),
                 (
@@ -2288,7 +2288,7 @@ class CacheKeyTest(PathTest, QueryTest):
             .defer(Address.user_id)
         )
         eq_(
-            opt._generate_cache_key(query_path),
+            opt._generate_path_cache_key(query_path),
             (
                 (
                     Address,
@@ -2316,12 +2316,12 @@ class CacheKeyTest(PathTest, QueryTest):
         )
 
         eq_(
-            opt1._generate_cache_key(query_path),
+            opt1._generate_path_cache_key(query_path),
             ((Order, "items", Item, ("lazy", "joined")),),
         )
 
         eq_(
-            opt2._generate_cache_key(query_path),
+            opt2._generate_path_cache_key(query_path),
             (
                 (Order, "address", Address, ("lazy", "joined")),
                 (
@@ -2356,7 +2356,7 @@ class CacheKeyTest(PathTest, QueryTest):
         query_path = self._make_path_registry([User, "orders"])
 
         eq_(
-            opt._generate_cache_key(query_path),
+            opt._generate_path_cache_key(query_path),
             (
                 (
                     Order,
@@ -2385,7 +2385,7 @@ class CacheKeyTest(PathTest, QueryTest):
 
         au = aliased(User)
         opt = Load(au).joinedload(au.orders).joinedload(Order.items)
-        eq_(opt._generate_cache_key(query_path), None)
+        eq_(opt._generate_path_cache_key(query_path), None)
 
     def test_bound_cache_key_included_unsafe_option_one(self):
         User, Address, Order, Item, SubItem = self.classes(
@@ -2399,7 +2399,7 @@ class CacheKeyTest(PathTest, QueryTest):
             .joinedload(User.orders)
             .joinedload(Order.items.of_type(aliased(SubItem)))
         )
-        eq_(opt._generate_cache_key(query_path), False)
+        eq_(opt._generate_path_cache_key(query_path), False)
 
     def test_bound_cache_key_included_unsafe_option_two(self):
         User, Address, Order, Item, SubItem = self.classes(
@@ -2413,7 +2413,7 @@ class CacheKeyTest(PathTest, QueryTest):
             .joinedload(User.orders)
             .joinedload(Order.items.of_type(aliased(SubItem)))
         )
-        eq_(opt._generate_cache_key(query_path), False)
+        eq_(opt._generate_path_cache_key(query_path), False)
 
     def test_bound_cache_key_included_unsafe_option_three(self):
         User, Address, Order, Item, SubItem = self.classes(
@@ -2427,7 +2427,7 @@ class CacheKeyTest(PathTest, QueryTest):
             .joinedload(User.orders)
             .joinedload(Order.items.of_type(aliased(SubItem)))
         )
-        eq_(opt._generate_cache_key(query_path), False)
+        eq_(opt._generate_path_cache_key(query_path), False)
 
     def test_bound_cache_key_included_unsafe_query(self):
         User, Address, Order, Item, SubItem = self.classes(
@@ -2438,7 +2438,7 @@ class CacheKeyTest(PathTest, QueryTest):
         query_path = self._make_path_registry([inspect(au), "orders"])
 
         opt = Load(au).joinedload(au.orders).joinedload(Order.items)
-        eq_(opt._generate_cache_key(query_path), False)
+        eq_(opt._generate_path_cache_key(query_path), False)
 
     def test_bound_cache_key_included_safe_w_option(self):
         User, Address, Order, Item, SubItem = self.classes(
@@ -2454,7 +2454,7 @@ class CacheKeyTest(PathTest, QueryTest):
         query_path = self._make_path_registry([User, "orders"])
 
         eq_(
-            opt._generate_cache_key(query_path),
+            opt._generate_path_cache_key(query_path),
             (
                 (
                     Order,
@@ -2483,7 +2483,7 @@ class CacheKeyTest(PathTest, QueryTest):
 
         opt = defaultload(User.addresses).load_only("id", "email_address")
         eq_(
-            opt._generate_cache_key(query_path),
+            opt._generate_path_cache_key(query_path),
             (
                 (Address, "id", ("deferred", False), ("instrument", True)),
                 (
@@ -2513,7 +2513,7 @@ class CacheKeyTest(PathTest, QueryTest):
             Address.id, Address.email_address
         )
         eq_(
-            opt._generate_cache_key(query_path),
+            opt._generate_path_cache_key(query_path),
             (
                 (Address, "id", ("deferred", False), ("instrument", True)),
                 (
@@ -2545,7 +2545,7 @@ class CacheKeyTest(PathTest, QueryTest):
             .load_only("id", "email_address")
         )
         eq_(
-            opt._generate_cache_key(query_path),
+            opt._generate_path_cache_key(query_path),
             (
                 (Address, "id", ("deferred", False), ("instrument", True)),
                 (
@@ -2572,7 +2572,7 @@ class CacheKeyTest(PathTest, QueryTest):
         opt = defaultload(User.addresses).undefer_group("xyz")
 
         eq_(
-            opt._generate_cache_key(query_path),
+            opt._generate_path_cache_key(query_path),
             ((Address, "column:*", ("undefer_group_xyz", True)),),
         )
 
@@ -2584,6 +2584,6 @@ class CacheKeyTest(PathTest, QueryTest):
         opt = Load(User).defaultload(User.addresses).undefer_group("xyz")
 
         eq_(
-            opt._generate_cache_key(query_path),
+            opt._generate_path_cache_key(query_path),
             ((Address, "column:*", ("undefer_group_xyz", True)),),
         )
index d48a8ed3389dea2842c9a670ff6ced54e7c006d9..5d21960b70376008a03c5227fd51df343440f633 100644 (file)
@@ -32,6 +32,7 @@ from sqlalchemy.sql import operators
 from sqlalchemy.sql import True_
 from sqlalchemy.sql import type_coerce
 from sqlalchemy.sql import visitors
+from sqlalchemy.sql.base import HasCacheKey
 from sqlalchemy.sql.elements import _label_reference
 from sqlalchemy.sql.elements import _textual_label_reference
 from sqlalchemy.sql.elements import Annotated
@@ -46,13 +47,13 @@ from sqlalchemy.sql.functions import FunctionElement
 from sqlalchemy.sql.functions import GenericFunction
 from sqlalchemy.sql.functions import ReturnTypeFromArgs
 from sqlalchemy.sql.selectable import _OffsetLimitParam
+from sqlalchemy.sql.selectable import AliasedReturnsRows
 from sqlalchemy.sql.selectable import FromGrouping
 from sqlalchemy.sql.selectable import Selectable
 from sqlalchemy.sql.selectable import SelectStatementGrouping
-from sqlalchemy.testing import assert_raises_message
+from sqlalchemy.sql.visitors import InternalTraversal
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
-from sqlalchemy.testing import is_
 from sqlalchemy.testing import is_false
 from sqlalchemy.testing import is_true
 from sqlalchemy.testing import ne_
@@ -63,8 +64,17 @@ meta = MetaData()
 meta2 = MetaData()
 
 table_a = Table("a", meta, Column("a", Integer), Column("b", String))
+table_b_like_a = Table("b2", meta, Column("a", Integer), Column("b", String))
+
 table_a_2 = Table("a", meta2, Column("a", Integer), Column("b", String))
 
+table_a_2_fs = Table(
+    "a", meta2, Column("a", Integer), Column("b", String), schema="fs"
+)
+table_a_2_bs = Table(
+    "a", meta2, Column("a", Integer), Column("b", String), schema="bs"
+)
+
 table_b = Table("b", meta, Column("a", Integer), Column("b", Integer))
 
 table_c = Table("c", meta, Column("x", Integer), Column("y", Integer))
@@ -72,8 +82,18 @@ table_c = Table("c", meta, Column("x", Integer), Column("y", Integer))
 table_d = Table("d", meta, Column("y", Integer), Column("z", Integer))
 
 
-class CompareAndCopyTest(fixtures.TestBase):
+class MyEntity(HasCacheKey):
+    def __init__(self, name, element):
+        self.name = name
+        self.element = element
+
+    _cache_key_traversal = [
+        ("name", InternalTraversal.dp_string),
+        ("element", InternalTraversal.dp_clauseelement),
+    ]
+
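
# Editor's sketch (illustrative, not part of the patch): MyEntity above
# demonstrates the data-driven protocol -- a HasCacheKey subclass lists
# (attribute, InternalTraversal symbol) pairs and the base class derives
# the cache key from that declaration.  Assuming _generate_cache_key() is
# available on HasCacheKey as it is on ClauseElement in the tests below:
e1 = MyEntity("a", table_a.c.a == 5)
e2 = MyEntity("a", table_a.c.a == 5)
e3 = MyEntity("b", table_a.c.a == 5)
# e1._generate_cache_key().key == e2._generate_cache_key().key
# e1._generate_cache_key().key != e3._generate_cache_key().key
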
 
+class CoreFixtures(object):
     # lambdas which return a tuple of ColumnElement objects.
     # must return at least two objects that should compare differently.
     # to test more varieties of "difference" additional objects can be added.
@@ -100,11 +120,47 @@ class CompareAndCopyTest(fixtures.TestBase):
             text("select a, b, c from table").columns(
                 a=Integer, b=String, c=Integer
             ),
+            text("select a, b, c from table where foo=:bar").bindparams(
+                bindparam("bar", Integer)
+            ),
+            text("select a, b, c from table where foo=:foo").bindparams(
+                bindparam("foo", Integer)
+            ),
+            text("select a, b, c from table where foo=:bar").bindparams(
+                bindparam("bar", String)
+            ),
         ),
         lambda: (
             column("q") == column("x"),
             column("q") == column("y"),
             column("z") == column("x"),
+            column("z") + column("x"),
+            column("z") - column("x"),
+            column("x") - column("z"),
+            column("z") > column("x"),
+            # note these two are mathematically equivalent but for now they
+            # are considered to be different
+            column("z") >= column("x"),
+            column("x") <= column("z"),
+            column("q").between(5, 6),
+            column("q").between(5, 6, symmetric=True),
+            column("q").like("somstr"),
+            column("q").like("somstr", escape="\\"),
+            column("q").like("somstr", escape="X"),
+        ),
+        lambda: (
+            table_a.c.a,
+            table_a.c.a._annotate({"orm": True}),
+            table_a.c.a._annotate({"orm": True})._annotate({"bar": False}),
+            table_a.c.a._annotate(
+                {"orm": True, "parententity": MyEntity("a", table_a)}
+            ),
+            table_a.c.a._annotate(
+                {"orm": True, "parententity": MyEntity("b", table_a)}
+            ),
+            table_a.c.a._annotate(
+                {"orm": True, "parententity": MyEntity("b", select([table_a]))}
+            ),
         ),
         lambda: (
             cast(column("q"), Integer),
@@ -225,6 +281,58 @@ class CompareAndCopyTest(fixtures.TestBase):
             .where(table_a.c.b == 5)
             .correlate_except(table_b),
         ),
+        lambda: (
+            select([table_a.c.a]).cte(),
+            select([table_a.c.a]).cte(recursive=True),
+            select([table_a.c.a]).cte(name="some_cte", recursive=True),
+            select([table_a.c.a]).cte(name="some_cte"),
+            select([table_a.c.a]).cte(name="some_cte").alias("other_cte"),
+            select([table_a.c.a])
+            .cte(name="some_cte")
+            .union_all(select([table_a.c.a])),
+            select([table_a.c.a])
+            .cte(name="some_cte")
+            .union_all(select([table_a.c.b])),
+            select([table_a.c.a]).lateral(),
+            select([table_a.c.a]).lateral(name="bar"),
+            table_a.tablesample(func.bernoulli(1)),
+            table_a.tablesample(func.bernoulli(1), seed=func.random()),
+            table_a.tablesample(func.bernoulli(1), seed=func.other_random()),
+            table_a.tablesample(func.hoho(1)),
+            table_a.tablesample(func.bernoulli(1), name="bar"),
+            table_a.tablesample(
+                func.bernoulli(1), name="bar", seed=func.random()
+            ),
+        ),
+        lambda: (
+            select([table_a.c.a]),
+            select([table_a.c.a]).prefix_with("foo"),
+            select([table_a.c.a]).prefix_with("foo", dialect="mysql"),
+            select([table_a.c.a]).prefix_with("foo", dialect="postgresql"),
+            select([table_a.c.a]).prefix_with("bar"),
+            select([table_a.c.a]).suffix_with("bar"),
+        ),
+        lambda: (
+            select([table_a_2.c.a]),
+            select([table_a_2_fs.c.a]),
+            select([table_a_2_bs.c.a]),
+        ),
+        lambda: (
+            select([table_a.c.a]),
+            select([table_a.c.a]).with_hint(None, "some hint"),
+            select([table_a.c.a]).with_hint(None, "some other hint"),
+            select([table_a.c.a]).with_hint(table_a, "some hint"),
+            select([table_a.c.a])
+            .with_hint(table_a, "some hint")
+            .with_hint(None, "some other hint"),
+            select([table_a.c.a]).with_hint(table_a, "some other hint"),
+            select([table_a.c.a]).with_hint(
+                table_a, "some hint", dialect_name="mysql"
+            ),
+            select([table_a.c.a]).with_hint(
+                table_a, "some hint", dialect_name="postgresql"
+            ),
+        ),
         lambda: (
             table_a.join(table_b, table_a.c.a == table_b.c.a),
             table_a.join(
@@ -273,12 +381,202 @@ class CompareAndCopyTest(fixtures.TestBase):
             table("a", column("x"), column("y", Integer)),
             table("a", column("q"), column("y", Integer)),
         ),
-        lambda: (
-            Table("a", MetaData(), Column("q", Integer), Column("b", String)),
-            Table("b", MetaData(), Column("q", Integer), Column("b", String)),
-        ),
+        lambda: (table_a, table_b),
     ]
 
+    def _complex_fixtures():
+        def one():
+            a1 = table_a.alias()
+            a2 = table_b_like_a.alias()
+
+            stmt = (
+                select([table_a.c.a, a1.c.b, a2.c.b])
+                .where(table_a.c.b == a1.c.b)
+                .where(a1.c.b == a2.c.b)
+                .where(a1.c.a == 5)
+            )
+
+            return stmt
+
+        def one_diff():
+            a1 = table_b_like_a.alias()
+            a2 = table_a.alias()
+
+            stmt = (
+                select([table_a.c.a, a1.c.b, a2.c.b])
+                .where(table_a.c.b == a1.c.b)
+                .where(a1.c.b == a2.c.b)
+                .where(a1.c.a == 5)
+            )
+
+            return stmt
+
+        def two():
+            inner = one().subquery()
+
+            stmt = select([table_b.c.a, inner.c.a, inner.c.b]).select_from(
+                table_b.join(inner, table_b.c.b == inner.c.b)
+            )
+
+            return stmt
+
+        def three():
+
+            a1 = table_a.alias()
+            a2 = table_a.alias()
+            ex = exists().where(table_b.c.b == a1.c.a)
+
+            stmt = (
+                select([a1.c.a, a2.c.a])
+                .select_from(a1.join(a2, a1.c.b == a2.c.b))
+                .where(ex)
+            )
+            return stmt
+
+        return [one(), one_diff(), two(), three()]
+
+    fixtures.append(_complex_fixtures)
+
+
+class CacheKeyFixture(object):
+    def _run_cache_key_fixture(self, fixture):
+        case_a = fixture()
+        case_b = fixture()
+
+        for a, b in itertools.combinations_with_replacement(
+            range(len(case_a)), 2
+        ):
+            if a == b:
+                a_key = case_a[a]._generate_cache_key()
+                b_key = case_b[b]._generate_cache_key()
+                eq_(a_key.key, b_key.key)
+
+                for a_param, b_param in zip(
+                    a_key.bindparams, b_key.bindparams
+                ):
+                    assert a_param.compare(b_param, compare_values=False)
+            else:
+                a_key = case_a[a]._generate_cache_key()
+                b_key = case_b[b]._generate_cache_key()
+
+                if a_key.key == b_key.key:
+                    for a_param, b_param in zip(
+                        a_key.bindparams, b_key.bindparams
+                    ):
+                        if not a_param.compare(b_param, compare_values=True):
+                            break
+                    else:
+                        # this fails unconditionally since we could not
+                        # find bound parameter values that differed.
+                        # Usually we intended to get two distinct keys here
+                        # so the failure will be more descriptive using the
+                        # ne_() assertion.
+                        ne_(a_key.key, b_key.key)
+                else:
+                    ne_(a_key.key, b_key.key)
+
+            # ClauseElement-specific test to ensure the cache key
+            # collected all the bound parameters
+            if isinstance(case_a[a], ClauseElement) and isinstance(
+                case_b[b], ClauseElement
+            ):
+                assert_a_params = []
+                assert_b_params = []
+                visitors.traverse_depthfirst(
+                    case_a[a], {}, {"bindparam": assert_a_params.append}
+                )
+                visitors.traverse_depthfirst(
+                    case_b[b], {}, {"bindparam": assert_b_params.append}
+                )
+
+                # note we're asserting the order of the params as well as
+                # if there are dupes or not.  ordering has to be deterministic
+                # and matches what a traversal would provide.
+                # regular traverse_depthfirst does produce dupes in cases like
+                # select([some_alias]).
+                #    select_from(join(some_alias, other_table))
+                # where a bound parameter is inside of some_alias.  the
+                # cache key case is more minimalistic
+                eq_(
+                    sorted(a_key.bindparams, key=lambda b: b.key),
+                    sorted(
+                        util.unique_list(assert_a_params), key=lambda b: b.key
+                    ),
+                )
+                eq_(
+                    sorted(b_key.bindparams, key=lambda b: b.key),
+                    sorted(
+                        util.unique_list(assert_b_params), key=lambda b: b.key
+                    ),
+                )
+
+
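
# Editor's sketch (illustrative, not part of the patch): the fixture above
# encodes two properties of _generate_cache_key() -- the key is independent
# of bound parameter *values*, and every bound parameter is gathered onto
# the returned key object so that the values can still be extracted:
expr_one = table_a.c.a == bindparam("x", value=5)
expr_two = table_a.c.a == bindparam("x", value=10)
key_one = expr_one._generate_cache_key()
key_two = expr_two._generate_cache_key()
# key_one.key == key_two.key            (same structure, differing values)
# [p.value for p in key_one.bindparams] == [5]
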
+class CacheKeyTest(CacheKeyFixture, CoreFixtures, fixtures.TestBase):
+    def test_cache_key(self):
+        for fixture in self.fixtures:
+            self._run_cache_key_fixture(fixture)
+
+    def test_cache_key_unknown_traverse(self):
+        class Foobar1(ClauseElement):
+            _traverse_internals = [
+                ("key", InternalTraversal.dp_anon_name),
+                ("type_", InternalTraversal.dp_unknown_structure),
+            ]
+
+            def __init__(self, key, type_):
+                self.key = key
+                self.type_ = type_
+
+        f1 = Foobar1("foo", String())
+        eq_(f1._generate_cache_key(), None)
+
+    def test_cache_key_no_method(self):
+        class Foobar1(ClauseElement):
+            pass
+
+        class Foobar2(ColumnElement):
+            pass
+
+        # the None for cache key will prevent objects
+        # which contain these elements from being cached.
+        f1 = Foobar1()
+        eq_(f1._generate_cache_key(), None)
+
+        f2 = Foobar2()
+        eq_(f2._generate_cache_key(), None)
+
+        s1 = select([column("q"), Foobar2()])
+
+        eq_(s1._generate_cache_key(), None)
+
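
# Editor's sketch (hypothetical class, not part of the patch): an element
# opts back in to caching by declaring a fully traversible
# _traverse_internals, i.e. one with no dp_unknown_structure entries:
class Foobar3(ClauseElement):
    _traverse_internals = [("key", InternalTraversal.dp_string)]

    def __init__(self, key):
        self.key = key

# Foobar3("foo")._generate_cache_key() is then expected to be non-None,
# in contrast to Foobar1 / Foobar2 above.
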
+    def test_get_children_no_method(self):
+        class Foobar1(ClauseElement):
+            pass
+
+        class Foobar2(ColumnElement):
+            pass
+
+        f1 = Foobar1()
+        eq_(f1.get_children(), [])
+
+        f2 = Foobar2()
+        eq_(f2.get_children(), [])
+
+    def test_copy_internals_no_method(self):
+        class Foobar1(ClauseElement):
+            pass
+
+        class Foobar2(ColumnElement):
+            pass
+
+        f1 = Foobar1()
+        f2 = Foobar2()
+
+        f1._copy_internals()
+        f2._copy_internals()
+
+
+class CompareAndCopyTest(CoreFixtures, fixtures.TestBase):
     @classmethod
     def setup_class(cls):
         # TODO: we need to get dialects here somehow, perhaps in test_suite?
@@ -293,7 +591,10 @@ class CompareAndCopyTest(fixtures.TestBase):
             cls
             for cls in class_hierarchy(ClauseElement)
             if issubclass(cls, (ColumnElement, Selectable))
-            and "__init__" in cls.__dict__
+            and (
+                "__init__" in cls.__dict__
+                or issubclass(cls, AliasedReturnsRows)
+            )
             and not issubclass(cls, (Annotated))
             and "orm" not in cls.__module__
             and "compiler" not in cls.__module__
@@ -318,123 +619,16 @@ class CompareAndCopyTest(fixtures.TestBase):
             ):
                 if a == b:
                     is_true(
-                        case_a[a].compare(
-                            case_b[b], arbitrary_expression=True
-                        ),
+                        case_a[a].compare(case_b[b], compare_annotations=True),
                         "%r != %r" % (case_a[a], case_b[b]),
                     )
 
                 else:
                     is_false(
-                        case_a[a].compare(
-                            case_b[b], arbitrary_expression=True
-                        ),
+                        case_a[a].compare(case_b[b], compare_annotations=True),
                         "%r == %r" % (case_a[a], case_b[b]),
                     )
 
-    def test_cache_key(self):
-        def assert_params_append(assert_params):
-            def append(param):
-                if param._value_required_for_cache:
-                    assert_params.append(param)
-                else:
-                    is_(param.value, None)
-
-            return append
-
-        for fixture in self.fixtures:
-            case_a = fixture()
-            case_b = fixture()
-
-            for a, b in itertools.combinations_with_replacement(
-                range(len(case_a)), 2
-            ):
-
-                assert_a_params = []
-                assert_b_params = []
-
-                visitors.traverse_depthfirst(
-                    case_a[a],
-                    {},
-                    {"bindparam": assert_params_append(assert_a_params)},
-                )
-                visitors.traverse_depthfirst(
-                    case_b[b],
-                    {},
-                    {"bindparam": assert_params_append(assert_b_params)},
-                )
-                if assert_a_params:
-                    assert_raises_message(
-                        NotImplementedError,
-                        "bindparams collection argument required ",
-                        case_a[a]._cache_key,
-                    )
-                if assert_b_params:
-                    assert_raises_message(
-                        NotImplementedError,
-                        "bindparams collection argument required ",
-                        case_b[b]._cache_key,
-                    )
-
-                if not assert_a_params and not assert_b_params:
-                    if a == b:
-                        eq_(case_a[a]._cache_key(), case_b[b]._cache_key())
-                    else:
-                        ne_(case_a[a]._cache_key(), case_b[b]._cache_key())
-
-    def test_cache_key_gather_bindparams(self):
-        for fixture in self.fixtures:
-            case_a = fixture()
-            case_b = fixture()
-
-            # in the "bindparams" case, the cache keys for bound parameters
-            # with only different values will be the same, but the params
-            # themselves are gathered into a collection.
-            for a, b in itertools.combinations_with_replacement(
-                range(len(case_a)), 2
-            ):
-                a_params = {"bindparams": []}
-                b_params = {"bindparams": []}
-                if a == b:
-                    a_key = case_a[a]._cache_key(**a_params)
-                    b_key = case_b[b]._cache_key(**b_params)
-                    eq_(a_key, b_key)
-
-                    if a_params["bindparams"]:
-                        for a_param, b_param in zip(
-                            a_params["bindparams"], b_params["bindparams"]
-                        ):
-                            assert a_param.compare(b_param)
-                else:
-                    a_key = case_a[a]._cache_key(**a_params)
-                    b_key = case_b[b]._cache_key(**b_params)
-
-                    if a_key == b_key:
-                        for a_param, b_param in zip(
-                            a_params["bindparams"], b_params["bindparams"]
-                        ):
-                            if not a_param.compare(b_param):
-                                break
-                        else:
-                            assert False, "Bound parameters are all the same"
-                    else:
-                        ne_(a_key, b_key)
-
-                assert_a_params = []
-                assert_b_params = []
-                visitors.traverse_depthfirst(
-                    case_a[a], {}, {"bindparam": assert_a_params.append}
-                )
-                visitors.traverse_depthfirst(
-                    case_b[b], {}, {"bindparam": assert_b_params.append}
-                )
-
-                # note we're asserting the order of the params as well as
-                # if there are dupes or not.  ordering has to be deterministic
-                # and matches what a traversal would provide.
-                eq_(a_params["bindparams"], assert_a_params)
-                eq_(b_params["bindparams"], assert_b_params)
-
     def test_compare_col_identity(self):
         stmt1 = (
             select([table_a.c.a, table_b.c.b])
@@ -473,8 +667,9 @@ class CompareAndCopyTest(fixtures.TestBase):
 
             assert case_a[0].compare(case_b[0])
 
-            clone = case_a[0]._clone()
-            clone._copy_internals()
+            clone = visitors.replacement_traverse(
+                case_a[0], {}, lambda elem: None
+            )
 
             assert clone.compare(case_b[0])
 
@@ -511,6 +706,37 @@ class CompareAndCopyTest(fixtures.TestBase):
 
 
 class CompareClausesTest(fixtures.TestBase):
+    def test_compare_metadata_tables(self):
+        # metadata Table objects cache on their own identity, not their
+        # structure.   This is mainly to reduce the size of cache keys
+        # as well as reduce computational overhead, as Table objects have
+        # very large internal state and they are also generally global
+        # objects.
+
+        t1 = Table("a", MetaData(), Column("q", Integer), Column("p", Integer))
+        t2 = Table("a", MetaData(), Column("q", Integer), Column("p", Integer))
+
+        ne_(t1._generate_cache_key(), t2._generate_cache_key())
+
+        eq_(t1._generate_cache_key().key, (t1,))
+
+    def test_compare_adhoc_tables(self):
+        # non-metadata tables compare on their structure.  these objects are
+        # not commonly used.
+
+        # note this test is a bit redundant as we have a similar test
+        # via the fixtures also
+        t1 = table("a", Column("q", Integer), Column("p", Integer))
+        t2 = table("a", Column("q", Integer), Column("p", Integer))
+        t3 = table("b", Column("q", Integer), Column("p", Integer))
+        t4 = table("a", Column("q", Integer), Column("x", Integer))
+
+        eq_(t1._generate_cache_key(), t2._generate_cache_key())
+
+        ne_(t1._generate_cache_key(), t3._generate_cache_key())
+        ne_(t1._generate_cache_key(), t4._generate_cache_key())
+        ne_(t3._generate_cache_key(), t4._generate_cache_key())
+
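
# Editor's note (illustrative, not part of the patch): because a
# metadata-bound Table keys on its own identity, two structurally identical
# tables from different MetaData collections never share a cache key, and
# statements built against them are likewise expected to produce distinct
# keys:
m1, m2 = MetaData(), MetaData()
ta = Table("a", m1, Column("q", Integer))
tb = Table("a", m2, Column("q", Integer))
# select([ta])._generate_cache_key() != select([tb])._generate_cache_key()
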
     def test_compare_comparison_associative(self):
 
         l1 = table_c.c.x == table_d.c.y
@@ -521,6 +747,15 @@ class CompareClausesTest(fixtures.TestBase):
         is_true(l1.compare(l2))
         is_false(l1.compare(l3))
 
+    def test_compare_comparison_non_commutative_inverses(self):
+        l1 = table_c.c.x >= table_d.c.y
+        l2 = table_d.c.y < table_c.c.x
+        l3 = table_d.c.y <= table_c.c.x
+
+        # we're not doing this kind of commutativity right now.
+        is_false(l1.compare(l2))
+        is_false(l1.compare(l3))
+
     def test_compare_clauselist_associative(self):
 
         l1 = and_(table_c.c.x == table_d.c.y, table_c.c.y == table_d.c.z)
@@ -624,3 +859,45 @@ class CompareClausesTest(fixtures.TestBase):
                 use_proxies=True,
             )
         )
+
+    def test_compare_annotated_clears_mapping(self):
+        t = table("t", column("x"), column("y"))
+        x_a = t.c.x._annotate({"foo": True})
+        x_b = t.c.x._annotate({"foo": True})
+
+        is_true(x_a.compare(x_b, compare_annotations=True))
+        is_false(
+            x_a.compare(x_b._annotate({"bar": True}), compare_annotations=True)
+        )
+
+        s1 = select([t.c.x])._annotate({"foo": True})
+        s2 = select([t.c.x])._annotate({"foo": True})
+
+        is_true(s1.compare(s2, compare_annotations=True))
+
+        is_false(
+            s1.compare(s2._annotate({"bar": True}), compare_annotations=True)
+        )
+
+    def test_compare_annotated_wo_annotations(self):
+        t = table("t", column("x"), column("y"))
+        x_a = t.c.x._annotate({})
+        x_b = t.c.x._annotate({"foo": True})
+
+        is_true(t.c.x.compare(x_a))
+        is_true(x_b.compare(x_a))
+
+        is_true(x_a.compare(t.c.x))
+        is_false(x_a.compare(t.c.y))
+        is_false(t.c.y.compare(x_a))
+        is_true((t.c.x == 5).compare(x_a == 5))
+        is_false((t.c.y == 5).compare(x_a == 5))
+
+        s = select([t]).subquery()
+        x_p = s.c.x
+        is_false(x_a.compare(x_p))
+        is_false(t.c.x.compare(x_p))
+        x_p_a = x_p._annotate({})
+        is_true(x_p_a.compare(x_p))
+        is_true(x_p.compare(x_p_a))
+        is_false(x_p_a.compare(x_a))
diff --git a/test/sql/test_generative.py b/test/sql/test_external_traversal.py
similarity index 99%
rename from test/sql/test_generative.py
rename to test/sql/test_external_traversal.py
index 8d347a522abba995216510a8392aa8948f789c6d..8bfe5cf6f8d1416fbc7500fa9c1dad10e0fc96e0 100644 (file)
@@ -55,6 +55,7 @@ class TraversalTest(fixtures.TestBase, AssertsExecutionResults):
         # identity semantics.
         class A(ClauseElement):
             __visit_name__ = "a"
+            _traverse_internals = []
 
             def __init__(self, expr):
                 self.expr = expr
diff --git a/test/sql/test_operators.py b/test/sql/test_operators.py
index 637f1f8a5a4a3a032979f3bec84782e7c26b3f24..06cfdc4b5ab4eb23e83dc730bdf39880a902979b 100644 (file)
@@ -118,11 +118,14 @@ class DefaultColumnComparatorTest(fixtures.TestBase):
             )
         )
 
+        modifiers = operator(left, right).modifiers
+
         assert operator(left, right).compare(
             BinaryExpression(
                 coercions.expect(roles.WhereHavingRole, left),
                 coercions.expect(roles.WhereHavingRole, right),
                 operator,
+                modifiers=modifiers,
             )
         )
 
diff --git a/test/sql/test_selectable.py b/test/sql/test_selectable.py
index 2bc7ccc9317a9b47f9c40991045f5f5d0a1d669e..184e4a99c27a86ff4b8a5306c49112f7f7cdf6ef 100644 (file)
@@ -1070,7 +1070,7 @@ class SelectableTest(
         s4 = s3.with_only_columns([table2.c.b])
         self.assert_compile(s4, "SELECT t2.b FROM t2")
 
-    def test_from_list_warning_against_existing(self):
+    def test_from_list_against_existing_one(self):
         c1 = Column("c1", Integer)
         s = select([c1])
 
@@ -1081,7 +1081,7 @@ class SelectableTest(
 
         self.assert_compile(s, "SELECT t.c1 FROM t")
 
-    def test_from_list_recovers_after_warning(self):
+    def test_from_list_against_existing_two(self):
         c1 = Column("c1", Integer)
         c2 = Column("c2", Integer)
 
@@ -1090,18 +1090,11 @@ class SelectableTest(
         # force a compile.
         eq_(str(s), "SELECT c1")
 
-        @testing.emits_warning()
-        def go():
-            return Table("t", MetaData(), c1, c2)
-
-        t = go()
+        t = Table("t", MetaData(), c1, c2)
 
         eq_(c1._from_objects, [t])
         eq_(c2._from_objects, [t])
 
-        # 's' has been baked.  Can't afford
-        # not caching select._froms.
-        # hopefully the warning will clue the user
         self.assert_compile(s, "SELECT t.c1 FROM t")
         self.assert_compile(select([c1]), "SELECT t.c1 FROM t")
         self.assert_compile(select([c2]), "SELECT t.c2 FROM t")
@@ -1124,6 +1117,26 @@ class SelectableTest(
             "foo",
         )
 
+    def test_whereclause_adapted(self):
+        table1 = table("t1", column("a"))
+
+        s1 = select([table1]).subquery()
+
+        s2 = select([s1]).where(s1.c.a == 5)
+
+        assert s2._whereclause.left.table is s1
+
+        ta = select([table1]).subquery()
+
+        s3 = sql_util.ClauseAdapter(ta).traverse(s2)
+
+        assert s1 not in s3._froms
+
+        # these are new assumptions with the newer approach that
+        # actively swaps out whereclause and others
+        assert s3._whereclause.left.table is not s1
+        assert s3._whereclause.left.table in s3._froms
+
 
 class RefreshForNewColTest(fixtures.TestBase):
     def test_join_uninit(self):
@@ -2241,25 +2254,6 @@ class AnnotationsTest(fixtures.TestBase):
             annot = obj._annotate({})
             ne_(set([obj]), set([annot]))
 
-    def test_compare(self):
-        t = table("t", column("x"), column("y"))
-        x_a = t.c.x._annotate({})
-        assert t.c.x.compare(x_a)
-        assert x_a.compare(t.c.x)
-        assert not x_a.compare(t.c.y)
-        assert not t.c.y.compare(x_a)
-        assert (t.c.x == 5).compare(x_a == 5)
-        assert not (t.c.y == 5).compare(x_a == 5)
-
-        s = select([t]).subquery()
-        x_p = s.c.x
-        assert not x_a.compare(x_p)
-        assert not t.c.x.compare(x_p)
-        x_p_a = x_p._annotate({})
-        assert x_p_a.compare(x_p)
-        assert x_p.compare(x_p_a)
-        assert not x_p_a.compare(x_a)
-
     def test_proxy_set_iteration_includes_annotated(self):
         from sqlalchemy.schema import Column
 
@@ -2542,13 +2536,13 @@ class AnnotationsTest(fixtures.TestBase):
         ):
             # the columns clause isn't changed at all
             assert sel._raw_columns[0].table is a1
-            assert sel._froms[0] is sel._froms[1].left
+            assert sel._froms[0].element is sel._froms[1].left.element
 
             eq_(str(s), str(sel))
 
         # when we are modifying annotations sets only
-        # partially, each element is copied unconditionally
-        # when encountered.
+        # partially, elements are copied uniquely based on id().
+        # this is new as of 1.4, previously they'd be copied every time
         for sel in (
             sql_util._deep_deannotate(s, {"foo": "bar"}),
             sql_util._deep_annotate(s, {"foo": "bar"}),
diff --git a/test/sql/test_utils.py b/test/sql/test_utils.py
index 988d5331eb6ad6a5c3b23ed5773916b0d4af7b9e..48d6de6db0d17abf9496aef636711e5a3f5e4b9c 100644 (file)
@@ -7,6 +7,6 @@ from sqlalchemy.testing import fixtures
 class MiscTest(fixtures.TestBase):
     def test_column_element_no_visit(self):
         class MyElement(ColumnElement):
-            pass
+            _traverse_internals = []
 
         eq_(sql_util.find_tables(MyElement(), check_columns=True), [])
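
# Editor's note (illustrative, not part of the patch): this hunk and the
# TraversalTest change above follow the same pattern -- ad-hoc
# ClauseElement / ColumnElement subclasses used in tests now declare an
# explicit (possibly empty) _traverse_internals so that the internal
# traversal system can drive get_children(), _copy_internals() and
# compare() for them:
class LeafElement(ColumnElement):   # hypothetical name
    _traverse_internals = []        # no child state to traverse
# LeafElement().get_children() is then expected to return [] without a
# hand-written get_children() implementation.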