]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add new infrastructure to support greater use of __slots__
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 18 Jan 2022 22:00:16 +0000 (17:00 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 21 Jan 2022 16:46:51 +0000 (11:46 -0500)
* Changed AliasedInsp to use __slots__
* Migrated all of strategy_options to use __slots__ for objects.
  Adds new infrastructure to traversals to support shallow
  copy, to dict and from dict based on internal traversal
  attributes.  Load / _LoadElement then leverage this to
  provide clone / generative / getstate without the need
  for __dict__ or explicit attribute lists.

Doing this change revealed that there are lots of things that
trigger off of whether or not a class has a __visit_name__ attribute.
so to suit that we've gone back to having Visitable, which is
a better name than Traversible at this point  (I think
Traversible is mis-spelled too).

Change-Id: I13d04e494339fac9dbda0b8e78153418abebaf72
References: #7527

lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/strategy_options.py
lib/sqlalchemy/orm/util.py
lib/sqlalchemy/sql/base.py
lib/sqlalchemy/sql/cache_key.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/traversals.py
lib/sqlalchemy/sql/visitors.py
lib/sqlalchemy/util/langhelpers.py
test/orm/test_options.py
test/orm/test_pickled.py

index 08189a1b757648283d5d30e2b74b77f9864ec3eb..b9a5aaf518ad5c53eab046042f99956a1ba1d8b4 100644 (file)
@@ -64,20 +64,24 @@ __all__ = (
 
 
 class ORMStatementRole(roles.StatementRole):
+    __slots__ = ()
     _role_name = (
         "Executable SQL or text() construct, including ORM " "aware objects"
     )
 
 
 class ORMColumnsClauseRole(roles.ColumnsClauseRole):
+    __slots__ = ()
     _role_name = "ORM mapped entity, aliased entity, or Column expression"
 
 
 class ORMEntityColumnsClauseRole(ORMColumnsClauseRole):
+    __slots__ = ()
     _role_name = "ORM mapped or aliased entity"
 
 
 class ORMFromClauseRole(roles.StrictFromClauseRole):
+    __slots__ = ()
     _role_name = "ORM mapped entity, aliased entity, or FROM expression"
 
 
@@ -798,6 +802,8 @@ class CompileStateOption(HasCacheKey, ORMOption):
 
     """
 
+    __slots__ = ()
+
     _is_compile_state = True
 
     def process_compile_state(self, compile_state):
@@ -832,6 +838,8 @@ class LoaderOption(CompileStateOption):
 
     """
 
+    __slots__ = ()
+
     def process_compile_state_replaced_entities(
         self, compile_state, mapper_entities
     ):
@@ -846,6 +854,8 @@ class CriteriaOption(CompileStateOption):
 
     """
 
+    __slots__ = ()
+
     _is_criteria_option = True
 
     def get_global_criteria(self, attributes):
@@ -861,6 +871,8 @@ class UserDefinedOption(ORMOption):
 
     """
 
+    __slots__ = ("payload",)
+
     _is_legacy_option = False
 
     propagate_to_loaders = False
@@ -887,6 +899,8 @@ class UserDefinedOption(ORMOption):
 class MapperOption(ORMOption):
     """Describe a modification to a Query"""
 
+    __slots__ = ()
+
     _is_legacy_option = True
 
     propagate_to_loaders = False
index c2cfbb9fc2184f73711421ccd2b4bb6583dc281e..0f993b86cfd38ac6623038fc34f0e00432ac4771 100644 (file)
@@ -13,6 +13,7 @@ from typing import Any
 from typing import cast
 from typing import Mapping
 from typing import NoReturn
+from typing import Optional
 from typing import Tuple
 from typing import Union
 
@@ -32,9 +33,9 @@ from ..sql import and_
 from ..sql import cache_key
 from ..sql import coercions
 from ..sql import roles
+from ..sql import traversals
 from ..sql import visitors
 from ..sql.base import _generative
-from ..sql.base import Generative
 
 _RELATIONSHIP_TOKEN = "relationship"
 _COLUMN_TOKEN = "column"
@@ -45,9 +46,11 @@ if typing.TYPE_CHECKING:
 Self_AbstractLoad = typing.TypeVar("Self_AbstractLoad", bound="_AbstractLoad")
 
 
-class _AbstractLoad(Generative, LoaderOption):
+class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
+    __slots__ = ("propagate_to_loaders",)
+
     _is_strategy_option = True
-    propagate_to_loaders = False
+    propagate_to_loaders: bool
 
     def contains_eager(self, attr, alias=None, _is_chain=False):
         r"""Indicate that the given attribute should be eagerly loaded from
@@ -882,13 +885,20 @@ class Load(_AbstractLoad):
 
     """
 
-    _cache_key_traversal = [
+    __slots__ = (
+        "path",
+        "context",
+    )
+
+    _traverse_internals = [
         ("path", visitors.ExtendedInternalTraversal.dp_has_cache_key),
         (
             "context",
             visitors.InternalTraversal.dp_has_cache_key_list,
         ),
+        ("propagate_to_loaders", visitors.InternalTraversal.dp_boolean),
     ]
+    _cache_key_traversal = None
 
     path: PathRegistry
     context: Tuple["_LoadElement", ...]
@@ -899,6 +909,7 @@ class Load(_AbstractLoad):
 
         self.path = insp._path_registry
         self.context = ()
+        self.propagate_to_loaders = False
 
     def __str__(self):
         return f"Load({self.path[0]})"
@@ -908,6 +919,7 @@ class Load(_AbstractLoad):
         load = cls.__new__(cls)
         load.path = path
         load.context = ()
+        load.propagate_to_loaders = False
         return load
 
     def _adjust_for_extra_criteria(self, context):
@@ -1128,13 +1140,13 @@ class Load(_AbstractLoad):
                     self.context += (load_element,)
 
     def __getstate__(self):
-        d = self.__dict__.copy()
+        d = self._shallow_to_dict()
         d["path"] = self.path.serialize()
         return d
 
     def __setstate__(self, state):
-        self.__dict__.update(state)
-        self.path = PathRegistry.deserialize(self.path)
+        state["path"] = PathRegistry.deserialize(state["path"])
+        self._shallow_from_dict(state)
 
 
 SelfWildcardLoad = typing.TypeVar("SelfWildcardLoad", bound="_WildcardLoad")
@@ -1143,16 +1155,27 @@ SelfWildcardLoad = typing.TypeVar("SelfWildcardLoad", bound="_WildcardLoad")
 class _WildcardLoad(_AbstractLoad):
     """represent a standalone '*' load operation"""
 
-    _cache_key_traversal = [
+    __slots__ = ("strategy", "path", "local_opts")
+
+    _traverse_internals = [
         ("strategy", visitors.ExtendedInternalTraversal.dp_plain_obj),
+        ("path", visitors.ExtendedInternalTraversal.dp_plain_obj),
         (
             "local_opts",
             visitors.ExtendedInternalTraversal.dp_string_multi_dict,
         ),
     ]
+    cache_key_traversal = None
 
-    local_opts = util.EMPTY_DICT
-    path: Tuple[str, ...] = ()
+    strategy: Optional[Tuple[Any, ...]]
+    local_opts: Mapping[str, Any]
+    path: Tuple[str, ...]
+    propagate_to_loaders = False
+
+    def __init__(self):
+        self.path = ()
+        self.strategy = None
+        self.local_opts = util.EMPTY_DICT
 
     def _clone_for_bind_strategy(
         self,
@@ -1171,16 +1194,6 @@ class _WildcardLoad(_AbstractLoad):
             and attr in (_WILDCARD_TOKEN, _DEFAULT_TOKEN)
         )
 
-        if attr == _DEFAULT_TOKEN:
-            # for someload('*'), this currently does propagate=False,
-            # to prevent it from taking effect for lazy loads.
-            # it seems like adjusting for current_path for a lazy load etc.
-            # should be taking care of that, so that the option still takes
-            # effect for a refresh as well, but currently it does not.
-            # probably should be adjusted to be more accurate re: current
-            # path vs. refresh
-            self.propagate_to_loaders = False
-
         attr = f"{wildcard_key}:{attr}"
 
         self.strategy = strategy
@@ -1310,13 +1323,16 @@ class _WildcardLoad(_AbstractLoad):
                 return None
 
     def __getstate__(self):
-        return self.__dict__.copy()
+        d = self._shallow_to_dict()
+        return d
 
     def __setstate__(self, state):
-        self.__dict__.update(state)
+        self._shallow_from_dict(state)
 
 
-class _LoadElement(cache_key.HasCacheKey):
+class _LoadElement(
+    cache_key.HasCacheKey, traversals.HasShallowCopy, visitors.Traversible
+):
     """represents strategy information to select for a LoaderStrategy
     and pass options to it.
 
@@ -1328,40 +1344,66 @@ class _LoadElement(cache_key.HasCacheKey):
 
     """
 
-    _cache_key_traversal = [
+    __slots__ = (
+        "path",
+        "strategy",
+        "propagate_to_loaders",
+        "local_opts",
+        "_extra_criteria",
+        "_reconcile_to_other",
+    )
+    __visit_name__ = "load_element"
+
+    _traverse_internals = [
         ("path", visitors.ExtendedInternalTraversal.dp_has_cache_key),
         ("strategy", visitors.ExtendedInternalTraversal.dp_plain_obj),
         (
             "local_opts",
             visitors.ExtendedInternalTraversal.dp_string_multi_dict,
         ),
+        ("_extra_criteria", visitors.InternalTraversal.dp_clauseelement_list),
+        ("propagate_to_loaders", visitors.InternalTraversal.dp_plain_obj),
+        ("_reconcile_to_other", visitors.InternalTraversal.dp_plain_obj),
     ]
+    _cache_key_traversal = None
 
-    _extra_criteria = ()
+    _extra_criteria: Tuple[Any, ...]
 
-    _reconcile_to_other = None
-    strategy = None
+    _reconcile_to_other: Optional[bool]
+    strategy: Tuple[Any, ...]
     path: PathRegistry
-    propagate_to_loaders = False
+    propagate_to_loaders: bool
 
     local_opts: Mapping[str, Any]
 
     is_token_strategy: bool
     is_class_strategy: bool
 
+    def __hash__(self):
+        return id(self)
+
+    def __eq__(self, other):
+        return traversals.compare(self, other)
+
     @property
     def is_opts_only(self):
         return bool(self.local_opts and self.strategy is None)
 
+    def _clone(self):
+        cls = self.__class__
+        s = cls.__new__(cls)
+
+        self._shallow_copy_to(s)
+        return s
+
     def __getstate__(self):
-        d = self.__dict__.copy()
+        d = self._shallow_to_dict()
         d["path"] = self.path.serialize()
-
         return d
 
     def __setstate__(self, state):
         state["path"] = PathRegistry.deserialize(state["path"])
-        self.__dict__.update(state)
+        self._shallow_from_dict(state)
 
     def _raise_for_no_match(self, parent_loader, mapper_entities):
         path = parent_loader.path
@@ -1498,11 +1540,14 @@ class _LoadElement(cache_key.HasCacheKey):
         opt.local_opts = (
             util.immutabledict(local_opts) if local_opts else util.EMPTY_DICT
         )
+        opt._extra_criteria = ()
 
         if reconcile_to_other is not None:
             opt._reconcile_to_other = reconcile_to_other
         elif strategy is None and not local_opts:
             opt._reconcile_to_other = True
+        else:
+            opt._reconcile_to_other = None
 
         path = opt._init_path(path, attr, wildcard_key, attr_group, raiseerr)
 
@@ -1517,12 +1562,6 @@ class _LoadElement(cache_key.HasCacheKey):
     def __init__(self, path, strategy, local_opts, propagate_to_loaders):
         raise NotImplementedError()
 
-    def _clone(self):
-        cls = self.__class__
-        s = cls.__new__(cls)
-        s.__dict__ = self.__dict__.copy()
-        return s
-
     def _prepend_path_from(self, parent):
         """adjust the path of this :class:`._LoadElement` to be
         a subpath of that of the given parent :class:`_orm.Load` object's
@@ -1617,20 +1656,28 @@ class _AttributeStrategyLoad(_LoadElement):
 
     """
 
-    _cache_key_traversal = _LoadElement._cache_key_traversal + [
+    __slots__ = ("_of_type", "_path_with_polymorphic_path")
+
+    __visit_name__ = "attribute_strategy_load_element"
+
+    _traverse_internals = _LoadElement._traverse_internals + [
         ("_of_type", visitors.ExtendedInternalTraversal.dp_multi),
-        ("_extra_criteria", visitors.InternalTraversal.dp_clauseelement_list),
+        (
+            "_path_with_polymorphic_path",
+            visitors.ExtendedInternalTraversal.dp_has_cache_key,
+        ),
     ]
 
-    _of_type: Union["Mapper", AliasedInsp, None] = None
-    _path_with_polymorphic_path = None
+    _of_type: Union["Mapper", AliasedInsp, None]
+    _path_with_polymorphic_path: Optional[PathRegistry]
 
-    inherit_cache = True
     is_class_strategy = False
     is_token_strategy = False
 
     def _init_path(self, path, attr, wildcard_key, attr_group, raiseerr):
         assert attr is not None
+        self._of_type = None
+        self._path_with_polymorphic_path = None
         insp, _, prop = _parse_attr_argument(attr)
 
         if insp.is_property:
@@ -1832,12 +1879,14 @@ class _AttributeStrategyLoad(_LoadElement):
         return [("loader", cast(PathRegistry, effective_path).natural_path)]
 
     def __getstate__(self):
-        d = self.__dict__.copy()
+        d = super().__getstate__()
+
+        # can't pickle this.  See
+        # test_pickled.py -> test_lazyload_extra_criteria_not_supported
+        # where we should be emitting a warning for the usual case where this
+        # would be non-None
         d["_extra_criteria"] = ()
-        d["path"] = self.path.serialize()
 
-        # TODO: we hope to do this logic only at compile time so that
-        # we aren't carrying these extra attributes around
         if self._path_with_polymorphic_path:
             d[
                 "_path_with_polymorphic_path"
@@ -1854,14 +1903,19 @@ class _AttributeStrategyLoad(_LoadElement):
         return d
 
     def __setstate__(self, state):
-        state["path"] = PathRegistry.deserialize(state["path"])
-        self.__dict__.update(state)
-        if "_path_with_polymorphic_path" in state:
+        super().__setstate__(state)
+
+        if state.get("_path_with_polymorphic_path", None):
             self._path_with_polymorphic_path = PathRegistry.deserialize(
-                self._path_with_polymorphic_path
+                state["_path_with_polymorphic_path"]
             )
-        if self._of_type is not None:
-            self._of_type = inspect(self._of_type)
+        else:
+            self._path_with_polymorphic_path = None
+
+        if state.get("_of_type", None):
+            self._of_type = inspect(state["_of_type"])
+        else:
+            self._of_type = None
 
 
 class _TokenStrategyLoad(_LoadElement):
@@ -1877,6 +1931,8 @@ class _TokenStrategyLoad(_LoadElement):
 
     """
 
+    __visit_name__ = "token_strategy_load_element"
+
     inherit_cache = True
     is_class_strategy = False
     is_token_strategy = True
@@ -1962,6 +2018,8 @@ class _ClassStrategyLoad(_LoadElement):
     is_class_strategy = True
     is_token_strategy = False
 
+    __visit_name__ = "class_strategy_load_element"
+
     def _init_path(self, path, attr, wildcard_key, attr_group, raiseerr):
         return path
 
index e84517670c8ca1d50a23f63e51ca0a98cff38930..75f711007890018417c21d721e3b0d2a9e6b6727 100644 (file)
@@ -45,6 +45,7 @@ from ..sql import util as sql_util
 from ..sql import visitors
 from ..sql.annotation import SupportsCloneAnnotations
 from ..sql.base import ColumnCollection
+from ..util.langhelpers import MemoizedSlots
 
 
 all_cascades = frozenset(
@@ -609,8 +610,9 @@ class AliasedClass:
 class AliasedInsp(
     ORMEntityColumnsClauseRole,
     ORMFromClauseRole,
-    sql_base.MemoizedHasCacheKey,
+    sql_base.HasCacheKey,
     InspectionAttr,
+    MemoizedSlots,
 ):
     """Provide an inspection interface for an
     :class:`.AliasedClass` object.
@@ -650,6 +652,30 @@ class AliasedInsp(
 
     """
 
+    __slots__ = (
+        "__weakref__",
+        "_weak_entity",
+        "mapper",
+        "selectable",
+        "name",
+        "_adapt_on_names",
+        "with_polymorphic_mappers",
+        "polymorphic_on",
+        "_use_mapper_path",
+        "_base_alias",
+        "represents_outer_join",
+        "persist_selectable",
+        "local_table",
+        "_is_with_polymorphic",
+        "_with_polymorphic_entities",
+        "_adapter",
+        "_target",
+        "__clause_element__",
+        "_memoized_values",
+        "_all_column_expressions",
+        "_nest_adapters",
+    )
+
     def __init__(
         self,
         entity,
@@ -738,8 +764,7 @@ class AliasedInsp(
     is_aliased_class = True
     "always returns True"
 
-    @util.memoized_instancemethod
-    def __clause_element__(self):
+    def _memoized_method___clause_element__(self):
         return self.selectable._annotate(
             {
                 "parentmapper": self.mapper,
@@ -863,8 +888,7 @@ class AliasedInsp(
         else:
             assert False, "mapper %s doesn't correspond to %s" % (mapper, self)
 
-    @util.memoized_property
-    def _get_clause(self):
+    def _memoized_attr__get_clause(self):
         onclause, replacemap = self.mapper._get_clause
         return (
             self._adapter.traverse(onclause),
@@ -874,12 +898,10 @@ class AliasedInsp(
             },
         )
 
-    @util.memoized_property
-    def _memoized_values(self):
+    def _memoized_attr__memoized_values(self):
         return {}
 
-    @util.memoized_property
-    def _all_column_expressions(self):
+    def _memoized_attr__all_column_expressions(self):
         if self._is_with_polymorphic:
             cols_plus_keys = self.mapper._columns_plus_keys(
                 [ent.mapper for ent in self._with_polymorphic_entities]
@@ -965,6 +987,15 @@ class LoaderCriteriaOption(CriteriaOption):
 
     """
 
+    __slots__ = (
+        "root_entity",
+        "entity",
+        "deferred_where_criteria",
+        "where_criteria",
+        "include_aliases",
+        "propagate_to_loaders",
+    )
+
     _traverse_internals = [
         ("root_entity", visitors.ExtendedInternalTraversal.dp_plain_obj),
         ("entity", visitors.ExtendedInternalTraversal.dp_has_cache_key),
index 74469b0350f3070ab1a0e637711be7e183c065e5..8ae8f8f65fcbc8ce88f41865aa0884edacc97580 100644 (file)
@@ -17,6 +17,7 @@ from itertools import zip_longest
 import operator
 import re
 import typing
+from typing import TypeVar
 
 from . import roles
 from . import visitors
@@ -571,11 +572,14 @@ class CompileState:
         return decorate
 
 
+SelfGenerative = TypeVar("SelfGenerative", bound="Generative")
+
+
 class Generative(HasMemoized):
     """Provide a method-chaining pattern in conjunction with the
     @_generative decorator."""
 
-    def _generate(self):
+    def _generate(self: SelfGenerative) -> SelfGenerative:
         skip = self._memoized_keys
         cls = self.__class__
         s = cls.__new__(cls)
@@ -783,6 +787,8 @@ class Options(metaclass=_MetaOptions):
 
 
 class CacheableOptions(Options, HasCacheKey):
+    __slots__ = ()
+
     @hybridmethod
     def _gen_cache_key(self, anon_map, bindparams):
         return HasCacheKey._gen_cache_key(self, anon_map, bindparams)
@@ -797,6 +803,8 @@ class CacheableOptions(Options, HasCacheKey):
 
 
 class ExecutableOption(HasCopyInternals):
+    __slots__ = ()
+
     _annotations = util.EMPTY_DICT
 
     __visit_name__ = "executable_option"
index 8dd44dbf080aa7b9c440269594b26c769453dccb..42bd603537146ed731f72146b8b2a805c91586e4 100644 (file)
@@ -47,6 +47,11 @@ class CacheTraverseTarget(enum.Enum):
 class HasCacheKey:
     """Mixin for objects which can produce a cache key.
 
+    This class is usually in a hierarchy that starts with the
+    :class:`.HasTraverseInternals` base, but this is optional.  Currently,
+    the class should be able to work on its own without including
+    :class:`.HasTraverseInternals`.
+
     .. seealso::
 
         :class:`.CacheKey`
@@ -55,6 +60,8 @@ class HasCacheKey:
 
     """
 
+    __slots__ = ()
+
     _cache_key_traversal = NO_CACHE
 
     _is_has_cache_key = True
@@ -106,11 +113,17 @@ class HasCacheKey:
             _cache_key_traversal = getattr(cls, "_cache_key_traversal", None)
             if _cache_key_traversal is None:
                 try:
+                    # this would be HasTraverseInternals
                     _cache_key_traversal = cls._traverse_internals
                 except AttributeError:
                     cls._generated_cache_key_traversal = NO_CACHE
                     return NO_CACHE
 
+            assert _cache_key_traversal is not NO_CACHE, (
+                f"class {cls} has _cache_key_traversal=NO_CACHE, "
+                "which conflicts with inherit_cache=True"
+            )
+
             # TODO: wouldn't we instead get this from our superclass?
             # also, our superclass may not have this yet, but in any case,
             # we'd generate for the superclass that has it.   this is a little
@@ -323,6 +336,8 @@ class HasCacheKey:
 
 
 class MemoizedHasCacheKey(HasCacheKey, HasMemoized):
+    __slots__ = ()
+
     @HasMemoized.memoized_instancemethod
     def _generate_cache_key(self):
         return HasCacheKey._generate_cache_key(self)
index 43979b4aeefd2f60a7206b632fee991c0eb676bf..d14521ba73f1772cf359c50e9e4d8e8a7a9e7c24 100644 (file)
@@ -46,7 +46,7 @@ from .traversals import HasCopyInternals
 from .visitors import cloned_traverse
 from .visitors import InternalTraversal
 from .visitors import traverse
-from .visitors import Traversible
+from .visitors import Visitable
 from .. import exc
 from .. import inspection
 from .. import util
@@ -126,7 +126,7 @@ def literal_column(text, type_=None):
     return ColumnClause(text, type_=type_, is_literal=True)
 
 
-class CompilerElement(Traversible):
+class CompilerElement(Visitable):
     """base class for SQL elements that can be compiled to produce a
     SQL string.
 
index 2fa3a040839bdb86b16c85347db2659daf245d75..18fd1d4b81ea0a9919dfba76d85a98659d28130e 100644 (file)
@@ -10,12 +10,22 @@ import collections.abc as collections_abc
 import itertools
 from itertools import zip_longest
 import operator
+import typing
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import Type
+from typing import TypeVar
 
 from . import operators
+from .cache_key import HasCacheKey
+from .visitors import _TraverseInternalsType
 from .visitors import anon_map
+from .visitors import ExtendedInternalTraversal
+from .visitors import HasTraverseInternals
 from .visitors import InternalTraversal
 from .. import util
-
+from ..util import langhelpers
 
 SKIP_TRAVERSE = util.symbol("skip_traverse")
 COMPARE_FAILED = False
@@ -47,11 +57,158 @@ def _preconfigure_traversals(target_hierarchy):
             )
 
 
+SelfHasShallowCopy = TypeVar("SelfHasShallowCopy", bound="HasShallowCopy")
+
+
+class HasShallowCopy(HasTraverseInternals):
+    """attribute-wide operations that are useful for classes that use
+    __slots__ and therefore can't operate on their attributes in a dictionary.
+
+
+    """
+
+    __slots__ = ()
+
+    if typing.TYPE_CHECKING:
+
+        def _generated_shallow_copy_traversal(
+            self: SelfHasShallowCopy, other: SelfHasShallowCopy
+        ) -> None:
+            ...
+
+        def _generated_shallow_from_dict_traversal(
+            self, d: Dict[str, Any]
+        ) -> None:
+            ...
+
+        def _generated_shallow_to_dict_traversal(self) -> Dict[str, Any]:
+            ...
+
+    @classmethod
+    def _generate_shallow_copy(
+        cls: Type[SelfHasShallowCopy],
+        internal_dispatch: _TraverseInternalsType,
+        method_name: str,
+    ) -> Callable[[SelfHasShallowCopy, SelfHasShallowCopy], None]:
+        code = "\n".join(
+            f"    other.{attrname} = self.{attrname}"
+            for attrname, _ in internal_dispatch
+        )
+        meth_text = f"def {method_name}(self, other):\n{code}\n"
+        return langhelpers._exec_code_in_env(meth_text, {}, method_name)
+
+    @classmethod
+    def _generate_shallow_to_dict(
+        cls: Type[SelfHasShallowCopy],
+        internal_dispatch: _TraverseInternalsType,
+        method_name: str,
+    ) -> Callable[[SelfHasShallowCopy], Dict[str, Any]]:
+        code = ",\n".join(
+            f"    '{attrname}': self.{attrname}"
+            for attrname, _ in internal_dispatch
+        )
+        meth_text = f"def {method_name}(self):\n    return {{{code}}}\n"
+        return langhelpers._exec_code_in_env(meth_text, {}, method_name)
+
+    @classmethod
+    def _generate_shallow_from_dict(
+        cls: Type[SelfHasShallowCopy],
+        internal_dispatch: _TraverseInternalsType,
+        method_name: str,
+    ) -> Callable[[SelfHasShallowCopy, Dict[str, Any]], None]:
+        code = "\n".join(
+            f"    self.{attrname} = d['{attrname}']"
+            for attrname, _ in internal_dispatch
+        )
+        meth_text = f"def {method_name}(self, d):\n{code}\n"
+        return langhelpers._exec_code_in_env(meth_text, {}, method_name)
+
+    def _shallow_from_dict(self, d: Dict) -> None:
+        cls = self.__class__
+
+        try:
+            shallow_from_dict = cls.__dict__[
+                "_generated_shallow_from_dict_traversal"
+            ]
+        except KeyError:
+            shallow_from_dict = (
+                cls._generated_shallow_from_dict_traversal  # type: ignore
+            ) = self._generate_shallow_from_dict(
+                cls._traverse_internals,
+                "_generated_shallow_from_dict_traversal",
+            )
+
+        shallow_from_dict(self, d)
+
+    def _shallow_to_dict(self) -> Dict[str, Any]:
+        cls = self.__class__
+
+        try:
+            shallow_to_dict = cls.__dict__[
+                "_generated_shallow_to_dict_traversal"
+            ]
+        except KeyError:
+            shallow_to_dict = (
+                cls._generated_shallow_to_dict_traversal  # type: ignore
+            ) = self._generate_shallow_to_dict(
+                cls._traverse_internals, "_generated_shallow_to_dict_traversal"
+            )
+
+        return shallow_to_dict(self)
+
+    def _shallow_copy_to(self: SelfHasShallowCopy, other: SelfHasShallowCopy):
+        cls = self.__class__
+
+        try:
+            shallow_copy = cls.__dict__["_generated_shallow_copy_traversal"]
+        except KeyError:
+            shallow_copy = (
+                cls._generated_shallow_copy_traversal  # type: ignore
+            ) = self._generate_shallow_copy(
+                cls._traverse_internals, "_generated_shallow_copy_traversal"
+            )
+
+        shallow_copy(self, other)
+
+    def _clone(self: SelfHasShallowCopy, **kw) -> SelfHasShallowCopy:
+        """Create a shallow copy"""
+        c = self.__class__.__new__(self.__class__)
+        self._shallow_copy_to(c)
+        return c
+
+
+SelfGenerativeOnTraversal = TypeVar(
+    "SelfGenerativeOnTraversal", bound="GenerativeOnTraversal"
+)
+
+
+class GenerativeOnTraversal(HasShallowCopy):
+    """Supplies Generative behavior but making use of traversals to shallow
+    copy.
+
+    .. seealso::
+
+        :class:`sqlalchemy.sql.base.Generative`
+
+
+    """
+
+    __slots__ = ()
+
+    def _generate(
+        self: SelfGenerativeOnTraversal,
+    ) -> SelfGenerativeOnTraversal:
+        cls = self.__class__
+        s = cls.__new__(cls)
+        self._shallow_copy_to(s)
+        return s
+
+
 def _clone(element, **kw):
     return element._clone()
 
 
-class HasCopyInternals:
+class HasCopyInternals(HasTraverseInternals):
     __slots__ = ()
 
     def _clone(self, **kw):
@@ -304,7 +461,9 @@ def _resolve_name_for_compare(element, name, anon_map, **kw):
     return name
 
 
-class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
+class TraversalComparatorStrategy(
+    ExtendedInternalTraversal, util.MemoizedSlots
+):
     __slots__ = "stack", "cache", "anon_map"
 
     def __init__(self):
@@ -377,6 +536,10 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
                     continue
 
                 dispatch = self.dispatch(left_visit_sym)
+                assert dispatch, (
+                    f"{self.__class__} has no dispatch for "
+                    f"'{self._dispatch_lookup[left_visit_sym]}'"
+                )
                 left_child = operator.attrgetter(left_attrname)(left)
                 right_child = operator.attrgetter(right_attrname)(right)
                 if left_child is None:
@@ -517,6 +680,46 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
     ):
         return left == right
 
+    def visit_string_multi_dict(
+        self, attrname, left_parent, left, right_parent, right, **kw
+    ):
+
+        for lk, rk in zip_longest(
+            sorted(left.keys()), sorted(right.keys()), fillvalue=(None, None)
+        ):
+            if lk != rk:
+                return COMPARE_FAILED
+
+            lv, rv = left[lk], right[rk]
+
+            lhc = isinstance(left, HasCacheKey)
+            rhc = isinstance(right, HasCacheKey)
+            if lhc and rhc:
+                if lv._gen_cache_key(
+                    self.anon_map[0], []
+                ) != rv._gen_cache_key(self.anon_map[1], []):
+                    return COMPARE_FAILED
+            elif lhc != rhc:
+                return COMPARE_FAILED
+            elif lv != rv:
+                return COMPARE_FAILED
+
+    def visit_multi(
+        self, attrname, left_parent, left, right_parent, right, **kw
+    ):
+
+        lhc = isinstance(left, HasCacheKey)
+        rhc = isinstance(right, HasCacheKey)
+        if lhc and rhc:
+            if left._gen_cache_key(
+                self.anon_map[0], []
+            ) != right._gen_cache_key(self.anon_map[1], []):
+                return COMPARE_FAILED
+        elif lhc != rhc:
+            return COMPARE_FAILED
+        else:
+            return left == right
+
     def visit_anon_name(
         self, attrname, left_parent, left, right_parent, right, **kw
     ):
index 70c4dc133275d54e9c97028396543c72d0501581..78384782b80c17e83c55125fed35ecccfcb3d6ee 100644 (file)
@@ -26,11 +26,14 @@ https://techspot.zzzeek.org/2008/01/23/expression-transformations/ .
 from collections import deque
 import itertools
 import operator
+from typing import List
+from typing import Tuple
 
 from .. import exc
 from .. import util
 from ..util import langhelpers
 from ..util import symbol
+from ..util.langhelpers import _symbol
 
 try:
     from sqlalchemy.cyextension.util import cache_anon_map as anon_map  # noqa
@@ -43,14 +46,67 @@ __all__ = [
     "traverse",
     "cloned_traverse",
     "replacement_traverse",
-    "Traversible",
+    "Visitable",
     "ExternalTraversal",
     "InternalTraversal",
 ]
 
+_TraverseInternalsType = List[Tuple[str, _symbol]]
 
-class Traversible:
-    """Base class for visitable objects."""
+
+class HasTraverseInternals:
+    """base for classes that have a "traverse internals" element,
+    which defines all kinds of ways of traversing the elements of an object.
+
+    """
+
+    __slots__ = ()
+
+    _traverse_internals: _TraverseInternalsType
+
+    @util.preload_module("sqlalchemy.sql.traversals")
+    def get_children(self, omit_attrs=(), **kw):
+        r"""Return immediate child :class:`.visitors.Visitable`
+        elements of this :class:`.visitors.Visitable`.
+
+        This is used for visit traversal.
+
+        \**kw may contain flags that change the collection that is
+        returned, for example to return a subset of items in order to
+        cut down on larger traversals, or to return child items from a
+        different context (such as schema-level collections instead of
+        clause-level).
+
+        """
+
+        traversals = util.preloaded.sql_traversals
+
+        try:
+            traverse_internals = self._traverse_internals
+        except AttributeError:
+            # user-defined classes may not have a _traverse_internals
+            return []
+
+        dispatch = traversals._get_children.run_generated_dispatch
+        return itertools.chain.from_iterable(
+            meth(obj, **kw)
+            for attrname, obj, meth in dispatch(
+                self, traverse_internals, "_generated_get_children_traversal"
+            )
+            if attrname not in omit_attrs and obj is not None
+        )
+
+
+class Visitable:
+    """Base class for visitable objects.
+
+    .. versionchanged:: 2.0  The :class:`.Visitable` class was named
+       :class:`.Traversible` in the 1.4 series; the name is changed back
+       to :class:`.Visitable` in 2.0 which is what it was prior to 1.4.
+
+       Both names remain importable in both 1.4 and 2.0 versions.
+
+    """
 
     __slots__ = ()
 
@@ -120,38 +176,6 @@ class Traversible:
         # allow generic classes in py3.9+
         return cls
 
-    @util.preload_module("sqlalchemy.sql.traversals")
-    def get_children(self, omit_attrs=(), **kw):
-        r"""Return immediate child :class:`.visitors.Traversible`
-        elements of this :class:`.visitors.Traversible`.
-
-        This is used for visit traversal.
-
-        \**kw may contain flags that change the collection that is
-        returned, for example to return a subset of items in order to
-        cut down on larger traversals, or to return child items from a
-        different context (such as schema-level collections instead of
-        clause-level).
-
-        """
-
-        traversals = util.preloaded.sql_traversals
-
-        try:
-            traverse_internals = self._traverse_internals
-        except AttributeError:
-            # user-defined classes may not have a _traverse_internals
-            return []
-
-        dispatch = traversals._get_children.run_generated_dispatch
-        return itertools.chain.from_iterable(
-            meth(obj, **kw)
-            for attrname, obj, meth in dispatch(
-                self, traverse_internals, "_generated_get_children_traversal"
-            )
-            if attrname not in omit_attrs and obj is not None
-        )
-
 
 class _HasTraversalDispatch:
     r"""Define infrastructure for the :class:`.InternalTraversal` class.
@@ -261,14 +285,14 @@ class InternalTraversal(_HasTraversalDispatch):
     :class:`.InternalTraversible` will have the following methods automatically
     implemented:
 
-    * :meth:`.Traversible.get_children`
+    * :meth:`.HasTraverseInternals.get_children`
 
-    * :meth:`.Traversible._copy_internals`
+    * :meth:`.HasTraverseInternals._copy_internals`
 
-    * :meth:`.Traversible._gen_cache_key`
+    * :meth:`.HasCacheKey._gen_cache_key`
 
     Subclasses can also implement these methods directly, particularly for the
-    :meth:`.Traversible._copy_internals` method, when special steps
+    :meth:`.HasTraverseInternals._copy_internals` method, when special steps
     are needed.
 
     .. versionadded:: 1.4
@@ -625,7 +649,8 @@ class ReplacingExternalTraversal(CloningExternalTraversal):
 
 
 # backwards compatibility
-Visitable = Traversible
+Traversible = Visitable
+
 ClauseVisitor = ExternalTraversal
 CloningVisitor = CloningExternalTraversal
 ReplacingCloningVisitor = ReplacingExternalTraversal
index 80ef3458c02f77153e0caf1162c9963c97df73eb..8b65fb4cf66d34b517768666123c251a9316de75 100644 (file)
@@ -1156,7 +1156,9 @@ class MemoizedSlots:
         raise AttributeError(key)
 
     def __getattr__(self, key):
-        if key.startswith("_memoized"):
+        if key.startswith("_memoized_attr_") or key.startswith(
+            "_memoized_method_"
+        ):
             raise AttributeError(key)
         elif hasattr(self, "_memoized_attr_%s" % key):
             value = getattr(self, "_memoized_attr_%s" % key)()
index 7098220801de01e98808dd77be1c50e58dffbd30..e74ffeced4385ce5b919cbac3917b4d0a961a358 100644 (file)
@@ -1,3 +1,5 @@
+import pickle
+
 import sqlalchemy as sa
 from sqlalchemy import Column
 from sqlalchemy import ForeignKey
@@ -5,6 +7,7 @@ from sqlalchemy import inspect
 from sqlalchemy import Integer
 from sqlalchemy import select
 from sqlalchemy import String
+from sqlalchemy import Table
 from sqlalchemy import testing
 from sqlalchemy.orm import aliased
 from sqlalchemy.orm import attributes
@@ -1241,31 +1244,61 @@ class OptionsNoPropTestInh(_Polymorphic):
         eq_(loader.path, orig_path)
 
 
-class PickleTest(PathTest, QueryTest):
-    def _option_fixture(self, *arg):
-        return strategy_options._generate_from_keys(
-            strategy_options.Load.joinedload, arg, True, {}
+class PickleTest(fixtures.MappedTest):
+    @classmethod
+    def define_tables(cls, metadata):
+        Table(
+            "users",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("name", String(30), nullable=False),
+        )
+        Table(
+            "addresses",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("user_id", None, ForeignKey("users.id")),
+            Column("email_address", String(50), nullable=False),
+        )
+
+    @testing.fixture
+    def user_address_fixture(self, registry):
+        from sqlalchemy.testing.pickleable import User, Address
+
+        registry.map_imperatively(
+            User,
+            self.tables.users,
+            properties={"addresses": relationship(Address)},
         )
+        registry.map_imperatively(Address, self.tables.addresses)
 
-    def test_modern_opt_getstate(self):
-        User = self.classes.User
+        return User, Address
 
-        opt = self._option_fixture(User.addresses)
+    def test_slots(self, user_address_fixture):
+        User, Address = user_address_fixture
+
+        opt = joinedload(User.addresses)
+
+        assert not hasattr(opt, "__dict__")
+        assert not hasattr(opt.context[0], "__dict__")
+
+    def test_pickle_relationship_loader(self, user_address_fixture):
+        User, Address = user_address_fixture
 
-        q1 = fixture_session().query(User).options(opt)
-        c1 = q1._compile_context()
+        for i in range(3):
+            opt = joinedload(User.addresses)
 
-        state = opt.__getstate__()
+            q1 = fixture_session().query(User).options(opt)
+            c1 = q1._compile_context()
 
-        opt2 = Load.__new__(Load)
-        opt2.__setstate__(state)
+            pickled = pickle.dumps(opt)
 
-        eq_(opt.__dict__, opt2.__dict__)
+            opt2 = pickle.loads(pickled)
 
-        q2 = fixture_session().query(User).options(opt2)
-        c2 = q2._compile_context()
+            q2 = fixture_session().query(User).options(opt2)
+            c2 = q2._compile_context()
 
-        eq_(c1.attributes, c2.attributes)
+            eq_(c1.attributes, c2.attributes)
 
 
 class LocalOptsTest(PathTest, QueryTest):
index 8e4c6ab17aff2f1685f6a7e9c5e86d92d6dfc295..a4250e375cd4a2b14cf2580becc93746d2e9922e 100644 (file)
@@ -271,7 +271,6 @@ class PickleTest(fixtures.MappedTest):
         sess.add(u1)
         sess.commit()
         sess.close()
-
         u1 = (
             sess.query(User)
             .options(