From: Mike Bayer Date: Mon, 10 Jan 2022 21:48:05 +0000 (-0500) Subject: remove internal use of metaclasses X-Git-Tag: rel_2_0_0b1~546^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=3a23e8ed29180e914883a263ec83373ecbd02efa;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git remove internal use of metaclasses All but one metaclass used internally can now be replaced using __init_subclass__(). Within this patch we remove: * events._EventMeta * sql.visitors.TraversibleType * sql.visitors.InternalTraversibleType * testing.fixtures.FindFixture * testing.fixtures.FindFixtureDeclarative * langhelpers.EnsureKWArgType * sql.functions._GenericMeta * sql.type_api.VisitableCheckKWArg (was a mixture of TraversibleType and EnsureKWArgType) The remaining internal class is MetaOptions used by the sql.Options object which is in turn currently mostly for ORM internal use, as this type implements class level overrides for the ``+`` operator. For declarative, removing DeclarativeMeta in place of an `__init_subclass__()` class would not be fully feasible as it would break backwards compatibility with applications that refer to this class explicitly, but also DeclarativeMeta intercepts class-level attribute set and delete operations which is a widely used pattern. An option for declarative base to use `__init_subclass__()` should be provided but this is out of scope for this particular change. Change-Id: I8aa898c7ab59d887739037d34b1cbab36521ab78 References: #6810 --- diff --git a/lib/sqlalchemy/event/base.py b/lib/sqlalchemy/event/base.py index c5b03dd721..25d3692408 100644 --- a/lib/sqlalchemy/event/base.py +++ b/lib/sqlalchemy/event/base.py @@ -15,13 +15,16 @@ at the class level of a particular ``_Dispatch`` class as well as within instances of ``_Dispatch``. """ +from typing import ClassVar +from typing import Optional +from typing import Type import weakref from .attr import _ClsLevelDispatch from .attr import _EmptyListener from .attr import _JoinedListener from .. import util - +from ..util.typing import Protocol _registrars = util.defaultdict(list) @@ -63,8 +66,8 @@ class _Dispatch: of the :class:`._Dispatch` class is returned. A :class:`._Dispatch` class is generated for each :class:`.Events` - class defined, by the :func:`._create_dispatcher_class` function. - The original :class:`.Events` classes remain untouched. + class defined, by the :meth:`._HasEventsDispatch._create_dispatcher_class` + method. The original :class:`.Events` classes remain untouched. This decouples the construction of :class:`.Events` subclasses from the implementation used by the event internals, and allows inspecting tools like Sphinx to work in an unsurprising @@ -78,6 +81,13 @@ class _Dispatch: _empty_listener_reg = weakref.WeakKeyDictionary() + _events: Type["_HasEventsDispatch"] + """reference back to the Events class. + + Bidirectional against _HasEventsDispatch.dispatch + + """ + def __init__(self, parent, instance_cls=None): self._parent = parent self._instance_cls = instance_cls @@ -159,56 +169,6 @@ class _Dispatch: ls.for_modify(self).clear() -class _EventMeta(type): - """Intercept new Event subclasses and create - associated _Dispatch classes.""" - - def __init__(cls, classname, bases, dict_): - _create_dispatcher_class(cls, classname, bases, dict_) - type.__init__(cls, classname, bases, dict_) - - -def _create_dispatcher_class(cls, classname, bases, dict_): - """Create a :class:`._Dispatch` class corresponding to an - :class:`.Events` class.""" - - # there's all kinds of ways to do this, - # i.e. 
make a Dispatch class that shares the '_listen' method - # of the Event class, this is the straight monkeypatch. - if hasattr(cls, "dispatch"): - dispatch_base = cls.dispatch.__class__ - else: - dispatch_base = _Dispatch - - event_names = [k for k in dict_ if _is_event_name(k)] - dispatch_cls = type( - "%sDispatch" % classname, (dispatch_base,), {"__slots__": event_names} - ) - - dispatch_cls._event_names = event_names - - dispatch_inst = cls._set_dispatch(cls, dispatch_cls) - for k in dispatch_cls._event_names: - setattr(dispatch_inst, k, _ClsLevelDispatch(cls, dict_[k])) - _registrars[k].append(cls) - - for super_ in dispatch_cls.__bases__: - if issubclass(super_, _Dispatch) and super_ is not _Dispatch: - for ls in super_._events.dispatch._event_descriptors: - setattr(dispatch_inst, ls.name, ls) - dispatch_cls._event_names.append(ls.name) - - if getattr(cls, "_dispatch_target", None): - the_cls = cls._dispatch_target - if ( - hasattr(the_cls, "__slots__") - and "_slots_dispatch" in the_cls.__slots__ - ): - cls._dispatch_target.dispatch = slots_dispatcher(cls) - else: - cls._dispatch_target.dispatch = dispatcher(cls) - - def _remove_dispatcher(cls): for k in cls.dispatch._event_names: _registrars[k].remove(cls) @@ -216,8 +176,31 @@ def _remove_dispatcher(cls): del _registrars[k] -class Events(metaclass=_EventMeta): - """Define event listening functions for a particular target type.""" +class _HasEventsDispatchProto(Protocol): + """protocol for non-event classes that will also receive the 'dispatch' + attribute in the form of a descriptor. + + """ + + dispatch: ClassVar["dispatcher"] + + +class _HasEventsDispatch: + _dispatch_target: Optional[Type[_HasEventsDispatchProto]] + """class which will receive the .dispatch collection""" + + dispatch: _Dispatch + """reference back to the _Dispatch class. + + Bidirectional against _Dispatch._events + + """ + + def __init_subclass__(cls) -> None: + """Intercept new Event subclasses and create associated _Dispatch + classes.""" + + cls._create_dispatcher_class(cls.__name__, cls.__bases__, cls.__dict__) @staticmethod def _set_dispatch(cls, dispatch_cls): @@ -230,6 +213,54 @@ class Events(metaclass=_EventMeta): dispatch_cls._events = cls return cls.dispatch + @classmethod + def _create_dispatcher_class(cls, classname, bases, dict_): + """Create a :class:`._Dispatch` class corresponding to an + :class:`.Events` class.""" + + # there's all kinds of ways to do this, + # i.e. make a Dispatch class that shares the '_listen' method + # of the Event class, this is the straight monkeypatch. 
+ if hasattr(cls, "dispatch"): + dispatch_base = cls.dispatch.__class__ + else: + dispatch_base = _Dispatch + + event_names = [k for k in dict_ if _is_event_name(k)] + dispatch_cls = type( + "%sDispatch" % classname, + (dispatch_base,), + {"__slots__": event_names}, + ) + + dispatch_cls._event_names = event_names + + dispatch_inst = cls._set_dispatch(cls, dispatch_cls) + for k in dispatch_cls._event_names: + setattr(dispatch_inst, k, _ClsLevelDispatch(cls, dict_[k])) + _registrars[k].append(cls) + + for super_ in dispatch_cls.__bases__: + if issubclass(super_, _Dispatch) and super_ is not _Dispatch: + for ls in super_._events.dispatch._event_descriptors: + setattr(dispatch_inst, ls.name, ls) + dispatch_cls._event_names.append(ls.name) + + if getattr(cls, "_dispatch_target", None): + dispatch_target_cls = cls._dispatch_target + assert dispatch_target_cls is not None + if ( + hasattr(dispatch_target_cls, "__slots__") + and "_slots_dispatch" in dispatch_target_cls.__slots__ + ): + dispatch_target_cls.dispatch = slots_dispatcher(cls) + else: + dispatch_target_cls.dispatch = dispatcher(cls) + + +class Events(_HasEventsDispatch): + """Define event listening functions for a particular target type.""" + @classmethod def _accept_with(cls, target): def dispatch_is(*types): diff --git a/lib/sqlalchemy/ext/mutable.py b/lib/sqlalchemy/ext/mutable.py index a2e1a38260..d75cf667bd 100644 --- a/lib/sqlalchemy/ext/mutable.py +++ b/lib/sqlalchemy/ext/mutable.py @@ -269,10 +269,10 @@ and to also route attribute set events via ``__setattr__`` to the def __ne__(self, other): return not self.__eq__(other) -The :class:`.MutableComposite` class uses a Python metaclass to automatically -establish listeners for any usage of :func:`_orm.composite` that specifies our -``Point`` type. Below, when ``Point`` is mapped to the ``Vertex`` class, -listeners are established which will route change events from ``Point`` +The :class:`.MutableComposite` class makes use of class mapping events to +automatically establish listeners for any usage of :func:`_orm.composite` that +specifies our ``Point`` type. Below, when ``Point`` is mapped to the ``Vertex`` +class, listeners are established which will route change events from ``Point`` objects to each of the ``Vertex.start`` and ``Vertex.end`` attributes:: from sqlalchemy.orm import composite, mapper diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py index 1094fa5164..2a5b1bb2b0 100644 --- a/lib/sqlalchemy/orm/decl_api.py +++ b/lib/sqlalchemy/orm/decl_api.py @@ -51,6 +51,10 @@ def has_inherited_table(cls): class DeclarativeMeta(type): + # DeclarativeMeta could be replaced by __subclass_init__() + # except for the class-level __setattr__() and __delattr__ hooks, + # which are still very important. 
+ def __init__(cls, classname, bases, dict_, **kw): # early-consume registry from the initial declarative base, # assign privately to not conflict with subclass attributes named diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index 6c84f89cfa..d842df2215 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -37,7 +37,7 @@ from ..sql import operators from ..sql import roles from ..sql import visitors from ..sql.base import ExecutableOption -from ..sql.traversals import HasCacheKey +from ..sql.cache_key import HasCacheKey __all__ = ( diff --git a/lib/sqlalchemy/orm/path_registry.py b/lib/sqlalchemy/orm/path_registry.py index 2e64696d9c..0d87739cc1 100644 --- a/lib/sqlalchemy/orm/path_registry.py +++ b/lib/sqlalchemy/orm/path_registry.py @@ -20,7 +20,7 @@ from .. import exc from .. import inspection from .. import util from ..sql import visitors -from ..sql.traversals import HasCacheKey +from ..sql.cache_key import HasCacheKey log = logging.getLogger(__name__) diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index ad31c24329..e6d16f1788 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -134,7 +134,7 @@ class Query( # local Query builder state, not needed for # compilation or execution _enable_assertions = True - _last_joined_entity = None + _statement = None # mirrors that of ClauseElement, used to propagate the "orm" diff --git a/lib/sqlalchemy/orm/strategy_options.py b/lib/sqlalchemy/orm/strategy_options.py index 4df275c71d..c2cfbb9fc2 100644 --- a/lib/sqlalchemy/orm/strategy_options.py +++ b/lib/sqlalchemy/orm/strategy_options.py @@ -29,9 +29,9 @@ from .. import exc as sa_exc from .. import inspect from .. import util from ..sql import and_ +from ..sql import cache_key from ..sql import coercions from ..sql import roles -from ..sql import traversals from ..sql import visitors from ..sql.base import _generative from ..sql.base import Generative @@ -1316,7 +1316,7 @@ class _WildcardLoad(_AbstractLoad): self.__dict__.update(state) -class _LoadElement(traversals.HasCacheKey): +class _LoadElement(cache_key.HasCacheKey): """represents strategy information to select for a LoaderStrategy and pass options to it. diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 805f7b1a02..6ab9a75f6f 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -20,9 +20,9 @@ import typing from . import roles from . import visitors -from .traversals import HasCacheKey # noqa +from .cache_key import HasCacheKey # noqa +from .cache_key import MemoizedHasCacheKey # noqa from .traversals import HasCopyInternals # noqa -from .traversals import MemoizedHasCacheKey # noqa from .visitors import ClauseVisitor from .visitors import ExtendedInternalTraversal from .visitors import InternalTraversal @@ -37,7 +37,6 @@ try: except ImportError: from ._py_util import prefix_anon_map # noqa - coercions = None elements = None type_api = None @@ -610,18 +609,13 @@ class HasCompileState(Generative): class _MetaOptions(type): - """metaclass for the Options class.""" + """metaclass for the Options class. 
- def __init__(cls, classname, bases, dict_): - cls._cache_attrs = tuple( - sorted( - d - for d in dict_ - if not d.startswith("__") - and d not in ("_cache_key_traversal",) - ) - ) - type.__init__(cls, classname, bases, dict_) + This metaclass is actually necessary despite the availability of the + ``__init_subclass__()`` hook as this type also provides custom class-level + behavior for the ``__add__()`` method. + + """ def __add__(self, other): o1 = self() @@ -640,6 +634,18 @@ class _MetaOptions(type): class Options(metaclass=_MetaOptions): """A cacheable option dictionary with defaults.""" + def __init_subclass__(cls) -> None: + dict_ = cls.__dict__ + cls._cache_attrs = tuple( + sorted( + d + for d in dict_ + if not d.startswith("__") + and d not in ("_cache_key_traversal",) + ) + ) + super().__init_subclass__() + def __init__(self, **kw): self.__dict__.update(kw) diff --git a/lib/sqlalchemy/sql/cache_key.py b/lib/sqlalchemy/sql/cache_key.py new file mode 100644 index 0000000000..8dd44dbf08 --- /dev/null +++ b/lib/sqlalchemy/sql/cache_key.py @@ -0,0 +1,762 @@ +# sql/cache_key.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from collections import namedtuple +import enum +from itertools import zip_longest +from typing import Callable +from typing import Union + +from .visitors import anon_map +from .visitors import ExtendedInternalTraversal +from .visitors import InternalTraversal +from .. import util +from ..inspection import inspect +from ..util import HasMemoized +from ..util.typing import Literal + + +class CacheConst(enum.Enum): + NO_CACHE = 0 + + +NO_CACHE = CacheConst.NO_CACHE + + +class CacheTraverseTarget(enum.Enum): + CACHE_IN_PLACE = 0 + CALL_GEN_CACHE_KEY = 1 + STATIC_CACHE_KEY = 2 + PROPAGATE_ATTRS = 3 + ANON_NAME = 4 + + +( + CACHE_IN_PLACE, + CALL_GEN_CACHE_KEY, + STATIC_CACHE_KEY, + PROPAGATE_ATTRS, + ANON_NAME, +) = tuple(CacheTraverseTarget) + + +class HasCacheKey: + """Mixin for objects which can produce a cache key. + + .. seealso:: + + :class:`.CacheKey` + + :ref:`sql_caching` + + """ + + _cache_key_traversal = NO_CACHE + + _is_has_cache_key = True + + _hierarchy_supports_caching = True + """private attribute which may be set to False to prevent the + inherit_cache warning from being emitted for a hierarchy of subclasses. + + Currently applies to the DDLElement hierarchy which does not implement + caching. + + """ + + inherit_cache = None + """Indicate if this :class:`.HasCacheKey` instance should make use of the + cache key generation scheme used by its immediate superclass. + + The attribute defaults to ``None``, which indicates that a construct has + not yet taken into account whether or not its appropriate for it to + participate in caching; this is functionally equivalent to setting the + value to ``False``, except that a warning is also emitted. + + This flag can be set to ``True`` on a particular class, if the SQL that + corresponds to the object does not change based on attributes which + are local to this class, and not its superclass. + + .. seealso:: + + :ref:`compilerext_caching` - General guideslines for setting the + :attr:`.HasCacheKey.inherit_cache` attribute for third-party or user + defined SQL constructs. + + """ + + __slots__ = () + + @classmethod + def _generate_cache_attrs(cls): + """generate cache key dispatcher for a new class. 
+ + This sets the _generated_cache_key_traversal attribute once called + so should only be called once per class. + + """ + inherit_cache = cls.__dict__.get("inherit_cache", None) + inherit = bool(inherit_cache) + + if inherit: + _cache_key_traversal = getattr(cls, "_cache_key_traversal", None) + if _cache_key_traversal is None: + try: + _cache_key_traversal = cls._traverse_internals + except AttributeError: + cls._generated_cache_key_traversal = NO_CACHE + return NO_CACHE + + # TODO: wouldn't we instead get this from our superclass? + # also, our superclass may not have this yet, but in any case, + # we'd generate for the superclass that has it. this is a little + # more complicated, so for the moment this is a little less + # efficient on startup but simpler. + return _cache_key_traversal_visitor.generate_dispatch( + cls, _cache_key_traversal, "_generated_cache_key_traversal" + ) + else: + _cache_key_traversal = cls.__dict__.get( + "_cache_key_traversal", None + ) + if _cache_key_traversal is None: + _cache_key_traversal = cls.__dict__.get( + "_traverse_internals", None + ) + if _cache_key_traversal is None: + cls._generated_cache_key_traversal = NO_CACHE + if ( + inherit_cache is None + and cls._hierarchy_supports_caching + ): + util.warn( + "Class %s will not make use of SQL compilation " + "caching as it does not set the 'inherit_cache' " + "attribute to ``True``. This can have " + "significant performance implications including " + "some performance degradations in comparison to " + "prior SQLAlchemy versions. Set this attribute " + "to True if this object can make use of the cache " + "key generated by the superclass. Alternatively, " + "this attribute may be set to False which will " + "disable this warning." % (cls.__name__), + code="cprf", + ) + return NO_CACHE + + return _cache_key_traversal_visitor.generate_dispatch( + cls, _cache_key_traversal, "_generated_cache_key_traversal" + ) + + @util.preload_module("sqlalchemy.sql.elements") + def _gen_cache_key(self, anon_map, bindparams): + """return an optional cache key. + + The cache key is a tuple which can contain any series of + objects that are hashable and also identifies + this object uniquely within the presence of a larger SQL expression + or statement, for the purposes of caching the resulting query. + + The cache key should be based on the SQL compiled structure that would + ultimately be produced. That is, two structures that are composed in + exactly the same way should produce the same cache key; any difference + in the structures that would affect the SQL string or the type handlers + should result in a different cache key. + + If a structure cannot produce a useful cache key, the NO_CACHE + symbol should be added to the anon_map and the method should + return None. + + """ + + cls = self.__class__ + + id_, found = anon_map.get_anon(self) + if found: + return (id_, cls) + + dispatcher: Union[ + Literal[CacheConst.NO_CACHE], + Callable[[HasCacheKey, "_CacheKeyTraversal"], "CacheKey"], + ] + + try: + dispatcher = cls.__dict__["_generated_cache_key_traversal"] + except KeyError: + # most of the dispatchers are generated up front + # in sqlalchemy/sql/__init__.py -> + # traversals.py-> _preconfigure_traversals(). + # this block will generate any remaining dispatchers. 
+ dispatcher = cls._generate_cache_attrs() + + if dispatcher is NO_CACHE: + anon_map[NO_CACHE] = True + return None + + result = (id_, cls) + + # inline of _cache_key_traversal_visitor.run_generated_dispatch() + + for attrname, obj, meth in dispatcher( + self, _cache_key_traversal_visitor + ): + if obj is not None: + # TODO: see if C code can help here as Python lacks an + # efficient switch construct + + if meth is STATIC_CACHE_KEY: + sck = obj._static_cache_key + if sck is NO_CACHE: + anon_map[NO_CACHE] = True + return None + result += (attrname, sck) + elif meth is ANON_NAME: + elements = util.preloaded.sql_elements + if isinstance(obj, elements._anonymous_label): + obj = obj.apply_map(anon_map) + result += (attrname, obj) + elif meth is CALL_GEN_CACHE_KEY: + result += ( + attrname, + obj._gen_cache_key(anon_map, bindparams), + ) + + # remaining cache functions are against + # Python tuples, dicts, lists, etc. so we can skip + # if they are empty + elif obj: + if meth is CACHE_IN_PLACE: + result += (attrname, obj) + elif meth is PROPAGATE_ATTRS: + result += ( + attrname, + obj["compile_state_plugin"], + obj["plugin_subject"]._gen_cache_key( + anon_map, bindparams + ) + if obj["plugin_subject"] + else None, + ) + elif meth is InternalTraversal.dp_annotations_key: + # obj is here is the _annotations dict. however, we + # want to use the memoized cache key version of it. for + # Columns, this should be long lived. For select() + # statements, not so much, but they usually won't have + # annotations. + result += self._annotations_cache_key + elif ( + meth is InternalTraversal.dp_clauseelement_list + or meth is InternalTraversal.dp_clauseelement_tuple + or meth + is InternalTraversal.dp_memoized_select_entities + ): + result += ( + attrname, + tuple( + [ + elem._gen_cache_key(anon_map, bindparams) + for elem in obj + ] + ), + ) + else: + result += meth( + attrname, obj, self, anon_map, bindparams + ) + return result + + def _generate_cache_key(self): + """return a cache key. + + The cache key is a tuple which can contain any series of + objects that are hashable and also identifies + this object uniquely within the presence of a larger SQL expression + or statement, for the purposes of caching the resulting query. + + The cache key should be based on the SQL compiled structure that would + ultimately be produced. That is, two structures that are composed in + exactly the same way should produce the same cache key; any difference + in the structures that would affect the SQL string or the type handlers + should result in a different cache key. + + The cache key returned by this method is an instance of + :class:`.CacheKey`, which consists of a tuple representing the + cache key, as well as a list of :class:`.BindParameter` objects + which are extracted from the expression. While two expressions + that produce identical cache key tuples will themselves generate + identical SQL strings, the list of :class:`.BindParameter` objects + indicates the bound values which may have different values in + each one; these bound parameters must be consulted in order to + execute the statement with the correct parameters. + + a :class:`_expression.ClauseElement` structure that does not implement + a :meth:`._gen_cache_key` method and does not implement a + :attr:`.traverse_internals` attribute will not be cacheable; when + such an element is embedded into a larger structure, this method + will return None, indicating no cache key is available. 
+ + """ + + bindparams = [] + + _anon_map = anon_map() + key = self._gen_cache_key(_anon_map, bindparams) + if NO_CACHE in _anon_map: + return None + else: + return CacheKey(key, bindparams) + + @classmethod + def _generate_cache_key_for_object(cls, obj): + bindparams = [] + + _anon_map = anon_map() + key = obj._gen_cache_key(_anon_map, bindparams) + if NO_CACHE in _anon_map: + return None + else: + return CacheKey(key, bindparams) + + +class MemoizedHasCacheKey(HasCacheKey, HasMemoized): + @HasMemoized.memoized_instancemethod + def _generate_cache_key(self): + return HasCacheKey._generate_cache_key(self) + + +class CacheKey(namedtuple("CacheKey", ["key", "bindparams"])): + """The key used to identify a SQL statement construct in the + SQL compilation cache. + + .. seealso:: + + :ref:`sql_caching` + + """ + + def __hash__(self): + """CacheKey itself is not hashable - hash the .key portion""" + + return None + + def to_offline_string(self, statement_cache, statement, parameters): + """Generate an "offline string" form of this :class:`.CacheKey` + + The "offline string" is basically the string SQL for the + statement plus a repr of the bound parameter values in series. + Whereas the :class:`.CacheKey` object is dependent on in-memory + identities in order to work as a cache key, the "offline" version + is suitable for a cache that will work for other processes as well. + + The given ``statement_cache`` is a dictionary-like object where the + string form of the statement itself will be cached. This dictionary + should be in a longer lived scope in order to reduce the time spent + stringifying statements. + + + """ + if self.key not in statement_cache: + statement_cache[self.key] = sql_str = str(statement) + else: + sql_str = statement_cache[self.key] + + if not self.bindparams: + param_tuple = tuple(parameters[key] for key in sorted(parameters)) + else: + param_tuple = tuple( + parameters.get(bindparam.key, bindparam.value) + for bindparam in self.bindparams + ) + + return repr((sql_str, param_tuple)) + + def __eq__(self, other): + return self.key == other.key + + @classmethod + def _diff_tuples(cls, left, right): + ck1 = CacheKey(left, []) + ck2 = CacheKey(right, []) + return ck1._diff(ck2) + + def _whats_different(self, other): + + k1 = self.key + k2 = other.key + + stack = [] + pickup_index = 0 + while True: + s1, s2 = k1, k2 + for idx in stack: + s1 = s1[idx] + s2 = s2[idx] + + for idx, (e1, e2) in enumerate(zip_longest(s1, s2)): + if idx < pickup_index: + continue + if e1 != e2: + if isinstance(e1, tuple) and isinstance(e2, tuple): + stack.append(idx) + break + else: + yield "key%s[%d]: %s != %s" % ( + "".join("[%d]" % id_ for id_ in stack), + idx, + e1, + e2, + ) + else: + pickup_index = stack.pop(-1) + break + + def _diff(self, other): + return ", ".join(self._whats_different(other)) + + def __str__(self): + stack = [self.key] + + output = [] + sentinel = object() + indent = -1 + while stack: + elem = stack.pop(0) + if elem is sentinel: + output.append((" " * (indent * 2)) + "),") + indent -= 1 + elif isinstance(elem, tuple): + if not elem: + output.append((" " * ((indent + 1) * 2)) + "()") + else: + indent += 1 + stack = list(elem) + [sentinel] + stack + output.append((" " * (indent * 2)) + "(") + else: + if isinstance(elem, HasCacheKey): + repr_ = "<%s object at %s>" % ( + type(elem).__name__, + hex(id(elem)), + ) + else: + repr_ = repr(elem) + output.append((" " * (indent * 2)) + " " + repr_ + ", ") + + return "CacheKey(key=%s)" % ("\n".join(output),) + + def 
_generate_param_dict(self): + """used for testing""" + + from .compiler import prefix_anon_map + + _anon_map = prefix_anon_map() + return {b.key % _anon_map: b.effective_value for b in self.bindparams} + + def _apply_params_to_element(self, original_cache_key, target_element): + translate = { + k.key: v.value + for k, v in zip(original_cache_key.bindparams, self.bindparams) + } + + return target_element.params(translate) + + +class _CacheKeyTraversal(ExtendedInternalTraversal): + # very common elements are inlined into the main _get_cache_key() method + # to produce a dramatic savings in Python function call overhead + + visit_has_cache_key = visit_clauseelement = CALL_GEN_CACHE_KEY + visit_clauseelement_list = InternalTraversal.dp_clauseelement_list + visit_annotations_key = InternalTraversal.dp_annotations_key + visit_clauseelement_tuple = InternalTraversal.dp_clauseelement_tuple + visit_memoized_select_entities = ( + InternalTraversal.dp_memoized_select_entities + ) + + visit_string = ( + visit_boolean + ) = visit_operator = visit_plain_obj = CACHE_IN_PLACE + visit_statement_hint_list = CACHE_IN_PLACE + visit_type = STATIC_CACHE_KEY + visit_anon_name = ANON_NAME + + visit_propagate_attrs = PROPAGATE_ATTRS + + def visit_with_context_options( + self, attrname, obj, parent, anon_map, bindparams + ): + return tuple((fn.__code__, c_key) for fn, c_key in obj) + + def visit_inspectable(self, attrname, obj, parent, anon_map, bindparams): + return (attrname, inspect(obj)._gen_cache_key(anon_map, bindparams)) + + def visit_string_list(self, attrname, obj, parent, anon_map, bindparams): + return tuple(obj) + + def visit_multi(self, attrname, obj, parent, anon_map, bindparams): + return ( + attrname, + obj._gen_cache_key(anon_map, bindparams) + if isinstance(obj, HasCacheKey) + else obj, + ) + + def visit_multi_list(self, attrname, obj, parent, anon_map, bindparams): + return ( + attrname, + tuple( + elem._gen_cache_key(anon_map, bindparams) + if isinstance(elem, HasCacheKey) + else elem + for elem in obj + ), + ) + + def visit_has_cache_key_tuples( + self, attrname, obj, parent, anon_map, bindparams + ): + if not obj: + return () + return ( + attrname, + tuple( + tuple( + elem._gen_cache_key(anon_map, bindparams) + for elem in tup_elem + ) + for tup_elem in obj + ), + ) + + def visit_has_cache_key_list( + self, attrname, obj, parent, anon_map, bindparams + ): + if not obj: + return () + return ( + attrname, + tuple(elem._gen_cache_key(anon_map, bindparams) for elem in obj), + ) + + def visit_executable_options( + self, attrname, obj, parent, anon_map, bindparams + ): + if not obj: + return () + return ( + attrname, + tuple( + elem._gen_cache_key(anon_map, bindparams) + for elem in obj + if elem._is_has_cache_key + ), + ) + + def visit_inspectable_list( + self, attrname, obj, parent, anon_map, bindparams + ): + return self.visit_has_cache_key_list( + attrname, [inspect(o) for o in obj], parent, anon_map, bindparams + ) + + def visit_clauseelement_tuples( + self, attrname, obj, parent, anon_map, bindparams + ): + return self.visit_has_cache_key_tuples( + attrname, obj, parent, anon_map, bindparams + ) + + def visit_fromclause_ordered_set( + self, attrname, obj, parent, anon_map, bindparams + ): + if not obj: + return () + return ( + attrname, + tuple([elem._gen_cache_key(anon_map, bindparams) for elem in obj]), + ) + + def visit_clauseelement_unordered_set( + self, attrname, obj, parent, anon_map, bindparams + ): + if not obj: + return () + cache_keys = [ + elem._gen_cache_key(anon_map, bindparams) 
for elem in obj + ] + return ( + attrname, + tuple( + sorted(cache_keys) + ), # cache keys all start with (id_, class) + ) + + def visit_named_ddl_element( + self, attrname, obj, parent, anon_map, bindparams + ): + return (attrname, obj.name) + + def visit_prefix_sequence( + self, attrname, obj, parent, anon_map, bindparams + ): + if not obj: + return () + + return ( + attrname, + tuple( + [ + (clause._gen_cache_key(anon_map, bindparams), strval) + for clause, strval in obj + ] + ), + ) + + def visit_setup_join_tuple( + self, attrname, obj, parent, anon_map, bindparams + ): + return tuple( + ( + target._gen_cache_key(anon_map, bindparams), + onclause._gen_cache_key(anon_map, bindparams) + if onclause is not None + else None, + from_._gen_cache_key(anon_map, bindparams) + if from_ is not None + else None, + tuple([(key, flags[key]) for key in sorted(flags)]), + ) + for (target, onclause, from_, flags) in obj + ) + + def visit_table_hint_list( + self, attrname, obj, parent, anon_map, bindparams + ): + if not obj: + return () + + return ( + attrname, + tuple( + [ + ( + clause._gen_cache_key(anon_map, bindparams), + dialect_name, + text, + ) + for (clause, dialect_name), text in obj.items() + ] + ), + ) + + def visit_plain_dict(self, attrname, obj, parent, anon_map, bindparams): + return (attrname, tuple([(key, obj[key]) for key in sorted(obj)])) + + def visit_dialect_options( + self, attrname, obj, parent, anon_map, bindparams + ): + return ( + attrname, + tuple( + ( + dialect_name, + tuple( + [ + (key, obj[dialect_name][key]) + for key in sorted(obj[dialect_name]) + ] + ), + ) + for dialect_name in sorted(obj) + ), + ) + + def visit_string_clauseelement_dict( + self, attrname, obj, parent, anon_map, bindparams + ): + return ( + attrname, + tuple( + (key, obj[key]._gen_cache_key(anon_map, bindparams)) + for key in sorted(obj) + ), + ) + + def visit_string_multi_dict( + self, attrname, obj, parent, anon_map, bindparams + ): + return ( + attrname, + tuple( + ( + key, + value._gen_cache_key(anon_map, bindparams) + if isinstance(value, HasCacheKey) + else value, + ) + for key, value in [(key, obj[key]) for key in sorted(obj)] + ), + ) + + def visit_fromclause_canonical_column_collection( + self, attrname, obj, parent, anon_map, bindparams + ): + # inlining into the internals of ColumnCollection + return ( + attrname, + tuple( + col._gen_cache_key(anon_map, bindparams) + for k, col in obj._collection + ), + ) + + def visit_unknown_structure( + self, attrname, obj, parent, anon_map, bindparams + ): + anon_map[NO_CACHE] = True + return () + + def visit_dml_ordered_values( + self, attrname, obj, parent, anon_map, bindparams + ): + return ( + attrname, + tuple( + ( + key._gen_cache_key(anon_map, bindparams) + if hasattr(key, "__clause_element__") + else key, + value._gen_cache_key(anon_map, bindparams), + ) + for key, value in obj + ), + ) + + def visit_dml_values(self, attrname, obj, parent, anon_map, bindparams): + # in py37 we can assume two dictionaries created in the same + # insert ordering will retain that sorting + return ( + attrname, + tuple( + ( + k._gen_cache_key(anon_map, bindparams) + if hasattr(k, "__clause_element__") + else k, + obj[k]._gen_cache_key(anon_map, bindparams), + ) + for k in obj + ), + ) + + def visit_dml_multi_values( + self, attrname, obj, parent, anon_map, bindparams + ): + # multivalues are simply not cacheable right now + anon_map[NO_CACHE] = True + return () + + +_cache_key_traversal_visitor = _CacheKeyTraversal() diff --git a/lib/sqlalchemy/sql/coercions.py 
b/lib/sqlalchemy/sql/coercions.py index 95697806e6..fe2b498c8e 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -14,7 +14,7 @@ from . import roles from . import visitors from .base import ExecutableOption from .base import Options -from .traversals import HasCacheKey +from .cache_key import HasCacheKey from .visitors import Visitable from .. import exc from .. import inspection diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 697550df44..cb10811c6a 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -503,7 +503,7 @@ class Compiled: return self.construct_params() -class TypeCompiler(metaclass=util.EnsureKWArgType): +class TypeCompiler(util.EnsureKWArg): """Produces DDL specification for TypeEngine objects.""" ensure_kwarg = r"visit_\w+" diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 12282de055..a025cce357 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -29,10 +29,10 @@ from .base import HasMemoized from .base import Immutable from .base import NO_ARG from .base import SingletonConstant +from .cache_key import MemoizedHasCacheKey +from .cache_key import NO_CACHE from .coercions import _document_text_coercion from .traversals import HasCopyInternals -from .traversals import MemoizedHasCacheKey -from .traversals import NO_CACHE from .visitors import cloned_traverse from .visitors import InternalTraversal from .visitors import traverse diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 2a3cd07d0a..54f67b930d 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -94,6 +94,7 @@ from .base import _from_objects from .base import _select_iterables from .base import ColumnCollection from .base import Executable +from .cache_key import CacheKey from .dml import Delete from .dml import Insert from .dml import Update @@ -173,7 +174,6 @@ from .selectable import TableValuedAlias from .selectable import TextAsFrom from .selectable import TextualSelect from .selectable import Values -from .traversals import CacheKey from .visitors import Visitable from ..util.langhelpers import public_factory diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index b7b9257b42..3b6da71757 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -37,8 +37,8 @@ from .elements import WithinGroup from .selectable import FromClause from .selectable import Select from .selectable import TableValuedAlias +from .type_api import TypeEngine from .visitors import InternalTraversal -from .visitors import TraversibleType from .. import util @@ -48,7 +48,7 @@ _registry = util.defaultdict(dict) def register_function(identifier, fn, package="_default"): """Associate a callable with a particular func. name. - This is normally called by _GenericMeta, but is also + This is normally called by GenericFunction, but is also available by itself so that a non-Function construct can be associated with the :data:`.func` accessor (i.e. CAST, EXTRACT). @@ -828,7 +828,11 @@ class Function(FunctionElement): ("type", InternalTraversal.dp_type), ] - type = sqltypes.NULLTYPE + name: str + + identifier: str + + type: TypeEngine = sqltypes.NULLTYPE """A :class:`_types.TypeEngine` object which refers to the SQL return type represented by this SQL function. 
@@ -871,30 +875,7 @@ class Function(FunctionElement): ) -class _GenericMeta(TraversibleType): - def __init__(cls, clsname, bases, clsdict): - if annotation.Annotated not in cls.__mro__: - cls.name = name = clsdict.get("name", clsname) - cls.identifier = identifier = clsdict.get("identifier", name) - package = clsdict.pop("package", "_default") - # legacy - if "__return_type__" in clsdict: - cls.type = clsdict["__return_type__"] - - # Check _register attribute status - cls._register = getattr(cls, "_register", True) - - # Register the function if required - if cls._register: - register_function(identifier, cls, package) - else: - # Set _register to True to register child classes by default - cls._register = True - - super(_GenericMeta, cls).__init__(clsname, bases, clsdict) - - -class GenericFunction(Function, metaclass=_GenericMeta): +class GenericFunction(Function): """Define a 'generic' function. A generic function is a pre-established :class:`.Function` @@ -986,9 +967,34 @@ class GenericFunction(Function, metaclass=_GenericMeta): """ coerce_arguments = True - _register = False inherit_cache = True + name = "GenericFunction" + + def __init_subclass__(cls) -> None: + if annotation.Annotated not in cls.__mro__: + cls._register_generic_function(cls.__name__, cls.__dict__) + super().__init_subclass__() + + @classmethod + def _register_generic_function(cls, clsname, clsdict): + cls.name = name = clsdict.get("name", clsname) + cls.identifier = identifier = clsdict.get("identifier", name) + package = clsdict.get("package", "_default") + # legacy + if "__return_type__" in clsdict: + cls.type = clsdict["__return_type__"] + + # Check _register attribute status + cls._register = getattr(cls, "_register", True) + + # Register the function if required + if cls._register: + register_function(identifier, cls, package) + else: + # Set _register to True to register child classes by default + cls._register = True + def __init__(self, *args, **kwargs): parsed_args = kwargs.pop("_parsed_args", None) if parsed_args is None: @@ -1006,6 +1012,7 @@ class GenericFunction(Function, metaclass=_GenericMeta): self.clause_expr = ClauseList( operator=operators.comma_op, group_contents=True, *parsed_args ).self_group() + self.type = sqltypes.to_instance( kwargs.pop("type_", None) or getattr(self, "type", None) ) diff --git a/lib/sqlalchemy/sql/lambdas.py b/lib/sqlalchemy/sql/lambdas.py index 2387e551e7..d71c85d609 100644 --- a/lib/sqlalchemy/sql/lambdas.py +++ b/lib/sqlalchemy/sql/lambdas.py @@ -11,6 +11,7 @@ import operator import types import weakref +from . import cache_key as _cache_key from . import coercions from . import elements from . 
import roles @@ -185,7 +186,7 @@ class LambdaElement(elements.ClauseElement): else: parent_closure_cache_key = () - if parent_closure_cache_key is not traversals.NO_CACHE: + if parent_closure_cache_key is not _cache_key.NO_CACHE: anon_map = traversals.anon_map() cache_key = tuple( [ @@ -194,7 +195,7 @@ class LambdaElement(elements.ClauseElement): ] ) - if traversals.NO_CACHE not in anon_map: + if _cache_key.NO_CACHE not in anon_map: cache_key = parent_closure_cache_key + cache_key self.closure_cache_key = cache_key @@ -204,17 +205,17 @@ class LambdaElement(elements.ClauseElement): except KeyError: rec = None else: - cache_key = traversals.NO_CACHE + cache_key = _cache_key.NO_CACHE rec = None else: - cache_key = traversals.NO_CACHE + cache_key = _cache_key.NO_CACHE rec = None self.closure_cache_key = cache_key if rec is None: - if cache_key is not traversals.NO_CACHE: + if cache_key is not _cache_key.NO_CACHE: rec = AnalyzedFunction( tracker, self, apply_propagate_attrs, fn ) @@ -233,7 +234,7 @@ class LambdaElement(elements.ClauseElement): self._rec = rec - if cache_key is not traversals.NO_CACHE: + if cache_key is not _cache_key.NO_CACHE: if self.parent_lambda is not None: bindparams[:0] = self.parent_lambda._resolved_bindparams @@ -326,8 +327,8 @@ class LambdaElement(elements.ClauseElement): return expr def _gen_cache_key(self, anon_map, bindparams): - if self.closure_cache_key is traversals.NO_CACHE: - anon_map[traversals.NO_CACHE] = True + if self.closure_cache_key is _cache_key.NO_CACHE: + anon_map[_cache_key.NO_CACHE] = True return None cache_key = ( @@ -808,7 +809,7 @@ class AnalyzedCode: for tup_elem in opts.track_on[idx] ) - elif isinstance(elem, traversals.HasCacheKey): + elif isinstance(elem, _cache_key.HasCacheKey): def get(closure, opts, anon_map, bindparams): return opts.track_on[idx]._gen_cache_key(anon_map, bindparams) @@ -834,7 +835,7 @@ class AnalyzedCode: """ - if isinstance(cell_contents, traversals.HasCacheKey): + if isinstance(cell_contents, _cache_key.HasCacheKey): def get(closure, opts, anon_map, bindparams): @@ -1166,7 +1167,7 @@ class PyWrapper(ColumnOperators): and not isinstance( # TODO: coverage where an ORM option or similar is here value, - traversals.HasCacheKey, + _cache_key.HasCacheKey, ) ): name = object.__getattribute__(self, "_name") diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 655d98b02f..e674c4b74d 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -18,6 +18,7 @@ import typing from typing import Type from typing import Union +from . import cache_key from . import coercions from . import operators from . import roles @@ -4300,7 +4301,7 @@ class _SelectFromElements: class _MemoizedSelectEntities( - traversals.HasCacheKey, traversals.HasCopyInternals, visitors.Traversible + cache_key.HasCacheKey, traversals.HasCopyInternals, visitors.Traversible ): __visit_name__ = "memoized_select_entities" diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 05007eff1f..d5506cda22 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -22,11 +22,11 @@ from . import roles from . 
import type_api from .base import NO_ARG from .base import SchemaEventTarget +from .cache_key import HasCacheKey from .elements import _NONE_NAME from .elements import quoted_name from .elements import Slice from .elements import TypeCoerce as type_coerce # noqa -from .traversals import HasCacheKey from .traversals import InternalTraversal from .type_api import Emulated from .type_api import NativeForEmulated # noqa diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index b689fe5789..2fa3a04083 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -1,32 +1,25 @@ +# sql/traversals.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + from collections import deque -from collections import namedtuple import collections.abc as collections_abc import itertools from itertools import zip_longest import operator from . import operators -from .visitors import ExtendedInternalTraversal +from .visitors import anon_map from .visitors import InternalTraversal from .. import util -from ..inspection import inspect -from ..util import HasMemoized - -try: - from sqlalchemy.cyextension.util import cache_anon_map as anon_map # noqa -except ImportError: - from ._py_util import cache_anon_map as anon_map # noqa SKIP_TRAVERSE = util.symbol("skip_traverse") COMPARE_FAILED = False COMPARE_SUCCEEDED = True -NO_CACHE = util.symbol("no_cache") -CACHE_IN_PLACE = util.symbol("cache_in_place") -CALL_GEN_CACHE_KEY = util.symbol("call_gen_cache_key") -STATIC_CACHE_KEY = util.symbol("static_cache_key") -PROPAGATE_ATTRS = util.symbol("propagate_attrs") -ANON_NAME = util.symbol("anon_name") def compare(obj1, obj2, **kw): @@ -54,729 +47,10 @@ def _preconfigure_traversals(target_hierarchy): ) -class HasCacheKey: - """Mixin for objects which can produce a cache key. - - .. seealso:: - - :class:`.CacheKey` - - :ref:`sql_caching` - - """ - - _cache_key_traversal = NO_CACHE - - _is_has_cache_key = True - - _hierarchy_supports_caching = True - """private attribute which may be set to False to prevent the - inherit_cache warning from being emitted for a hierarchy of subclasses. - - Currently applies to the DDLElement hierarchy which does not implement - caching. - - """ - - inherit_cache = None - """Indicate if this :class:`.HasCacheKey` instance should make use of the - cache key generation scheme used by its immediate superclass. - - The attribute defaults to ``None``, which indicates that a construct has - not yet taken into account whether or not its appropriate for it to - participate in caching; this is functionally equivalent to setting the - value to ``False``, except that a warning is also emitted. - - This flag can be set to ``True`` on a particular class, if the SQL that - corresponds to the object does not change based on attributes which - are local to this class, and not its superclass. - - .. seealso:: - - :ref:`compilerext_caching` - General guideslines for setting the - :attr:`.HasCacheKey.inherit_cache` attribute for third-party or user - defined SQL constructs. - - """ - - __slots__ = () - - @classmethod - def _generate_cache_attrs(cls): - """generate cache key dispatcher for a new class. - - This sets the _generated_cache_key_traversal attribute once called - so should only be called once per class. 
- - """ - inherit_cache = cls.__dict__.get("inherit_cache", None) - inherit = bool(inherit_cache) - - if inherit: - _cache_key_traversal = getattr(cls, "_cache_key_traversal", None) - if _cache_key_traversal is None: - try: - _cache_key_traversal = cls._traverse_internals - except AttributeError: - cls._generated_cache_key_traversal = NO_CACHE - return NO_CACHE - - # TODO: wouldn't we instead get this from our superclass? - # also, our superclass may not have this yet, but in any case, - # we'd generate for the superclass that has it. this is a little - # more complicated, so for the moment this is a little less - # efficient on startup but simpler. - return _cache_key_traversal_visitor.generate_dispatch( - cls, _cache_key_traversal, "_generated_cache_key_traversal" - ) - else: - _cache_key_traversal = cls.__dict__.get( - "_cache_key_traversal", None - ) - if _cache_key_traversal is None: - _cache_key_traversal = cls.__dict__.get( - "_traverse_internals", None - ) - if _cache_key_traversal is None: - cls._generated_cache_key_traversal = NO_CACHE - if ( - inherit_cache is None - and cls._hierarchy_supports_caching - ): - util.warn( - "Class %s will not make use of SQL compilation " - "caching as it does not set the 'inherit_cache' " - "attribute to ``True``. This can have " - "significant performance implications including " - "some performance degradations in comparison to " - "prior SQLAlchemy versions. Set this attribute " - "to True if this object can make use of the cache " - "key generated by the superclass. Alternatively, " - "this attribute may be set to False which will " - "disable this warning." % (cls.__name__), - code="cprf", - ) - return NO_CACHE - - return _cache_key_traversal_visitor.generate_dispatch( - cls, _cache_key_traversal, "_generated_cache_key_traversal" - ) - - @util.preload_module("sqlalchemy.sql.elements") - def _gen_cache_key(self, anon_map, bindparams): - """return an optional cache key. - - The cache key is a tuple which can contain any series of - objects that are hashable and also identifies - this object uniquely within the presence of a larger SQL expression - or statement, for the purposes of caching the resulting query. - - The cache key should be based on the SQL compiled structure that would - ultimately be produced. That is, two structures that are composed in - exactly the same way should produce the same cache key; any difference - in the structures that would affect the SQL string or the type handlers - should result in a different cache key. - - If a structure cannot produce a useful cache key, the NO_CACHE - symbol should be added to the anon_map and the method should - return None. - - """ - - cls = self.__class__ - - id_, found = anon_map.get_anon(self) - if found: - return (id_, cls) - - try: - dispatcher = cls.__dict__["_generated_cache_key_traversal"] - except KeyError: - # most of the dispatchers are generated up front - # in sqlalchemy/sql/__init__.py -> - # traversals.py-> _preconfigure_traversals(). - # this block will generate any remaining dispatchers. 
- dispatcher = cls._generate_cache_attrs() - - if dispatcher is NO_CACHE: - anon_map[NO_CACHE] = True - return None - - result = (id_, cls) - - # inline of _cache_key_traversal_visitor.run_generated_dispatch() - - for attrname, obj, meth in dispatcher( - self, _cache_key_traversal_visitor - ): - if obj is not None: - # TODO: see if C code can help here as Python lacks an - # efficient switch construct - - if meth is STATIC_CACHE_KEY: - sck = obj._static_cache_key - if sck is NO_CACHE: - anon_map[NO_CACHE] = True - return None - result += (attrname, sck) - elif meth is ANON_NAME: - elements = util.preloaded.sql_elements - if isinstance(obj, elements._anonymous_label): - obj = obj.apply_map(anon_map) - result += (attrname, obj) - elif meth is CALL_GEN_CACHE_KEY: - result += ( - attrname, - obj._gen_cache_key(anon_map, bindparams), - ) - - # remaining cache functions are against - # Python tuples, dicts, lists, etc. so we can skip - # if they are empty - elif obj: - if meth is CACHE_IN_PLACE: - result += (attrname, obj) - elif meth is PROPAGATE_ATTRS: - result += ( - attrname, - obj["compile_state_plugin"], - obj["plugin_subject"]._gen_cache_key( - anon_map, bindparams - ) - if obj["plugin_subject"] - else None, - ) - elif meth is InternalTraversal.dp_annotations_key: - # obj is here is the _annotations dict. however, we - # want to use the memoized cache key version of it. for - # Columns, this should be long lived. For select() - # statements, not so much, but they usually won't have - # annotations. - result += self._annotations_cache_key - elif ( - meth is InternalTraversal.dp_clauseelement_list - or meth is InternalTraversal.dp_clauseelement_tuple - or meth - is InternalTraversal.dp_memoized_select_entities - ): - result += ( - attrname, - tuple( - [ - elem._gen_cache_key(anon_map, bindparams) - for elem in obj - ] - ), - ) - else: - result += meth( - attrname, obj, self, anon_map, bindparams - ) - return result - - def _generate_cache_key(self): - """return a cache key. - - The cache key is a tuple which can contain any series of - objects that are hashable and also identifies - this object uniquely within the presence of a larger SQL expression - or statement, for the purposes of caching the resulting query. - - The cache key should be based on the SQL compiled structure that would - ultimately be produced. That is, two structures that are composed in - exactly the same way should produce the same cache key; any difference - in the structures that would affect the SQL string or the type handlers - should result in a different cache key. - - The cache key returned by this method is an instance of - :class:`.CacheKey`, which consists of a tuple representing the - cache key, as well as a list of :class:`.BindParameter` objects - which are extracted from the expression. While two expressions - that produce identical cache key tuples will themselves generate - identical SQL strings, the list of :class:`.BindParameter` objects - indicates the bound values which may have different values in - each one; these bound parameters must be consulted in order to - execute the statement with the correct parameters. - - a :class:`_expression.ClauseElement` structure that does not implement - a :meth:`._gen_cache_key` method and does not implement a - :attr:`.traverse_internals` attribute will not be cacheable; when - such an element is embedded into a larger structure, this method - will return None, indicating no cache key is available. 
- - """ - - bindparams = [] - - _anon_map = anon_map() - key = self._gen_cache_key(_anon_map, bindparams) - if NO_CACHE in _anon_map: - return None - else: - return CacheKey(key, bindparams) - - @classmethod - def _generate_cache_key_for_object(cls, obj): - bindparams = [] - - _anon_map = anon_map() - key = obj._gen_cache_key(_anon_map, bindparams) - if NO_CACHE in _anon_map: - return None - else: - return CacheKey(key, bindparams) - - -class MemoizedHasCacheKey(HasCacheKey, HasMemoized): - @HasMemoized.memoized_instancemethod - def _generate_cache_key(self): - return HasCacheKey._generate_cache_key(self) - - -class CacheKey(namedtuple("CacheKey", ["key", "bindparams"])): - """The key used to identify a SQL statement construct in the - SQL compilation cache. - - .. seealso:: - - :ref:`sql_caching` - - """ - - def __hash__(self): - """CacheKey itself is not hashable - hash the .key portion""" - - return None - - def to_offline_string(self, statement_cache, statement, parameters): - """Generate an "offline string" form of this :class:`.CacheKey` - - The "offline string" is basically the string SQL for the - statement plus a repr of the bound parameter values in series. - Whereas the :class:`.CacheKey` object is dependent on in-memory - identities in order to work as a cache key, the "offline" version - is suitable for a cache that will work for other processes as well. - - The given ``statement_cache`` is a dictionary-like object where the - string form of the statement itself will be cached. This dictionary - should be in a longer lived scope in order to reduce the time spent - stringifying statements. - - - """ - if self.key not in statement_cache: - statement_cache[self.key] = sql_str = str(statement) - else: - sql_str = statement_cache[self.key] - - if not self.bindparams: - param_tuple = tuple(parameters[key] for key in sorted(parameters)) - else: - param_tuple = tuple( - parameters.get(bindparam.key, bindparam.value) - for bindparam in self.bindparams - ) - - return repr((sql_str, param_tuple)) - - def __eq__(self, other): - return self.key == other.key - - @classmethod - def _diff_tuples(cls, left, right): - ck1 = CacheKey(left, []) - ck2 = CacheKey(right, []) - return ck1._diff(ck2) - - def _whats_different(self, other): - - k1 = self.key - k2 = other.key - - stack = [] - pickup_index = 0 - while True: - s1, s2 = k1, k2 - for idx in stack: - s1 = s1[idx] - s2 = s2[idx] - - for idx, (e1, e2) in enumerate(zip_longest(s1, s2)): - if idx < pickup_index: - continue - if e1 != e2: - if isinstance(e1, tuple) and isinstance(e2, tuple): - stack.append(idx) - break - else: - yield "key%s[%d]: %s != %s" % ( - "".join("[%d]" % id_ for id_ in stack), - idx, - e1, - e2, - ) - else: - pickup_index = stack.pop(-1) - break - - def _diff(self, other): - return ", ".join(self._whats_different(other)) - - def __str__(self): - stack = [self.key] - - output = [] - sentinel = object() - indent = -1 - while stack: - elem = stack.pop(0) - if elem is sentinel: - output.append((" " * (indent * 2)) + "),") - indent -= 1 - elif isinstance(elem, tuple): - if not elem: - output.append((" " * ((indent + 1) * 2)) + "()") - else: - indent += 1 - stack = list(elem) + [sentinel] + stack - output.append((" " * (indent * 2)) + "(") - else: - if isinstance(elem, HasCacheKey): - repr_ = "<%s object at %s>" % ( - type(elem).__name__, - hex(id(elem)), - ) - else: - repr_ = repr(elem) - output.append((" " * (indent * 2)) + " " + repr_ + ", ") - - return "CacheKey(key=%s)" % ("\n".join(output),) - - def 
_generate_param_dict(self): - """used for testing""" - - from .compiler import prefix_anon_map - - _anon_map = prefix_anon_map() - return {b.key % _anon_map: b.effective_value for b in self.bindparams} - - def _apply_params_to_element(self, original_cache_key, target_element): - translate = { - k.key: v.value - for k, v in zip(original_cache_key.bindparams, self.bindparams) - } - - return target_element.params(translate) - - def _clone(element, **kw): return element._clone() -class _CacheKey(ExtendedInternalTraversal): - # very common elements are inlined into the main _get_cache_key() method - # to produce a dramatic savings in Python function call overhead - - visit_has_cache_key = visit_clauseelement = CALL_GEN_CACHE_KEY - visit_clauseelement_list = InternalTraversal.dp_clauseelement_list - visit_annotations_key = InternalTraversal.dp_annotations_key - visit_clauseelement_tuple = InternalTraversal.dp_clauseelement_tuple - visit_memoized_select_entities = ( - InternalTraversal.dp_memoized_select_entities - ) - - visit_string = ( - visit_boolean - ) = visit_operator = visit_plain_obj = CACHE_IN_PLACE - visit_statement_hint_list = CACHE_IN_PLACE - visit_type = STATIC_CACHE_KEY - visit_anon_name = ANON_NAME - - visit_propagate_attrs = PROPAGATE_ATTRS - - def visit_with_context_options( - self, attrname, obj, parent, anon_map, bindparams - ): - return tuple((fn.__code__, c_key) for fn, c_key in obj) - - def visit_inspectable(self, attrname, obj, parent, anon_map, bindparams): - return (attrname, inspect(obj)._gen_cache_key(anon_map, bindparams)) - - def visit_string_list(self, attrname, obj, parent, anon_map, bindparams): - return tuple(obj) - - def visit_multi(self, attrname, obj, parent, anon_map, bindparams): - return ( - attrname, - obj._gen_cache_key(anon_map, bindparams) - if isinstance(obj, HasCacheKey) - else obj, - ) - - def visit_multi_list(self, attrname, obj, parent, anon_map, bindparams): - return ( - attrname, - tuple( - elem._gen_cache_key(anon_map, bindparams) - if isinstance(elem, HasCacheKey) - else elem - for elem in obj - ), - ) - - def visit_has_cache_key_tuples( - self, attrname, obj, parent, anon_map, bindparams - ): - if not obj: - return () - return ( - attrname, - tuple( - tuple( - elem._gen_cache_key(anon_map, bindparams) - for elem in tup_elem - ) - for tup_elem in obj - ), - ) - - def visit_has_cache_key_list( - self, attrname, obj, parent, anon_map, bindparams - ): - if not obj: - return () - return ( - attrname, - tuple(elem._gen_cache_key(anon_map, bindparams) for elem in obj), - ) - - def visit_executable_options( - self, attrname, obj, parent, anon_map, bindparams - ): - if not obj: - return () - return ( - attrname, - tuple( - elem._gen_cache_key(anon_map, bindparams) - for elem in obj - if elem._is_has_cache_key - ), - ) - - def visit_inspectable_list( - self, attrname, obj, parent, anon_map, bindparams - ): - return self.visit_has_cache_key_list( - attrname, [inspect(o) for o in obj], parent, anon_map, bindparams - ) - - def visit_clauseelement_tuples( - self, attrname, obj, parent, anon_map, bindparams - ): - return self.visit_has_cache_key_tuples( - attrname, obj, parent, anon_map, bindparams - ) - - def visit_fromclause_ordered_set( - self, attrname, obj, parent, anon_map, bindparams - ): - if not obj: - return () - return ( - attrname, - tuple([elem._gen_cache_key(anon_map, bindparams) for elem in obj]), - ) - - def visit_clauseelement_unordered_set( - self, attrname, obj, parent, anon_map, bindparams - ): - if not obj: - return () - cache_keys = [ - 
elem._gen_cache_key(anon_map, bindparams) for elem in obj - ] - return ( - attrname, - tuple( - sorted(cache_keys) - ), # cache keys all start with (id_, class) - ) - - def visit_named_ddl_element( - self, attrname, obj, parent, anon_map, bindparams - ): - return (attrname, obj.name) - - def visit_prefix_sequence( - self, attrname, obj, parent, anon_map, bindparams - ): - if not obj: - return () - - return ( - attrname, - tuple( - [ - (clause._gen_cache_key(anon_map, bindparams), strval) - for clause, strval in obj - ] - ), - ) - - def visit_setup_join_tuple( - self, attrname, obj, parent, anon_map, bindparams - ): - is_legacy = "legacy" in attrname - - return tuple( - ( - target - if is_legacy and isinstance(target, str) - else target._gen_cache_key(anon_map, bindparams), - onclause - if is_legacy and isinstance(onclause, str) - else onclause._gen_cache_key(anon_map, bindparams) - if onclause is not None - else None, - from_._gen_cache_key(anon_map, bindparams) - if from_ is not None - else None, - tuple([(key, flags[key]) for key in sorted(flags)]), - ) - for (target, onclause, from_, flags) in obj - ) - - def visit_table_hint_list( - self, attrname, obj, parent, anon_map, bindparams - ): - if not obj: - return () - - return ( - attrname, - tuple( - [ - ( - clause._gen_cache_key(anon_map, bindparams), - dialect_name, - text, - ) - for (clause, dialect_name), text in obj.items() - ] - ), - ) - - def visit_plain_dict(self, attrname, obj, parent, anon_map, bindparams): - return (attrname, tuple([(key, obj[key]) for key in sorted(obj)])) - - def visit_dialect_options( - self, attrname, obj, parent, anon_map, bindparams - ): - return ( - attrname, - tuple( - ( - dialect_name, - tuple( - [ - (key, obj[dialect_name][key]) - for key in sorted(obj[dialect_name]) - ] - ), - ) - for dialect_name in sorted(obj) - ), - ) - - def visit_string_clauseelement_dict( - self, attrname, obj, parent, anon_map, bindparams - ): - return ( - attrname, - tuple( - (key, obj[key]._gen_cache_key(anon_map, bindparams)) - for key in sorted(obj) - ), - ) - - def visit_string_multi_dict( - self, attrname, obj, parent, anon_map, bindparams - ): - return ( - attrname, - tuple( - ( - key, - value._gen_cache_key(anon_map, bindparams) - if isinstance(value, HasCacheKey) - else value, - ) - for key, value in [(key, obj[key]) for key in sorted(obj)] - ), - ) - - def visit_fromclause_canonical_column_collection( - self, attrname, obj, parent, anon_map, bindparams - ): - # inlining into the internals of ColumnCollection - return ( - attrname, - tuple( - col._gen_cache_key(anon_map, bindparams) - for k, col in obj._collection - ), - ) - - def visit_unknown_structure( - self, attrname, obj, parent, anon_map, bindparams - ): - anon_map[NO_CACHE] = True - return () - - def visit_dml_ordered_values( - self, attrname, obj, parent, anon_map, bindparams - ): - return ( - attrname, - tuple( - ( - key._gen_cache_key(anon_map, bindparams) - if hasattr(key, "__clause_element__") - else key, - value._gen_cache_key(anon_map, bindparams), - ) - for key, value in obj - ), - ) - - def visit_dml_values(self, attrname, obj, parent, anon_map, bindparams): - # in py37 we can assume two dictionaries created in the same - # insert ordering will retain that sorting - return ( - attrname, - tuple( - ( - k._gen_cache_key(anon_map, bindparams) - if hasattr(k, "__clause_element__") - else k, - obj[k]._gen_cache_key(anon_map, bindparams), - ) - for k in obj - ), - ) - - def visit_dml_multi_values( - self, attrname, obj, parent, anon_map, bindparams - ): - # 
multivalues are simply not cacheable right now - anon_map[NO_CACHE] = True - return () - - -_cache_key_traversal_visitor = _CacheKey() - - class HasCopyInternals: __slots__ = () @@ -813,7 +87,7 @@ class HasCopyInternals: setattr(self, attrname, result) -class _CopyInternals(InternalTraversal): +class _CopyInternalsTraversal(InternalTraversal): """Generate a _copy_internals internal traversal dispatch for classes with a _traverse_internals collection.""" @@ -936,7 +210,7 @@ class _CopyInternals(InternalTraversal): return element -_copy_internals = _CopyInternals() +_copy_internals = _CopyInternalsTraversal() def _flatten_clauseelement(element): @@ -948,7 +222,7 @@ def _flatten_clauseelement(element): return element -class _GetChildren(InternalTraversal): +class _GetChildrenTraversal(InternalTraversal): """Generate a _children_traversal internal traversal dispatch for classes with a _traverse_internals collection.""" @@ -1019,7 +293,7 @@ class _GetChildren(InternalTraversal): return () -_get_children = _GetChildren() +_get_children = _GetChildrenTraversal() @util.preload_module("sqlalchemy.sql.elements") diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index 9dd702410b..55289cb853 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -13,9 +13,8 @@ import typing from . import operators from .base import SchemaEventTarget -from .traversals import NO_CACHE +from .cache_key import NO_CACHE from .visitors import Traversible -from .visitors import TraversibleType from .. import exc from .. import util @@ -858,10 +857,6 @@ class TypeEngine(Traversible): return util.generic_repr(self) -class VisitableCheckKWArg(util.EnsureKWArgType, TraversibleType): - pass - - class ExternalType: """mixin that defines attributes and behaviors specific to third-party datatypes. @@ -1038,7 +1033,7 @@ class ExternalType: return NO_CACHE -class UserDefinedType(ExternalType, TypeEngine, metaclass=VisitableCheckKWArg): +class UserDefinedType(ExternalType, TypeEngine, util.EnsureKWArg): """Base for user defined types. This should be the base of new types. Note that diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 69e83e46a6..63067585ed 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -22,6 +22,7 @@ from .annotation import _shallow_annotate # noqa from .base import _expand_cloned from .base import _from_objects from .base import ColumnSet +from .cache_key import HasCacheKey # noqa from .ddl import sort_tables # noqa from .elements import _find_columns # noqa from .elements import _label_reference @@ -41,7 +42,6 @@ from .selectable import Join from .selectable import ScalarSelect from .selectable import SelectBase from .selectable import TableClause -from .traversals import HasCacheKey # noqa from .. import exc from .. import util diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 87fe369444..70c4dc1332 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -32,6 +32,11 @@ from .. 
import util from ..util import langhelpers from ..util import symbol +try: + from sqlalchemy.cyextension.util import cache_anon_map as anon_map # noqa +except ImportError: + from ._py_util import cache_anon_map as anon_map # noqa + __all__ = [ "iterate", "traverse_using", @@ -39,88 +44,77 @@ __all__ = [ "cloned_traverse", "replacement_traverse", "Traversible", - "TraversibleType", "ExternalTraversal", "InternalTraversal", ] -def _generate_compiler_dispatch(cls): - """Generate a _compiler_dispatch() external traversal on classes with a - __visit_name__ attribute. - - """ - visit_name = cls.__visit_name__ - - if "_compiler_dispatch" in cls.__dict__: - # class has a fixed _compiler_dispatch() method. - # copy it to "original" so that we can get it back if - # sqlalchemy.ext.compiles overrides it. - cls._original_compiler_dispatch = cls._compiler_dispatch - return - - if not isinstance(visit_name, str): - raise exc.InvalidRequestError( - "__visit_name__ on class %s must be a string at the class level" - % cls.__name__ - ) - - name = "visit_%s" % visit_name - getter = operator.attrgetter(name) - - def _compiler_dispatch(self, visitor, **kw): - """Look for an attribute named "visit_" on the - visitor, and call it with the same kw params. - - """ - try: - meth = getter(visitor) - except AttributeError as err: - return visitor.visit_unsupported_compilation(self, err, **kw) - - else: - return meth(self, **kw) - - cls._compiler_dispatch = ( - cls._original_compiler_dispatch - ) = _compiler_dispatch +class Traversible: + """Base class for visitable objects.""" + __slots__ = () -class TraversibleType(type): - """Metaclass which assigns dispatch attributes to various kinds of - "visitable" classes. + __visit_name__: str - Attributes include: + def __init_subclass__(cls) -> None: + if "__visit_name__" in cls.__dict__: + cls._generate_compiler_dispatch() + super().__init_subclass__() - * The ``_compiler_dispatch`` method, corresponding to ``__visit_name__``. - This is called "external traversal" because the caller of each visit() - method is responsible for sub-traversing the inner elements of each - object. This is appropriate for string compilers and other traversals - that need to call upon the inner elements in a specific pattern. + @classmethod + def _generate_compiler_dispatch(cls): + """Assign dispatch attributes to various kinds of + "visitable" classes. - * internal traversal collections ``_children_traversal``, - ``_cache_key_traversal``, ``_copy_internals_traversal``, generated from - an optional ``_traverse_internals`` collection of symbols which comes - from the :class:`.InternalTraversal` list of symbols. This is called - "internal traversal" MARKMARK + Attributes include: - """ + * The ``_compiler_dispatch`` method, corresponding to + ``__visit_name__``. This is called "external traversal" because the + caller of each visit() method is responsible for sub-traversing the + inner elements of each object. This is appropriate for string + compilers and other traversals that need to call upon the inner + elements in a specific pattern. - def __init__(cls, clsname, bases, clsdict): - if clsname != "Traversible": - if "__visit_name__" in clsdict: - _generate_compiler_dispatch(cls) + * internal traversal collections ``_children_traversal``, + ``_cache_key_traversal``, ``_copy_internals_traversal``, generated + from an optional ``_traverse_internals`` collection of symbols which + comes from the :class:`.InternalTraversal` list of symbols. This is + called "internal traversal". 
- super(TraversibleType, cls).__init__(clsname, bases, clsdict) + """ + visit_name = cls.__visit_name__ + + if "_compiler_dispatch" in cls.__dict__: + # class has a fixed _compiler_dispatch() method. + # copy it to "original" so that we can get it back if + # sqlalchemy.ext.compiles overrides it. + cls._original_compiler_dispatch = cls._compiler_dispatch + return + + if not isinstance(visit_name, str): + raise exc.InvalidRequestError( + f"__visit_name__ on class {cls.__name__} must be a string " + "at the class level" + ) + name = "visit_%s" % visit_name + getter = operator.attrgetter(name) -class Traversible(metaclass=TraversibleType): - """Base class for visitable objects, applies the - :class:`.visitors.TraversibleType` metaclass. + def _compiler_dispatch(self, visitor, **kw): + """Look for an attribute named "visit_" on the + visitor, and call it with the same kw params. - """ + """ + try: + meth = getter(visitor) + except AttributeError as err: + return visitor.visit_unsupported_compilation(self, err, **kw) + else: + return meth(self, **kw) - __slots__ = () + cls._compiler_dispatch = ( + cls._original_compiler_dispatch + ) = _compiler_dispatch def __class_getitem__(cls, key): # allow generic classes in py3.9+ @@ -159,48 +153,90 @@ class Traversible(metaclass=TraversibleType): ) -class _InternalTraversalType(type): - def __init__(cls, clsname, bases, clsdict): - if cls.__name__ in ("InternalTraversal", "ExtendedInternalTraversal"): - lookup = {} - for key, sym in clsdict.items(): - if key.startswith("dp_"): - visit_key = key.replace("dp_", "visit_") - sym_name = sym.name - assert sym_name not in lookup, sym_name - lookup[sym] = lookup[sym_name] = visit_key - if hasattr(cls, "_dispatch_lookup"): - lookup.update(cls._dispatch_lookup) - cls._dispatch_lookup = lookup +class _HasTraversalDispatch: + r"""Define infrastructure for the :class:`.InternalTraversal` class. - super(_InternalTraversalType, cls).__init__(clsname, bases, clsdict) + .. versionadded:: 2.0 + """ -def _generate_dispatcher(visitor, internal_dispatch, method_name): - names = [] - for attrname, visit_sym in internal_dispatch: - meth = visitor.dispatch(visit_sym) - if meth: - visit_name = ExtendedInternalTraversal._dispatch_lookup[visit_sym] - names.append((attrname, visit_name)) - - code = ( - (" return [\n") - + ( - ", \n".join( - " (%r, self.%s, visitor.%s)" - % (attrname, attrname, visit_name) - for attrname, visit_name in names + def __init_subclass__(cls) -> None: + cls._generate_traversal_dispatch() + super().__init_subclass__() + + def dispatch(self, visit_symbol): + """Given a method from :class:`._HasTraversalDispatch`, return the + corresponding method on a subclass. + + """ + name = self._dispatch_lookup[visit_symbol] + return getattr(self, name, None) + + def run_generated_dispatch( + self, target, internal_dispatch, generate_dispatcher_name + ): + try: + dispatcher = target.__class__.__dict__[generate_dispatcher_name] + except KeyError: + # most of the dispatchers are generated up front + # in sqlalchemy/sql/__init__.py -> + # traversals.py-> _preconfigure_traversals(). + # this block will generate any remaining dispatchers. 
+ dispatcher = self.generate_dispatch( + target.__class__, internal_dispatch, generate_dispatcher_name + ) + return dispatcher(target, self) + + def generate_dispatch( + self, target_cls, internal_dispatch, generate_dispatcher_name + ): + dispatcher = self._generate_dispatcher( + internal_dispatch, generate_dispatcher_name + ) + # assert isinstance(target_cls, type) + setattr(target_cls, generate_dispatcher_name, dispatcher) + return dispatcher + + @classmethod + def _generate_traversal_dispatch(cls): + lookup = {} + clsdict = cls.__dict__ + for key, sym in clsdict.items(): + if key.startswith("dp_"): + visit_key = key.replace("dp_", "visit_") + sym_name = sym.name + assert sym_name not in lookup, sym_name + lookup[sym] = lookup[sym_name] = visit_key + if hasattr(cls, "_dispatch_lookup"): + lookup.update(cls._dispatch_lookup) + cls._dispatch_lookup = lookup + + def _generate_dispatcher(self, internal_dispatch, method_name): + names = [] + for attrname, visit_sym in internal_dispatch: + meth = self.dispatch(visit_sym) + if meth: + visit_name = ExtendedInternalTraversal._dispatch_lookup[ + visit_sym + ] + names.append((attrname, visit_name)) + + code = ( + (" return [\n") + + ( + ", \n".join( + " (%r, self.%s, visitor.%s)" + % (attrname, attrname, visit_name) + for attrname, visit_name in names + ) ) + + ("\n ]\n") ) - + ("\n ]\n") - ) - meth_text = ("def %s(self, visitor):\n" % method_name) + code + "\n" - # print(meth_text) - return langhelpers._exec_code_in_env(meth_text, {}, method_name) + meth_text = ("def %s(self, visitor):\n" % method_name) + code + "\n" + return langhelpers._exec_code_in_env(meth_text, {}, method_name) -class InternalTraversal(metaclass=_InternalTraversalType): +class InternalTraversal(_HasTraversalDispatch): r"""Defines visitor symbols used for internal traversal. The :class:`.InternalTraversal` class is used in two ways. One is that @@ -239,39 +275,6 @@ class InternalTraversal(metaclass=_InternalTraversalType): """ - def dispatch(self, visit_symbol): - """Given a method from :class:`.InternalTraversal`, return the - corresponding method on a subclass. - - """ - name = self._dispatch_lookup[visit_symbol] - return getattr(self, name, None) - - def run_generated_dispatch( - self, target, internal_dispatch, generate_dispatcher_name - ): - try: - dispatcher = target.__class__.__dict__[generate_dispatcher_name] - except KeyError: - # most of the dispatchers are generated up front - # in sqlalchemy/sql/__init__.py -> - # traversals.py-> _preconfigure_traversals(). - # this block will generate any remaining dispatchers. 
- dispatcher = self.generate_dispatch( - target.__class__, internal_dispatch, generate_dispatcher_name - ) - return dispatcher(target, self) - - def generate_dispatch( - self, target_cls, internal_dispatch, generate_dispatcher_name - ): - dispatcher = _generate_dispatcher( - self, internal_dispatch, generate_dispatcher_name - ) - # assert isinstance(target_cls, type) - setattr(target_cls, generate_dispatcher_name, dispatcher) - return dispatcher - dp_has_cache_key = symbol("HC") """Visit a :class:`.HasCacheKey` object.""" @@ -623,7 +626,6 @@ class ReplacingExternalTraversal(CloningExternalTraversal): # backwards compatibility Visitable = Traversible -VisitableType = TraversibleType ClauseVisitor = ExternalTraversal CloningVisitor = CloningExternalTraversal ReplacingCloningVisitor = ReplacingExternalTraversal diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py index 74c86e85ab..ecc20f1638 100644 --- a/lib/sqlalchemy/testing/fixtures.py +++ b/lib/sqlalchemy/testing/fixtures.py @@ -21,7 +21,6 @@ from .. import event from .. import util from ..orm import declarative_base from ..orm import registry -from ..orm.decl_api import DeclarativeMeta from ..schema import sort_tables_and_constraints @@ -647,15 +646,11 @@ class MappedTest(TablesTest, assertions.AssertsExecutionResults): """ cls_registry = cls.classes - assert cls_registry is not None - - class FindFixture(type): - def __init__(cls, classname, bases, dict_): - cls_registry[classname] = cls - type.__init__(cls, classname, bases, dict_) - - class _Base(metaclass=FindFixture): - pass + class _Base: + def __init_subclass__(cls) -> None: + assert cls_registry is not None + cls_registry[cls.__name__] = cls + super().__init_subclass__() class Basic(BasicEntity, _Base): pass @@ -699,17 +694,16 @@ class DeclarativeMappedTest(MappedTest): def _with_register_classes(cls, fn): cls_registry = cls.classes - class FindFixtureDeclarative(DeclarativeMeta): - def __init__(cls, classname, bases, dict_): - cls_registry[classname] = cls - DeclarativeMeta.__init__(cls, classname, bases, dict_) - class DeclarativeBasic: __table_cls__ = schema.Table + def __init_subclass__(cls) -> None: + assert cls_registry is not None + cls_registry[cls.__name__] = cls + super().__init_subclass__() + _DeclBase = declarative_base( metadata=cls._tables_metadata, - metaclass=FindFixtureDeclarative, cls=DeclarativeBasic, ) diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index 2d2ff35651..7c03bcd4ba 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -98,7 +98,7 @@ from .langhelpers import decorator from .langhelpers import dictlike_iteritems from .langhelpers import duck_type_collection from .langhelpers import ellipses_string -from .langhelpers import EnsureKWArgType +from .langhelpers import EnsureKWArg from .langhelpers import format_argspec_init from .langhelpers import format_argspec_plus from .langhelpers import generic_repr diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index 1b277cdee9..66c5308678 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -1743,14 +1743,28 @@ def attrsetter(attrname): return env["set"] -class EnsureKWArgType(type): +class EnsureKWArg: r"""Apply translation of functions to accept \**kw arguments if they don't already. 
+ Used to ensure cross-compatibility with third party legacy code, for things + like compiler visit methods that need to accept ``**kw`` arguments, + but may have been copied from old code that didn't accept them. + + """ + + ensure_kwarg: str + """a regular expression that indicates method names for which the method + should accept ``**kw`` arguments. + + The class will scan for methods matching the name template and decorate + them if necessary to ensure ``**kw`` parameters are accepted. + """ - def __init__(cls, clsname, bases, clsdict): + def __init_subclass__(cls) -> None: fn_reg = cls.ensure_kwarg + clsdict = cls.__dict__ if fn_reg: for key in clsdict: m = re.match(fn_reg, key) @@ -1758,11 +1772,12 @@ class EnsureKWArgType(type): fn = clsdict[key] spec = compat.inspect_getfullargspec(fn) if not spec.varkw: - clsdict[key] = wrapped = cls._wrap_w_kw(fn) + wrapped = cls._wrap_w_kw(fn) setattr(cls, key, wrapped) - super(EnsureKWArgType, cls).__init__(clsname, bases, clsdict) + super().__init_subclass__() - def _wrap_w_kw(self, fn): + @classmethod + def _wrap_w_kw(cls, fn): def wrap(*arg, **kw): return fn(*arg) diff --git a/test/orm/test_events.py b/test/orm/test_events.py index 92ef241ed7..37b4d0ae1f 100644 --- a/test/orm/test_events.py +++ b/test/orm/test_events.py @@ -32,7 +32,7 @@ from sqlalchemy.orm import Session from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import subqueryload from sqlalchemy.orm import UserDefinedOption -from sqlalchemy.sql.traversals import NO_CACHE +from sqlalchemy.sql.cache_key import NO_CACHE from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL diff --git a/test/sql/test_lambdas.py b/test/sql/test_lambdas.py index bbf9716f58..43aed06725 100644 --- a/test/sql/test_lambdas.py +++ b/test/sql/test_lambdas.py @@ -18,7 +18,7 @@ from sqlalchemy.sql import select from sqlalchemy.sql import table from sqlalchemy.sql import util as sql_util from sqlalchemy.sql.base import ExecutableOption -from sqlalchemy.sql.traversals import HasCacheKey +from sqlalchemy.sql.cache_key import HasCacheKey from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_
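For reference, the core pattern this change applies throughout — replacing a registering metaclass such as ``FindFixture`` or ``_EventMeta`` with ``__init_subclass__()`` — can be sketched in isolation as follows. This is an illustrative example only, not SQLAlchemy code; the class and registry names are made up::

    # Before: a metaclass records each class built with it in a registry.
    registry_via_meta = {}

    class FindFixtureMeta(type):
        def __init__(cls, classname, bases, dict_):
            registry_via_meta[classname] = cls
            type.__init__(cls, classname, bases, dict_)

    class MetaBase(metaclass=FindFixtureMeta):
        pass

    class MetaThing(MetaBase):
        pass

    # After: __init_subclass__() hooks subclass creation with no metaclass.
    registry_via_hook = {}

    class Base:
        def __init_subclass__(cls) -> None:
            registry_via_hook[cls.__name__] = cls
            super().__init_subclass__()

    class Thing(Base):
        pass

    # The metaclass registers the base class itself as well; the hook fires
    # only for subclasses, which is why per-class guards such as
    # ``if clsname != "Traversible"`` are no longer needed above.
    assert "MetaBase" in registry_via_meta and "MetaThing" in registry_via_meta
    assert registry_via_hook == {"Thing": Thing}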
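The ``EnsureKWArgType`` to ``EnsureKWArg`` conversion in ``langhelpers.py`` uses the same hook for a slightly different job: scanning each new subclass for methods whose names match the ``ensure_kwarg`` regular expression and wrapping any that do not already accept ``**kw``. A simplified, self-contained sketch of that behavior — written for illustration and not identical to the actual implementation — looks like this::

    import inspect
    import re

    class EnsureKWArgSketch:
        """Wrap matching methods so they tolerate unexpected **kw."""

        ensure_kwarg: str = ""

        def __init_subclass__(cls) -> None:
            pattern = cls.ensure_kwarg
            if pattern:
                # only inspect methods defined directly on the new subclass
                for key, fn in list(cls.__dict__.items()):
                    if not callable(fn) or not re.match(pattern, key):
                        continue
                    if not inspect.getfullargspec(fn).varkw:
                        setattr(cls, key, cls._wrap_w_kw(fn))
            super().__init_subclass__()

        @classmethod
        def _wrap_w_kw(cls, fn):
            def wrap(*arg, **kw):
                # extra keyword arguments are accepted and discarded
                return fn(*arg)

            return wrap

    class MyVisitor(EnsureKWArgSketch):
        ensure_kwarg = "visit_"

        def visit_widget(self, element):
            return "WIDGET(%s)" % element

    # the wrapped method now accepts keyword arguments it never declared
    assert MyVisitor().visit_widget("w", literal_binds=True) == "WIDGET(w)"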