]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
remove internal use of metaclasses
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 10 Jan 2022 21:48:05 +0000 (16:48 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 11 Jan 2022 18:06:57 +0000 (13:06 -0500)
All but one metaclass used internally can now
be replaced using __init_subclass__().  Within this
patch we remove:

* events._EventMeta
* sql.visitors.TraversibleType
* sql.visitors.InternalTraversibleType
* testing.fixtures.FindFixture
* testing.fixtures.FindFixtureDeclarative
* langhelpers.EnsureKWArgType
* sql.functions._GenericMeta
* sql.type_api.VisitableCheckKWArg (was a mixture of TraversibleType
  and EnsureKWArgType)

The remaining internal class is MetaOptions used by the
sql.Options object which is in turn currently mostly for
ORM internal use, as this type implements class level overrides
for the ``+`` operator.

For declarative, removing DeclarativeMeta in place of
an `__init_subclass__()` class would not be fully feasible as
it would break backwards compatibility with applications that
refer to this class explicitly, but also DeclarativeMeta intercepts
class-level attribute set and delete operations which is a widely
used pattern.   An option for declarative base to use
`__init_subclass__()` should be provided but this is out of
scope for this particular change.

Change-Id: I8aa898c7ab59d887739037d34b1cbab36521ab78
References: #6810

26 files changed:
lib/sqlalchemy/event/base.py
lib/sqlalchemy/ext/mutable.py
lib/sqlalchemy/orm/decl_api.py
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/path_registry.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/strategy_options.py
lib/sqlalchemy/sql/base.py
lib/sqlalchemy/sql/cache_key.py [new file with mode: 0644]
lib/sqlalchemy/sql/coercions.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/sql/functions.py
lib/sqlalchemy/sql/lambdas.py
lib/sqlalchemy/sql/selectable.py
lib/sqlalchemy/sql/sqltypes.py
lib/sqlalchemy/sql/traversals.py
lib/sqlalchemy/sql/type_api.py
lib/sqlalchemy/sql/util.py
lib/sqlalchemy/sql/visitors.py
lib/sqlalchemy/testing/fixtures.py
lib/sqlalchemy/util/__init__.py
lib/sqlalchemy/util/langhelpers.py
test/orm/test_events.py
test/sql/test_lambdas.py

index c5b03dd721a72d786f64f54c30e9a6b76374f864..25d36924081e3a7718f801585fc58b43b45e88ab 100644 (file)
@@ -15,13 +15,16 @@ at the class level of a particular ``_Dispatch`` class as well as within
 instances of ``_Dispatch``.
 
 """
+from typing import ClassVar
+from typing import Optional
+from typing import Type
 import weakref
 
 from .attr import _ClsLevelDispatch
 from .attr import _EmptyListener
 from .attr import _JoinedListener
 from .. import util
-
+from ..util.typing import Protocol
 
 _registrars = util.defaultdict(list)
 
@@ -63,8 +66,8 @@ class _Dispatch:
     of the :class:`._Dispatch` class is returned.
 
     A :class:`._Dispatch` class is generated for each :class:`.Events`
-    class defined, by the :func:`._create_dispatcher_class` function.
-    The original :class:`.Events` classes remain untouched.
+    class defined, by the :meth:`._HasEventsDispatch._create_dispatcher_class`
+    method.  The original :class:`.Events` classes remain untouched.
     This decouples the construction of :class:`.Events` subclasses from
     the implementation used by the event internals, and allows
     inspecting tools like Sphinx to work in an unsurprising
@@ -78,6 +81,13 @@ class _Dispatch:
 
     _empty_listener_reg = weakref.WeakKeyDictionary()
 
+    _events: Type["_HasEventsDispatch"]
+    """reference back to the Events class.
+
+    Bidirectional against _HasEventsDispatch.dispatch
+
+    """
+
     def __init__(self, parent, instance_cls=None):
         self._parent = parent
         self._instance_cls = instance_cls
@@ -159,56 +169,6 @@ class _Dispatch:
             ls.for_modify(self).clear()
 
 
-class _EventMeta(type):
-    """Intercept new Event subclasses and create
-    associated _Dispatch classes."""
-
-    def __init__(cls, classname, bases, dict_):
-        _create_dispatcher_class(cls, classname, bases, dict_)
-        type.__init__(cls, classname, bases, dict_)
-
-
-def _create_dispatcher_class(cls, classname, bases, dict_):
-    """Create a :class:`._Dispatch` class corresponding to an
-    :class:`.Events` class."""
-
-    # there's all kinds of ways to do this,
-    # i.e. make a Dispatch class that shares the '_listen' method
-    # of the Event class, this is the straight monkeypatch.
-    if hasattr(cls, "dispatch"):
-        dispatch_base = cls.dispatch.__class__
-    else:
-        dispatch_base = _Dispatch
-
-    event_names = [k for k in dict_ if _is_event_name(k)]
-    dispatch_cls = type(
-        "%sDispatch" % classname, (dispatch_base,), {"__slots__": event_names}
-    )
-
-    dispatch_cls._event_names = event_names
-
-    dispatch_inst = cls._set_dispatch(cls, dispatch_cls)
-    for k in dispatch_cls._event_names:
-        setattr(dispatch_inst, k, _ClsLevelDispatch(cls, dict_[k]))
-        _registrars[k].append(cls)
-
-    for super_ in dispatch_cls.__bases__:
-        if issubclass(super_, _Dispatch) and super_ is not _Dispatch:
-            for ls in super_._events.dispatch._event_descriptors:
-                setattr(dispatch_inst, ls.name, ls)
-                dispatch_cls._event_names.append(ls.name)
-
-    if getattr(cls, "_dispatch_target", None):
-        the_cls = cls._dispatch_target
-        if (
-            hasattr(the_cls, "__slots__")
-            and "_slots_dispatch" in the_cls.__slots__
-        ):
-            cls._dispatch_target.dispatch = slots_dispatcher(cls)
-        else:
-            cls._dispatch_target.dispatch = dispatcher(cls)
-
-
 def _remove_dispatcher(cls):
     for k in cls.dispatch._event_names:
         _registrars[k].remove(cls)
@@ -216,8 +176,31 @@ def _remove_dispatcher(cls):
             del _registrars[k]
 
 
-class Events(metaclass=_EventMeta):
-    """Define event listening functions for a particular target type."""
+class _HasEventsDispatchProto(Protocol):
+    """protocol for non-event classes that will also receive the 'dispatch'
+    attribute in the form of a descriptor.
+
+    """
+
+    dispatch: ClassVar["dispatcher"]
+
+
+class _HasEventsDispatch:
+    _dispatch_target: Optional[Type[_HasEventsDispatchProto]]
+    """class which will receive the .dispatch collection"""
+
+    dispatch: _Dispatch
+    """reference back to the _Dispatch class.
+
+    Bidirectional against _Dispatch._events
+
+    """
+
+    def __init_subclass__(cls) -> None:
+        """Intercept new Event subclasses and create associated _Dispatch
+        classes."""
+
+        cls._create_dispatcher_class(cls.__name__, cls.__bases__, cls.__dict__)
 
     @staticmethod
     def _set_dispatch(cls, dispatch_cls):
@@ -230,6 +213,54 @@ class Events(metaclass=_EventMeta):
         dispatch_cls._events = cls
         return cls.dispatch
 
+    @classmethod
+    def _create_dispatcher_class(cls, classname, bases, dict_):
+        """Create a :class:`._Dispatch` class corresponding to an
+        :class:`.Events` class."""
+
+        # there's all kinds of ways to do this,
+        # i.e. make a Dispatch class that shares the '_listen' method
+        # of the Event class, this is the straight monkeypatch.
+        if hasattr(cls, "dispatch"):
+            dispatch_base = cls.dispatch.__class__
+        else:
+            dispatch_base = _Dispatch
+
+        event_names = [k for k in dict_ if _is_event_name(k)]
+        dispatch_cls = type(
+            "%sDispatch" % classname,
+            (dispatch_base,),
+            {"__slots__": event_names},
+        )
+
+        dispatch_cls._event_names = event_names
+
+        dispatch_inst = cls._set_dispatch(cls, dispatch_cls)
+        for k in dispatch_cls._event_names:
+            setattr(dispatch_inst, k, _ClsLevelDispatch(cls, dict_[k]))
+            _registrars[k].append(cls)
+
+        for super_ in dispatch_cls.__bases__:
+            if issubclass(super_, _Dispatch) and super_ is not _Dispatch:
+                for ls in super_._events.dispatch._event_descriptors:
+                    setattr(dispatch_inst, ls.name, ls)
+                    dispatch_cls._event_names.append(ls.name)
+
+        if getattr(cls, "_dispatch_target", None):
+            dispatch_target_cls = cls._dispatch_target
+            assert dispatch_target_cls is not None
+            if (
+                hasattr(dispatch_target_cls, "__slots__")
+                and "_slots_dispatch" in dispatch_target_cls.__slots__
+            ):
+                dispatch_target_cls.dispatch = slots_dispatcher(cls)
+            else:
+                dispatch_target_cls.dispatch = dispatcher(cls)
+
+
+class Events(_HasEventsDispatch):
+    """Define event listening functions for a particular target type."""
+
     @classmethod
     def _accept_with(cls, target):
         def dispatch_is(*types):
index a2e1a38260b04a783c75630f863299ea49c2b7b2..d75cf667bd7eaa303ea4fe1abf40f26d55066f9d 100644 (file)
@@ -269,10 +269,10 @@ and to also route attribute set events via ``__setattr__`` to the
         def __ne__(self, other):
             return not self.__eq__(other)
 
-The :class:`.MutableComposite` class uses a Python metaclass to automatically
-establish listeners for any usage of :func:`_orm.composite` that specifies our
-``Point`` type. Below, when ``Point`` is mapped to the ``Vertex`` class,
-listeners are established which will route change events from ``Point``
+The :class:`.MutableComposite` class makes use of class mapping events to
+automatically establish listeners for any usage of :func:`_orm.composite` that
+specifies our ``Point`` type. Below, when ``Point`` is mapped to the ``Vertex``
+class, listeners are established which will route change events from ``Point``
 objects to each of the ``Vertex.start`` and ``Vertex.end`` attributes::
 
     from sqlalchemy.orm import composite, mapper
index 1094fa51642f11cdb6bd44999b6d41aeb249a3b4..2a5b1bb2b0e2467c586c0f656316c075528cc8cf 100644 (file)
@@ -51,6 +51,10 @@ def has_inherited_table(cls):
 
 
 class DeclarativeMeta(type):
+    # DeclarativeMeta could be replaced by __subclass_init__()
+    # except for the class-level __setattr__() and __delattr__ hooks,
+    # which are still very important.
+
     def __init__(cls, classname, bases, dict_, **kw):
         # early-consume registry from the initial declarative base,
         # assign privately to not conflict with subclass attributes named
index 6c84f89cfad3b1352e588553e011d55a1423c021..d842df2215657769a307855712c21fd8b571df3e 100644 (file)
@@ -37,7 +37,7 @@ from ..sql import operators
 from ..sql import roles
 from ..sql import visitors
 from ..sql.base import ExecutableOption
-from ..sql.traversals import HasCacheKey
+from ..sql.cache_key import HasCacheKey
 
 
 __all__ = (
index 2e64696d9c580ebeef4970fdc72ac96d8fa2bcd9..0d87739cc196b3dc36a13f3c2485eec8a31bc4d3 100644 (file)
@@ -20,7 +20,7 @@ from .. import exc
 from .. import inspection
 from .. import util
 from ..sql import visitors
-from ..sql.traversals import HasCacheKey
+from ..sql.cache_key import HasCacheKey
 
 log = logging.getLogger(__name__)
 
index ad31c24329b609df03373c85fd3ffa9706aa96cd..e6d16f1788a68bbe6d71164f522b89172d81e941 100644 (file)
@@ -134,7 +134,7 @@ class Query(
     # local Query builder state, not needed for
     # compilation or execution
     _enable_assertions = True
-    _last_joined_entity = None
+
     _statement = None
 
     # mirrors that of ClauseElement, used to propagate the "orm"
index 4df275c71de612220e0ff5a3200f03c67ba63232..c2cfbb9fc2184f73711421ccd2b4bb6583dc281e 100644 (file)
@@ -29,9 +29,9 @@ from .. import exc as sa_exc
 from .. import inspect
 from .. import util
 from ..sql import and_
+from ..sql import cache_key
 from ..sql import coercions
 from ..sql import roles
-from ..sql import traversals
 from ..sql import visitors
 from ..sql.base import _generative
 from ..sql.base import Generative
@@ -1316,7 +1316,7 @@ class _WildcardLoad(_AbstractLoad):
         self.__dict__.update(state)
 
 
-class _LoadElement(traversals.HasCacheKey):
+class _LoadElement(cache_key.HasCacheKey):
     """represents strategy information to select for a LoaderStrategy
     and pass options to it.
 
index 805f7b1a02037cc5150329538b9deb782fdfee87..6ab9a75f6f01857280123a141c7b0cd2c9f05ca4 100644 (file)
@@ -20,9 +20,9 @@ import typing
 
 from . import roles
 from . import visitors
-from .traversals import HasCacheKey  # noqa
+from .cache_key import HasCacheKey  # noqa
+from .cache_key import MemoizedHasCacheKey  # noqa
 from .traversals import HasCopyInternals  # noqa
-from .traversals import MemoizedHasCacheKey  # noqa
 from .visitors import ClauseVisitor
 from .visitors import ExtendedInternalTraversal
 from .visitors import InternalTraversal
@@ -37,7 +37,6 @@ try:
 except ImportError:
     from ._py_util import prefix_anon_map  # noqa
 
-
 coercions = None
 elements = None
 type_api = None
@@ -610,18 +609,13 @@ class HasCompileState(Generative):
 
 
 class _MetaOptions(type):
-    """metaclass for the Options class."""
+    """metaclass for the Options class.
 
-    def __init__(cls, classname, bases, dict_):
-        cls._cache_attrs = tuple(
-            sorted(
-                d
-                for d in dict_
-                if not d.startswith("__")
-                and d not in ("_cache_key_traversal",)
-            )
-        )
-        type.__init__(cls, classname, bases, dict_)
+    This metaclass is actually necessary despite the availability of the
+    ``__init_subclass__()`` hook as this type also provides custom class-level
+    behavior for the ``__add__()`` method.
+
+    """
 
     def __add__(self, other):
         o1 = self()
@@ -640,6 +634,18 @@ class _MetaOptions(type):
 class Options(metaclass=_MetaOptions):
     """A cacheable option dictionary with defaults."""
 
+    def __init_subclass__(cls) -> None:
+        dict_ = cls.__dict__
+        cls._cache_attrs = tuple(
+            sorted(
+                d
+                for d in dict_
+                if not d.startswith("__")
+                and d not in ("_cache_key_traversal",)
+            )
+        )
+        super().__init_subclass__()
+
     def __init__(self, **kw):
         self.__dict__.update(kw)
 
diff --git a/lib/sqlalchemy/sql/cache_key.py b/lib/sqlalchemy/sql/cache_key.py
new file mode 100644 (file)
index 0000000..8dd44db
--- /dev/null
@@ -0,0 +1,762 @@
+# sql/cache_key.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from collections import namedtuple
+import enum
+from itertools import zip_longest
+from typing import Callable
+from typing import Union
+
+from .visitors import anon_map
+from .visitors import ExtendedInternalTraversal
+from .visitors import InternalTraversal
+from .. import util
+from ..inspection import inspect
+from ..util import HasMemoized
+from ..util.typing import Literal
+
+
+class CacheConst(enum.Enum):
+    NO_CACHE = 0
+
+
+NO_CACHE = CacheConst.NO_CACHE
+
+
+class CacheTraverseTarget(enum.Enum):
+    CACHE_IN_PLACE = 0
+    CALL_GEN_CACHE_KEY = 1
+    STATIC_CACHE_KEY = 2
+    PROPAGATE_ATTRS = 3
+    ANON_NAME = 4
+
+
+(
+    CACHE_IN_PLACE,
+    CALL_GEN_CACHE_KEY,
+    STATIC_CACHE_KEY,
+    PROPAGATE_ATTRS,
+    ANON_NAME,
+) = tuple(CacheTraverseTarget)
+
+
+class HasCacheKey:
+    """Mixin for objects which can produce a cache key.
+
+    .. seealso::
+
+        :class:`.CacheKey`
+
+        :ref:`sql_caching`
+
+    """
+
+    _cache_key_traversal = NO_CACHE
+
+    _is_has_cache_key = True
+
+    _hierarchy_supports_caching = True
+    """private attribute which may be set to False to prevent the
+    inherit_cache warning from being emitted for a hierarchy of subclasses.
+
+    Currently applies to the DDLElement hierarchy which does not implement
+    caching.
+
+    """
+
+    inherit_cache = None
+    """Indicate if this :class:`.HasCacheKey` instance should make use of the
+    cache key generation scheme used by its immediate superclass.
+
+    The attribute defaults to ``None``, which indicates that a construct has
+    not yet taken into account whether or not its appropriate for it to
+    participate in caching; this is functionally equivalent to setting the
+    value to ``False``, except that a warning is also emitted.
+
+    This flag can be set to ``True`` on a particular class, if the SQL that
+    corresponds to the object does not change based on attributes which
+    are local to this class, and not its superclass.
+
+    .. seealso::
+
+        :ref:`compilerext_caching` - General guideslines for setting the
+        :attr:`.HasCacheKey.inherit_cache` attribute for third-party or user
+        defined SQL constructs.
+
+    """
+
+    __slots__ = ()
+
+    @classmethod
+    def _generate_cache_attrs(cls):
+        """generate cache key dispatcher for a new class.
+
+        This sets the _generated_cache_key_traversal attribute once called
+        so should only be called once per class.
+
+        """
+        inherit_cache = cls.__dict__.get("inherit_cache", None)
+        inherit = bool(inherit_cache)
+
+        if inherit:
+            _cache_key_traversal = getattr(cls, "_cache_key_traversal", None)
+            if _cache_key_traversal is None:
+                try:
+                    _cache_key_traversal = cls._traverse_internals
+                except AttributeError:
+                    cls._generated_cache_key_traversal = NO_CACHE
+                    return NO_CACHE
+
+            # TODO: wouldn't we instead get this from our superclass?
+            # also, our superclass may not have this yet, but in any case,
+            # we'd generate for the superclass that has it.   this is a little
+            # more complicated, so for the moment this is a little less
+            # efficient on startup but simpler.
+            return _cache_key_traversal_visitor.generate_dispatch(
+                cls, _cache_key_traversal, "_generated_cache_key_traversal"
+            )
+        else:
+            _cache_key_traversal = cls.__dict__.get(
+                "_cache_key_traversal", None
+            )
+            if _cache_key_traversal is None:
+                _cache_key_traversal = cls.__dict__.get(
+                    "_traverse_internals", None
+                )
+                if _cache_key_traversal is None:
+                    cls._generated_cache_key_traversal = NO_CACHE
+                    if (
+                        inherit_cache is None
+                        and cls._hierarchy_supports_caching
+                    ):
+                        util.warn(
+                            "Class %s will not make use of SQL compilation "
+                            "caching as it does not set the 'inherit_cache' "
+                            "attribute to ``True``.  This can have "
+                            "significant performance implications including "
+                            "some performance degradations in comparison to "
+                            "prior SQLAlchemy versions.  Set this attribute "
+                            "to True if this object can make use of the cache "
+                            "key generated by the superclass.  Alternatively, "
+                            "this attribute may be set to False which will "
+                            "disable this warning." % (cls.__name__),
+                            code="cprf",
+                        )
+                    return NO_CACHE
+
+            return _cache_key_traversal_visitor.generate_dispatch(
+                cls, _cache_key_traversal, "_generated_cache_key_traversal"
+            )
+
+    @util.preload_module("sqlalchemy.sql.elements")
+    def _gen_cache_key(self, anon_map, bindparams):
+        """return an optional cache key.
+
+        The cache key is a tuple which can contain any series of
+        objects that are hashable and also identifies
+        this object uniquely within the presence of a larger SQL expression
+        or statement, for the purposes of caching the resulting query.
+
+        The cache key should be based on the SQL compiled structure that would
+        ultimately be produced.   That is, two structures that are composed in
+        exactly the same way should produce the same cache key; any difference
+        in the structures that would affect the SQL string or the type handlers
+        should result in a different cache key.
+
+        If a structure cannot produce a useful cache key, the NO_CACHE
+        symbol should be added to the anon_map and the method should
+        return None.
+
+        """
+
+        cls = self.__class__
+
+        id_, found = anon_map.get_anon(self)
+        if found:
+            return (id_, cls)
+
+        dispatcher: Union[
+            Literal[CacheConst.NO_CACHE],
+            Callable[[HasCacheKey, "_CacheKeyTraversal"], "CacheKey"],
+        ]
+
+        try:
+            dispatcher = cls.__dict__["_generated_cache_key_traversal"]
+        except KeyError:
+            # most of the dispatchers are generated up front
+            # in sqlalchemy/sql/__init__.py ->
+            # traversals.py-> _preconfigure_traversals().
+            # this block will generate any remaining dispatchers.
+            dispatcher = cls._generate_cache_attrs()
+
+        if dispatcher is NO_CACHE:
+            anon_map[NO_CACHE] = True
+            return None
+
+        result = (id_, cls)
+
+        # inline of _cache_key_traversal_visitor.run_generated_dispatch()
+
+        for attrname, obj, meth in dispatcher(
+            self, _cache_key_traversal_visitor
+        ):
+            if obj is not None:
+                # TODO: see if C code can help here as Python lacks an
+                # efficient switch construct
+
+                if meth is STATIC_CACHE_KEY:
+                    sck = obj._static_cache_key
+                    if sck is NO_CACHE:
+                        anon_map[NO_CACHE] = True
+                        return None
+                    result += (attrname, sck)
+                elif meth is ANON_NAME:
+                    elements = util.preloaded.sql_elements
+                    if isinstance(obj, elements._anonymous_label):
+                        obj = obj.apply_map(anon_map)
+                    result += (attrname, obj)
+                elif meth is CALL_GEN_CACHE_KEY:
+                    result += (
+                        attrname,
+                        obj._gen_cache_key(anon_map, bindparams),
+                    )
+
+                # remaining cache functions are against
+                # Python tuples, dicts, lists, etc. so we can skip
+                # if they are empty
+                elif obj:
+                    if meth is CACHE_IN_PLACE:
+                        result += (attrname, obj)
+                    elif meth is PROPAGATE_ATTRS:
+                        result += (
+                            attrname,
+                            obj["compile_state_plugin"],
+                            obj["plugin_subject"]._gen_cache_key(
+                                anon_map, bindparams
+                            )
+                            if obj["plugin_subject"]
+                            else None,
+                        )
+                    elif meth is InternalTraversal.dp_annotations_key:
+                        # obj is here is the _annotations dict.   however, we
+                        # want to use the memoized cache key version of it. for
+                        # Columns, this should be long lived.   For select()
+                        # statements, not so much, but they usually won't have
+                        # annotations.
+                        result += self._annotations_cache_key
+                    elif (
+                        meth is InternalTraversal.dp_clauseelement_list
+                        or meth is InternalTraversal.dp_clauseelement_tuple
+                        or meth
+                        is InternalTraversal.dp_memoized_select_entities
+                    ):
+                        result += (
+                            attrname,
+                            tuple(
+                                [
+                                    elem._gen_cache_key(anon_map, bindparams)
+                                    for elem in obj
+                                ]
+                            ),
+                        )
+                    else:
+                        result += meth(
+                            attrname, obj, self, anon_map, bindparams
+                        )
+        return result
+
+    def _generate_cache_key(self):
+        """return a cache key.
+
+        The cache key is a tuple which can contain any series of
+        objects that are hashable and also identifies
+        this object uniquely within the presence of a larger SQL expression
+        or statement, for the purposes of caching the resulting query.
+
+        The cache key should be based on the SQL compiled structure that would
+        ultimately be produced.   That is, two structures that are composed in
+        exactly the same way should produce the same cache key; any difference
+        in the structures that would affect the SQL string or the type handlers
+        should result in a different cache key.
+
+        The cache key returned by this method is an instance of
+        :class:`.CacheKey`, which consists of a tuple representing the
+        cache key, as well as a list of :class:`.BindParameter` objects
+        which are extracted from the expression.   While two expressions
+        that produce identical cache key tuples will themselves generate
+        identical SQL strings, the list of :class:`.BindParameter` objects
+        indicates the bound values which may have different values in
+        each one; these bound parameters must be consulted in order to
+        execute the statement with the correct parameters.
+
+        a :class:`_expression.ClauseElement` structure that does not implement
+        a :meth:`._gen_cache_key` method and does not implement a
+        :attr:`.traverse_internals` attribute will not be cacheable; when
+        such an element is embedded into a larger structure, this method
+        will return None, indicating no cache key is available.
+
+        """
+
+        bindparams = []
+
+        _anon_map = anon_map()
+        key = self._gen_cache_key(_anon_map, bindparams)
+        if NO_CACHE in _anon_map:
+            return None
+        else:
+            return CacheKey(key, bindparams)
+
+    @classmethod
+    def _generate_cache_key_for_object(cls, obj):
+        bindparams = []
+
+        _anon_map = anon_map()
+        key = obj._gen_cache_key(_anon_map, bindparams)
+        if NO_CACHE in _anon_map:
+            return None
+        else:
+            return CacheKey(key, bindparams)
+
+
+class MemoizedHasCacheKey(HasCacheKey, HasMemoized):
+    @HasMemoized.memoized_instancemethod
+    def _generate_cache_key(self):
+        return HasCacheKey._generate_cache_key(self)
+
+
+class CacheKey(namedtuple("CacheKey", ["key", "bindparams"])):
+    """The key used to identify a SQL statement construct in the
+    SQL compilation cache.
+
+    .. seealso::
+
+        :ref:`sql_caching`
+
+    """
+
+    def __hash__(self):
+        """CacheKey itself is not hashable - hash the .key portion"""
+
+        return None
+
+    def to_offline_string(self, statement_cache, statement, parameters):
+        """Generate an "offline string" form of this :class:`.CacheKey`
+
+        The "offline string" is basically the string SQL for the
+        statement plus a repr of the bound parameter values in series.
+        Whereas the :class:`.CacheKey` object is dependent on in-memory
+        identities in order to work as a cache key, the "offline" version
+        is suitable for a cache that will work for other processes as well.
+
+        The given ``statement_cache`` is a dictionary-like object where the
+        string form of the statement itself will be cached.  This dictionary
+        should be in a longer lived scope in order to reduce the time spent
+        stringifying statements.
+
+
+        """
+        if self.key not in statement_cache:
+            statement_cache[self.key] = sql_str = str(statement)
+        else:
+            sql_str = statement_cache[self.key]
+
+        if not self.bindparams:
+            param_tuple = tuple(parameters[key] for key in sorted(parameters))
+        else:
+            param_tuple = tuple(
+                parameters.get(bindparam.key, bindparam.value)
+                for bindparam in self.bindparams
+            )
+
+        return repr((sql_str, param_tuple))
+
+    def __eq__(self, other):
+        return self.key == other.key
+
+    @classmethod
+    def _diff_tuples(cls, left, right):
+        ck1 = CacheKey(left, [])
+        ck2 = CacheKey(right, [])
+        return ck1._diff(ck2)
+
+    def _whats_different(self, other):
+
+        k1 = self.key
+        k2 = other.key
+
+        stack = []
+        pickup_index = 0
+        while True:
+            s1, s2 = k1, k2
+            for idx in stack:
+                s1 = s1[idx]
+                s2 = s2[idx]
+
+            for idx, (e1, e2) in enumerate(zip_longest(s1, s2)):
+                if idx < pickup_index:
+                    continue
+                if e1 != e2:
+                    if isinstance(e1, tuple) and isinstance(e2, tuple):
+                        stack.append(idx)
+                        break
+                    else:
+                        yield "key%s[%d]:  %s != %s" % (
+                            "".join("[%d]" % id_ for id_ in stack),
+                            idx,
+                            e1,
+                            e2,
+                        )
+            else:
+                pickup_index = stack.pop(-1)
+                break
+
+    def _diff(self, other):
+        return ", ".join(self._whats_different(other))
+
+    def __str__(self):
+        stack = [self.key]
+
+        output = []
+        sentinel = object()
+        indent = -1
+        while stack:
+            elem = stack.pop(0)
+            if elem is sentinel:
+                output.append((" " * (indent * 2)) + "),")
+                indent -= 1
+            elif isinstance(elem, tuple):
+                if not elem:
+                    output.append((" " * ((indent + 1) * 2)) + "()")
+                else:
+                    indent += 1
+                    stack = list(elem) + [sentinel] + stack
+                    output.append((" " * (indent * 2)) + "(")
+            else:
+                if isinstance(elem, HasCacheKey):
+                    repr_ = "<%s object at %s>" % (
+                        type(elem).__name__,
+                        hex(id(elem)),
+                    )
+                else:
+                    repr_ = repr(elem)
+                output.append((" " * (indent * 2)) + "  " + repr_ + ", ")
+
+        return "CacheKey(key=%s)" % ("\n".join(output),)
+
+    def _generate_param_dict(self):
+        """used for testing"""
+
+        from .compiler import prefix_anon_map
+
+        _anon_map = prefix_anon_map()
+        return {b.key % _anon_map: b.effective_value for b in self.bindparams}
+
+    def _apply_params_to_element(self, original_cache_key, target_element):
+        translate = {
+            k.key: v.value
+            for k, v in zip(original_cache_key.bindparams, self.bindparams)
+        }
+
+        return target_element.params(translate)
+
+
+class _CacheKeyTraversal(ExtendedInternalTraversal):
+    # very common elements are inlined into the main _get_cache_key() method
+    # to produce a dramatic savings in Python function call overhead
+
+    visit_has_cache_key = visit_clauseelement = CALL_GEN_CACHE_KEY
+    visit_clauseelement_list = InternalTraversal.dp_clauseelement_list
+    visit_annotations_key = InternalTraversal.dp_annotations_key
+    visit_clauseelement_tuple = InternalTraversal.dp_clauseelement_tuple
+    visit_memoized_select_entities = (
+        InternalTraversal.dp_memoized_select_entities
+    )
+
+    visit_string = (
+        visit_boolean
+    ) = visit_operator = visit_plain_obj = CACHE_IN_PLACE
+    visit_statement_hint_list = CACHE_IN_PLACE
+    visit_type = STATIC_CACHE_KEY
+    visit_anon_name = ANON_NAME
+
+    visit_propagate_attrs = PROPAGATE_ATTRS
+
+    def visit_with_context_options(
+        self, attrname, obj, parent, anon_map, bindparams
+    ):
+        return tuple((fn.__code__, c_key) for fn, c_key in obj)
+
+    def visit_inspectable(self, attrname, obj, parent, anon_map, bindparams):
+        return (attrname, inspect(obj)._gen_cache_key(anon_map, bindparams))
+
+    def visit_string_list(self, attrname, obj, parent, anon_map, bindparams):
+        return tuple(obj)
+
+    def visit_multi(self, attrname, obj, parent, anon_map, bindparams):
+        return (
+            attrname,
+            obj._gen_cache_key(anon_map, bindparams)
+            if isinstance(obj, HasCacheKey)
+            else obj,
+        )
+
+    def visit_multi_list(self, attrname, obj, parent, anon_map, bindparams):
+        return (
+            attrname,
+            tuple(
+                elem._gen_cache_key(anon_map, bindparams)
+                if isinstance(elem, HasCacheKey)
+                else elem
+                for elem in obj
+            ),
+        )
+
+    def visit_has_cache_key_tuples(
+        self, attrname, obj, parent, anon_map, bindparams
+    ):
+        if not obj:
+            return ()
+        return (
+            attrname,
+            tuple(
+                tuple(
+                    elem._gen_cache_key(anon_map, bindparams)
+                    for elem in tup_elem
+                )
+                for tup_elem in obj
+            ),
+        )
+
+    def visit_has_cache_key_list(
+        self, attrname, obj, parent, anon_map, bindparams
+    ):
+        if not obj:
+            return ()
+        return (
+            attrname,
+            tuple(elem._gen_cache_key(anon_map, bindparams) for elem in obj),
+        )
+
+    def visit_executable_options(
+        self, attrname, obj, parent, anon_map, bindparams
+    ):
+        if not obj:
+            return ()
+        return (
+            attrname,
+            tuple(
+                elem._gen_cache_key(anon_map, bindparams)
+                for elem in obj
+                if elem._is_has_cache_key
+            ),
+        )
+
+    def visit_inspectable_list(
+        self, attrname, obj, parent, anon_map, bindparams
+    ):
+        return self.visit_has_cache_key_list(
+            attrname, [inspect(o) for o in obj], parent, anon_map, bindparams
+        )
+
+    def visit_clauseelement_tuples(
+        self, attrname, obj, parent, anon_map, bindparams
+    ):
+        return self.visit_has_cache_key_tuples(
+            attrname, obj, parent, anon_map, bindparams
+        )
+
+    def visit_fromclause_ordered_set(
+        self, attrname, obj, parent, anon_map, bindparams
+    ):
+        if not obj:
+            return ()
+        return (
+            attrname,
+            tuple([elem._gen_cache_key(anon_map, bindparams) for elem in obj]),
+        )
+
+    def visit_clauseelement_unordered_set(
+        self, attrname, obj, parent, anon_map, bindparams
+    ):
+        if not obj:
+            return ()
+        cache_keys = [
+            elem._gen_cache_key(anon_map, bindparams) for elem in obj
+        ]
+        return (
+            attrname,
+            tuple(
+                sorted(cache_keys)
+            ),  # cache keys all start with (id_, class)
+        )
+
+    def visit_named_ddl_element(
+        self, attrname, obj, parent, anon_map, bindparams
+    ):
+        return (attrname, obj.name)
+
+    def visit_prefix_sequence(
+        self, attrname, obj, parent, anon_map, bindparams
+    ):
+        if not obj:
+            return ()
+
+        return (
+            attrname,
+            tuple(
+                [
+                    (clause._gen_cache_key(anon_map, bindparams), strval)
+                    for clause, strval in obj
+                ]
+            ),
+        )
+
+    def visit_setup_join_tuple(
+        self, attrname, obj, parent, anon_map, bindparams
+    ):
+        return tuple(
+            (
+                target._gen_cache_key(anon_map, bindparams),
+                onclause._gen_cache_key(anon_map, bindparams)
+                if onclause is not None
+                else None,
+                from_._gen_cache_key(anon_map, bindparams)
+                if from_ is not None
+                else None,
+                tuple([(key, flags[key]) for key in sorted(flags)]),
+            )
+            for (target, onclause, from_, flags) in obj
+        )
+
+    def visit_table_hint_list(
+        self, attrname, obj, parent, anon_map, bindparams
+    ):
+        if not obj:
+            return ()
+
+        return (
+            attrname,
+            tuple(
+                [
+                    (
+                        clause._gen_cache_key(anon_map, bindparams),
+                        dialect_name,
+                        text,
+                    )
+                    for (clause, dialect_name), text in obj.items()
+                ]
+            ),
+        )
+
+    def visit_plain_dict(self, attrname, obj, parent, anon_map, bindparams):
+        return (attrname, tuple([(key, obj[key]) for key in sorted(obj)]))
+
+    def visit_dialect_options(
+        self, attrname, obj, parent, anon_map, bindparams
+    ):
+        return (
+            attrname,
+            tuple(
+                (
+                    dialect_name,
+                    tuple(
+                        [
+                            (key, obj[dialect_name][key])
+                            for key in sorted(obj[dialect_name])
+                        ]
+                    ),
+                )
+                for dialect_name in sorted(obj)
+            ),
+        )
+
+    def visit_string_clauseelement_dict(
+        self, attrname, obj, parent, anon_map, bindparams
+    ):
+        return (
+            attrname,
+            tuple(
+                (key, obj[key]._gen_cache_key(anon_map, bindparams))
+                for key in sorted(obj)
+            ),
+        )
+
+    def visit_string_multi_dict(
+        self, attrname, obj, parent, anon_map, bindparams
+    ):
+        return (
+            attrname,
+            tuple(
+                (
+                    key,
+                    value._gen_cache_key(anon_map, bindparams)
+                    if isinstance(value, HasCacheKey)
+                    else value,
+                )
+                for key, value in [(key, obj[key]) for key in sorted(obj)]
+            ),
+        )
+
+    def visit_fromclause_canonical_column_collection(
+        self, attrname, obj, parent, anon_map, bindparams
+    ):
+        # inlining into the internals of ColumnCollection
+        return (
+            attrname,
+            tuple(
+                col._gen_cache_key(anon_map, bindparams)
+                for k, col in obj._collection
+            ),
+        )
+
+    def visit_unknown_structure(
+        self, attrname, obj, parent, anon_map, bindparams
+    ):
+        anon_map[NO_CACHE] = True
+        return ()
+
+    def visit_dml_ordered_values(
+        self, attrname, obj, parent, anon_map, bindparams
+    ):
+        return (
+            attrname,
+            tuple(
+                (
+                    key._gen_cache_key(anon_map, bindparams)
+                    if hasattr(key, "__clause_element__")
+                    else key,
+                    value._gen_cache_key(anon_map, bindparams),
+                )
+                for key, value in obj
+            ),
+        )
+
+    def visit_dml_values(self, attrname, obj, parent, anon_map, bindparams):
+        # in py37 we can assume two dictionaries created in the same
+        # insert ordering will retain that sorting
+        return (
+            attrname,
+            tuple(
+                (
+                    k._gen_cache_key(anon_map, bindparams)
+                    if hasattr(k, "__clause_element__")
+                    else k,
+                    obj[k]._gen_cache_key(anon_map, bindparams),
+                )
+                for k in obj
+            ),
+        )
+
+    def visit_dml_multi_values(
+        self, attrname, obj, parent, anon_map, bindparams
+    ):
+        # multivalues are simply not cacheable right now
+        anon_map[NO_CACHE] = True
+        return ()
+
+
+_cache_key_traversal_visitor = _CacheKeyTraversal()
index 95697806e612679df53a5e25024730747cc748b0..fe2b498c8e1fd24b0a385e49092ab2ef17afaaf4 100644 (file)
@@ -14,7 +14,7 @@ from . import roles
 from . import visitors
 from .base import ExecutableOption
 from .base import Options
-from .traversals import HasCacheKey
+from .cache_key import HasCacheKey
 from .visitors import Visitable
 from .. import exc
 from .. import inspection
index 697550df447a2c294d29712116b6fc39fdc629bf..cb10811c6aa461dd024e64db9aa2ab451c614c3c 100644 (file)
@@ -503,7 +503,7 @@ class Compiled:
         return self.construct_params()
 
 
-class TypeCompiler(metaclass=util.EnsureKWArgType):
+class TypeCompiler(util.EnsureKWArg):
     """Produces DDL specification for TypeEngine objects."""
 
     ensure_kwarg = r"visit_\w+"
index 12282de0559dbf8d6c656c74396c2160b42e4fb6..a025cce357fb7bb0e1efb7720663279122aefe5e 100644 (file)
@@ -29,10 +29,10 @@ from .base import HasMemoized
 from .base import Immutable
 from .base import NO_ARG
 from .base import SingletonConstant
+from .cache_key import MemoizedHasCacheKey
+from .cache_key import NO_CACHE
 from .coercions import _document_text_coercion
 from .traversals import HasCopyInternals
-from .traversals import MemoizedHasCacheKey
-from .traversals import NO_CACHE
 from .visitors import cloned_traverse
 from .visitors import InternalTraversal
 from .visitors import traverse
index 2a3cd07d0aa8ed63122ed91eaf827b6ea3c905c0..54f67b930dedfe87d13abd40aa8a5082ad4dcf58 100644 (file)
@@ -94,6 +94,7 @@ from .base import _from_objects
 from .base import _select_iterables
 from .base import ColumnCollection
 from .base import Executable
+from .cache_key import CacheKey
 from .dml import Delete
 from .dml import Insert
 from .dml import Update
@@ -173,7 +174,6 @@ from .selectable import TableValuedAlias
 from .selectable import TextAsFrom
 from .selectable import TextualSelect
 from .selectable import Values
-from .traversals import CacheKey
 from .visitors import Visitable
 from ..util.langhelpers import public_factory
 
index b7b9257b42210a8b65ca554dc37a0caf8c2dd7ba..3b6da71757f997bb975dad9e174b5e739a64478a 100644 (file)
@@ -37,8 +37,8 @@ from .elements import WithinGroup
 from .selectable import FromClause
 from .selectable import Select
 from .selectable import TableValuedAlias
+from .type_api import TypeEngine
 from .visitors import InternalTraversal
-from .visitors import TraversibleType
 from .. import util
 
 
@@ -48,7 +48,7 @@ _registry = util.defaultdict(dict)
 def register_function(identifier, fn, package="_default"):
     """Associate a callable with a particular func. name.
 
-    This is normally called by _GenericMeta, but is also
+    This is normally called by GenericFunction, but is also
     available by itself so that a non-Function construct
     can be associated with the :data:`.func` accessor (i.e.
     CAST, EXTRACT).
@@ -828,7 +828,11 @@ class Function(FunctionElement):
         ("type", InternalTraversal.dp_type),
     ]
 
-    type = sqltypes.NULLTYPE
+    name: str
+
+    identifier: str
+
+    type: TypeEngine = sqltypes.NULLTYPE
     """A :class:`_types.TypeEngine` object which refers to the SQL return
     type represented by this SQL function.
 
@@ -871,30 +875,7 @@ class Function(FunctionElement):
         )
 
 
-class _GenericMeta(TraversibleType):
-    def __init__(cls, clsname, bases, clsdict):
-        if annotation.Annotated not in cls.__mro__:
-            cls.name = name = clsdict.get("name", clsname)
-            cls.identifier = identifier = clsdict.get("identifier", name)
-            package = clsdict.pop("package", "_default")
-            # legacy
-            if "__return_type__" in clsdict:
-                cls.type = clsdict["__return_type__"]
-
-            # Check _register attribute status
-            cls._register = getattr(cls, "_register", True)
-
-            # Register the function if required
-            if cls._register:
-                register_function(identifier, cls, package)
-            else:
-                # Set _register to True to register child classes by default
-                cls._register = True
-
-        super(_GenericMeta, cls).__init__(clsname, bases, clsdict)
-
-
-class GenericFunction(Function, metaclass=_GenericMeta):
+class GenericFunction(Function):
     """Define a 'generic' function.
 
     A generic function is a pre-established :class:`.Function`
@@ -986,9 +967,34 @@ class GenericFunction(Function, metaclass=_GenericMeta):
     """
 
     coerce_arguments = True
-    _register = False
     inherit_cache = True
 
+    name = "GenericFunction"
+
+    def __init_subclass__(cls) -> None:
+        if annotation.Annotated not in cls.__mro__:
+            cls._register_generic_function(cls.__name__, cls.__dict__)
+        super().__init_subclass__()
+
+    @classmethod
+    def _register_generic_function(cls, clsname, clsdict):
+        cls.name = name = clsdict.get("name", clsname)
+        cls.identifier = identifier = clsdict.get("identifier", name)
+        package = clsdict.get("package", "_default")
+        # legacy
+        if "__return_type__" in clsdict:
+            cls.type = clsdict["__return_type__"]
+
+        # Check _register attribute status
+        cls._register = getattr(cls, "_register", True)
+
+        # Register the function if required
+        if cls._register:
+            register_function(identifier, cls, package)
+        else:
+            # Set _register to True to register child classes by default
+            cls._register = True
+
     def __init__(self, *args, **kwargs):
         parsed_args = kwargs.pop("_parsed_args", None)
         if parsed_args is None:
@@ -1006,6 +1012,7 @@ class GenericFunction(Function, metaclass=_GenericMeta):
         self.clause_expr = ClauseList(
             operator=operators.comma_op, group_contents=True, *parsed_args
         ).self_group()
+
         self.type = sqltypes.to_instance(
             kwargs.pop("type_", None) or getattr(self, "type", None)
         )
index 2387e551e7101b84af02b5ffb2f7532808743246..d71c85d609037b44f36566b4bc56a04cad9ea0d4 100644 (file)
@@ -11,6 +11,7 @@ import operator
 import types
 import weakref
 
+from . import cache_key as _cache_key
 from . import coercions
 from . import elements
 from . import roles
@@ -185,7 +186,7 @@ class LambdaElement(elements.ClauseElement):
         else:
             parent_closure_cache_key = ()
 
-        if parent_closure_cache_key is not traversals.NO_CACHE:
+        if parent_closure_cache_key is not _cache_key.NO_CACHE:
             anon_map = traversals.anon_map()
             cache_key = tuple(
                 [
@@ -194,7 +195,7 @@ class LambdaElement(elements.ClauseElement):
                 ]
             )
 
-            if traversals.NO_CACHE not in anon_map:
+            if _cache_key.NO_CACHE not in anon_map:
                 cache_key = parent_closure_cache_key + cache_key
 
                 self.closure_cache_key = cache_key
@@ -204,17 +205,17 @@ class LambdaElement(elements.ClauseElement):
                 except KeyError:
                     rec = None
             else:
-                cache_key = traversals.NO_CACHE
+                cache_key = _cache_key.NO_CACHE
                 rec = None
 
         else:
-            cache_key = traversals.NO_CACHE
+            cache_key = _cache_key.NO_CACHE
             rec = None
 
         self.closure_cache_key = cache_key
 
         if rec is None:
-            if cache_key is not traversals.NO_CACHE:
+            if cache_key is not _cache_key.NO_CACHE:
                 rec = AnalyzedFunction(
                     tracker, self, apply_propagate_attrs, fn
                 )
@@ -233,7 +234,7 @@ class LambdaElement(elements.ClauseElement):
 
         self._rec = rec
 
-        if cache_key is not traversals.NO_CACHE:
+        if cache_key is not _cache_key.NO_CACHE:
             if self.parent_lambda is not None:
                 bindparams[:0] = self.parent_lambda._resolved_bindparams
 
@@ -326,8 +327,8 @@ class LambdaElement(elements.ClauseElement):
         return expr
 
     def _gen_cache_key(self, anon_map, bindparams):
-        if self.closure_cache_key is traversals.NO_CACHE:
-            anon_map[traversals.NO_CACHE] = True
+        if self.closure_cache_key is _cache_key.NO_CACHE:
+            anon_map[_cache_key.NO_CACHE] = True
             return None
 
         cache_key = (
@@ -808,7 +809,7 @@ class AnalyzedCode:
                     for tup_elem in opts.track_on[idx]
                 )
 
-        elif isinstance(elem, traversals.HasCacheKey):
+        elif isinstance(elem, _cache_key.HasCacheKey):
 
             def get(closure, opts, anon_map, bindparams):
                 return opts.track_on[idx]._gen_cache_key(anon_map, bindparams)
@@ -834,7 +835,7 @@ class AnalyzedCode:
 
         """
 
-        if isinstance(cell_contents, traversals.HasCacheKey):
+        if isinstance(cell_contents, _cache_key.HasCacheKey):
 
             def get(closure, opts, anon_map, bindparams):
 
@@ -1166,7 +1167,7 @@ class PyWrapper(ColumnOperators):
             and not isinstance(
                 # TODO: coverage where an ORM option or similar is here
                 value,
-                traversals.HasCacheKey,
+                _cache_key.HasCacheKey,
             )
         ):
             name = object.__getattribute__(self, "_name")
index 655d98b02feb5119ffe4e885edee567500e5c47e..e674c4b74d657ec50577a5e57796549fb194d9c1 100644 (file)
@@ -18,6 +18,7 @@ import typing
 from typing import Type
 from typing import Union
 
+from . import cache_key
 from . import coercions
 from . import operators
 from . import roles
@@ -4300,7 +4301,7 @@ class _SelectFromElements:
 
 
 class _MemoizedSelectEntities(
-    traversals.HasCacheKey, traversals.HasCopyInternals, visitors.Traversible
+    cache_key.HasCacheKey, traversals.HasCopyInternals, visitors.Traversible
 ):
     __visit_name__ = "memoized_select_entities"
 
index 05007eff1f9ff105e3901406121bf27f1be6039e..d5506cda223bf12263af4fa2cff83c7a6255a4ab 100644 (file)
@@ -22,11 +22,11 @@ from . import roles
 from . import type_api
 from .base import NO_ARG
 from .base import SchemaEventTarget
+from .cache_key import HasCacheKey
 from .elements import _NONE_NAME
 from .elements import quoted_name
 from .elements import Slice
 from .elements import TypeCoerce as type_coerce  # noqa
-from .traversals import HasCacheKey
 from .traversals import InternalTraversal
 from .type_api import Emulated
 from .type_api import NativeForEmulated  # noqa
index b689fe57897088e6084da51cc48dc428519cab5d..2fa3a040839bdb86b16c85347db2659daf245d75 100644 (file)
@@ -1,32 +1,25 @@
+# sql/traversals.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
 from collections import deque
-from collections import namedtuple
 import collections.abc as collections_abc
 import itertools
 from itertools import zip_longest
 import operator
 
 from . import operators
-from .visitors import ExtendedInternalTraversal
+from .visitors import anon_map
 from .visitors import InternalTraversal
 from .. import util
-from ..inspection import inspect
-from ..util import HasMemoized
-
-try:
-    from sqlalchemy.cyextension.util import cache_anon_map as anon_map  # noqa
-except ImportError:
-    from ._py_util import cache_anon_map as anon_map  # noqa
 
 
 SKIP_TRAVERSE = util.symbol("skip_traverse")
 COMPARE_FAILED = False
 COMPARE_SUCCEEDED = True
-NO_CACHE = util.symbol("no_cache")
-CACHE_IN_PLACE = util.symbol("cache_in_place")
-CALL_GEN_CACHE_KEY = util.symbol("call_gen_cache_key")
-STATIC_CACHE_KEY = util.symbol("static_cache_key")
-PROPAGATE_ATTRS = util.symbol("propagate_attrs")
-ANON_NAME = util.symbol("anon_name")
 
 
 def compare(obj1, obj2, **kw):
@@ -54,729 +47,10 @@ def _preconfigure_traversals(target_hierarchy):
             )
 
 
-class HasCacheKey:
-    """Mixin for objects which can produce a cache key.
-
-    .. seealso::
-
-        :class:`.CacheKey`
-
-        :ref:`sql_caching`
-
-    """
-
-    _cache_key_traversal = NO_CACHE
-
-    _is_has_cache_key = True
-
-    _hierarchy_supports_caching = True
-    """private attribute which may be set to False to prevent the
-    inherit_cache warning from being emitted for a hierarchy of subclasses.
-
-    Currently applies to the DDLElement hierarchy which does not implement
-    caching.
-
-    """
-
-    inherit_cache = None
-    """Indicate if this :class:`.HasCacheKey` instance should make use of the
-    cache key generation scheme used by its immediate superclass.
-
-    The attribute defaults to ``None``, which indicates that a construct has
-    not yet taken into account whether or not its appropriate for it to
-    participate in caching; this is functionally equivalent to setting the
-    value to ``False``, except that a warning is also emitted.
-
-    This flag can be set to ``True`` on a particular class, if the SQL that
-    corresponds to the object does not change based on attributes which
-    are local to this class, and not its superclass.
-
-    .. seealso::
-
-        :ref:`compilerext_caching` - General guideslines for setting the
-        :attr:`.HasCacheKey.inherit_cache` attribute for third-party or user
-        defined SQL constructs.
-
-    """
-
-    __slots__ = ()
-
-    @classmethod
-    def _generate_cache_attrs(cls):
-        """generate cache key dispatcher for a new class.
-
-        This sets the _generated_cache_key_traversal attribute once called
-        so should only be called once per class.
-
-        """
-        inherit_cache = cls.__dict__.get("inherit_cache", None)
-        inherit = bool(inherit_cache)
-
-        if inherit:
-            _cache_key_traversal = getattr(cls, "_cache_key_traversal", None)
-            if _cache_key_traversal is None:
-                try:
-                    _cache_key_traversal = cls._traverse_internals
-                except AttributeError:
-                    cls._generated_cache_key_traversal = NO_CACHE
-                    return NO_CACHE
-
-            # TODO: wouldn't we instead get this from our superclass?
-            # also, our superclass may not have this yet, but in any case,
-            # we'd generate for the superclass that has it.   this is a little
-            # more complicated, so for the moment this is a little less
-            # efficient on startup but simpler.
-            return _cache_key_traversal_visitor.generate_dispatch(
-                cls, _cache_key_traversal, "_generated_cache_key_traversal"
-            )
-        else:
-            _cache_key_traversal = cls.__dict__.get(
-                "_cache_key_traversal", None
-            )
-            if _cache_key_traversal is None:
-                _cache_key_traversal = cls.__dict__.get(
-                    "_traverse_internals", None
-                )
-                if _cache_key_traversal is None:
-                    cls._generated_cache_key_traversal = NO_CACHE
-                    if (
-                        inherit_cache is None
-                        and cls._hierarchy_supports_caching
-                    ):
-                        util.warn(
-                            "Class %s will not make use of SQL compilation "
-                            "caching as it does not set the 'inherit_cache' "
-                            "attribute to ``True``.  This can have "
-                            "significant performance implications including "
-                            "some performance degradations in comparison to "
-                            "prior SQLAlchemy versions.  Set this attribute "
-                            "to True if this object can make use of the cache "
-                            "key generated by the superclass.  Alternatively, "
-                            "this attribute may be set to False which will "
-                            "disable this warning." % (cls.__name__),
-                            code="cprf",
-                        )
-                    return NO_CACHE
-
-            return _cache_key_traversal_visitor.generate_dispatch(
-                cls, _cache_key_traversal, "_generated_cache_key_traversal"
-            )
-
-    @util.preload_module("sqlalchemy.sql.elements")
-    def _gen_cache_key(self, anon_map, bindparams):
-        """return an optional cache key.
-
-        The cache key is a tuple which can contain any series of
-        objects that are hashable and also identifies
-        this object uniquely within the presence of a larger SQL expression
-        or statement, for the purposes of caching the resulting query.
-
-        The cache key should be based on the SQL compiled structure that would
-        ultimately be produced.   That is, two structures that are composed in
-        exactly the same way should produce the same cache key; any difference
-        in the structures that would affect the SQL string or the type handlers
-        should result in a different cache key.
-
-        If a structure cannot produce a useful cache key, the NO_CACHE
-        symbol should be added to the anon_map and the method should
-        return None.
-
-        """
-
-        cls = self.__class__
-
-        id_, found = anon_map.get_anon(self)
-        if found:
-            return (id_, cls)
-
-        try:
-            dispatcher = cls.__dict__["_generated_cache_key_traversal"]
-        except KeyError:
-            # most of the dispatchers are generated up front
-            # in sqlalchemy/sql/__init__.py ->
-            # traversals.py-> _preconfigure_traversals().
-            # this block will generate any remaining dispatchers.
-            dispatcher = cls._generate_cache_attrs()
-
-        if dispatcher is NO_CACHE:
-            anon_map[NO_CACHE] = True
-            return None
-
-        result = (id_, cls)
-
-        # inline of _cache_key_traversal_visitor.run_generated_dispatch()
-
-        for attrname, obj, meth in dispatcher(
-            self, _cache_key_traversal_visitor
-        ):
-            if obj is not None:
-                # TODO: see if C code can help here as Python lacks an
-                # efficient switch construct
-
-                if meth is STATIC_CACHE_KEY:
-                    sck = obj._static_cache_key
-                    if sck is NO_CACHE:
-                        anon_map[NO_CACHE] = True
-                        return None
-                    result += (attrname, sck)
-                elif meth is ANON_NAME:
-                    elements = util.preloaded.sql_elements
-                    if isinstance(obj, elements._anonymous_label):
-                        obj = obj.apply_map(anon_map)
-                    result += (attrname, obj)
-                elif meth is CALL_GEN_CACHE_KEY:
-                    result += (
-                        attrname,
-                        obj._gen_cache_key(anon_map, bindparams),
-                    )
-
-                # remaining cache functions are against
-                # Python tuples, dicts, lists, etc. so we can skip
-                # if they are empty
-                elif obj:
-                    if meth is CACHE_IN_PLACE:
-                        result += (attrname, obj)
-                    elif meth is PROPAGATE_ATTRS:
-                        result += (
-                            attrname,
-                            obj["compile_state_plugin"],
-                            obj["plugin_subject"]._gen_cache_key(
-                                anon_map, bindparams
-                            )
-                            if obj["plugin_subject"]
-                            else None,
-                        )
-                    elif meth is InternalTraversal.dp_annotations_key:
-                        # obj is here is the _annotations dict.   however, we
-                        # want to use the memoized cache key version of it. for
-                        # Columns, this should be long lived.   For select()
-                        # statements, not so much, but they usually won't have
-                        # annotations.
-                        result += self._annotations_cache_key
-                    elif (
-                        meth is InternalTraversal.dp_clauseelement_list
-                        or meth is InternalTraversal.dp_clauseelement_tuple
-                        or meth
-                        is InternalTraversal.dp_memoized_select_entities
-                    ):
-                        result += (
-                            attrname,
-                            tuple(
-                                [
-                                    elem._gen_cache_key(anon_map, bindparams)
-                                    for elem in obj
-                                ]
-                            ),
-                        )
-                    else:
-                        result += meth(
-                            attrname, obj, self, anon_map, bindparams
-                        )
-        return result
-
-    def _generate_cache_key(self):
-        """return a cache key.
-
-        The cache key is a tuple which can contain any series of
-        objects that are hashable and also identifies
-        this object uniquely within the presence of a larger SQL expression
-        or statement, for the purposes of caching the resulting query.
-
-        The cache key should be based on the SQL compiled structure that would
-        ultimately be produced.   That is, two structures that are composed in
-        exactly the same way should produce the same cache key; any difference
-        in the structures that would affect the SQL string or the type handlers
-        should result in a different cache key.
-
-        The cache key returned by this method is an instance of
-        :class:`.CacheKey`, which consists of a tuple representing the
-        cache key, as well as a list of :class:`.BindParameter` objects
-        which are extracted from the expression.   While two expressions
-        that produce identical cache key tuples will themselves generate
-        identical SQL strings, the list of :class:`.BindParameter` objects
-        indicates the bound values which may have different values in
-        each one; these bound parameters must be consulted in order to
-        execute the statement with the correct parameters.
-
-        a :class:`_expression.ClauseElement` structure that does not implement
-        a :meth:`._gen_cache_key` method and does not implement a
-        :attr:`.traverse_internals` attribute will not be cacheable; when
-        such an element is embedded into a larger structure, this method
-        will return None, indicating no cache key is available.
-
-        """
-
-        bindparams = []
-
-        _anon_map = anon_map()
-        key = self._gen_cache_key(_anon_map, bindparams)
-        if NO_CACHE in _anon_map:
-            return None
-        else:
-            return CacheKey(key, bindparams)
-
-    @classmethod
-    def _generate_cache_key_for_object(cls, obj):
-        bindparams = []
-
-        _anon_map = anon_map()
-        key = obj._gen_cache_key(_anon_map, bindparams)
-        if NO_CACHE in _anon_map:
-            return None
-        else:
-            return CacheKey(key, bindparams)
-
-
-class MemoizedHasCacheKey(HasCacheKey, HasMemoized):
-    @HasMemoized.memoized_instancemethod
-    def _generate_cache_key(self):
-        return HasCacheKey._generate_cache_key(self)
-
-
-class CacheKey(namedtuple("CacheKey", ["key", "bindparams"])):
-    """The key used to identify a SQL statement construct in the
-    SQL compilation cache.
-
-    .. seealso::
-
-        :ref:`sql_caching`
-
-    """
-
-    def __hash__(self):
-        """CacheKey itself is not hashable - hash the .key portion"""
-
-        return None
-
-    def to_offline_string(self, statement_cache, statement, parameters):
-        """Generate an "offline string" form of this :class:`.CacheKey`
-
-        The "offline string" is basically the string SQL for the
-        statement plus a repr of the bound parameter values in series.
-        Whereas the :class:`.CacheKey` object is dependent on in-memory
-        identities in order to work as a cache key, the "offline" version
-        is suitable for a cache that will work for other processes as well.
-
-        The given ``statement_cache`` is a dictionary-like object where the
-        string form of the statement itself will be cached.  This dictionary
-        should be in a longer lived scope in order to reduce the time spent
-        stringifying statements.
-
-
-        """
-        if self.key not in statement_cache:
-            statement_cache[self.key] = sql_str = str(statement)
-        else:
-            sql_str = statement_cache[self.key]
-
-        if not self.bindparams:
-            param_tuple = tuple(parameters[key] for key in sorted(parameters))
-        else:
-            param_tuple = tuple(
-                parameters.get(bindparam.key, bindparam.value)
-                for bindparam in self.bindparams
-            )
-
-        return repr((sql_str, param_tuple))
-
-    def __eq__(self, other):
-        return self.key == other.key
-
-    @classmethod
-    def _diff_tuples(cls, left, right):
-        ck1 = CacheKey(left, [])
-        ck2 = CacheKey(right, [])
-        return ck1._diff(ck2)
-
-    def _whats_different(self, other):
-
-        k1 = self.key
-        k2 = other.key
-
-        stack = []
-        pickup_index = 0
-        while True:
-            s1, s2 = k1, k2
-            for idx in stack:
-                s1 = s1[idx]
-                s2 = s2[idx]
-
-            for idx, (e1, e2) in enumerate(zip_longest(s1, s2)):
-                if idx < pickup_index:
-                    continue
-                if e1 != e2:
-                    if isinstance(e1, tuple) and isinstance(e2, tuple):
-                        stack.append(idx)
-                        break
-                    else:
-                        yield "key%s[%d]:  %s != %s" % (
-                            "".join("[%d]" % id_ for id_ in stack),
-                            idx,
-                            e1,
-                            e2,
-                        )
-            else:
-                pickup_index = stack.pop(-1)
-                break
-
-    def _diff(self, other):
-        return ", ".join(self._whats_different(other))
-
-    def __str__(self):
-        stack = [self.key]
-
-        output = []
-        sentinel = object()
-        indent = -1
-        while stack:
-            elem = stack.pop(0)
-            if elem is sentinel:
-                output.append((" " * (indent * 2)) + "),")
-                indent -= 1
-            elif isinstance(elem, tuple):
-                if not elem:
-                    output.append((" " * ((indent + 1) * 2)) + "()")
-                else:
-                    indent += 1
-                    stack = list(elem) + [sentinel] + stack
-                    output.append((" " * (indent * 2)) + "(")
-            else:
-                if isinstance(elem, HasCacheKey):
-                    repr_ = "<%s object at %s>" % (
-                        type(elem).__name__,
-                        hex(id(elem)),
-                    )
-                else:
-                    repr_ = repr(elem)
-                output.append((" " * (indent * 2)) + "  " + repr_ + ", ")
-
-        return "CacheKey(key=%s)" % ("\n".join(output),)
-
-    def _generate_param_dict(self):
-        """used for testing"""
-
-        from .compiler import prefix_anon_map
-
-        _anon_map = prefix_anon_map()
-        return {b.key % _anon_map: b.effective_value for b in self.bindparams}
-
-    def _apply_params_to_element(self, original_cache_key, target_element):
-        translate = {
-            k.key: v.value
-            for k, v in zip(original_cache_key.bindparams, self.bindparams)
-        }
-
-        return target_element.params(translate)
-
-
 def _clone(element, **kw):
     return element._clone()
 
 
-class _CacheKey(ExtendedInternalTraversal):
-    # very common elements are inlined into the main _get_cache_key() method
-    # to produce a dramatic savings in Python function call overhead
-
-    visit_has_cache_key = visit_clauseelement = CALL_GEN_CACHE_KEY
-    visit_clauseelement_list = InternalTraversal.dp_clauseelement_list
-    visit_annotations_key = InternalTraversal.dp_annotations_key
-    visit_clauseelement_tuple = InternalTraversal.dp_clauseelement_tuple
-    visit_memoized_select_entities = (
-        InternalTraversal.dp_memoized_select_entities
-    )
-
-    visit_string = (
-        visit_boolean
-    ) = visit_operator = visit_plain_obj = CACHE_IN_PLACE
-    visit_statement_hint_list = CACHE_IN_PLACE
-    visit_type = STATIC_CACHE_KEY
-    visit_anon_name = ANON_NAME
-
-    visit_propagate_attrs = PROPAGATE_ATTRS
-
-    def visit_with_context_options(
-        self, attrname, obj, parent, anon_map, bindparams
-    ):
-        return tuple((fn.__code__, c_key) for fn, c_key in obj)
-
-    def visit_inspectable(self, attrname, obj, parent, anon_map, bindparams):
-        return (attrname, inspect(obj)._gen_cache_key(anon_map, bindparams))
-
-    def visit_string_list(self, attrname, obj, parent, anon_map, bindparams):
-        return tuple(obj)
-
-    def visit_multi(self, attrname, obj, parent, anon_map, bindparams):
-        return (
-            attrname,
-            obj._gen_cache_key(anon_map, bindparams)
-            if isinstance(obj, HasCacheKey)
-            else obj,
-        )
-
-    def visit_multi_list(self, attrname, obj, parent, anon_map, bindparams):
-        return (
-            attrname,
-            tuple(
-                elem._gen_cache_key(anon_map, bindparams)
-                if isinstance(elem, HasCacheKey)
-                else elem
-                for elem in obj
-            ),
-        )
-
-    def visit_has_cache_key_tuples(
-        self, attrname, obj, parent, anon_map, bindparams
-    ):
-        if not obj:
-            return ()
-        return (
-            attrname,
-            tuple(
-                tuple(
-                    elem._gen_cache_key(anon_map, bindparams)
-                    for elem in tup_elem
-                )
-                for tup_elem in obj
-            ),
-        )
-
-    def visit_has_cache_key_list(
-        self, attrname, obj, parent, anon_map, bindparams
-    ):
-        if not obj:
-            return ()
-        return (
-            attrname,
-            tuple(elem._gen_cache_key(anon_map, bindparams) for elem in obj),
-        )
-
-    def visit_executable_options(
-        self, attrname, obj, parent, anon_map, bindparams
-    ):
-        if not obj:
-            return ()
-        return (
-            attrname,
-            tuple(
-                elem._gen_cache_key(anon_map, bindparams)
-                for elem in obj
-                if elem._is_has_cache_key
-            ),
-        )
-
-    def visit_inspectable_list(
-        self, attrname, obj, parent, anon_map, bindparams
-    ):
-        return self.visit_has_cache_key_list(
-            attrname, [inspect(o) for o in obj], parent, anon_map, bindparams
-        )
-
-    def visit_clauseelement_tuples(
-        self, attrname, obj, parent, anon_map, bindparams
-    ):
-        return self.visit_has_cache_key_tuples(
-            attrname, obj, parent, anon_map, bindparams
-        )
-
-    def visit_fromclause_ordered_set(
-        self, attrname, obj, parent, anon_map, bindparams
-    ):
-        if not obj:
-            return ()
-        return (
-            attrname,
-            tuple([elem._gen_cache_key(anon_map, bindparams) for elem in obj]),
-        )
-
-    def visit_clauseelement_unordered_set(
-        self, attrname, obj, parent, anon_map, bindparams
-    ):
-        if not obj:
-            return ()
-        cache_keys = [
-            elem._gen_cache_key(anon_map, bindparams) for elem in obj
-        ]
-        return (
-            attrname,
-            tuple(
-                sorted(cache_keys)
-            ),  # cache keys all start with (id_, class)
-        )
-
-    def visit_named_ddl_element(
-        self, attrname, obj, parent, anon_map, bindparams
-    ):
-        return (attrname, obj.name)
-
-    def visit_prefix_sequence(
-        self, attrname, obj, parent, anon_map, bindparams
-    ):
-        if not obj:
-            return ()
-
-        return (
-            attrname,
-            tuple(
-                [
-                    (clause._gen_cache_key(anon_map, bindparams), strval)
-                    for clause, strval in obj
-                ]
-            ),
-        )
-
-    def visit_setup_join_tuple(
-        self, attrname, obj, parent, anon_map, bindparams
-    ):
-        is_legacy = "legacy" in attrname
-
-        return tuple(
-            (
-                target
-                if is_legacy and isinstance(target, str)
-                else target._gen_cache_key(anon_map, bindparams),
-                onclause
-                if is_legacy and isinstance(onclause, str)
-                else onclause._gen_cache_key(anon_map, bindparams)
-                if onclause is not None
-                else None,
-                from_._gen_cache_key(anon_map, bindparams)
-                if from_ is not None
-                else None,
-                tuple([(key, flags[key]) for key in sorted(flags)]),
-            )
-            for (target, onclause, from_, flags) in obj
-        )
-
-    def visit_table_hint_list(
-        self, attrname, obj, parent, anon_map, bindparams
-    ):
-        if not obj:
-            return ()
-
-        return (
-            attrname,
-            tuple(
-                [
-                    (
-                        clause._gen_cache_key(anon_map, bindparams),
-                        dialect_name,
-                        text,
-                    )
-                    for (clause, dialect_name), text in obj.items()
-                ]
-            ),
-        )
-
-    def visit_plain_dict(self, attrname, obj, parent, anon_map, bindparams):
-        return (attrname, tuple([(key, obj[key]) for key in sorted(obj)]))
-
-    def visit_dialect_options(
-        self, attrname, obj, parent, anon_map, bindparams
-    ):
-        return (
-            attrname,
-            tuple(
-                (
-                    dialect_name,
-                    tuple(
-                        [
-                            (key, obj[dialect_name][key])
-                            for key in sorted(obj[dialect_name])
-                        ]
-                    ),
-                )
-                for dialect_name in sorted(obj)
-            ),
-        )
-
-    def visit_string_clauseelement_dict(
-        self, attrname, obj, parent, anon_map, bindparams
-    ):
-        return (
-            attrname,
-            tuple(
-                (key, obj[key]._gen_cache_key(anon_map, bindparams))
-                for key in sorted(obj)
-            ),
-        )
-
-    def visit_string_multi_dict(
-        self, attrname, obj, parent, anon_map, bindparams
-    ):
-        return (
-            attrname,
-            tuple(
-                (
-                    key,
-                    value._gen_cache_key(anon_map, bindparams)
-                    if isinstance(value, HasCacheKey)
-                    else value,
-                )
-                for key, value in [(key, obj[key]) for key in sorted(obj)]
-            ),
-        )
-
-    def visit_fromclause_canonical_column_collection(
-        self, attrname, obj, parent, anon_map, bindparams
-    ):
-        # inlining into the internals of ColumnCollection
-        return (
-            attrname,
-            tuple(
-                col._gen_cache_key(anon_map, bindparams)
-                for k, col in obj._collection
-            ),
-        )
-
-    def visit_unknown_structure(
-        self, attrname, obj, parent, anon_map, bindparams
-    ):
-        anon_map[NO_CACHE] = True
-        return ()
-
-    def visit_dml_ordered_values(
-        self, attrname, obj, parent, anon_map, bindparams
-    ):
-        return (
-            attrname,
-            tuple(
-                (
-                    key._gen_cache_key(anon_map, bindparams)
-                    if hasattr(key, "__clause_element__")
-                    else key,
-                    value._gen_cache_key(anon_map, bindparams),
-                )
-                for key, value in obj
-            ),
-        )
-
-    def visit_dml_values(self, attrname, obj, parent, anon_map, bindparams):
-        # in py37 we can assume two dictionaries created in the same
-        # insert ordering will retain that sorting
-        return (
-            attrname,
-            tuple(
-                (
-                    k._gen_cache_key(anon_map, bindparams)
-                    if hasattr(k, "__clause_element__")
-                    else k,
-                    obj[k]._gen_cache_key(anon_map, bindparams),
-                )
-                for k in obj
-            ),
-        )
-
-    def visit_dml_multi_values(
-        self, attrname, obj, parent, anon_map, bindparams
-    ):
-        # multivalues are simply not cacheable right now
-        anon_map[NO_CACHE] = True
-        return ()
-
-
-_cache_key_traversal_visitor = _CacheKey()
-
-
 class HasCopyInternals:
     __slots__ = ()
 
@@ -813,7 +87,7 @@ class HasCopyInternals:
                     setattr(self, attrname, result)
 
 
-class _CopyInternals(InternalTraversal):
+class _CopyInternalsTraversal(InternalTraversal):
     """Generate a _copy_internals internal traversal dispatch for classes
     with a _traverse_internals collection."""
 
@@ -936,7 +210,7 @@ class _CopyInternals(InternalTraversal):
         return element
 
 
-_copy_internals = _CopyInternals()
+_copy_internals = _CopyInternalsTraversal()
 
 
 def _flatten_clauseelement(element):
@@ -948,7 +222,7 @@ def _flatten_clauseelement(element):
     return element
 
 
-class _GetChildren(InternalTraversal):
+class _GetChildrenTraversal(InternalTraversal):
     """Generate a _children_traversal internal traversal dispatch for classes
     with a _traverse_internals collection."""
 
@@ -1019,7 +293,7 @@ class _GetChildren(InternalTraversal):
         return ()
 
 
-_get_children = _GetChildren()
+_get_children = _GetChildrenTraversal()
 
 
 @util.preload_module("sqlalchemy.sql.elements")
index 9dd702410b4f90dde6eef23cb8d00c533f0f77ab..55289cb8534cc27ae046d3f6994001fd01f9e15e 100644 (file)
@@ -13,9 +13,8 @@ import typing
 
 from . import operators
 from .base import SchemaEventTarget
-from .traversals import NO_CACHE
+from .cache_key import NO_CACHE
 from .visitors import Traversible
-from .visitors import TraversibleType
 from .. import exc
 from .. import util
 
@@ -858,10 +857,6 @@ class TypeEngine(Traversible):
         return util.generic_repr(self)
 
 
-class VisitableCheckKWArg(util.EnsureKWArgType, TraversibleType):
-    pass
-
-
 class ExternalType:
     """mixin that defines attributes and behaviors specific to third-party
     datatypes.
@@ -1038,7 +1033,7 @@ class ExternalType:
         return NO_CACHE
 
 
-class UserDefinedType(ExternalType, TypeEngine, metaclass=VisitableCheckKWArg):
+class UserDefinedType(ExternalType, TypeEngine, util.EnsureKWArg):
     """Base for user defined types.
 
     This should be the base of new types.  Note that
index 69e83e46a639d6acf5e61c7d42641be1f8576af6..63067585ed1c1f584da6a97eb882cca822d66435 100644 (file)
@@ -22,6 +22,7 @@ from .annotation import _shallow_annotate  # noqa
 from .base import _expand_cloned
 from .base import _from_objects
 from .base import ColumnSet
+from .cache_key import HasCacheKey  # noqa
 from .ddl import sort_tables  # noqa
 from .elements import _find_columns  # noqa
 from .elements import _label_reference
@@ -41,7 +42,6 @@ from .selectable import Join
 from .selectable import ScalarSelect
 from .selectable import SelectBase
 from .selectable import TableClause
-from .traversals import HasCacheKey  # noqa
 from .. import exc
 from .. import util
 
index 87fe3694445b77afb1c2dfd98ea69fa4cef98203..70c4dc133275d54e9c97028396543c72d0501581 100644 (file)
@@ -32,6 +32,11 @@ from .. import util
 from ..util import langhelpers
 from ..util import symbol
 
+try:
+    from sqlalchemy.cyextension.util import cache_anon_map as anon_map  # noqa
+except ImportError:
+    from ._py_util import cache_anon_map as anon_map  # noqa
+
 __all__ = [
     "iterate",
     "traverse_using",
@@ -39,88 +44,77 @@ __all__ = [
     "cloned_traverse",
     "replacement_traverse",
     "Traversible",
-    "TraversibleType",
     "ExternalTraversal",
     "InternalTraversal",
 ]
 
 
-def _generate_compiler_dispatch(cls):
-    """Generate a _compiler_dispatch() external traversal on classes with a
-    __visit_name__ attribute.
-
-    """
-    visit_name = cls.__visit_name__
-
-    if "_compiler_dispatch" in cls.__dict__:
-        # class has a fixed _compiler_dispatch() method.
-        # copy it to "original" so that we can get it back if
-        # sqlalchemy.ext.compiles overrides it.
-        cls._original_compiler_dispatch = cls._compiler_dispatch
-        return
-
-    if not isinstance(visit_name, str):
-        raise exc.InvalidRequestError(
-            "__visit_name__ on class %s must be a string at the class level"
-            % cls.__name__
-        )
-
-    name = "visit_%s" % visit_name
-    getter = operator.attrgetter(name)
-
-    def _compiler_dispatch(self, visitor, **kw):
-        """Look for an attribute named "visit_<visit_name>" on the
-        visitor, and call it with the same kw params.
-
-        """
-        try:
-            meth = getter(visitor)
-        except AttributeError as err:
-            return visitor.visit_unsupported_compilation(self, err, **kw)
-
-        else:
-            return meth(self, **kw)
-
-    cls._compiler_dispatch = (
-        cls._original_compiler_dispatch
-    ) = _compiler_dispatch
+class Traversible:
+    """Base class for visitable objects."""
 
+    __slots__ = ()
 
-class TraversibleType(type):
-    """Metaclass which assigns dispatch attributes to various kinds of
-    "visitable" classes.
+    __visit_name__: str
 
-    Attributes include:
+    def __init_subclass__(cls) -> None:
+        if "__visit_name__" in cls.__dict__:
+            cls._generate_compiler_dispatch()
+        super().__init_subclass__()
 
-    * The ``_compiler_dispatch`` method, corresponding to ``__visit_name__``.
-      This is called "external traversal" because the caller of each visit()
-      method is responsible for sub-traversing the inner elements of each
-      object. This is appropriate for string compilers and other traversals
-      that need to call upon the inner elements in a specific pattern.
+    @classmethod
+    def _generate_compiler_dispatch(cls):
+        """Assign dispatch attributes to various kinds of
+        "visitable" classes.
 
-    * internal traversal collections ``_children_traversal``,
-      ``_cache_key_traversal``, ``_copy_internals_traversal``, generated from
-      an optional ``_traverse_internals`` collection of symbols which comes
-      from the :class:`.InternalTraversal` list of symbols.  This is called
-      "internal traversal" MARKMARK
+        Attributes include:
 
-    """
+        * The ``_compiler_dispatch`` method, corresponding to
+          ``__visit_name__``. This is called "external traversal" because the
+          caller of each visit() method is responsible for sub-traversing the
+          inner elements of each object. This is appropriate for string
+          compilers and other traversals that need to call upon the inner
+          elements in a specific pattern.
 
-    def __init__(cls, clsname, bases, clsdict):
-        if clsname != "Traversible":
-            if "__visit_name__" in clsdict:
-                _generate_compiler_dispatch(cls)
+        * internal traversal collections ``_children_traversal``,
+          ``_cache_key_traversal``, ``_copy_internals_traversal``, generated
+          from an optional ``_traverse_internals`` collection of symbols which
+          comes from the :class:`.InternalTraversal` list of symbols. This is
+          called "internal traversal".
 
-        super(TraversibleType, cls).__init__(clsname, bases, clsdict)
+        """
+        visit_name = cls.__visit_name__
+
+        if "_compiler_dispatch" in cls.__dict__:
+            # class has a fixed _compiler_dispatch() method.
+            # copy it to "original" so that we can get it back if
+            # sqlalchemy.ext.compiles overrides it.
+            cls._original_compiler_dispatch = cls._compiler_dispatch
+            return
+
+        if not isinstance(visit_name, str):
+            raise exc.InvalidRequestError(
+                f"__visit_name__ on class {cls.__name__} must be a string "
+                "at the class level"
+            )
 
+        name = "visit_%s" % visit_name
+        getter = operator.attrgetter(name)
 
-class Traversible(metaclass=TraversibleType):
-    """Base class for visitable objects, applies the
-    :class:`.visitors.TraversibleType` metaclass.
+        def _compiler_dispatch(self, visitor, **kw):
+            """Look for an attribute named "visit_<visit_name>" on the
+            visitor, and call it with the same kw params.
 
-    """
+            """
+            try:
+                meth = getter(visitor)
+            except AttributeError as err:
+                return visitor.visit_unsupported_compilation(self, err, **kw)
+            else:
+                return meth(self, **kw)
 
-    __slots__ = ()
+        cls._compiler_dispatch = (
+            cls._original_compiler_dispatch
+        ) = _compiler_dispatch
 
     def __class_getitem__(cls, key):
         # allow generic classes in py3.9+
@@ -159,48 +153,90 @@ class Traversible(metaclass=TraversibleType):
         )
 
 
-class _InternalTraversalType(type):
-    def __init__(cls, clsname, bases, clsdict):
-        if cls.__name__ in ("InternalTraversal", "ExtendedInternalTraversal"):
-            lookup = {}
-            for key, sym in clsdict.items():
-                if key.startswith("dp_"):
-                    visit_key = key.replace("dp_", "visit_")
-                    sym_name = sym.name
-                    assert sym_name not in lookup, sym_name
-                    lookup[sym] = lookup[sym_name] = visit_key
-            if hasattr(cls, "_dispatch_lookup"):
-                lookup.update(cls._dispatch_lookup)
-            cls._dispatch_lookup = lookup
+class _HasTraversalDispatch:
+    r"""Define infrastructure for the :class:`.InternalTraversal` class.
 
-        super(_InternalTraversalType, cls).__init__(clsname, bases, clsdict)
+    .. versionadded:: 2.0
 
+    """
 
-def _generate_dispatcher(visitor, internal_dispatch, method_name):
-    names = []
-    for attrname, visit_sym in internal_dispatch:
-        meth = visitor.dispatch(visit_sym)
-        if meth:
-            visit_name = ExtendedInternalTraversal._dispatch_lookup[visit_sym]
-            names.append((attrname, visit_name))
-
-    code = (
-        ("    return [\n")
-        + (
-            ", \n".join(
-                "        (%r, self.%s, visitor.%s)"
-                % (attrname, attrname, visit_name)
-                for attrname, visit_name in names
+    def __init_subclass__(cls) -> None:
+        cls._generate_traversal_dispatch()
+        super().__init_subclass__()
+
+    def dispatch(self, visit_symbol):
+        """Given a method from :class:`._HasTraversalDispatch`, return the
+        corresponding method on a subclass.
+
+        """
+        name = self._dispatch_lookup[visit_symbol]
+        return getattr(self, name, None)
+
+    def run_generated_dispatch(
+        self, target, internal_dispatch, generate_dispatcher_name
+    ):
+        try:
+            dispatcher = target.__class__.__dict__[generate_dispatcher_name]
+        except KeyError:
+            # most of the dispatchers are generated up front
+            # in sqlalchemy/sql/__init__.py ->
+            # traversals.py-> _preconfigure_traversals().
+            # this block will generate any remaining dispatchers.
+            dispatcher = self.generate_dispatch(
+                target.__class__, internal_dispatch, generate_dispatcher_name
+            )
+        return dispatcher(target, self)
+
+    def generate_dispatch(
+        self, target_cls, internal_dispatch, generate_dispatcher_name
+    ):
+        dispatcher = self._generate_dispatcher(
+            internal_dispatch, generate_dispatcher_name
+        )
+        # assert isinstance(target_cls, type)
+        setattr(target_cls, generate_dispatcher_name, dispatcher)
+        return dispatcher
+
+    @classmethod
+    def _generate_traversal_dispatch(cls):
+        lookup = {}
+        clsdict = cls.__dict__
+        for key, sym in clsdict.items():
+            if key.startswith("dp_"):
+                visit_key = key.replace("dp_", "visit_")
+                sym_name = sym.name
+                assert sym_name not in lookup, sym_name
+                lookup[sym] = lookup[sym_name] = visit_key
+        if hasattr(cls, "_dispatch_lookup"):
+            lookup.update(cls._dispatch_lookup)
+        cls._dispatch_lookup = lookup
+
+    def _generate_dispatcher(self, internal_dispatch, method_name):
+        names = []
+        for attrname, visit_sym in internal_dispatch:
+            meth = self.dispatch(visit_sym)
+            if meth:
+                visit_name = ExtendedInternalTraversal._dispatch_lookup[
+                    visit_sym
+                ]
+                names.append((attrname, visit_name))
+
+        code = (
+            ("    return [\n")
+            + (
+                ", \n".join(
+                    "        (%r, self.%s, visitor.%s)"
+                    % (attrname, attrname, visit_name)
+                    for attrname, visit_name in names
+                )
             )
+            + ("\n    ]\n")
         )
-        + ("\n    ]\n")
-    )
-    meth_text = ("def %s(self, visitor):\n" % method_name) + code + "\n"
-    # print(meth_text)
-    return langhelpers._exec_code_in_env(meth_text, {}, method_name)
+        meth_text = ("def %s(self, visitor):\n" % method_name) + code + "\n"
+        return langhelpers._exec_code_in_env(meth_text, {}, method_name)
 
 
-class InternalTraversal(metaclass=_InternalTraversalType):
+class InternalTraversal(_HasTraversalDispatch):
     r"""Defines visitor symbols used for internal traversal.
 
     The :class:`.InternalTraversal` class is used in two ways.  One is that
@@ -239,39 +275,6 @@ class InternalTraversal(metaclass=_InternalTraversalType):
 
     """
 
-    def dispatch(self, visit_symbol):
-        """Given a method from :class:`.InternalTraversal`, return the
-        corresponding method on a subclass.
-
-        """
-        name = self._dispatch_lookup[visit_symbol]
-        return getattr(self, name, None)
-
-    def run_generated_dispatch(
-        self, target, internal_dispatch, generate_dispatcher_name
-    ):
-        try:
-            dispatcher = target.__class__.__dict__[generate_dispatcher_name]
-        except KeyError:
-            # most of the dispatchers are generated up front
-            # in sqlalchemy/sql/__init__.py ->
-            # traversals.py-> _preconfigure_traversals().
-            # this block will generate any remaining dispatchers.
-            dispatcher = self.generate_dispatch(
-                target.__class__, internal_dispatch, generate_dispatcher_name
-            )
-        return dispatcher(target, self)
-
-    def generate_dispatch(
-        self, target_cls, internal_dispatch, generate_dispatcher_name
-    ):
-        dispatcher = _generate_dispatcher(
-            self, internal_dispatch, generate_dispatcher_name
-        )
-        # assert isinstance(target_cls, type)
-        setattr(target_cls, generate_dispatcher_name, dispatcher)
-        return dispatcher
-
     dp_has_cache_key = symbol("HC")
     """Visit a :class:`.HasCacheKey` object."""
 
@@ -623,7 +626,6 @@ class ReplacingExternalTraversal(CloningExternalTraversal):
 
 # backwards compatibility
 Visitable = Traversible
-VisitableType = TraversibleType
 ClauseVisitor = ExternalTraversal
 CloningVisitor = CloningExternalTraversal
 ReplacingCloningVisitor = ReplacingExternalTraversal
index 74c86e85ab6d20a9d9a8b686c581a33cbbfa6829..ecc20f1638b134fcb9628d30b7872091621016a8 100644 (file)
@@ -21,7 +21,6 @@ from .. import event
 from .. import util
 from ..orm import declarative_base
 from ..orm import registry
-from ..orm.decl_api import DeclarativeMeta
 from ..schema import sort_tables_and_constraints
 
 
@@ -647,15 +646,11 @@ class MappedTest(TablesTest, assertions.AssertsExecutionResults):
         """
         cls_registry = cls.classes
 
-        assert cls_registry is not None
-
-        class FindFixture(type):
-            def __init__(cls, classname, bases, dict_):
-                cls_registry[classname] = cls
-                type.__init__(cls, classname, bases, dict_)
-
-        class _Base(metaclass=FindFixture):
-            pass
+        class _Base:
+            def __init_subclass__(cls) -> None:
+                assert cls_registry is not None
+                cls_registry[cls.__name__] = cls
+                super().__init_subclass__()
 
         class Basic(BasicEntity, _Base):
             pass
@@ -699,17 +694,16 @@ class DeclarativeMappedTest(MappedTest):
     def _with_register_classes(cls, fn):
         cls_registry = cls.classes
 
-        class FindFixtureDeclarative(DeclarativeMeta):
-            def __init__(cls, classname, bases, dict_):
-                cls_registry[classname] = cls
-                DeclarativeMeta.__init__(cls, classname, bases, dict_)
-
         class DeclarativeBasic:
             __table_cls__ = schema.Table
 
+            def __init_subclass__(cls) -> None:
+                assert cls_registry is not None
+                cls_registry[cls.__name__] = cls
+                super().__init_subclass__()
+
         _DeclBase = declarative_base(
             metadata=cls._tables_metadata,
-            metaclass=FindFixtureDeclarative,
             cls=DeclarativeBasic,
         )
 
index 2d2ff35651cddf0c76f2c1911babd6f6f9e5037e..7c03bcd4ba35159877262540b9a68fed950d8ef3 100644 (file)
@@ -98,7 +98,7 @@ from .langhelpers import decorator
 from .langhelpers import dictlike_iteritems
 from .langhelpers import duck_type_collection
 from .langhelpers import ellipses_string
-from .langhelpers import EnsureKWArgType
+from .langhelpers import EnsureKWArg
 from .langhelpers import format_argspec_init
 from .langhelpers import format_argspec_plus
 from .langhelpers import generic_repr
index 1b277cdee9f0c9d68de8b7f534ceaf7dd6517ee8..66c5308678241a7452eb07f05c7d7f1396aeb06f 100644 (file)
@@ -1743,14 +1743,28 @@ def attrsetter(attrname):
     return env["set"]
 
 
-class EnsureKWArgType(type):
+class EnsureKWArg:
     r"""Apply translation of functions to accept \**kw arguments if they
     don't already.
 
+    Used to ensure cross-compatibility with third party legacy code, for things
+    like compiler visit methods that need to accept ``**kw`` arguments,
+    but may have been copied from old code that didn't accept them.
+
+    """
+
+    ensure_kwarg: str
+    """a regular expression that indicates method names for which the method
+    should accept ``**kw`` arguments.
+
+    The class will scan for methods matching the name template and decorate
+    them if necessary to ensure ``**kw`` parameters are accepted.
+
     """
 
-    def __init__(cls, clsname, bases, clsdict):
+    def __init_subclass__(cls) -> None:
         fn_reg = cls.ensure_kwarg
+        clsdict = cls.__dict__
         if fn_reg:
             for key in clsdict:
                 m = re.match(fn_reg, key)
@@ -1758,11 +1772,12 @@ class EnsureKWArgType(type):
                     fn = clsdict[key]
                     spec = compat.inspect_getfullargspec(fn)
                     if not spec.varkw:
-                        clsdict[key] = wrapped = cls._wrap_w_kw(fn)
+                        wrapped = cls._wrap_w_kw(fn)
                         setattr(cls, key, wrapped)
-        super(EnsureKWArgType, cls).__init__(clsname, bases, clsdict)
+        super().__init_subclass__()
 
-    def _wrap_w_kw(self, fn):
+    @classmethod
+    def _wrap_w_kw(cls, fn):
         def wrap(*arg, **kw):
             return fn(*arg)
 
index 92ef241ed75e2756bddbdc1c678b9c7b2e4b67d9..37b4d0ae1f58f5f485fd61ef1b45ffee493cbdce 100644 (file)
@@ -32,7 +32,7 @@ from sqlalchemy.orm import Session
 from sqlalchemy.orm import sessionmaker
 from sqlalchemy.orm import subqueryload
 from sqlalchemy.orm import UserDefinedOption
-from sqlalchemy.sql.traversals import NO_CACHE
+from sqlalchemy.sql.cache_key import NO_CACHE
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import AssertsCompiledSQL
index bbf9716f5836b810c64f664ea67d7cca9b2a3c7b..43aed0672542e5b0c6be00ec36cf242702dacdd9 100644 (file)
@@ -18,7 +18,7 @@ from sqlalchemy.sql import select
 from sqlalchemy.sql import table
 from sqlalchemy.sql import util as sql_util
 from sqlalchemy.sql.base import ExecutableOption
-from sqlalchemy.sql.traversals import HasCacheKey
+from sqlalchemy.sql.cache_key import HasCacheKey
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import eq_