From: Mike Bayer Date: Tue, 18 Jan 2022 22:00:16 +0000 (-0500) Subject: Add new infrastructure to support greater use of __slots__ X-Git-Tag: rel_2_0_0b1~522 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=d46a4c0326bd2e697794514b920e6727d5153324;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Add new infrastructure to support greater use of __slots__ * Changed AliasedInsp to use __slots__ * Migrated all of strategy_options to use __slots__ for objects. Adds new infrastructure to traversals to support shallow copy, to dict and from dict based on internal traversal attributes. Load / _LoadElement then leverage this to provide clone / generative / getstate without the need for __dict__ or explicit attribute lists. Doing this change revealed that there are lots of things that trigger off of whether or not a class has a __visit_name__ attribute. so to suit that we've gone back to having Visitable, which is a better name than Traversible at this point (I think Traversible is mis-spelled too). Change-Id: I13d04e494339fac9dbda0b8e78153418abebaf72 References: #7527 --- diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index 08189a1b75..b9a5aaf518 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -64,20 +64,24 @@ __all__ = ( class ORMStatementRole(roles.StatementRole): + __slots__ = () _role_name = ( "Executable SQL or text() construct, including ORM " "aware objects" ) class ORMColumnsClauseRole(roles.ColumnsClauseRole): + __slots__ = () _role_name = "ORM mapped entity, aliased entity, or Column expression" class ORMEntityColumnsClauseRole(ORMColumnsClauseRole): + __slots__ = () _role_name = "ORM mapped or aliased entity" class ORMFromClauseRole(roles.StrictFromClauseRole): + __slots__ = () _role_name = "ORM mapped entity, aliased entity, or FROM expression" @@ -798,6 +802,8 @@ class CompileStateOption(HasCacheKey, ORMOption): """ + __slots__ = () + _is_compile_state = True def process_compile_state(self, compile_state): @@ -832,6 +838,8 @@ class LoaderOption(CompileStateOption): """ + __slots__ = () + def process_compile_state_replaced_entities( self, compile_state, mapper_entities ): @@ -846,6 +854,8 @@ class CriteriaOption(CompileStateOption): """ + __slots__ = () + _is_criteria_option = True def get_global_criteria(self, attributes): @@ -861,6 +871,8 @@ class UserDefinedOption(ORMOption): """ + __slots__ = ("payload",) + _is_legacy_option = False propagate_to_loaders = False @@ -887,6 +899,8 @@ class UserDefinedOption(ORMOption): class MapperOption(ORMOption): """Describe a modification to a Query""" + __slots__ = () + _is_legacy_option = True propagate_to_loaders = False diff --git a/lib/sqlalchemy/orm/strategy_options.py b/lib/sqlalchemy/orm/strategy_options.py index c2cfbb9fc2..0f993b86cf 100644 --- a/lib/sqlalchemy/orm/strategy_options.py +++ b/lib/sqlalchemy/orm/strategy_options.py @@ -13,6 +13,7 @@ from typing import Any from typing import cast from typing import Mapping from typing import NoReturn +from typing import Optional from typing import Tuple from typing import Union @@ -32,9 +33,9 @@ from ..sql import and_ from ..sql import cache_key from ..sql import coercions from ..sql import roles +from ..sql import traversals from ..sql import visitors from ..sql.base import _generative -from ..sql.base import Generative _RELATIONSHIP_TOKEN = "relationship" _COLUMN_TOKEN = "column" @@ -45,9 +46,11 @@ if typing.TYPE_CHECKING: Self_AbstractLoad = typing.TypeVar("Self_AbstractLoad", bound="_AbstractLoad") -class _AbstractLoad(Generative, LoaderOption): +class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): + __slots__ = ("propagate_to_loaders",) + _is_strategy_option = True - propagate_to_loaders = False + propagate_to_loaders: bool def contains_eager(self, attr, alias=None, _is_chain=False): r"""Indicate that the given attribute should be eagerly loaded from @@ -882,13 +885,20 @@ class Load(_AbstractLoad): """ - _cache_key_traversal = [ + __slots__ = ( + "path", + "context", + ) + + _traverse_internals = [ ("path", visitors.ExtendedInternalTraversal.dp_has_cache_key), ( "context", visitors.InternalTraversal.dp_has_cache_key_list, ), + ("propagate_to_loaders", visitors.InternalTraversal.dp_boolean), ] + _cache_key_traversal = None path: PathRegistry context: Tuple["_LoadElement", ...] @@ -899,6 +909,7 @@ class Load(_AbstractLoad): self.path = insp._path_registry self.context = () + self.propagate_to_loaders = False def __str__(self): return f"Load({self.path[0]})" @@ -908,6 +919,7 @@ class Load(_AbstractLoad): load = cls.__new__(cls) load.path = path load.context = () + load.propagate_to_loaders = False return load def _adjust_for_extra_criteria(self, context): @@ -1128,13 +1140,13 @@ class Load(_AbstractLoad): self.context += (load_element,) def __getstate__(self): - d = self.__dict__.copy() + d = self._shallow_to_dict() d["path"] = self.path.serialize() return d def __setstate__(self, state): - self.__dict__.update(state) - self.path = PathRegistry.deserialize(self.path) + state["path"] = PathRegistry.deserialize(state["path"]) + self._shallow_from_dict(state) SelfWildcardLoad = typing.TypeVar("SelfWildcardLoad", bound="_WildcardLoad") @@ -1143,16 +1155,27 @@ SelfWildcardLoad = typing.TypeVar("SelfWildcardLoad", bound="_WildcardLoad") class _WildcardLoad(_AbstractLoad): """represent a standalone '*' load operation""" - _cache_key_traversal = [ + __slots__ = ("strategy", "path", "local_opts") + + _traverse_internals = [ ("strategy", visitors.ExtendedInternalTraversal.dp_plain_obj), + ("path", visitors.ExtendedInternalTraversal.dp_plain_obj), ( "local_opts", visitors.ExtendedInternalTraversal.dp_string_multi_dict, ), ] + cache_key_traversal = None - local_opts = util.EMPTY_DICT - path: Tuple[str, ...] = () + strategy: Optional[Tuple[Any, ...]] + local_opts: Mapping[str, Any] + path: Tuple[str, ...] + propagate_to_loaders = False + + def __init__(self): + self.path = () + self.strategy = None + self.local_opts = util.EMPTY_DICT def _clone_for_bind_strategy( self, @@ -1171,16 +1194,6 @@ class _WildcardLoad(_AbstractLoad): and attr in (_WILDCARD_TOKEN, _DEFAULT_TOKEN) ) - if attr == _DEFAULT_TOKEN: - # for someload('*'), this currently does propagate=False, - # to prevent it from taking effect for lazy loads. - # it seems like adjusting for current_path for a lazy load etc. - # should be taking care of that, so that the option still takes - # effect for a refresh as well, but currently it does not. - # probably should be adjusted to be more accurate re: current - # path vs. refresh - self.propagate_to_loaders = False - attr = f"{wildcard_key}:{attr}" self.strategy = strategy @@ -1310,13 +1323,16 @@ class _WildcardLoad(_AbstractLoad): return None def __getstate__(self): - return self.__dict__.copy() + d = self._shallow_to_dict() + return d def __setstate__(self, state): - self.__dict__.update(state) + self._shallow_from_dict(state) -class _LoadElement(cache_key.HasCacheKey): +class _LoadElement( + cache_key.HasCacheKey, traversals.HasShallowCopy, visitors.Traversible +): """represents strategy information to select for a LoaderStrategy and pass options to it. @@ -1328,40 +1344,66 @@ class _LoadElement(cache_key.HasCacheKey): """ - _cache_key_traversal = [ + __slots__ = ( + "path", + "strategy", + "propagate_to_loaders", + "local_opts", + "_extra_criteria", + "_reconcile_to_other", + ) + __visit_name__ = "load_element" + + _traverse_internals = [ ("path", visitors.ExtendedInternalTraversal.dp_has_cache_key), ("strategy", visitors.ExtendedInternalTraversal.dp_plain_obj), ( "local_opts", visitors.ExtendedInternalTraversal.dp_string_multi_dict, ), + ("_extra_criteria", visitors.InternalTraversal.dp_clauseelement_list), + ("propagate_to_loaders", visitors.InternalTraversal.dp_plain_obj), + ("_reconcile_to_other", visitors.InternalTraversal.dp_plain_obj), ] + _cache_key_traversal = None - _extra_criteria = () + _extra_criteria: Tuple[Any, ...] - _reconcile_to_other = None - strategy = None + _reconcile_to_other: Optional[bool] + strategy: Tuple[Any, ...] path: PathRegistry - propagate_to_loaders = False + propagate_to_loaders: bool local_opts: Mapping[str, Any] is_token_strategy: bool is_class_strategy: bool + def __hash__(self): + return id(self) + + def __eq__(self, other): + return traversals.compare(self, other) + @property def is_opts_only(self): return bool(self.local_opts and self.strategy is None) + def _clone(self): + cls = self.__class__ + s = cls.__new__(cls) + + self._shallow_copy_to(s) + return s + def __getstate__(self): - d = self.__dict__.copy() + d = self._shallow_to_dict() d["path"] = self.path.serialize() - return d def __setstate__(self, state): state["path"] = PathRegistry.deserialize(state["path"]) - self.__dict__.update(state) + self._shallow_from_dict(state) def _raise_for_no_match(self, parent_loader, mapper_entities): path = parent_loader.path @@ -1498,11 +1540,14 @@ class _LoadElement(cache_key.HasCacheKey): opt.local_opts = ( util.immutabledict(local_opts) if local_opts else util.EMPTY_DICT ) + opt._extra_criteria = () if reconcile_to_other is not None: opt._reconcile_to_other = reconcile_to_other elif strategy is None and not local_opts: opt._reconcile_to_other = True + else: + opt._reconcile_to_other = None path = opt._init_path(path, attr, wildcard_key, attr_group, raiseerr) @@ -1517,12 +1562,6 @@ class _LoadElement(cache_key.HasCacheKey): def __init__(self, path, strategy, local_opts, propagate_to_loaders): raise NotImplementedError() - def _clone(self): - cls = self.__class__ - s = cls.__new__(cls) - s.__dict__ = self.__dict__.copy() - return s - def _prepend_path_from(self, parent): """adjust the path of this :class:`._LoadElement` to be a subpath of that of the given parent :class:`_orm.Load` object's @@ -1617,20 +1656,28 @@ class _AttributeStrategyLoad(_LoadElement): """ - _cache_key_traversal = _LoadElement._cache_key_traversal + [ + __slots__ = ("_of_type", "_path_with_polymorphic_path") + + __visit_name__ = "attribute_strategy_load_element" + + _traverse_internals = _LoadElement._traverse_internals + [ ("_of_type", visitors.ExtendedInternalTraversal.dp_multi), - ("_extra_criteria", visitors.InternalTraversal.dp_clauseelement_list), + ( + "_path_with_polymorphic_path", + visitors.ExtendedInternalTraversal.dp_has_cache_key, + ), ] - _of_type: Union["Mapper", AliasedInsp, None] = None - _path_with_polymorphic_path = None + _of_type: Union["Mapper", AliasedInsp, None] + _path_with_polymorphic_path: Optional[PathRegistry] - inherit_cache = True is_class_strategy = False is_token_strategy = False def _init_path(self, path, attr, wildcard_key, attr_group, raiseerr): assert attr is not None + self._of_type = None + self._path_with_polymorphic_path = None insp, _, prop = _parse_attr_argument(attr) if insp.is_property: @@ -1832,12 +1879,14 @@ class _AttributeStrategyLoad(_LoadElement): return [("loader", cast(PathRegistry, effective_path).natural_path)] def __getstate__(self): - d = self.__dict__.copy() + d = super().__getstate__() + + # can't pickle this. See + # test_pickled.py -> test_lazyload_extra_criteria_not_supported + # where we should be emitting a warning for the usual case where this + # would be non-None d["_extra_criteria"] = () - d["path"] = self.path.serialize() - # TODO: we hope to do this logic only at compile time so that - # we aren't carrying these extra attributes around if self._path_with_polymorphic_path: d[ "_path_with_polymorphic_path" @@ -1854,14 +1903,19 @@ class _AttributeStrategyLoad(_LoadElement): return d def __setstate__(self, state): - state["path"] = PathRegistry.deserialize(state["path"]) - self.__dict__.update(state) - if "_path_with_polymorphic_path" in state: + super().__setstate__(state) + + if state.get("_path_with_polymorphic_path", None): self._path_with_polymorphic_path = PathRegistry.deserialize( - self._path_with_polymorphic_path + state["_path_with_polymorphic_path"] ) - if self._of_type is not None: - self._of_type = inspect(self._of_type) + else: + self._path_with_polymorphic_path = None + + if state.get("_of_type", None): + self._of_type = inspect(state["_of_type"]) + else: + self._of_type = None class _TokenStrategyLoad(_LoadElement): @@ -1877,6 +1931,8 @@ class _TokenStrategyLoad(_LoadElement): """ + __visit_name__ = "token_strategy_load_element" + inherit_cache = True is_class_strategy = False is_token_strategy = True @@ -1962,6 +2018,8 @@ class _ClassStrategyLoad(_LoadElement): is_class_strategy = True is_token_strategy = False + __visit_name__ = "class_strategy_load_element" + def _init_path(self, path, attr, wildcard_key, attr_group, raiseerr): return path diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index e84517670c..75f7110078 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -45,6 +45,7 @@ from ..sql import util as sql_util from ..sql import visitors from ..sql.annotation import SupportsCloneAnnotations from ..sql.base import ColumnCollection +from ..util.langhelpers import MemoizedSlots all_cascades = frozenset( @@ -609,8 +610,9 @@ class AliasedClass: class AliasedInsp( ORMEntityColumnsClauseRole, ORMFromClauseRole, - sql_base.MemoizedHasCacheKey, + sql_base.HasCacheKey, InspectionAttr, + MemoizedSlots, ): """Provide an inspection interface for an :class:`.AliasedClass` object. @@ -650,6 +652,30 @@ class AliasedInsp( """ + __slots__ = ( + "__weakref__", + "_weak_entity", + "mapper", + "selectable", + "name", + "_adapt_on_names", + "with_polymorphic_mappers", + "polymorphic_on", + "_use_mapper_path", + "_base_alias", + "represents_outer_join", + "persist_selectable", + "local_table", + "_is_with_polymorphic", + "_with_polymorphic_entities", + "_adapter", + "_target", + "__clause_element__", + "_memoized_values", + "_all_column_expressions", + "_nest_adapters", + ) + def __init__( self, entity, @@ -738,8 +764,7 @@ class AliasedInsp( is_aliased_class = True "always returns True" - @util.memoized_instancemethod - def __clause_element__(self): + def _memoized_method___clause_element__(self): return self.selectable._annotate( { "parentmapper": self.mapper, @@ -863,8 +888,7 @@ class AliasedInsp( else: assert False, "mapper %s doesn't correspond to %s" % (mapper, self) - @util.memoized_property - def _get_clause(self): + def _memoized_attr__get_clause(self): onclause, replacemap = self.mapper._get_clause return ( self._adapter.traverse(onclause), @@ -874,12 +898,10 @@ class AliasedInsp( }, ) - @util.memoized_property - def _memoized_values(self): + def _memoized_attr__memoized_values(self): return {} - @util.memoized_property - def _all_column_expressions(self): + def _memoized_attr__all_column_expressions(self): if self._is_with_polymorphic: cols_plus_keys = self.mapper._columns_plus_keys( [ent.mapper for ent in self._with_polymorphic_entities] @@ -965,6 +987,15 @@ class LoaderCriteriaOption(CriteriaOption): """ + __slots__ = ( + "root_entity", + "entity", + "deferred_where_criteria", + "where_criteria", + "include_aliases", + "propagate_to_loaders", + ) + _traverse_internals = [ ("root_entity", visitors.ExtendedInternalTraversal.dp_plain_obj), ("entity", visitors.ExtendedInternalTraversal.dp_has_cache_key), diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 74469b0350..8ae8f8f65f 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -17,6 +17,7 @@ from itertools import zip_longest import operator import re import typing +from typing import TypeVar from . import roles from . import visitors @@ -571,11 +572,14 @@ class CompileState: return decorate +SelfGenerative = TypeVar("SelfGenerative", bound="Generative") + + class Generative(HasMemoized): """Provide a method-chaining pattern in conjunction with the @_generative decorator.""" - def _generate(self): + def _generate(self: SelfGenerative) -> SelfGenerative: skip = self._memoized_keys cls = self.__class__ s = cls.__new__(cls) @@ -783,6 +787,8 @@ class Options(metaclass=_MetaOptions): class CacheableOptions(Options, HasCacheKey): + __slots__ = () + @hybridmethod def _gen_cache_key(self, anon_map, bindparams): return HasCacheKey._gen_cache_key(self, anon_map, bindparams) @@ -797,6 +803,8 @@ class CacheableOptions(Options, HasCacheKey): class ExecutableOption(HasCopyInternals): + __slots__ = () + _annotations = util.EMPTY_DICT __visit_name__ = "executable_option" diff --git a/lib/sqlalchemy/sql/cache_key.py b/lib/sqlalchemy/sql/cache_key.py index 8dd44dbf08..42bd603537 100644 --- a/lib/sqlalchemy/sql/cache_key.py +++ b/lib/sqlalchemy/sql/cache_key.py @@ -47,6 +47,11 @@ class CacheTraverseTarget(enum.Enum): class HasCacheKey: """Mixin for objects which can produce a cache key. + This class is usually in a hierarchy that starts with the + :class:`.HasTraverseInternals` base, but this is optional. Currently, + the class should be able to work on its own without including + :class:`.HasTraverseInternals`. + .. seealso:: :class:`.CacheKey` @@ -55,6 +60,8 @@ class HasCacheKey: """ + __slots__ = () + _cache_key_traversal = NO_CACHE _is_has_cache_key = True @@ -106,11 +113,17 @@ class HasCacheKey: _cache_key_traversal = getattr(cls, "_cache_key_traversal", None) if _cache_key_traversal is None: try: + # this would be HasTraverseInternals _cache_key_traversal = cls._traverse_internals except AttributeError: cls._generated_cache_key_traversal = NO_CACHE return NO_CACHE + assert _cache_key_traversal is not NO_CACHE, ( + f"class {cls} has _cache_key_traversal=NO_CACHE, " + "which conflicts with inherit_cache=True" + ) + # TODO: wouldn't we instead get this from our superclass? # also, our superclass may not have this yet, but in any case, # we'd generate for the superclass that has it. this is a little @@ -323,6 +336,8 @@ class HasCacheKey: class MemoizedHasCacheKey(HasCacheKey, HasMemoized): + __slots__ = () + @HasMemoized.memoized_instancemethod def _generate_cache_key(self): return HasCacheKey._generate_cache_key(self) diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 43979b4aee..d14521ba73 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -46,7 +46,7 @@ from .traversals import HasCopyInternals from .visitors import cloned_traverse from .visitors import InternalTraversal from .visitors import traverse -from .visitors import Traversible +from .visitors import Visitable from .. import exc from .. import inspection from .. import util @@ -126,7 +126,7 @@ def literal_column(text, type_=None): return ColumnClause(text, type_=type_, is_literal=True) -class CompilerElement(Traversible): +class CompilerElement(Visitable): """base class for SQL elements that can be compiled to produce a SQL string. diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index 2fa3a04083..18fd1d4b81 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -10,12 +10,22 @@ import collections.abc as collections_abc import itertools from itertools import zip_longest import operator +import typing +from typing import Any +from typing import Callable +from typing import Dict +from typing import Type +from typing import TypeVar from . import operators +from .cache_key import HasCacheKey +from .visitors import _TraverseInternalsType from .visitors import anon_map +from .visitors import ExtendedInternalTraversal +from .visitors import HasTraverseInternals from .visitors import InternalTraversal from .. import util - +from ..util import langhelpers SKIP_TRAVERSE = util.symbol("skip_traverse") COMPARE_FAILED = False @@ -47,11 +57,158 @@ def _preconfigure_traversals(target_hierarchy): ) +SelfHasShallowCopy = TypeVar("SelfHasShallowCopy", bound="HasShallowCopy") + + +class HasShallowCopy(HasTraverseInternals): + """attribute-wide operations that are useful for classes that use + __slots__ and therefore can't operate on their attributes in a dictionary. + + + """ + + __slots__ = () + + if typing.TYPE_CHECKING: + + def _generated_shallow_copy_traversal( + self: SelfHasShallowCopy, other: SelfHasShallowCopy + ) -> None: + ... + + def _generated_shallow_from_dict_traversal( + self, d: Dict[str, Any] + ) -> None: + ... + + def _generated_shallow_to_dict_traversal(self) -> Dict[str, Any]: + ... + + @classmethod + def _generate_shallow_copy( + cls: Type[SelfHasShallowCopy], + internal_dispatch: _TraverseInternalsType, + method_name: str, + ) -> Callable[[SelfHasShallowCopy, SelfHasShallowCopy], None]: + code = "\n".join( + f" other.{attrname} = self.{attrname}" + for attrname, _ in internal_dispatch + ) + meth_text = f"def {method_name}(self, other):\n{code}\n" + return langhelpers._exec_code_in_env(meth_text, {}, method_name) + + @classmethod + def _generate_shallow_to_dict( + cls: Type[SelfHasShallowCopy], + internal_dispatch: _TraverseInternalsType, + method_name: str, + ) -> Callable[[SelfHasShallowCopy], Dict[str, Any]]: + code = ",\n".join( + f" '{attrname}': self.{attrname}" + for attrname, _ in internal_dispatch + ) + meth_text = f"def {method_name}(self):\n return {{{code}}}\n" + return langhelpers._exec_code_in_env(meth_text, {}, method_name) + + @classmethod + def _generate_shallow_from_dict( + cls: Type[SelfHasShallowCopy], + internal_dispatch: _TraverseInternalsType, + method_name: str, + ) -> Callable[[SelfHasShallowCopy, Dict[str, Any]], None]: + code = "\n".join( + f" self.{attrname} = d['{attrname}']" + for attrname, _ in internal_dispatch + ) + meth_text = f"def {method_name}(self, d):\n{code}\n" + return langhelpers._exec_code_in_env(meth_text, {}, method_name) + + def _shallow_from_dict(self, d: Dict) -> None: + cls = self.__class__ + + try: + shallow_from_dict = cls.__dict__[ + "_generated_shallow_from_dict_traversal" + ] + except KeyError: + shallow_from_dict = ( + cls._generated_shallow_from_dict_traversal # type: ignore + ) = self._generate_shallow_from_dict( + cls._traverse_internals, + "_generated_shallow_from_dict_traversal", + ) + + shallow_from_dict(self, d) + + def _shallow_to_dict(self) -> Dict[str, Any]: + cls = self.__class__ + + try: + shallow_to_dict = cls.__dict__[ + "_generated_shallow_to_dict_traversal" + ] + except KeyError: + shallow_to_dict = ( + cls._generated_shallow_to_dict_traversal # type: ignore + ) = self._generate_shallow_to_dict( + cls._traverse_internals, "_generated_shallow_to_dict_traversal" + ) + + return shallow_to_dict(self) + + def _shallow_copy_to(self: SelfHasShallowCopy, other: SelfHasShallowCopy): + cls = self.__class__ + + try: + shallow_copy = cls.__dict__["_generated_shallow_copy_traversal"] + except KeyError: + shallow_copy = ( + cls._generated_shallow_copy_traversal # type: ignore + ) = self._generate_shallow_copy( + cls._traverse_internals, "_generated_shallow_copy_traversal" + ) + + shallow_copy(self, other) + + def _clone(self: SelfHasShallowCopy, **kw) -> SelfHasShallowCopy: + """Create a shallow copy""" + c = self.__class__.__new__(self.__class__) + self._shallow_copy_to(c) + return c + + +SelfGenerativeOnTraversal = TypeVar( + "SelfGenerativeOnTraversal", bound="GenerativeOnTraversal" +) + + +class GenerativeOnTraversal(HasShallowCopy): + """Supplies Generative behavior but making use of traversals to shallow + copy. + + .. seealso:: + + :class:`sqlalchemy.sql.base.Generative` + + + """ + + __slots__ = () + + def _generate( + self: SelfGenerativeOnTraversal, + ) -> SelfGenerativeOnTraversal: + cls = self.__class__ + s = cls.__new__(cls) + self._shallow_copy_to(s) + return s + + def _clone(element, **kw): return element._clone() -class HasCopyInternals: +class HasCopyInternals(HasTraverseInternals): __slots__ = () def _clone(self, **kw): @@ -304,7 +461,9 @@ def _resolve_name_for_compare(element, name, anon_map, **kw): return name -class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): +class TraversalComparatorStrategy( + ExtendedInternalTraversal, util.MemoizedSlots +): __slots__ = "stack", "cache", "anon_map" def __init__(self): @@ -377,6 +536,10 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): continue dispatch = self.dispatch(left_visit_sym) + assert dispatch, ( + f"{self.__class__} has no dispatch for " + f"'{self._dispatch_lookup[left_visit_sym]}'" + ) left_child = operator.attrgetter(left_attrname)(left) right_child = operator.attrgetter(right_attrname)(right) if left_child is None: @@ -517,6 +680,46 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): ): return left == right + def visit_string_multi_dict( + self, attrname, left_parent, left, right_parent, right, **kw + ): + + for lk, rk in zip_longest( + sorted(left.keys()), sorted(right.keys()), fillvalue=(None, None) + ): + if lk != rk: + return COMPARE_FAILED + + lv, rv = left[lk], right[rk] + + lhc = isinstance(left, HasCacheKey) + rhc = isinstance(right, HasCacheKey) + if lhc and rhc: + if lv._gen_cache_key( + self.anon_map[0], [] + ) != rv._gen_cache_key(self.anon_map[1], []): + return COMPARE_FAILED + elif lhc != rhc: + return COMPARE_FAILED + elif lv != rv: + return COMPARE_FAILED + + def visit_multi( + self, attrname, left_parent, left, right_parent, right, **kw + ): + + lhc = isinstance(left, HasCacheKey) + rhc = isinstance(right, HasCacheKey) + if lhc and rhc: + if left._gen_cache_key( + self.anon_map[0], [] + ) != right._gen_cache_key(self.anon_map[1], []): + return COMPARE_FAILED + elif lhc != rhc: + return COMPARE_FAILED + else: + return left == right + def visit_anon_name( self, attrname, left_parent, left, right_parent, right, **kw ): diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 70c4dc1332..78384782b8 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -26,11 +26,14 @@ https://techspot.zzzeek.org/2008/01/23/expression-transformations/ . from collections import deque import itertools import operator +from typing import List +from typing import Tuple from .. import exc from .. import util from ..util import langhelpers from ..util import symbol +from ..util.langhelpers import _symbol try: from sqlalchemy.cyextension.util import cache_anon_map as anon_map # noqa @@ -43,14 +46,67 @@ __all__ = [ "traverse", "cloned_traverse", "replacement_traverse", - "Traversible", + "Visitable", "ExternalTraversal", "InternalTraversal", ] +_TraverseInternalsType = List[Tuple[str, _symbol]] -class Traversible: - """Base class for visitable objects.""" + +class HasTraverseInternals: + """base for classes that have a "traverse internals" element, + which defines all kinds of ways of traversing the elements of an object. + + """ + + __slots__ = () + + _traverse_internals: _TraverseInternalsType + + @util.preload_module("sqlalchemy.sql.traversals") + def get_children(self, omit_attrs=(), **kw): + r"""Return immediate child :class:`.visitors.Visitable` + elements of this :class:`.visitors.Visitable`. + + This is used for visit traversal. + + \**kw may contain flags that change the collection that is + returned, for example to return a subset of items in order to + cut down on larger traversals, or to return child items from a + different context (such as schema-level collections instead of + clause-level). + + """ + + traversals = util.preloaded.sql_traversals + + try: + traverse_internals = self._traverse_internals + except AttributeError: + # user-defined classes may not have a _traverse_internals + return [] + + dispatch = traversals._get_children.run_generated_dispatch + return itertools.chain.from_iterable( + meth(obj, **kw) + for attrname, obj, meth in dispatch( + self, traverse_internals, "_generated_get_children_traversal" + ) + if attrname not in omit_attrs and obj is not None + ) + + +class Visitable: + """Base class for visitable objects. + + .. versionchanged:: 2.0 The :class:`.Visitable` class was named + :class:`.Traversible` in the 1.4 series; the name is changed back + to :class:`.Visitable` in 2.0 which is what it was prior to 1.4. + + Both names remain importable in both 1.4 and 2.0 versions. + + """ __slots__ = () @@ -120,38 +176,6 @@ class Traversible: # allow generic classes in py3.9+ return cls - @util.preload_module("sqlalchemy.sql.traversals") - def get_children(self, omit_attrs=(), **kw): - r"""Return immediate child :class:`.visitors.Traversible` - elements of this :class:`.visitors.Traversible`. - - This is used for visit traversal. - - \**kw may contain flags that change the collection that is - returned, for example to return a subset of items in order to - cut down on larger traversals, or to return child items from a - different context (such as schema-level collections instead of - clause-level). - - """ - - traversals = util.preloaded.sql_traversals - - try: - traverse_internals = self._traverse_internals - except AttributeError: - # user-defined classes may not have a _traverse_internals - return [] - - dispatch = traversals._get_children.run_generated_dispatch - return itertools.chain.from_iterable( - meth(obj, **kw) - for attrname, obj, meth in dispatch( - self, traverse_internals, "_generated_get_children_traversal" - ) - if attrname not in omit_attrs and obj is not None - ) - class _HasTraversalDispatch: r"""Define infrastructure for the :class:`.InternalTraversal` class. @@ -261,14 +285,14 @@ class InternalTraversal(_HasTraversalDispatch): :class:`.InternalTraversible` will have the following methods automatically implemented: - * :meth:`.Traversible.get_children` + * :meth:`.HasTraverseInternals.get_children` - * :meth:`.Traversible._copy_internals` + * :meth:`.HasTraverseInternals._copy_internals` - * :meth:`.Traversible._gen_cache_key` + * :meth:`.HasCacheKey._gen_cache_key` Subclasses can also implement these methods directly, particularly for the - :meth:`.Traversible._copy_internals` method, when special steps + :meth:`.HasTraverseInternals._copy_internals` method, when special steps are needed. .. versionadded:: 1.4 @@ -625,7 +649,8 @@ class ReplacingExternalTraversal(CloningExternalTraversal): # backwards compatibility -Visitable = Traversible +Traversible = Visitable + ClauseVisitor = ExternalTraversal CloningVisitor = CloningExternalTraversal ReplacingCloningVisitor = ReplacingExternalTraversal diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index 80ef3458c0..8b65fb4cf6 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -1156,7 +1156,9 @@ class MemoizedSlots: raise AttributeError(key) def __getattr__(self, key): - if key.startswith("_memoized"): + if key.startswith("_memoized_attr_") or key.startswith( + "_memoized_method_" + ): raise AttributeError(key) elif hasattr(self, "_memoized_attr_%s" % key): value = getattr(self, "_memoized_attr_%s" % key)() diff --git a/test/orm/test_options.py b/test/orm/test_options.py index 7098220801..e74ffeced4 100644 --- a/test/orm/test_options.py +++ b/test/orm/test_options.py @@ -1,3 +1,5 @@ +import pickle + import sqlalchemy as sa from sqlalchemy import Column from sqlalchemy import ForeignKey @@ -5,6 +7,7 @@ from sqlalchemy import inspect from sqlalchemy import Integer from sqlalchemy import select from sqlalchemy import String +from sqlalchemy import Table from sqlalchemy import testing from sqlalchemy.orm import aliased from sqlalchemy.orm import attributes @@ -1241,31 +1244,61 @@ class OptionsNoPropTestInh(_Polymorphic): eq_(loader.path, orig_path) -class PickleTest(PathTest, QueryTest): - def _option_fixture(self, *arg): - return strategy_options._generate_from_keys( - strategy_options.Load.joinedload, arg, True, {} +class PickleTest(fixtures.MappedTest): + @classmethod + def define_tables(cls, metadata): + Table( + "users", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(30), nullable=False), + ) + Table( + "addresses", + metadata, + Column("id", Integer, primary_key=True), + Column("user_id", None, ForeignKey("users.id")), + Column("email_address", String(50), nullable=False), + ) + + @testing.fixture + def user_address_fixture(self, registry): + from sqlalchemy.testing.pickleable import User, Address + + registry.map_imperatively( + User, + self.tables.users, + properties={"addresses": relationship(Address)}, ) + registry.map_imperatively(Address, self.tables.addresses) - def test_modern_opt_getstate(self): - User = self.classes.User + return User, Address - opt = self._option_fixture(User.addresses) + def test_slots(self, user_address_fixture): + User, Address = user_address_fixture + + opt = joinedload(User.addresses) + + assert not hasattr(opt, "__dict__") + assert not hasattr(opt.context[0], "__dict__") + + def test_pickle_relationship_loader(self, user_address_fixture): + User, Address = user_address_fixture - q1 = fixture_session().query(User).options(opt) - c1 = q1._compile_context() + for i in range(3): + opt = joinedload(User.addresses) - state = opt.__getstate__() + q1 = fixture_session().query(User).options(opt) + c1 = q1._compile_context() - opt2 = Load.__new__(Load) - opt2.__setstate__(state) + pickled = pickle.dumps(opt) - eq_(opt.__dict__, opt2.__dict__) + opt2 = pickle.loads(pickled) - q2 = fixture_session().query(User).options(opt2) - c2 = q2._compile_context() + q2 = fixture_session().query(User).options(opt2) + c2 = q2._compile_context() - eq_(c1.attributes, c2.attributes) + eq_(c1.attributes, c2.attributes) class LocalOptsTest(PathTest, QueryTest): diff --git a/test/orm/test_pickled.py b/test/orm/test_pickled.py index 8e4c6ab17a..a4250e375c 100644 --- a/test/orm/test_pickled.py +++ b/test/orm/test_pickled.py @@ -271,7 +271,6 @@ class PickleTest(fixtures.MappedTest): sess.add(u1) sess.commit() sess.close() - u1 = ( sess.query(User) .options(