instances of ``_Dispatch``.
"""
+from typing import ClassVar
+from typing import Optional
+from typing import Type
import weakref
from .attr import _ClsLevelDispatch
from .attr import _EmptyListener
from .attr import _JoinedListener
from .. import util
-
+from ..util.typing import Protocol
_registrars = util.defaultdict(list)
of the :class:`._Dispatch` class is returned.
A :class:`._Dispatch` class is generated for each :class:`.Events`
- class defined, by the :func:`._create_dispatcher_class` function.
- The original :class:`.Events` classes remain untouched.
+ class defined, by the :meth:`._HasEventsDispatch._create_dispatcher_class`
+ method. The original :class:`.Events` classes remain untouched.
This decouples the construction of :class:`.Events` subclasses from
the implementation used by the event internals, and allows
inspecting tools like Sphinx to work in an unsurprising
_empty_listener_reg = weakref.WeakKeyDictionary()
+ _events: Type["_HasEventsDispatch"]
+ """reference back to the Events class.
+
+ Bidirectional against _HasEventsDispatch.dispatch
+
+ """
+
def __init__(self, parent, instance_cls=None):
self._parent = parent
self._instance_cls = instance_cls
ls.for_modify(self).clear()
-class _EventMeta(type):
- """Intercept new Event subclasses and create
- associated _Dispatch classes."""
-
- def __init__(cls, classname, bases, dict_):
- _create_dispatcher_class(cls, classname, bases, dict_)
- type.__init__(cls, classname, bases, dict_)
-
-
-def _create_dispatcher_class(cls, classname, bases, dict_):
- """Create a :class:`._Dispatch` class corresponding to an
- :class:`.Events` class."""
-
- # there's all kinds of ways to do this,
- # i.e. make a Dispatch class that shares the '_listen' method
- # of the Event class, this is the straight monkeypatch.
- if hasattr(cls, "dispatch"):
- dispatch_base = cls.dispatch.__class__
- else:
- dispatch_base = _Dispatch
-
- event_names = [k for k in dict_ if _is_event_name(k)]
- dispatch_cls = type(
- "%sDispatch" % classname, (dispatch_base,), {"__slots__": event_names}
- )
-
- dispatch_cls._event_names = event_names
-
- dispatch_inst = cls._set_dispatch(cls, dispatch_cls)
- for k in dispatch_cls._event_names:
- setattr(dispatch_inst, k, _ClsLevelDispatch(cls, dict_[k]))
- _registrars[k].append(cls)
-
- for super_ in dispatch_cls.__bases__:
- if issubclass(super_, _Dispatch) and super_ is not _Dispatch:
- for ls in super_._events.dispatch._event_descriptors:
- setattr(dispatch_inst, ls.name, ls)
- dispatch_cls._event_names.append(ls.name)
-
- if getattr(cls, "_dispatch_target", None):
- the_cls = cls._dispatch_target
- if (
- hasattr(the_cls, "__slots__")
- and "_slots_dispatch" in the_cls.__slots__
- ):
- cls._dispatch_target.dispatch = slots_dispatcher(cls)
- else:
- cls._dispatch_target.dispatch = dispatcher(cls)
-
-
def _remove_dispatcher(cls):
for k in cls.dispatch._event_names:
_registrars[k].remove(cls)
del _registrars[k]
-class Events(metaclass=_EventMeta):
- """Define event listening functions for a particular target type."""
class _HasEventsDispatchProto(Protocol):
    """protocol for non-event classes that will also receive the 'dispatch'
    attribute in the form of a descriptor.

    """

    # the descriptor is installed onto the target class by
    # _create_dispatcher_class (either dispatcher() or slots_dispatcher());
    # ClassVar because the assignment happens at the class level, not per
    # instance.
    dispatch: ClassVar["dispatcher"]
+
+
+class _HasEventsDispatch:
+ _dispatch_target: Optional[Type[_HasEventsDispatchProto]]
+ """class which will receive the .dispatch collection"""
+
+ dispatch: _Dispatch
+ """reference back to the _Dispatch class.
+
+ Bidirectional against _Dispatch._events
+
+ """
+
    def __init_subclass__(cls) -> None:
        """Intercept new Event subclasses and create associated _Dispatch
        classes."""

        # replaces the former _EventMeta metaclass hook; runs automatically
        # for every subclass, passing the same (name, bases, namespace)
        # triple the metaclass __init__ used to receive.
        cls._create_dispatcher_class(cls.__name__, cls.__bases__, cls.__dict__)
@staticmethod
def _set_dispatch(cls, dispatch_cls):
dispatch_cls._events = cls
return cls.dispatch
    @classmethod
    def _create_dispatcher_class(cls, classname, bases, dict_):
        """Create a :class:`._Dispatch` class corresponding to an
        :class:`.Events` class."""

        # there's all kinds of ways to do this,
        # i.e. make a Dispatch class that shares the '_listen' method
        # of the Event class, this is the straight monkeypatch.
        if hasattr(cls, "dispatch"):
            # subclassing an existing Events class: extend its Dispatch
            dispatch_base = cls.dispatch.__class__
        else:
            dispatch_base = _Dispatch

        # every name in the class namespace that looks like an event
        # becomes a slot on the generated Dispatch subclass
        event_names = [k for k in dict_ if _is_event_name(k)]
        dispatch_cls = type(
            "%sDispatch" % classname,
            (dispatch_base,),
            {"__slots__": event_names},
        )

        dispatch_cls._event_names = event_names

        # _set_dispatch wires the bidirectional Events <-> Dispatch link
        # and returns the class-level dispatch instance
        dispatch_inst = cls._set_dispatch(cls, dispatch_cls)
        for k in dispatch_cls._event_names:
            setattr(dispatch_inst, k, _ClsLevelDispatch(cls, dict_[k]))
            _registrars[k].append(cls)

        # also pull in event descriptors declared on intermediate
        # Dispatch bases so they are reachable from this subclass
        for super_ in dispatch_cls.__bases__:
            if issubclass(super_, _Dispatch) and super_ is not _Dispatch:
                for ls in super_._events.dispatch._event_descriptors:
                    setattr(dispatch_inst, ls.name, ls)
                    dispatch_cls._event_names.append(ls.name)

        # finally, install the 'dispatch' descriptor onto the target class,
        # choosing the slots-aware form when the target reserves
        # '_slots_dispatch' in its __slots__
        if getattr(cls, "_dispatch_target", None):
            dispatch_target_cls = cls._dispatch_target
            assert dispatch_target_cls is not None
            if (
                hasattr(dispatch_target_cls, "__slots__")
                and "_slots_dispatch" in dispatch_target_cls.__slots__
            ):
                dispatch_target_cls.dispatch = slots_dispatcher(cls)
            else:
                dispatch_target_cls.dispatch = dispatcher(cls)
+
+
+class Events(_HasEventsDispatch):
+ """Define event listening functions for a particular target type."""
+
@classmethod
def _accept_with(cls, target):
def dispatch_is(*types):
def __ne__(self, other):
return not self.__eq__(other)
-The :class:`.MutableComposite` class uses a Python metaclass to automatically
-establish listeners for any usage of :func:`_orm.composite` that specifies our
-``Point`` type. Below, when ``Point`` is mapped to the ``Vertex`` class,
-listeners are established which will route change events from ``Point``
+The :class:`.MutableComposite` class makes use of class mapping events to
+automatically establish listeners for any usage of :func:`_orm.composite` that
+specifies our ``Point`` type. Below, when ``Point`` is mapped to the ``Vertex``
+class, listeners are established which will route change events from ``Point``
objects to each of the ``Vertex.start`` and ``Vertex.end`` attributes::
from sqlalchemy.orm import composite, mapper
class DeclarativeMeta(type):
+ # DeclarativeMeta could be replaced by __subclass_init__()
+ # except for the class-level __setattr__() and __delattr__ hooks,
+ # which are still very important.
+
def __init__(cls, classname, bases, dict_, **kw):
# early-consume registry from the initial declarative base,
# assign privately to not conflict with subclass attributes named
from ..sql import roles
from ..sql import visitors
from ..sql.base import ExecutableOption
-from ..sql.traversals import HasCacheKey
+from ..sql.cache_key import HasCacheKey
__all__ = (
from .. import inspection
from .. import util
from ..sql import visitors
-from ..sql.traversals import HasCacheKey
+from ..sql.cache_key import HasCacheKey
log = logging.getLogger(__name__)
# local Query builder state, not needed for
# compilation or execution
_enable_assertions = True
- _last_joined_entity = None
+
_statement = None
# mirrors that of ClauseElement, used to propagate the "orm"
from .. import inspect
from .. import util
from ..sql import and_
+from ..sql import cache_key
from ..sql import coercions
from ..sql import roles
-from ..sql import traversals
from ..sql import visitors
from ..sql.base import _generative
from ..sql.base import Generative
self.__dict__.update(state)
-class _LoadElement(traversals.HasCacheKey):
+class _LoadElement(cache_key.HasCacheKey):
"""represents strategy information to select for a LoaderStrategy
and pass options to it.
from . import roles
from . import visitors
-from .traversals import HasCacheKey # noqa
+from .cache_key import HasCacheKey # noqa
+from .cache_key import MemoizedHasCacheKey # noqa
from .traversals import HasCopyInternals # noqa
-from .traversals import MemoizedHasCacheKey # noqa
from .visitors import ClauseVisitor
from .visitors import ExtendedInternalTraversal
from .visitors import InternalTraversal
except ImportError:
from ._py_util import prefix_anon_map # noqa
-
coercions = None
elements = None
type_api = None
class _MetaOptions(type):
- """metaclass for the Options class."""
+ """metaclass for the Options class.
- def __init__(cls, classname, bases, dict_):
- cls._cache_attrs = tuple(
- sorted(
- d
- for d in dict_
- if not d.startswith("__")
- and d not in ("_cache_key_traversal",)
- )
- )
- type.__init__(cls, classname, bases, dict_)
+ This metaclass is actually necessary despite the availability of the
+ ``__init_subclass__()`` hook as this type also provides custom class-level
+ behavior for the ``__add__()`` method.
+
+ """
def __add__(self, other):
o1 = self()
class Options(metaclass=_MetaOptions):
"""A cacheable option dictionary with defaults."""
    def __init_subclass__(cls) -> None:
        """Compute the cacheable attribute names for each Options subclass.

        Replaces the attribute-collection step formerly done by the
        _MetaOptions metaclass __init__.
        """
        dict_ = cls.__dict__
        # cache the sorted names declared directly on this subclass,
        # excluding dunders and the traversal directive itself
        cls._cache_attrs = tuple(
            sorted(
                d
                for d in dict_
                if not d.startswith("__")
                and d not in ("_cache_key_traversal",)
            )
        )
        super().__init_subclass__()
+
def __init__(self, **kw):
self.__dict__.update(kw)
--- /dev/null
+# sql/cache_key.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from collections import namedtuple
+import enum
+from itertools import zip_longest
+from typing import Callable
+from typing import Union
+
+from .visitors import anon_map
+from .visitors import ExtendedInternalTraversal
+from .visitors import InternalTraversal
+from .. import util
+from ..inspection import inspect
+from ..util import HasMemoized
+from ..util.typing import Literal
+
+
class CacheConst(enum.Enum):
    """Singleton enum providing the "no cache key available" sentinel."""

    NO_CACHE = 0


# module-level alias for the sentinel, compared by identity elsewhere
NO_CACHE = CacheConst.NO_CACHE


class CacheTraverseTarget(enum.Enum):
    """Dispatch targets that are inlined into the cache-key traversal."""

    CACHE_IN_PLACE = 0
    CALL_GEN_CACHE_KEY = 1
    STATIC_CACHE_KEY = 2
    PROPAGATE_ATTRS = 3
    ANON_NAME = 4


# module-level aliases, one per enum member, in declaration order
CACHE_IN_PLACE = CacheTraverseTarget.CACHE_IN_PLACE
CALL_GEN_CACHE_KEY = CacheTraverseTarget.CALL_GEN_CACHE_KEY
STATIC_CACHE_KEY = CacheTraverseTarget.STATIC_CACHE_KEY
PROPAGATE_ATTRS = CacheTraverseTarget.PROPAGATE_ATTRS
ANON_NAME = CacheTraverseTarget.ANON_NAME
+
+
class HasCacheKey:
    """Mixin for objects which can produce a cache key.

    .. seealso::

        :class:`.CacheKey`

        :ref:`sql_caching`

    """

    # default: no cache key scheme until a subclass opts in
    _cache_key_traversal = NO_CACHE

    _is_has_cache_key = True

    _hierarchy_supports_caching = True
    """private attribute which may be set to False to prevent the
    inherit_cache warning from being emitted for a hierarchy of subclasses.

    Currently applies to the DDLElement hierarchy which does not implement
    caching.

    """

    inherit_cache = None
    """Indicate if this :class:`.HasCacheKey` instance should make use of the
    cache key generation scheme used by its immediate superclass.

    The attribute defaults to ``None``, which indicates that a construct has
    not yet taken into account whether or not it's appropriate for it to
    participate in caching; this is functionally equivalent to setting the
    value to ``False``, except that a warning is also emitted.

    This flag can be set to ``True`` on a particular class, if the SQL that
    corresponds to the object does not change based on attributes which
    are local to this class, and not its superclass.

    .. seealso::

        :ref:`compilerext_caching` - General guidelines for setting the
        :attr:`.HasCacheKey.inherit_cache` attribute for third-party or user
        defined SQL constructs.

    """

    __slots__ = ()

    @classmethod
    def _generate_cache_attrs(cls):
        """generate cache key dispatcher for a new class.

        This sets the _generated_cache_key_traversal attribute once called
        so should only be called once per class.

        """
        inherit_cache = cls.__dict__.get("inherit_cache", None)
        inherit = bool(inherit_cache)

        if inherit:
            # inheriting: look up the traversal anywhere in the MRO,
            # falling back to _traverse_internals
            _cache_key_traversal = getattr(cls, "_cache_key_traversal", None)
            if _cache_key_traversal is None:
                try:
                    _cache_key_traversal = cls._traverse_internals
                except AttributeError:
                    cls._generated_cache_key_traversal = NO_CACHE
                    return NO_CACHE

            # TODO: wouldn't we instead get this from our superclass?
            # also, our superclass may not have this yet, but in any case,
            # we'd generate for the superclass that has it. this is a little
            # more complicated, so for the moment this is a little less
            # efficient on startup but simpler.
            return _cache_key_traversal_visitor.generate_dispatch(
                cls, _cache_key_traversal, "_generated_cache_key_traversal"
            )
        else:
            # not inheriting: only accept a traversal declared directly
            # on this class (cls.__dict__, not the MRO)
            _cache_key_traversal = cls.__dict__.get(
                "_cache_key_traversal", None
            )
            if _cache_key_traversal is None:
                _cache_key_traversal = cls.__dict__.get(
                    "_traverse_internals", None
                )
                if _cache_key_traversal is None:
                    cls._generated_cache_key_traversal = NO_CACHE
                    # warn only when inherit_cache was left unset (None);
                    # an explicit False silences the warning
                    if (
                        inherit_cache is None
                        and cls._hierarchy_supports_caching
                    ):
                        util.warn(
                            "Class %s will not make use of SQL compilation "
                            "caching as it does not set the 'inherit_cache' "
                            "attribute to ``True``. This can have "
                            "significant performance implications including "
                            "some performance degradations in comparison to "
                            "prior SQLAlchemy versions. Set this attribute "
                            "to True if this object can make use of the cache "
                            "key generated by the superclass. Alternatively, "
                            "this attribute may be set to False which will "
                            "disable this warning." % (cls.__name__),
                            code="cprf",
                        )
                    return NO_CACHE

            return _cache_key_traversal_visitor.generate_dispatch(
                cls, _cache_key_traversal, "_generated_cache_key_traversal"
            )

    @util.preload_module("sqlalchemy.sql.elements")
    def _gen_cache_key(self, anon_map, bindparams):
        """return an optional cache key.

        The cache key is a tuple which can contain any series of
        objects that are hashable and also identifies
        this object uniquely within the presence of a larger SQL expression
        or statement, for the purposes of caching the resulting query.

        The cache key should be based on the SQL compiled structure that would
        ultimately be produced. That is, two structures that are composed in
        exactly the same way should produce the same cache key; any difference
        in the structures that would affect the SQL string or the type handlers
        should result in a different cache key.

        If a structure cannot produce a useful cache key, the NO_CACHE
        symbol should be added to the anon_map and the method should
        return None.

        """

        cls = self.__class__

        # objects already seen in this traversal are represented solely
        # by their anonymized id + class
        id_, found = anon_map.get_anon(self)
        if found:
            return (id_, cls)

        dispatcher: Union[
            Literal[CacheConst.NO_CACHE],
            Callable[[HasCacheKey, "_CacheKeyTraversal"], "CacheKey"],
        ]

        try:
            dispatcher = cls.__dict__["_generated_cache_key_traversal"]
        except KeyError:
            # most of the dispatchers are generated up front
            # in sqlalchemy/sql/__init__.py ->
            # traversals.py-> _preconfigure_traversals().
            # this block will generate any remaining dispatchers.
            dispatcher = cls._generate_cache_attrs()

        if dispatcher is NO_CACHE:
            # poison the anon_map so callers know caching is off
            anon_map[NO_CACHE] = True
            return None

        result = (id_, cls)

        # inline of _cache_key_traversal_visitor.run_generated_dispatch()

        for attrname, obj, meth in dispatcher(
            self, _cache_key_traversal_visitor
        ):
            if obj is not None:
                # TODO: see if C code can help here as Python lacks an
                # efficient switch construct

                if meth is STATIC_CACHE_KEY:
                    sck = obj._static_cache_key
                    if sck is NO_CACHE:
                        anon_map[NO_CACHE] = True
                        return None
                    result += (attrname, sck)
                elif meth is ANON_NAME:
                    elements = util.preloaded.sql_elements
                    if isinstance(obj, elements._anonymous_label):
                        obj = obj.apply_map(anon_map)
                    result += (attrname, obj)
                elif meth is CALL_GEN_CACHE_KEY:
                    result += (
                        attrname,
                        obj._gen_cache_key(anon_map, bindparams),
                    )

                # remaining cache functions are against
                # Python tuples, dicts, lists, etc. so we can skip
                # if they are empty
                elif obj:
                    if meth is CACHE_IN_PLACE:
                        result += (attrname, obj)
                    elif meth is PROPAGATE_ATTRS:
                        result += (
                            attrname,
                            obj["compile_state_plugin"],
                            obj["plugin_subject"]._gen_cache_key(
                                anon_map, bindparams
                            )
                            if obj["plugin_subject"]
                            else None,
                        )
                    elif meth is InternalTraversal.dp_annotations_key:
                        # obj is here is the _annotations dict. however, we
                        # want to use the memoized cache key version of it. for
                        # Columns, this should be long lived. For select()
                        # statements, not so much, but they usually won't have
                        # annotations.
                        result += self._annotations_cache_key
                    elif (
                        meth is InternalTraversal.dp_clauseelement_list
                        or meth is InternalTraversal.dp_clauseelement_tuple
                        or meth
                        is InternalTraversal.dp_memoized_select_entities
                    ):
                        result += (
                            attrname,
                            tuple(
                                [
                                    elem._gen_cache_key(anon_map, bindparams)
                                    for elem in obj
                                ]
                            ),
                        )
                    else:
                        # fall back to the visitor method itself
                        result += meth(
                            attrname, obj, self, anon_map, bindparams
                        )
        return result

    def _generate_cache_key(self):
        """return a cache key.

        The cache key is a tuple which can contain any series of
        objects that are hashable and also identifies
        this object uniquely within the presence of a larger SQL expression
        or statement, for the purposes of caching the resulting query.

        The cache key should be based on the SQL compiled structure that would
        ultimately be produced. That is, two structures that are composed in
        exactly the same way should produce the same cache key; any difference
        in the structures that would affect the SQL string or the type handlers
        should result in a different cache key.

        The cache key returned by this method is an instance of
        :class:`.CacheKey`, which consists of a tuple representing the
        cache key, as well as a list of :class:`.BindParameter` objects
        which are extracted from the expression. While two expressions
        that produce identical cache key tuples will themselves generate
        identical SQL strings, the list of :class:`.BindParameter` objects
        indicates the bound values which may have different values in
        each one; these bound parameters must be consulted in order to
        execute the statement with the correct parameters.

        a :class:`_expression.ClauseElement` structure that does not implement
        a :meth:`._gen_cache_key` method and does not implement a
        :attr:`.traverse_internals` attribute will not be cacheable; when
        such an element is embedded into a larger structure, this method
        will return None, indicating no cache key is available.

        """

        bindparams = []

        _anon_map = anon_map()
        key = self._gen_cache_key(_anon_map, bindparams)
        # NO_CACHE present in the map means some element opted out
        if NO_CACHE in _anon_map:
            return None
        else:
            return CacheKey(key, bindparams)

    @classmethod
    def _generate_cache_key_for_object(cls, obj):
        # same contract as _generate_cache_key(), but driven externally
        # for an arbitrary object rather than ``self``
        bindparams = []

        _anon_map = anon_map()
        key = obj._gen_cache_key(_anon_map, bindparams)
        if NO_CACHE in _anon_map:
            return None
        else:
            return CacheKey(key, bindparams)
+
+
class MemoizedHasCacheKey(HasCacheKey, HasMemoized):
    """HasCacheKey variant that memoizes the generated cache key
    per instance via :class:`.HasMemoized`."""

    @HasMemoized.memoized_instancemethod
    def _generate_cache_key(self):
        # delegate to the unmemoized implementation; the decorator
        # caches the result on the instance after the first call
        return HasCacheKey._generate_cache_key(self)
+
+
class CacheKey(namedtuple("CacheKey", ["key", "bindparams"])):
    """The key used to identify a SQL statement construct in the
    SQL compilation cache.

    .. seealso::

        :ref:`sql_caching`

    """

    def __hash__(self):
        """CacheKey itself is not hashable - hash the .key portion"""

        # returning a non-int from __hash__ makes hash(instance) raise
        # TypeError, deliberately blocking use of the whole CacheKey as a
        # dict key; callers hash .key instead
        return None

    def to_offline_string(self, statement_cache, statement, parameters):
        """Generate an "offline string" form of this :class:`.CacheKey`

        The "offline string" is basically the string SQL for the
        statement plus a repr of the bound parameter values in series.
        Whereas the :class:`.CacheKey` object is dependent on in-memory
        identities in order to work as a cache key, the "offline" version
        is suitable for a cache that will work for other processes as well.

        The given ``statement_cache`` is a dictionary-like object where the
        string form of the statement itself will be cached. This dictionary
        should be in a longer lived scope in order to reduce the time spent
        stringifying statements.


        """
        if self.key not in statement_cache:
            statement_cache[self.key] = sql_str = str(statement)
        else:
            sql_str = statement_cache[self.key]

        if not self.bindparams:
            # no extracted bind parameters: use the raw parameter dict,
            # ordered by key for determinism
            param_tuple = tuple(parameters[key] for key in sorted(parameters))
        else:
            param_tuple = tuple(
                parameters.get(bindparam.key, bindparam.value)
                for bindparam in self.bindparams
            )

        return repr((sql_str, param_tuple))

    def __eq__(self, other):
        # equality considers only the key tuple, never the bindparams
        return self.key == other.key

    @classmethod
    def _diff_tuples(cls, left, right):
        # convenience: diff two raw key tuples without bindparams
        ck1 = CacheKey(left, [])
        ck2 = CacheKey(right, [])
        return ck1._diff(ck2)

    def _whats_different(self, other):
        # generator yielding human-readable descriptions of where this
        # key's tuple differs from ``other``'s, descending into nested
        # tuples via an index ``stack``

        k1 = self.key
        k2 = other.key

        stack = []
        pickup_index = 0
        while True:
            # navigate both keys down to the current nesting level
            s1, s2 = k1, k2
            for idx in stack:
                s1 = s1[idx]
                s2 = s2[idx]

            for idx, (e1, e2) in enumerate(zip_longest(s1, s2)):
                if idx < pickup_index:
                    continue
                if e1 != e2:
                    if isinstance(e1, tuple) and isinstance(e2, tuple):
                        # descend into the differing sub-tuples
                        stack.append(idx)
                        break
                    else:
                        yield "key%s[%d]: %s != %s" % (
                            "".join("[%d]" % id_ for id_ in stack),
                            idx,
                            e1,
                            e2,
                        )
            else:
                # level exhausted without descending; pop back up
                pickup_index = stack.pop(-1)
                break

    def _diff(self, other):
        return ", ".join(self._whats_different(other))

    def __str__(self):
        # pretty-print the key tuple as an indented tree, one element
        # per line; ``sentinel`` marks where a nested tuple closes
        stack = [self.key]

        output = []
        sentinel = object()
        indent = -1
        while stack:
            elem = stack.pop(0)
            if elem is sentinel:
                output.append((" " * (indent * 2)) + "),")
                indent -= 1
            elif isinstance(elem, tuple):
                if not elem:
                    output.append((" " * ((indent + 1) * 2)) + "()")
                else:
                    indent += 1
                    stack = list(elem) + [sentinel] + stack
                    output.append((" " * (indent * 2)) + "(")
            else:
                if isinstance(elem, HasCacheKey):
                    # avoid full repr of SQL elements; identity is enough
                    repr_ = "<%s object at %s>" % (
                        type(elem).__name__,
                        hex(id(elem)),
                    )
                else:
                    repr_ = repr(elem)
                output.append((" " * (indent * 2)) + " " + repr_ + ", ")

        return "CacheKey(key=%s)" % ("\n".join(output),)

    def _generate_param_dict(self):
        """used for testing"""

        from .compiler import prefix_anon_map

        _anon_map = prefix_anon_map()
        return {b.key % _anon_map: b.effective_value for b in self.bindparams}

    def _apply_params_to_element(self, original_cache_key, target_element):
        # map bound values positionally from this key's bindparams onto
        # the bindparam keys of ``original_cache_key``
        translate = {
            k.key: v.value
            for k, v in zip(original_cache_key.bindparams, self.bindparams)
        }

        return target_element.params(translate)
+
+
class _CacheKeyTraversal(ExtendedInternalTraversal):
    """Visitor supplying per-datatype cache-key extraction methods.

    Each ``visit_*`` attribute is either one of the inlined
    CacheTraverseTarget symbols (handled directly inside
    HasCacheKey._gen_cache_key) or a callable receiving
    (attrname, obj, parent, anon_map, bindparams) and returning a
    hashable tuple fragment.
    """

    # very common elements are inlined into the main _get_cache_key() method
    # to produce a dramatic savings in Python function call overhead

    visit_has_cache_key = visit_clauseelement = CALL_GEN_CACHE_KEY
    visit_clauseelement_list = InternalTraversal.dp_clauseelement_list
    visit_annotations_key = InternalTraversal.dp_annotations_key
    visit_clauseelement_tuple = InternalTraversal.dp_clauseelement_tuple
    visit_memoized_select_entities = (
        InternalTraversal.dp_memoized_select_entities
    )

    visit_string = (
        visit_boolean
    ) = visit_operator = visit_plain_obj = CACHE_IN_PLACE
    visit_statement_hint_list = CACHE_IN_PLACE
    visit_type = STATIC_CACHE_KEY
    visit_anon_name = ANON_NAME

    visit_propagate_attrs = PROPAGATE_ATTRS

    def visit_with_context_options(
        self, attrname, obj, parent, anon_map, bindparams
    ):
        # functions are keyed by their code object so equivalent
        # lambdas defined at the same location compare equal
        return tuple((fn.__code__, c_key) for fn, c_key in obj)

    def visit_inspectable(self, attrname, obj, parent, anon_map, bindparams):
        return (attrname, inspect(obj)._gen_cache_key(anon_map, bindparams))

    def visit_string_list(self, attrname, obj, parent, anon_map, bindparams):
        return tuple(obj)

    def visit_multi(self, attrname, obj, parent, anon_map, bindparams):
        return (
            attrname,
            obj._gen_cache_key(anon_map, bindparams)
            if isinstance(obj, HasCacheKey)
            else obj,
        )

    def visit_multi_list(self, attrname, obj, parent, anon_map, bindparams):
        return (
            attrname,
            tuple(
                elem._gen_cache_key(anon_map, bindparams)
                if isinstance(elem, HasCacheKey)
                else elem
                for elem in obj
            ),
        )

    def visit_has_cache_key_tuples(
        self, attrname, obj, parent, anon_map, bindparams
    ):
        if not obj:
            return ()
        return (
            attrname,
            tuple(
                tuple(
                    elem._gen_cache_key(anon_map, bindparams)
                    for elem in tup_elem
                )
                for tup_elem in obj
            ),
        )

    def visit_has_cache_key_list(
        self, attrname, obj, parent, anon_map, bindparams
    ):
        if not obj:
            return ()
        return (
            attrname,
            tuple(elem._gen_cache_key(anon_map, bindparams) for elem in obj),
        )

    def visit_executable_options(
        self, attrname, obj, parent, anon_map, bindparams
    ):
        if not obj:
            return ()
        # only options that participate in caching contribute
        return (
            attrname,
            tuple(
                elem._gen_cache_key(anon_map, bindparams)
                for elem in obj
                if elem._is_has_cache_key
            ),
        )

    def visit_inspectable_list(
        self, attrname, obj, parent, anon_map, bindparams
    ):
        return self.visit_has_cache_key_list(
            attrname, [inspect(o) for o in obj], parent, anon_map, bindparams
        )

    def visit_clauseelement_tuples(
        self, attrname, obj, parent, anon_map, bindparams
    ):
        return self.visit_has_cache_key_tuples(
            attrname, obj, parent, anon_map, bindparams
        )

    def visit_fromclause_ordered_set(
        self, attrname, obj, parent, anon_map, bindparams
    ):
        if not obj:
            return ()
        return (
            attrname,
            tuple([elem._gen_cache_key(anon_map, bindparams) for elem in obj]),
        )

    def visit_clauseelement_unordered_set(
        self, attrname, obj, parent, anon_map, bindparams
    ):
        if not obj:
            return ()
        cache_keys = [
            elem._gen_cache_key(anon_map, bindparams) for elem in obj
        ]
        return (
            attrname,
            tuple(
                sorted(cache_keys)
            ),  # cache keys all start with (id_, class)
        )

    def visit_named_ddl_element(
        self, attrname, obj, parent, anon_map, bindparams
    ):
        return (attrname, obj.name)

    def visit_prefix_sequence(
        self, attrname, obj, parent, anon_map, bindparams
    ):
        if not obj:
            return ()

        return (
            attrname,
            tuple(
                [
                    (clause._gen_cache_key(anon_map, bindparams), strval)
                    for clause, strval in obj
                ]
            ),
        )

    def visit_setup_join_tuple(
        self, attrname, obj, parent, anon_map, bindparams
    ):
        # each entry is (target, onclause, from_, flags); onclause and
        # from_ may be None, flags is a dict sorted for determinism
        return tuple(
            (
                target._gen_cache_key(anon_map, bindparams),
                onclause._gen_cache_key(anon_map, bindparams)
                if onclause is not None
                else None,
                from_._gen_cache_key(anon_map, bindparams)
                if from_ is not None
                else None,
                tuple([(key, flags[key]) for key in sorted(flags)]),
            )
            for (target, onclause, from_, flags) in obj
        )

    def visit_table_hint_list(
        self, attrname, obj, parent, anon_map, bindparams
    ):
        if not obj:
            return ()

        return (
            attrname,
            tuple(
                [
                    (
                        clause._gen_cache_key(anon_map, bindparams),
                        dialect_name,
                        text,
                    )
                    for (clause, dialect_name), text in obj.items()
                ]
            ),
        )

    def visit_plain_dict(self, attrname, obj, parent, anon_map, bindparams):
        return (attrname, tuple([(key, obj[key]) for key in sorted(obj)]))

    def visit_dialect_options(
        self, attrname, obj, parent, anon_map, bindparams
    ):
        # nested dict: dialect name -> option dict; sort both levels
        return (
            attrname,
            tuple(
                (
                    dialect_name,
                    tuple(
                        [
                            (key, obj[dialect_name][key])
                            for key in sorted(obj[dialect_name])
                        ]
                    ),
                )
                for dialect_name in sorted(obj)
            ),
        )

    def visit_string_clauseelement_dict(
        self, attrname, obj, parent, anon_map, bindparams
    ):
        return (
            attrname,
            tuple(
                (key, obj[key]._gen_cache_key(anon_map, bindparams))
                for key in sorted(obj)
            ),
        )

    def visit_string_multi_dict(
        self, attrname, obj, parent, anon_map, bindparams
    ):
        return (
            attrname,
            tuple(
                (
                    key,
                    value._gen_cache_key(anon_map, bindparams)
                    if isinstance(value, HasCacheKey)
                    else value,
                )
                for key, value in [(key, obj[key]) for key in sorted(obj)]
            ),
        )

    def visit_fromclause_canonical_column_collection(
        self, attrname, obj, parent, anon_map, bindparams
    ):
        # inlining into the internals of ColumnCollection
        return (
            attrname,
            tuple(
                col._gen_cache_key(anon_map, bindparams)
                for k, col in obj._collection
            ),
        )

    def visit_unknown_structure(
        self, attrname, obj, parent, anon_map, bindparams
    ):
        # can't produce a key for this: poison the traversal
        anon_map[NO_CACHE] = True
        return ()

    def visit_dml_ordered_values(
        self, attrname, obj, parent, anon_map, bindparams
    ):
        return (
            attrname,
            tuple(
                (
                    key._gen_cache_key(anon_map, bindparams)
                    if hasattr(key, "__clause_element__")
                    else key,
                    value._gen_cache_key(anon_map, bindparams),
                )
                for key, value in obj
            ),
        )

    def visit_dml_values(self, attrname, obj, parent, anon_map, bindparams):
        # in py37 we can assume two dictionaries created in the same
        # insert ordering will retain that sorting
        return (
            attrname,
            tuple(
                (
                    k._gen_cache_key(anon_map, bindparams)
                    if hasattr(k, "__clause_element__")
                    else k,
                    obj[k]._gen_cache_key(anon_map, bindparams),
                )
                for k in obj
            ),
        )

    def visit_dml_multi_values(
        self, attrname, obj, parent, anon_map, bindparams
    ):
        # multivalues are simply not cacheable right now
        anon_map[NO_CACHE] = True
        return ()


# module-level singleton; shared by HasCacheKey._gen_cache_key
_cache_key_traversal_visitor = _CacheKeyTraversal()
from . import visitors
from .base import ExecutableOption
from .base import Options
-from .traversals import HasCacheKey
+from .cache_key import HasCacheKey
from .visitors import Visitable
from .. import exc
from .. import inspection
return self.construct_params()
-class TypeCompiler(metaclass=util.EnsureKWArgType):
+class TypeCompiler(util.EnsureKWArg):
"""Produces DDL specification for TypeEngine objects."""
ensure_kwarg = r"visit_\w+"
from .base import Immutable
from .base import NO_ARG
from .base import SingletonConstant
+from .cache_key import MemoizedHasCacheKey
+from .cache_key import NO_CACHE
from .coercions import _document_text_coercion
from .traversals import HasCopyInternals
-from .traversals import MemoizedHasCacheKey
-from .traversals import NO_CACHE
from .visitors import cloned_traverse
from .visitors import InternalTraversal
from .visitors import traverse
from .base import _select_iterables
from .base import ColumnCollection
from .base import Executable
+from .cache_key import CacheKey
from .dml import Delete
from .dml import Insert
from .dml import Update
from .selectable import TextAsFrom
from .selectable import TextualSelect
from .selectable import Values
-from .traversals import CacheKey
from .visitors import Visitable
from ..util.langhelpers import public_factory
from .selectable import FromClause
from .selectable import Select
from .selectable import TableValuedAlias
+from .type_api import TypeEngine
from .visitors import InternalTraversal
-from .visitors import TraversibleType
from .. import util
def register_function(identifier, fn, package="_default"):
"""Associate a callable with a particular func. name.
- This is normally called by _GenericMeta, but is also
+ This is normally called by GenericFunction, but is also
available by itself so that a non-Function construct
can be associated with the :data:`.func` accessor (i.e.
CAST, EXTRACT).
("type", InternalTraversal.dp_type),
]
- type = sqltypes.NULLTYPE
+ name: str
+
+ identifier: str
+
+ type: TypeEngine = sqltypes.NULLTYPE
"""A :class:`_types.TypeEngine` object which refers to the SQL return
type represented by this SQL function.
)
-class _GenericMeta(TraversibleType):
- def __init__(cls, clsname, bases, clsdict):
- if annotation.Annotated not in cls.__mro__:
- cls.name = name = clsdict.get("name", clsname)
- cls.identifier = identifier = clsdict.get("identifier", name)
- package = clsdict.pop("package", "_default")
- # legacy
- if "__return_type__" in clsdict:
- cls.type = clsdict["__return_type__"]
-
- # Check _register attribute status
- cls._register = getattr(cls, "_register", True)
-
- # Register the function if required
- if cls._register:
- register_function(identifier, cls, package)
- else:
- # Set _register to True to register child classes by default
- cls._register = True
-
- super(_GenericMeta, cls).__init__(clsname, bases, clsdict)
-
-
-class GenericFunction(Function, metaclass=_GenericMeta):
+class GenericFunction(Function):
"""Define a 'generic' function.
A generic function is a pre-established :class:`.Function`
"""
coerce_arguments = True
- _register = False
inherit_cache = True
+ name = "GenericFunction"
+
    def __init_subclass__(cls) -> None:
        """Register each new GenericFunction subclass with the func registry.

        Replaces the former _GenericMeta metaclass hook.
        """
        # skip Annotated proxy classes generated by the annotation system
        if annotation.Annotated not in cls.__mro__:
            cls._register_generic_function(cls.__name__, cls.__dict__)
        super().__init_subclass__()
+
    @classmethod
    def _register_generic_function(cls, clsname, clsdict):
        """Derive name/identifier/type from the subclass namespace and
        register the function under the func. accessor."""
        cls.name = name = clsdict.get("name", clsname)
        cls.identifier = identifier = clsdict.get("identifier", name)
        # .get() rather than .pop(): cls.__dict__ is a read-only
        # mappingproxy inside __init_subclass__
        package = clsdict.get("package", "_default")
        # legacy
        if "__return_type__" in clsdict:
            cls.type = clsdict["__return_type__"]

        # Check _register attribute status
        cls._register = getattr(cls, "_register", True)

        # Register the function if required
        if cls._register:
            register_function(identifier, cls, package)
        else:
            # Set _register to True to register child classes by default
            cls._register = True
+
def __init__(self, *args, **kwargs):
parsed_args = kwargs.pop("_parsed_args", None)
if parsed_args is None:
self.clause_expr = ClauseList(
operator=operators.comma_op, group_contents=True, *parsed_args
).self_group()
+
self.type = sqltypes.to_instance(
kwargs.pop("type_", None) or getattr(self, "type", None)
)
import types
import weakref
+from . import cache_key as _cache_key
from . import coercions
from . import elements
from . import roles
else:
parent_closure_cache_key = ()
- if parent_closure_cache_key is not traversals.NO_CACHE:
+ if parent_closure_cache_key is not _cache_key.NO_CACHE:
anon_map = traversals.anon_map()
cache_key = tuple(
[
]
)
- if traversals.NO_CACHE not in anon_map:
+ if _cache_key.NO_CACHE not in anon_map:
cache_key = parent_closure_cache_key + cache_key
self.closure_cache_key = cache_key
except KeyError:
rec = None
else:
- cache_key = traversals.NO_CACHE
+ cache_key = _cache_key.NO_CACHE
rec = None
else:
- cache_key = traversals.NO_CACHE
+ cache_key = _cache_key.NO_CACHE
rec = None
self.closure_cache_key = cache_key
if rec is None:
- if cache_key is not traversals.NO_CACHE:
+ if cache_key is not _cache_key.NO_CACHE:
rec = AnalyzedFunction(
tracker, self, apply_propagate_attrs, fn
)
self._rec = rec
- if cache_key is not traversals.NO_CACHE:
+ if cache_key is not _cache_key.NO_CACHE:
if self.parent_lambda is not None:
bindparams[:0] = self.parent_lambda._resolved_bindparams
return expr
def _gen_cache_key(self, anon_map, bindparams):
- if self.closure_cache_key is traversals.NO_CACHE:
- anon_map[traversals.NO_CACHE] = True
+ if self.closure_cache_key is _cache_key.NO_CACHE:
+ anon_map[_cache_key.NO_CACHE] = True
return None
cache_key = (
for tup_elem in opts.track_on[idx]
)
- elif isinstance(elem, traversals.HasCacheKey):
+ elif isinstance(elem, _cache_key.HasCacheKey):
def get(closure, opts, anon_map, bindparams):
return opts.track_on[idx]._gen_cache_key(anon_map, bindparams)
"""
- if isinstance(cell_contents, traversals.HasCacheKey):
+ if isinstance(cell_contents, _cache_key.HasCacheKey):
def get(closure, opts, anon_map, bindparams):
and not isinstance(
# TODO: coverage where an ORM option or similar is here
value,
- traversals.HasCacheKey,
+ _cache_key.HasCacheKey,
)
):
name = object.__getattribute__(self, "_name")
from typing import Type
from typing import Union
+from . import cache_key
from . import coercions
from . import operators
from . import roles
class _MemoizedSelectEntities(
- traversals.HasCacheKey, traversals.HasCopyInternals, visitors.Traversible
+ cache_key.HasCacheKey, traversals.HasCopyInternals, visitors.Traversible
):
__visit_name__ = "memoized_select_entities"
from . import type_api
from .base import NO_ARG
from .base import SchemaEventTarget
+from .cache_key import HasCacheKey
from .elements import _NONE_NAME
from .elements import quoted_name
from .elements import Slice
from .elements import TypeCoerce as type_coerce # noqa
-from .traversals import HasCacheKey
from .traversals import InternalTraversal
from .type_api import Emulated
from .type_api import NativeForEmulated # noqa
+# sql/traversals.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
from collections import deque
-from collections import namedtuple
import collections.abc as collections_abc
import itertools
from itertools import zip_longest
import operator
from . import operators
-from .visitors import ExtendedInternalTraversal
+from .visitors import anon_map
from .visitors import InternalTraversal
from .. import util
-from ..inspection import inspect
-from ..util import HasMemoized
-
-try:
- from sqlalchemy.cyextension.util import cache_anon_map as anon_map # noqa
-except ImportError:
- from ._py_util import cache_anon_map as anon_map # noqa
SKIP_TRAVERSE = util.symbol("skip_traverse")
COMPARE_FAILED = False
COMPARE_SUCCEEDED = True
-NO_CACHE = util.symbol("no_cache")
-CACHE_IN_PLACE = util.symbol("cache_in_place")
-CALL_GEN_CACHE_KEY = util.symbol("call_gen_cache_key")
-STATIC_CACHE_KEY = util.symbol("static_cache_key")
-PROPAGATE_ATTRS = util.symbol("propagate_attrs")
-ANON_NAME = util.symbol("anon_name")
def compare(obj1, obj2, **kw):
)
-class HasCacheKey:
- """Mixin for objects which can produce a cache key.
-
- .. seealso::
-
- :class:`.CacheKey`
-
- :ref:`sql_caching`
-
- """
-
- _cache_key_traversal = NO_CACHE
-
- _is_has_cache_key = True
-
- _hierarchy_supports_caching = True
- """private attribute which may be set to False to prevent the
- inherit_cache warning from being emitted for a hierarchy of subclasses.
-
- Currently applies to the DDLElement hierarchy which does not implement
- caching.
-
- """
-
- inherit_cache = None
- """Indicate if this :class:`.HasCacheKey` instance should make use of the
- cache key generation scheme used by its immediate superclass.
-
- The attribute defaults to ``None``, which indicates that a construct has
- not yet taken into account whether or not its appropriate for it to
- participate in caching; this is functionally equivalent to setting the
- value to ``False``, except that a warning is also emitted.
-
- This flag can be set to ``True`` on a particular class, if the SQL that
- corresponds to the object does not change based on attributes which
- are local to this class, and not its superclass.
-
- .. seealso::
-
- :ref:`compilerext_caching` - General guideslines for setting the
- :attr:`.HasCacheKey.inherit_cache` attribute for third-party or user
- defined SQL constructs.
-
- """
-
- __slots__ = ()
-
- @classmethod
- def _generate_cache_attrs(cls):
- """generate cache key dispatcher for a new class.
-
- This sets the _generated_cache_key_traversal attribute once called
- so should only be called once per class.
-
- """
- inherit_cache = cls.__dict__.get("inherit_cache", None)
- inherit = bool(inherit_cache)
-
- if inherit:
- _cache_key_traversal = getattr(cls, "_cache_key_traversal", None)
- if _cache_key_traversal is None:
- try:
- _cache_key_traversal = cls._traverse_internals
- except AttributeError:
- cls._generated_cache_key_traversal = NO_CACHE
- return NO_CACHE
-
- # TODO: wouldn't we instead get this from our superclass?
- # also, our superclass may not have this yet, but in any case,
- # we'd generate for the superclass that has it. this is a little
- # more complicated, so for the moment this is a little less
- # efficient on startup but simpler.
- return _cache_key_traversal_visitor.generate_dispatch(
- cls, _cache_key_traversal, "_generated_cache_key_traversal"
- )
- else:
- _cache_key_traversal = cls.__dict__.get(
- "_cache_key_traversal", None
- )
- if _cache_key_traversal is None:
- _cache_key_traversal = cls.__dict__.get(
- "_traverse_internals", None
- )
- if _cache_key_traversal is None:
- cls._generated_cache_key_traversal = NO_CACHE
- if (
- inherit_cache is None
- and cls._hierarchy_supports_caching
- ):
- util.warn(
- "Class %s will not make use of SQL compilation "
- "caching as it does not set the 'inherit_cache' "
- "attribute to ``True``. This can have "
- "significant performance implications including "
- "some performance degradations in comparison to "
- "prior SQLAlchemy versions. Set this attribute "
- "to True if this object can make use of the cache "
- "key generated by the superclass. Alternatively, "
- "this attribute may be set to False which will "
- "disable this warning." % (cls.__name__),
- code="cprf",
- )
- return NO_CACHE
-
- return _cache_key_traversal_visitor.generate_dispatch(
- cls, _cache_key_traversal, "_generated_cache_key_traversal"
- )
-
- @util.preload_module("sqlalchemy.sql.elements")
- def _gen_cache_key(self, anon_map, bindparams):
- """return an optional cache key.
-
- The cache key is a tuple which can contain any series of
- objects that are hashable and also identifies
- this object uniquely within the presence of a larger SQL expression
- or statement, for the purposes of caching the resulting query.
-
- The cache key should be based on the SQL compiled structure that would
- ultimately be produced. That is, two structures that are composed in
- exactly the same way should produce the same cache key; any difference
- in the structures that would affect the SQL string or the type handlers
- should result in a different cache key.
-
- If a structure cannot produce a useful cache key, the NO_CACHE
- symbol should be added to the anon_map and the method should
- return None.
-
- """
-
- cls = self.__class__
-
- id_, found = anon_map.get_anon(self)
- if found:
- return (id_, cls)
-
- try:
- dispatcher = cls.__dict__["_generated_cache_key_traversal"]
- except KeyError:
- # most of the dispatchers are generated up front
- # in sqlalchemy/sql/__init__.py ->
- # traversals.py-> _preconfigure_traversals().
- # this block will generate any remaining dispatchers.
- dispatcher = cls._generate_cache_attrs()
-
- if dispatcher is NO_CACHE:
- anon_map[NO_CACHE] = True
- return None
-
- result = (id_, cls)
-
- # inline of _cache_key_traversal_visitor.run_generated_dispatch()
-
- for attrname, obj, meth in dispatcher(
- self, _cache_key_traversal_visitor
- ):
- if obj is not None:
- # TODO: see if C code can help here as Python lacks an
- # efficient switch construct
-
- if meth is STATIC_CACHE_KEY:
- sck = obj._static_cache_key
- if sck is NO_CACHE:
- anon_map[NO_CACHE] = True
- return None
- result += (attrname, sck)
- elif meth is ANON_NAME:
- elements = util.preloaded.sql_elements
- if isinstance(obj, elements._anonymous_label):
- obj = obj.apply_map(anon_map)
- result += (attrname, obj)
- elif meth is CALL_GEN_CACHE_KEY:
- result += (
- attrname,
- obj._gen_cache_key(anon_map, bindparams),
- )
-
- # remaining cache functions are against
- # Python tuples, dicts, lists, etc. so we can skip
- # if they are empty
- elif obj:
- if meth is CACHE_IN_PLACE:
- result += (attrname, obj)
- elif meth is PROPAGATE_ATTRS:
- result += (
- attrname,
- obj["compile_state_plugin"],
- obj["plugin_subject"]._gen_cache_key(
- anon_map, bindparams
- )
- if obj["plugin_subject"]
- else None,
- )
- elif meth is InternalTraversal.dp_annotations_key:
- # obj is here is the _annotations dict. however, we
- # want to use the memoized cache key version of it. for
- # Columns, this should be long lived. For select()
- # statements, not so much, but they usually won't have
- # annotations.
- result += self._annotations_cache_key
- elif (
- meth is InternalTraversal.dp_clauseelement_list
- or meth is InternalTraversal.dp_clauseelement_tuple
- or meth
- is InternalTraversal.dp_memoized_select_entities
- ):
- result += (
- attrname,
- tuple(
- [
- elem._gen_cache_key(anon_map, bindparams)
- for elem in obj
- ]
- ),
- )
- else:
- result += meth(
- attrname, obj, self, anon_map, bindparams
- )
- return result
-
- def _generate_cache_key(self):
- """return a cache key.
-
- The cache key is a tuple which can contain any series of
- objects that are hashable and also identifies
- this object uniquely within the presence of a larger SQL expression
- or statement, for the purposes of caching the resulting query.
-
- The cache key should be based on the SQL compiled structure that would
- ultimately be produced. That is, two structures that are composed in
- exactly the same way should produce the same cache key; any difference
- in the structures that would affect the SQL string or the type handlers
- should result in a different cache key.
-
- The cache key returned by this method is an instance of
- :class:`.CacheKey`, which consists of a tuple representing the
- cache key, as well as a list of :class:`.BindParameter` objects
- which are extracted from the expression. While two expressions
- that produce identical cache key tuples will themselves generate
- identical SQL strings, the list of :class:`.BindParameter` objects
- indicates the bound values which may have different values in
- each one; these bound parameters must be consulted in order to
- execute the statement with the correct parameters.
-
- a :class:`_expression.ClauseElement` structure that does not implement
- a :meth:`._gen_cache_key` method and does not implement a
- :attr:`.traverse_internals` attribute will not be cacheable; when
- such an element is embedded into a larger structure, this method
- will return None, indicating no cache key is available.
-
- """
-
- bindparams = []
-
- _anon_map = anon_map()
- key = self._gen_cache_key(_anon_map, bindparams)
- if NO_CACHE in _anon_map:
- return None
- else:
- return CacheKey(key, bindparams)
-
- @classmethod
- def _generate_cache_key_for_object(cls, obj):
- bindparams = []
-
- _anon_map = anon_map()
- key = obj._gen_cache_key(_anon_map, bindparams)
- if NO_CACHE in _anon_map:
- return None
- else:
- return CacheKey(key, bindparams)
-
-
-class MemoizedHasCacheKey(HasCacheKey, HasMemoized):
- @HasMemoized.memoized_instancemethod
- def _generate_cache_key(self):
- return HasCacheKey._generate_cache_key(self)
-
-
-class CacheKey(namedtuple("CacheKey", ["key", "bindparams"])):
- """The key used to identify a SQL statement construct in the
- SQL compilation cache.
-
- .. seealso::
-
- :ref:`sql_caching`
-
- """
-
- def __hash__(self):
- """CacheKey itself is not hashable - hash the .key portion"""
-
- return None
-
- def to_offline_string(self, statement_cache, statement, parameters):
- """Generate an "offline string" form of this :class:`.CacheKey`
-
- The "offline string" is basically the string SQL for the
- statement plus a repr of the bound parameter values in series.
- Whereas the :class:`.CacheKey` object is dependent on in-memory
- identities in order to work as a cache key, the "offline" version
- is suitable for a cache that will work for other processes as well.
-
- The given ``statement_cache`` is a dictionary-like object where the
- string form of the statement itself will be cached. This dictionary
- should be in a longer lived scope in order to reduce the time spent
- stringifying statements.
-
-
- """
- if self.key not in statement_cache:
- statement_cache[self.key] = sql_str = str(statement)
- else:
- sql_str = statement_cache[self.key]
-
- if not self.bindparams:
- param_tuple = tuple(parameters[key] for key in sorted(parameters))
- else:
- param_tuple = tuple(
- parameters.get(bindparam.key, bindparam.value)
- for bindparam in self.bindparams
- )
-
- return repr((sql_str, param_tuple))
-
- def __eq__(self, other):
- return self.key == other.key
-
- @classmethod
- def _diff_tuples(cls, left, right):
- ck1 = CacheKey(left, [])
- ck2 = CacheKey(right, [])
- return ck1._diff(ck2)
-
- def _whats_different(self, other):
-
- k1 = self.key
- k2 = other.key
-
- stack = []
- pickup_index = 0
- while True:
- s1, s2 = k1, k2
- for idx in stack:
- s1 = s1[idx]
- s2 = s2[idx]
-
- for idx, (e1, e2) in enumerate(zip_longest(s1, s2)):
- if idx < pickup_index:
- continue
- if e1 != e2:
- if isinstance(e1, tuple) and isinstance(e2, tuple):
- stack.append(idx)
- break
- else:
- yield "key%s[%d]: %s != %s" % (
- "".join("[%d]" % id_ for id_ in stack),
- idx,
- e1,
- e2,
- )
- else:
- pickup_index = stack.pop(-1)
- break
-
- def _diff(self, other):
- return ", ".join(self._whats_different(other))
-
- def __str__(self):
- stack = [self.key]
-
- output = []
- sentinel = object()
- indent = -1
- while stack:
- elem = stack.pop(0)
- if elem is sentinel:
- output.append((" " * (indent * 2)) + "),")
- indent -= 1
- elif isinstance(elem, tuple):
- if not elem:
- output.append((" " * ((indent + 1) * 2)) + "()")
- else:
- indent += 1
- stack = list(elem) + [sentinel] + stack
- output.append((" " * (indent * 2)) + "(")
- else:
- if isinstance(elem, HasCacheKey):
- repr_ = "<%s object at %s>" % (
- type(elem).__name__,
- hex(id(elem)),
- )
- else:
- repr_ = repr(elem)
- output.append((" " * (indent * 2)) + " " + repr_ + ", ")
-
- return "CacheKey(key=%s)" % ("\n".join(output),)
-
- def _generate_param_dict(self):
- """used for testing"""
-
- from .compiler import prefix_anon_map
-
- _anon_map = prefix_anon_map()
- return {b.key % _anon_map: b.effective_value for b in self.bindparams}
-
- def _apply_params_to_element(self, original_cache_key, target_element):
- translate = {
- k.key: v.value
- for k, v in zip(original_cache_key.bindparams, self.bindparams)
- }
-
- return target_element.params(translate)
-
-
def _clone(element, **kw):
return element._clone()
-class _CacheKey(ExtendedInternalTraversal):
- # very common elements are inlined into the main _get_cache_key() method
- # to produce a dramatic savings in Python function call overhead
-
- visit_has_cache_key = visit_clauseelement = CALL_GEN_CACHE_KEY
- visit_clauseelement_list = InternalTraversal.dp_clauseelement_list
- visit_annotations_key = InternalTraversal.dp_annotations_key
- visit_clauseelement_tuple = InternalTraversal.dp_clauseelement_tuple
- visit_memoized_select_entities = (
- InternalTraversal.dp_memoized_select_entities
- )
-
- visit_string = (
- visit_boolean
- ) = visit_operator = visit_plain_obj = CACHE_IN_PLACE
- visit_statement_hint_list = CACHE_IN_PLACE
- visit_type = STATIC_CACHE_KEY
- visit_anon_name = ANON_NAME
-
- visit_propagate_attrs = PROPAGATE_ATTRS
-
- def visit_with_context_options(
- self, attrname, obj, parent, anon_map, bindparams
- ):
- return tuple((fn.__code__, c_key) for fn, c_key in obj)
-
- def visit_inspectable(self, attrname, obj, parent, anon_map, bindparams):
- return (attrname, inspect(obj)._gen_cache_key(anon_map, bindparams))
-
- def visit_string_list(self, attrname, obj, parent, anon_map, bindparams):
- return tuple(obj)
-
- def visit_multi(self, attrname, obj, parent, anon_map, bindparams):
- return (
- attrname,
- obj._gen_cache_key(anon_map, bindparams)
- if isinstance(obj, HasCacheKey)
- else obj,
- )
-
- def visit_multi_list(self, attrname, obj, parent, anon_map, bindparams):
- return (
- attrname,
- tuple(
- elem._gen_cache_key(anon_map, bindparams)
- if isinstance(elem, HasCacheKey)
- else elem
- for elem in obj
- ),
- )
-
- def visit_has_cache_key_tuples(
- self, attrname, obj, parent, anon_map, bindparams
- ):
- if not obj:
- return ()
- return (
- attrname,
- tuple(
- tuple(
- elem._gen_cache_key(anon_map, bindparams)
- for elem in tup_elem
- )
- for tup_elem in obj
- ),
- )
-
- def visit_has_cache_key_list(
- self, attrname, obj, parent, anon_map, bindparams
- ):
- if not obj:
- return ()
- return (
- attrname,
- tuple(elem._gen_cache_key(anon_map, bindparams) for elem in obj),
- )
-
- def visit_executable_options(
- self, attrname, obj, parent, anon_map, bindparams
- ):
- if not obj:
- return ()
- return (
- attrname,
- tuple(
- elem._gen_cache_key(anon_map, bindparams)
- for elem in obj
- if elem._is_has_cache_key
- ),
- )
-
- def visit_inspectable_list(
- self, attrname, obj, parent, anon_map, bindparams
- ):
- return self.visit_has_cache_key_list(
- attrname, [inspect(o) for o in obj], parent, anon_map, bindparams
- )
-
- def visit_clauseelement_tuples(
- self, attrname, obj, parent, anon_map, bindparams
- ):
- return self.visit_has_cache_key_tuples(
- attrname, obj, parent, anon_map, bindparams
- )
-
- def visit_fromclause_ordered_set(
- self, attrname, obj, parent, anon_map, bindparams
- ):
- if not obj:
- return ()
- return (
- attrname,
- tuple([elem._gen_cache_key(anon_map, bindparams) for elem in obj]),
- )
-
- def visit_clauseelement_unordered_set(
- self, attrname, obj, parent, anon_map, bindparams
- ):
- if not obj:
- return ()
- cache_keys = [
- elem._gen_cache_key(anon_map, bindparams) for elem in obj
- ]
- return (
- attrname,
- tuple(
- sorted(cache_keys)
- ), # cache keys all start with (id_, class)
- )
-
- def visit_named_ddl_element(
- self, attrname, obj, parent, anon_map, bindparams
- ):
- return (attrname, obj.name)
-
- def visit_prefix_sequence(
- self, attrname, obj, parent, anon_map, bindparams
- ):
- if not obj:
- return ()
-
- return (
- attrname,
- tuple(
- [
- (clause._gen_cache_key(anon_map, bindparams), strval)
- for clause, strval in obj
- ]
- ),
- )
-
- def visit_setup_join_tuple(
- self, attrname, obj, parent, anon_map, bindparams
- ):
- is_legacy = "legacy" in attrname
-
- return tuple(
- (
- target
- if is_legacy and isinstance(target, str)
- else target._gen_cache_key(anon_map, bindparams),
- onclause
- if is_legacy and isinstance(onclause, str)
- else onclause._gen_cache_key(anon_map, bindparams)
- if onclause is not None
- else None,
- from_._gen_cache_key(anon_map, bindparams)
- if from_ is not None
- else None,
- tuple([(key, flags[key]) for key in sorted(flags)]),
- )
- for (target, onclause, from_, flags) in obj
- )
-
- def visit_table_hint_list(
- self, attrname, obj, parent, anon_map, bindparams
- ):
- if not obj:
- return ()
-
- return (
- attrname,
- tuple(
- [
- (
- clause._gen_cache_key(anon_map, bindparams),
- dialect_name,
- text,
- )
- for (clause, dialect_name), text in obj.items()
- ]
- ),
- )
-
- def visit_plain_dict(self, attrname, obj, parent, anon_map, bindparams):
- return (attrname, tuple([(key, obj[key]) for key in sorted(obj)]))
-
- def visit_dialect_options(
- self, attrname, obj, parent, anon_map, bindparams
- ):
- return (
- attrname,
- tuple(
- (
- dialect_name,
- tuple(
- [
- (key, obj[dialect_name][key])
- for key in sorted(obj[dialect_name])
- ]
- ),
- )
- for dialect_name in sorted(obj)
- ),
- )
-
- def visit_string_clauseelement_dict(
- self, attrname, obj, parent, anon_map, bindparams
- ):
- return (
- attrname,
- tuple(
- (key, obj[key]._gen_cache_key(anon_map, bindparams))
- for key in sorted(obj)
- ),
- )
-
- def visit_string_multi_dict(
- self, attrname, obj, parent, anon_map, bindparams
- ):
- return (
- attrname,
- tuple(
- (
- key,
- value._gen_cache_key(anon_map, bindparams)
- if isinstance(value, HasCacheKey)
- else value,
- )
- for key, value in [(key, obj[key]) for key in sorted(obj)]
- ),
- )
-
- def visit_fromclause_canonical_column_collection(
- self, attrname, obj, parent, anon_map, bindparams
- ):
- # inlining into the internals of ColumnCollection
- return (
- attrname,
- tuple(
- col._gen_cache_key(anon_map, bindparams)
- for k, col in obj._collection
- ),
- )
-
- def visit_unknown_structure(
- self, attrname, obj, parent, anon_map, bindparams
- ):
- anon_map[NO_CACHE] = True
- return ()
-
- def visit_dml_ordered_values(
- self, attrname, obj, parent, anon_map, bindparams
- ):
- return (
- attrname,
- tuple(
- (
- key._gen_cache_key(anon_map, bindparams)
- if hasattr(key, "__clause_element__")
- else key,
- value._gen_cache_key(anon_map, bindparams),
- )
- for key, value in obj
- ),
- )
-
- def visit_dml_values(self, attrname, obj, parent, anon_map, bindparams):
- # in py37 we can assume two dictionaries created in the same
- # insert ordering will retain that sorting
- return (
- attrname,
- tuple(
- (
- k._gen_cache_key(anon_map, bindparams)
- if hasattr(k, "__clause_element__")
- else k,
- obj[k]._gen_cache_key(anon_map, bindparams),
- )
- for k in obj
- ),
- )
-
- def visit_dml_multi_values(
- self, attrname, obj, parent, anon_map, bindparams
- ):
- # multivalues are simply not cacheable right now
- anon_map[NO_CACHE] = True
- return ()
-
-
-_cache_key_traversal_visitor = _CacheKey()
-
-
class HasCopyInternals:
__slots__ = ()
setattr(self, attrname, result)
-class _CopyInternals(InternalTraversal):
+class _CopyInternalsTraversal(InternalTraversal):
"""Generate a _copy_internals internal traversal dispatch for classes
with a _traverse_internals collection."""
return element
-_copy_internals = _CopyInternals()
+_copy_internals = _CopyInternalsTraversal()
def _flatten_clauseelement(element):
return element
-class _GetChildren(InternalTraversal):
+class _GetChildrenTraversal(InternalTraversal):
"""Generate a _children_traversal internal traversal dispatch for classes
with a _traverse_internals collection."""
return ()
-_get_children = _GetChildren()
+_get_children = _GetChildrenTraversal()
@util.preload_module("sqlalchemy.sql.elements")
from . import operators
from .base import SchemaEventTarget
-from .traversals import NO_CACHE
+from .cache_key import NO_CACHE
from .visitors import Traversible
-from .visitors import TraversibleType
from .. import exc
from .. import util
return util.generic_repr(self)
-class VisitableCheckKWArg(util.EnsureKWArgType, TraversibleType):
- pass
-
-
class ExternalType:
"""mixin that defines attributes and behaviors specific to third-party
datatypes.
return NO_CACHE
-class UserDefinedType(ExternalType, TypeEngine, metaclass=VisitableCheckKWArg):
+class UserDefinedType(ExternalType, TypeEngine, util.EnsureKWArg):
"""Base for user defined types.
This should be the base of new types. Note that
from .base import _expand_cloned
from .base import _from_objects
from .base import ColumnSet
+from .cache_key import HasCacheKey # noqa
from .ddl import sort_tables # noqa
from .elements import _find_columns # noqa
from .elements import _label_reference
from .selectable import ScalarSelect
from .selectable import SelectBase
from .selectable import TableClause
-from .traversals import HasCacheKey # noqa
from .. import exc
from .. import util
from ..util import langhelpers
from ..util import symbol
+try:
+ from sqlalchemy.cyextension.util import cache_anon_map as anon_map # noqa
+except ImportError:
+ from ._py_util import cache_anon_map as anon_map # noqa
+
__all__ = [
"iterate",
"traverse_using",
"cloned_traverse",
"replacement_traverse",
"Traversible",
- "TraversibleType",
"ExternalTraversal",
"InternalTraversal",
]
-def _generate_compiler_dispatch(cls):
- """Generate a _compiler_dispatch() external traversal on classes with a
- __visit_name__ attribute.
-
- """
- visit_name = cls.__visit_name__
-
- if "_compiler_dispatch" in cls.__dict__:
- # class has a fixed _compiler_dispatch() method.
- # copy it to "original" so that we can get it back if
- # sqlalchemy.ext.compiles overrides it.
- cls._original_compiler_dispatch = cls._compiler_dispatch
- return
-
- if not isinstance(visit_name, str):
- raise exc.InvalidRequestError(
- "__visit_name__ on class %s must be a string at the class level"
- % cls.__name__
- )
-
- name = "visit_%s" % visit_name
- getter = operator.attrgetter(name)
-
- def _compiler_dispatch(self, visitor, **kw):
- """Look for an attribute named "visit_<visit_name>" on the
- visitor, and call it with the same kw params.
-
- """
- try:
- meth = getter(visitor)
- except AttributeError as err:
- return visitor.visit_unsupported_compilation(self, err, **kw)
-
- else:
- return meth(self, **kw)
-
- cls._compiler_dispatch = (
- cls._original_compiler_dispatch
- ) = _compiler_dispatch
+class Traversible:
+ """Base class for visitable objects."""
+ __slots__ = ()
-class TraversibleType(type):
- """Metaclass which assigns dispatch attributes to various kinds of
- "visitable" classes.
+ __visit_name__: str
- Attributes include:
+ def __init_subclass__(cls) -> None:
+ if "__visit_name__" in cls.__dict__:
+ cls._generate_compiler_dispatch()
+ super().__init_subclass__()
- * The ``_compiler_dispatch`` method, corresponding to ``__visit_name__``.
- This is called "external traversal" because the caller of each visit()
- method is responsible for sub-traversing the inner elements of each
- object. This is appropriate for string compilers and other traversals
- that need to call upon the inner elements in a specific pattern.
+ @classmethod
+ def _generate_compiler_dispatch(cls):
+ """Assign dispatch attributes to various kinds of
+ "visitable" classes.
- * internal traversal collections ``_children_traversal``,
- ``_cache_key_traversal``, ``_copy_internals_traversal``, generated from
- an optional ``_traverse_internals`` collection of symbols which comes
- from the :class:`.InternalTraversal` list of symbols. This is called
- "internal traversal" MARKMARK
+ Attributes include:
- """
+ * The ``_compiler_dispatch`` method, corresponding to
+ ``__visit_name__``. This is called "external traversal" because the
+ caller of each visit() method is responsible for sub-traversing the
+ inner elements of each object. This is appropriate for string
+ compilers and other traversals that need to call upon the inner
+ elements in a specific pattern.
- def __init__(cls, clsname, bases, clsdict):
- if clsname != "Traversible":
- if "__visit_name__" in clsdict:
- _generate_compiler_dispatch(cls)
+ * internal traversal collections ``_children_traversal``,
+ ``_cache_key_traversal``, ``_copy_internals_traversal``, generated
+ from an optional ``_traverse_internals`` collection of symbols which
+ comes from the :class:`.InternalTraversal` list of symbols. This is
+ called "internal traversal".
- super(TraversibleType, cls).__init__(clsname, bases, clsdict)
+ """
+ visit_name = cls.__visit_name__
+
+ if "_compiler_dispatch" in cls.__dict__:
+ # class has a fixed _compiler_dispatch() method.
+ # copy it to "original" so that we can get it back if
+ # sqlalchemy.ext.compiles overrides it.
+ cls._original_compiler_dispatch = cls._compiler_dispatch
+ return
+
+ if not isinstance(visit_name, str):
+ raise exc.InvalidRequestError(
+ f"__visit_name__ on class {cls.__name__} must be a string "
+ "at the class level"
+ )
+ name = "visit_%s" % visit_name
+ getter = operator.attrgetter(name)
-class Traversible(metaclass=TraversibleType):
- """Base class for visitable objects, applies the
- :class:`.visitors.TraversibleType` metaclass.
+ def _compiler_dispatch(self, visitor, **kw):
+ """Look for an attribute named "visit_<visit_name>" on the
+ visitor, and call it with the same kw params.
- """
+ """
+ try:
+ meth = getter(visitor)
+ except AttributeError as err:
+ return visitor.visit_unsupported_compilation(self, err, **kw)
+ else:
+ return meth(self, **kw)
- __slots__ = ()
+ cls._compiler_dispatch = (
+ cls._original_compiler_dispatch
+ ) = _compiler_dispatch
def __class_getitem__(cls, key):
# allow generic classes in py3.9+
)
-class _InternalTraversalType(type):
- def __init__(cls, clsname, bases, clsdict):
- if cls.__name__ in ("InternalTraversal", "ExtendedInternalTraversal"):
- lookup = {}
- for key, sym in clsdict.items():
- if key.startswith("dp_"):
- visit_key = key.replace("dp_", "visit_")
- sym_name = sym.name
- assert sym_name not in lookup, sym_name
- lookup[sym] = lookup[sym_name] = visit_key
- if hasattr(cls, "_dispatch_lookup"):
- lookup.update(cls._dispatch_lookup)
- cls._dispatch_lookup = lookup
+class _HasTraversalDispatch:
+ r"""Define infrastructure for the :class:`.InternalTraversal` class.
- super(_InternalTraversalType, cls).__init__(clsname, bases, clsdict)
+ .. versionadded:: 2.0
+ """
-def _generate_dispatcher(visitor, internal_dispatch, method_name):
- names = []
- for attrname, visit_sym in internal_dispatch:
- meth = visitor.dispatch(visit_sym)
- if meth:
- visit_name = ExtendedInternalTraversal._dispatch_lookup[visit_sym]
- names.append((attrname, visit_name))
-
- code = (
- (" return [\n")
- + (
- ", \n".join(
- " (%r, self.%s, visitor.%s)"
- % (attrname, attrname, visit_name)
- for attrname, visit_name in names
+ def __init_subclass__(cls) -> None:
+ cls._generate_traversal_dispatch()
+ super().__init_subclass__()
+
+ def dispatch(self, visit_symbol):
+        """Given a visit symbol from :class:`._HasTraversalDispatch`, return
+        the corresponding ``visit_*`` method on this subclass, or ``None``.
+
+ """
+ name = self._dispatch_lookup[visit_symbol]
+ return getattr(self, name, None)
+
+ def run_generated_dispatch(
+ self, target, internal_dispatch, generate_dispatcher_name
+ ):
+ try:
+ dispatcher = target.__class__.__dict__[generate_dispatcher_name]
+ except KeyError:
+ # most of the dispatchers are generated up front
+ # in sqlalchemy/sql/__init__.py ->
+            # traversals.py -> _preconfigure_traversals().
+ # this block will generate any remaining dispatchers.
+ dispatcher = self.generate_dispatch(
+ target.__class__, internal_dispatch, generate_dispatcher_name
+ )
+ return dispatcher(target, self)
+
+ def generate_dispatch(
+ self, target_cls, internal_dispatch, generate_dispatcher_name
+ ):
+ dispatcher = self._generate_dispatcher(
+ internal_dispatch, generate_dispatcher_name
+ )
+ # assert isinstance(target_cls, type)
+ setattr(target_cls, generate_dispatcher_name, dispatcher)
+ return dispatcher
+
+ @classmethod
+ def _generate_traversal_dispatch(cls):
+ lookup = {}
+ clsdict = cls.__dict__
+ for key, sym in clsdict.items():
+ if key.startswith("dp_"):
+ visit_key = key.replace("dp_", "visit_")
+ sym_name = sym.name
+ assert sym_name not in lookup, sym_name
+ lookup[sym] = lookup[sym_name] = visit_key
+ if hasattr(cls, "_dispatch_lookup"):
+ lookup.update(cls._dispatch_lookup)
+ cls._dispatch_lookup = lookup
+
+ def _generate_dispatcher(self, internal_dispatch, method_name):
+ names = []
+ for attrname, visit_sym in internal_dispatch:
+ meth = self.dispatch(visit_sym)
+ if meth:
+ visit_name = ExtendedInternalTraversal._dispatch_lookup[
+ visit_sym
+ ]
+ names.append((attrname, visit_name))
+
+ code = (
+ (" return [\n")
+ + (
+ ", \n".join(
+ " (%r, self.%s, visitor.%s)"
+ % (attrname, attrname, visit_name)
+ for attrname, visit_name in names
+ )
)
+ + ("\n ]\n")
)
- + ("\n ]\n")
- )
- meth_text = ("def %s(self, visitor):\n" % method_name) + code + "\n"
- # print(meth_text)
- return langhelpers._exec_code_in_env(meth_text, {}, method_name)
+ meth_text = ("def %s(self, visitor):\n" % method_name) + code + "\n"
+ return langhelpers._exec_code_in_env(meth_text, {}, method_name)
-class InternalTraversal(metaclass=_InternalTraversalType):
+class InternalTraversal(_HasTraversalDispatch):
r"""Defines visitor symbols used for internal traversal.
The :class:`.InternalTraversal` class is used in two ways. One is that
"""
- def dispatch(self, visit_symbol):
- """Given a method from :class:`.InternalTraversal`, return the
- corresponding method on a subclass.
-
- """
- name = self._dispatch_lookup[visit_symbol]
- return getattr(self, name, None)
-
- def run_generated_dispatch(
- self, target, internal_dispatch, generate_dispatcher_name
- ):
- try:
- dispatcher = target.__class__.__dict__[generate_dispatcher_name]
- except KeyError:
- # most of the dispatchers are generated up front
- # in sqlalchemy/sql/__init__.py ->
- # traversals.py-> _preconfigure_traversals().
- # this block will generate any remaining dispatchers.
- dispatcher = self.generate_dispatch(
- target.__class__, internal_dispatch, generate_dispatcher_name
- )
- return dispatcher(target, self)
-
- def generate_dispatch(
- self, target_cls, internal_dispatch, generate_dispatcher_name
- ):
- dispatcher = _generate_dispatcher(
- self, internal_dispatch, generate_dispatcher_name
- )
- # assert isinstance(target_cls, type)
- setattr(target_cls, generate_dispatcher_name, dispatcher)
- return dispatcher
-
dp_has_cache_key = symbol("HC")
"""Visit a :class:`.HasCacheKey` object."""
# backwards compatibility
Visitable = Traversible
-VisitableType = TraversibleType
ClauseVisitor = ExternalTraversal
CloningVisitor = CloningExternalTraversal
ReplacingCloningVisitor = ReplacingExternalTraversal
from .. import util
from ..orm import declarative_base
from ..orm import registry
-from ..orm.decl_api import DeclarativeMeta
from ..schema import sort_tables_and_constraints
"""
cls_registry = cls.classes
- assert cls_registry is not None
-
- class FindFixture(type):
- def __init__(cls, classname, bases, dict_):
- cls_registry[classname] = cls
- type.__init__(cls, classname, bases, dict_)
-
- class _Base(metaclass=FindFixture):
- pass
+ class _Base:
+ def __init_subclass__(cls) -> None:
+ assert cls_registry is not None
+ cls_registry[cls.__name__] = cls
+ super().__init_subclass__()
class Basic(BasicEntity, _Base):
pass
def _with_register_classes(cls, fn):
cls_registry = cls.classes
- class FindFixtureDeclarative(DeclarativeMeta):
- def __init__(cls, classname, bases, dict_):
- cls_registry[classname] = cls
- DeclarativeMeta.__init__(cls, classname, bases, dict_)
-
class DeclarativeBasic:
__table_cls__ = schema.Table
+ def __init_subclass__(cls) -> None:
+ assert cls_registry is not None
+ cls_registry[cls.__name__] = cls
+ super().__init_subclass__()
+
_DeclBase = declarative_base(
metadata=cls._tables_metadata,
- metaclass=FindFixtureDeclarative,
cls=DeclarativeBasic,
)
from .langhelpers import dictlike_iteritems
from .langhelpers import duck_type_collection
from .langhelpers import ellipses_string
-from .langhelpers import EnsureKWArgType
+from .langhelpers import EnsureKWArg
from .langhelpers import format_argspec_init
from .langhelpers import format_argspec_plus
from .langhelpers import generic_repr
return env["set"]
-class EnsureKWArgType(type):
+class EnsureKWArg:
r"""Apply translation of functions to accept \**kw arguments if they
don't already.
+ Used to ensure cross-compatibility with third party legacy code, for things
+ like compiler visit methods that need to accept ``**kw`` arguments,
+ but may have been copied from old code that didn't accept them.
+
+ """
+
+ ensure_kwarg: str
+ """a regular expression that indicates method names for which the method
+ should accept ``**kw`` arguments.
+
+ The class will scan for methods matching the name template and decorate
+ them if necessary to ensure ``**kw`` parameters are accepted.
+
"""
- def __init__(cls, clsname, bases, clsdict):
+ def __init_subclass__(cls) -> None:
fn_reg = cls.ensure_kwarg
+ clsdict = cls.__dict__
if fn_reg:
for key in clsdict:
m = re.match(fn_reg, key)
fn = clsdict[key]
spec = compat.inspect_getfullargspec(fn)
if not spec.varkw:
- clsdict[key] = wrapped = cls._wrap_w_kw(fn)
+ wrapped = cls._wrap_w_kw(fn)
setattr(cls, key, wrapped)
- super(EnsureKWArgType, cls).__init__(clsname, bases, clsdict)
+ super().__init_subclass__()
- def _wrap_w_kw(self, fn):
+ @classmethod
+ def _wrap_w_kw(cls, fn):
def wrap(*arg, **kw):
return fn(*arg)
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import subqueryload
from sqlalchemy.orm import UserDefinedOption
-from sqlalchemy.sql.traversals import NO_CACHE
+from sqlalchemy.sql.cache_key import NO_CACHE
from sqlalchemy.testing import assert_raises
from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import AssertsCompiledSQL
from sqlalchemy.sql import table
from sqlalchemy.sql import util as sql_util
from sqlalchemy.sql.base import ExecutableOption
-from sqlalchemy.sql.traversals import HasCacheKey
+from sqlalchemy.sql.cache_key import HasCacheKey
from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import AssertsCompiledSQL
from sqlalchemy.testing import eq_