]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
pep484: attributes and related
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 28 Apr 2022 20:19:43 +0000 (16:19 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 3 May 2022 19:58:45 +0000 (15:58 -0400)
also implements __slots__ for QueryableAttribute,
InstrumentedAttribute, Relationship.Comparator.

Change-Id: I47e823160706fc35a616f1179a06c7864089e5b5

27 files changed:
doc/build/orm/internals.rst
examples/custom_attributes/custom_management.py
lib/sqlalchemy/event/base.py
lib/sqlalchemy/ext/hybrid.py
lib/sqlalchemy/orm/__init__.py
lib/sqlalchemy/orm/_typing.py
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/base.py
lib/sqlalchemy/orm/collections.py
lib/sqlalchemy/orm/dynamic.py
lib/sqlalchemy/orm/instrumentation.py
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/relationships.py
lib/sqlalchemy/orm/state.py
lib/sqlalchemy/orm/util.py
lib/sqlalchemy/sql/base.py
lib/sqlalchemy/sql/cache_key.py
lib/sqlalchemy/sql/util.py
lib/sqlalchemy/util/_collections.py
lib/sqlalchemy/util/langhelpers.py
lib/sqlalchemy/util/typing.py
test/ext/test_extendedattr.py
test/ext/test_mutable.py
test/orm/test_attributes.py
test/orm/test_collection.py
test/orm/test_deprecations.py
test/orm/test_instrumentation.py

index f251e43bd0a052cc1d44e560850e9a2b196acc5b..9aa3b2db67998cc2cabea97f47176a39d2a00922 100644 (file)
@@ -37,7 +37,7 @@ sections, are listed here.
 
 .. autodata:: CompositeProperty
 
-.. autoclass:: AttributeEvent
+.. autoclass:: AttributeEventToken
     :members:
 
 .. autoclass:: IdentityMap
index 5ee5a45f83c82a0b534bb79a441fde775f5460b7..aa9ea7a68998eb1ccf95f0917ea1ba117c806696 100644 (file)
@@ -17,7 +17,7 @@ from sqlalchemy import MetaData
 from sqlalchemy import Table
 from sqlalchemy import Text
 from sqlalchemy.ext.instrumentation import InstrumentationManager
-from sqlalchemy.orm import mapper
+from sqlalchemy.orm import registry as _reg
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import Session
 from sqlalchemy.orm.attributes import del_attribute
@@ -26,6 +26,9 @@ from sqlalchemy.orm.attributes import set_attribute
 from sqlalchemy.orm.instrumentation import is_instrumented
 
 
+registry = _reg()
+
+
 class MyClassState(InstrumentationManager):
     def get_instance_dict(self, class_, instance):
         return instance._goofy_dict
@@ -97,9 +100,9 @@ if __name__ == "__main__":
     class B(MyClass):
         pass
 
-    mapper(A, table1, properties={"bs": relationship(B)})
+    registry.map_imperatively(A, table1, properties={"bs": relationship(B)})
 
-    mapper(B, table2)
+    registry.map_imperatively(B, table2)
 
     a1 = A(name="a1", bs=[B(name="b1"), B(name="b2")])
 
index c16f6870be2ba592294fbed01adba7b18e2d9c02..83b34a17fc9a5c8576c1570314a052012fe82dae 100644 (file)
@@ -108,10 +108,12 @@ class _Dispatch(_DispatchCommon[_ET]):
 
     """
 
-    # In one ORM edge case, an attribute is added to _Dispatch,
-    # so __dict__ is used in just that case and potentially others.
+    # "active_history" is an ORM case we add here.   ideally a better
+    # system would be in place for ad-hoc attributes.
     __slots__ = "_parent", "_instance_cls", "__dict__", "_empty_listeners"
 
+    _active_history: bool
+
     _empty_listener_reg: MutableMapping[
         Type[_ET], Dict[str, _EmptyListener[_ET]]
     ] = weakref.WeakKeyDictionary()
index 7200414a183116fd001e1b6edb1e40e8ca507b7c..ea558495b41b8a769fa345aad31c6541597d1fac 100644 (file)
@@ -824,15 +824,14 @@ from ..orm import attributes
 from ..orm import InspectionAttrExtensionType
 from ..orm import interfaces
 from ..orm import ORMDescriptor
+from ..sql import roles
 from ..sql._typing import is_has_clause_element
 from ..sql.elements import ColumnElement
 from ..sql.elements import SQLCoreOperations
 from ..util.typing import Literal
 from ..util.typing import Protocol
 
-
 if TYPE_CHECKING:
-    from ..orm._typing import _ORMColumnExprArgument
     from ..orm.interfaces import MapperProperty
     from ..orm.util import AliasedInsp
     from ..sql._typing import _ColumnExpressionArgument
@@ -840,7 +839,6 @@ if TYPE_CHECKING:
     from ..sql._typing import _HasClauseElement
     from ..sql._typing import _InfoType
     from ..sql.operators import OperatorType
-    from ..sql.roles import ColumnsClauseRole
 
 _T = TypeVar("_T", bound=Any)
 _T_co = TypeVar("_T_co", bound=Any, covariant=True)
@@ -1290,7 +1288,7 @@ class Comparator(interfaces.PropComparator[_T]):
     ):
         self.expression = expression
 
-    def __clause_element__(self) -> _ORMColumnExprArgument[_T]:
+    def __clause_element__(self) -> roles.ColumnsClauseRole:
         expr = self.expression
         if is_has_clause_element(expr):
             ret_expr = expr.__clause_element__()
@@ -1306,7 +1304,7 @@ class Comparator(interfaces.PropComparator[_T]):
             assert isinstance(ret_expr, ColumnElement)
         return ret_expr
 
-    @util.ro_non_memoized_property
+    @util.non_memoized_property
     def property(self) -> Optional[interfaces.MapperProperty[_T]]:
         return None
 
@@ -1345,8 +1343,11 @@ class ExprComparator(Comparator[_T]):
         else:
             return [(self.expression, value)]
 
-    @util.ro_non_memoized_property
+    @util.non_memoized_property
     def property(self) -> Optional[MapperProperty[_T]]:
+        # this accessor is not normally used, however is accessed by things
+        # like ORM synonyms if the hybrid is used in this context; the
+        # .property attribute is not necessarily accessible
         return self.expression.property  # type: ignore
 
     def operate(
index 58900ab99ac920df0ad081237c8129efb9f6216f..b7d1df532234f9de926672b0157b1183da92b37d 100644 (file)
@@ -41,7 +41,7 @@ from ._orm_constructors import synonym as synonym
 from ._orm_constructors import SynonymProperty as SynonymProperty
 from ._orm_constructors import with_loader_criteria as with_loader_criteria
 from ._orm_constructors import with_polymorphic as with_polymorphic
-from .attributes import AttributeEvent as AttributeEvent
+from .attributes import AttributeEventToken as AttributeEventToken
 from .attributes import InstrumentedAttribute as InstrumentedAttribute
 from .attributes import QueryableAttribute as QueryableAttribute
 from .base import class_mapper as class_mapper
index 339844f14756da63f55c3529c4864e3dad7838d8..29d82340aba3fe0225712a4f3096fcd4e4f6f750 100644 (file)
@@ -47,6 +47,8 @@ if TYPE_CHECKING:
 
 _InternalEntityType = Union["Mapper[_T]", "AliasedInsp[_T]"]
 
+_ExternalEntityType = Union[Type[_T], "AliasedClass[_T]"]
+
 _EntityType = Union[
     Type[_T], "AliasedClass[_T]", "Mapper[_T]", "AliasedInsp[_T]"
 ]
index 9a6e94e228a09eb31e42422d336b952513debc66..9aeaeaa2726d4987d7d394eeb28297b930ebf2fa 100644 (file)
@@ -4,7 +4,7 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
+# mypy: allow-untyped-defs, allow-untyped-calls
 
 """Defines instrumentation for class attributes and their interaction
 with instances.
@@ -17,16 +17,18 @@ defines a large part of the ORM's interactivity.
 
 from __future__ import annotations
 
-from collections import namedtuple
+import dataclasses
 import operator
 from typing import Any
 from typing import Callable
-from typing import Collection
+from typing import cast
+from typing import ClassVar
 from typing import Dict
 from typing import List
 from typing import NamedTuple
 from typing import Optional
 from typing import overload
+from typing import Sequence
 from typing import Tuple
 from typing import Type
 from typing import TYPE_CHECKING
@@ -36,6 +38,7 @@ from typing import Union
 from . import collections
 from . import exc as orm_exc
 from . import interfaces
+from ._typing import insp_is_aliased_class
 from .base import ATTR_EMPTY
 from .base import ATTR_WAS_SET
 from .base import CALLABLES_OK
@@ -45,6 +48,7 @@ from .base import instance_dict as instance_dict
 from .base import instance_state as instance_state
 from .base import instance_str
 from .base import LOAD_AGAINST_COMMITTED
+from .base import LoaderCallableStatus
 from .base import manager_of_class as manager_of_class
 from .base import Mapped as Mapped  # noqa
 from .base import NEVER_SET  # noqa
@@ -70,17 +74,41 @@ from .. import event
 from .. import exc
 from .. import inspection
 from .. import util
+from ..event import dispatcher
+from ..event import EventTarget
 from ..sql import base as sql_base
 from ..sql import cache_key
+from ..sql import coercions
 from ..sql import roles
-from ..sql import traversals
 from ..sql import visitors
+from ..util.typing import Literal
+from ..util.typing import TypeGuard
 
 if TYPE_CHECKING:
+    from ._typing import _EntityType
+    from ._typing import _ExternalEntityType
+    from ._typing import _InstanceDict
+    from ._typing import _InternalEntityType
+    from ._typing import _LoaderCallable
+    from ._typing import _O
+    from .collections import _AdaptedCollectionProtocol
+    from .collections import CollectionAdapter
+    from .dynamic import DynamicAttributeImpl
     from .interfaces import MapperProperty
+    from .relationships import Relationship
     from .state import InstanceState
-    from ..sql.dml import _DMLColumnElement
+    from .util import AliasedInsp
+    from ..event.base import _Dispatch
+    from ..sql._typing import _ColumnExpressionArgument
+    from ..sql._typing import _DMLColumnArgument
+    from ..sql._typing import _InfoType
+    from ..sql._typing import _PropagateAttrsType
+    from ..sql.annotation import _AnnotationDict
     from ..sql.elements import ColumnElement
+    from ..sql.elements import Label
+    from ..sql.operators import OperatorType
+    from ..sql.selectable import FromClause
+
 
 _T = TypeVar("_T")
 
@@ -89,19 +117,27 @@ class NoKey(str):
     pass
 
 
+_AllPendingType = List[Tuple[Optional["InstanceState[Any]"], Optional[object]]]
+
 NO_KEY = NoKey("no name")
 
+SelfQueryableAttribute = TypeVar(
+    "SelfQueryableAttribute", bound="QueryableAttribute[Any]"
+)
+
 
 @inspection._self_inspects
 class QueryableAttribute(
+    roles.ExpressionElementRole[_T],
     interfaces._MappedAttribute[_T],
     interfaces.InspectionAttr,
     interfaces.PropComparator[_T],
-    traversals.HasCopyInternals,
     roles.JoinTargetRole,
     roles.OnClauseRole,
     sql_base.Immutable,
-    cache_key.MemoizedHasCacheKey,
+    cache_key.SlotsMemoizedHasCacheKey,
+    util.MemoizedSlots,
+    EventTarget,
 ):
     """Base class for :term:`descriptor` objects that intercept
     attribute events on behalf of a :class:`.MapperProperty`
@@ -121,9 +157,33 @@ class QueryableAttribute(
         :attr:`_orm.Mapper.attrs`
     """
 
+    __slots__ = (
+        "class_",
+        "key",
+        "impl",
+        "comparator",
+        "property",
+        "parent",
+        "expression",
+        "_of_type",
+        "_extra_criteria",
+        "_slots_dispatch",
+        "_propagate_attrs",
+        "_doc",
+    )
+
     is_attribute = True
 
+    dispatch: dispatcher[QueryableAttribute[_T]]
+
+    class_: _ExternalEntityType[Any]
+    key: str
+    parententity: _InternalEntityType[Any]
     impl: AttributeImpl
+    comparator: interfaces.PropComparator[_T]
+    _of_type: Optional[_InternalEntityType[Any]]
+    _extra_criteria: Tuple[ColumnElement[bool], ...]
+    _doc: Optional[str]
 
     # PropComparator has a __visit_name__ to participate within
     # traversals.   Disambiguate the attribute vs. a comparator.
@@ -131,21 +191,30 @@ class QueryableAttribute(
 
     def __init__(
         self,
-        class_,
-        key,
-        parententity,
-        impl=None,
-        comparator=None,
-        of_type=None,
-        extra_criteria=(),
+        class_: _ExternalEntityType[_O],
+        key: str,
+        parententity: _InternalEntityType[_O],
+        comparator: interfaces.PropComparator[_T],
+        impl: Optional[AttributeImpl] = None,
+        of_type: Optional[_InternalEntityType[Any]] = None,
+        extra_criteria: Tuple[ColumnElement[bool], ...] = (),
     ):
         self.class_ = class_
         self.key = key
-        self._parententity = parententity
-        self.impl = impl
+
+        self._parententity = self.parent = parententity
+
+        # this attribute is non-None after mappers are set up, however in the
+        # interim class manager setup, there's a check for None to see if it
+        # needs to be populated, so we assign None here leaving the attribute
+        # in a temporarily not-type-correct state
+        self.impl = impl  # type: ignore
+
+        assert comparator is not None
         self.comparator = comparator
         self._of_type = of_type
         self._extra_criteria = extra_criteria
+        self._doc = None
 
         manager = opt_manager_of_class(class_)
         # manager is None in the case of AliasedClass
@@ -156,7 +225,7 @@ class QueryableAttribute(
                 if key in base:
                     self.dispatch._update(base[key].dispatch)
                     if base[key].dispatch._active_history:
-                        self.dispatch._active_history = True
+                        self.dispatch._active_history = True  # type: ignore
 
     _cache_key_traversal = [
         ("key", visitors.ExtendedInternalTraversal.dp_string),
@@ -165,7 +234,7 @@ class QueryableAttribute(
         ("_extra_criteria", visitors.InternalTraversal.dp_clauseelement_list),
     ]
 
-    def __reduce__(self):
+    def __reduce__(self) -> Any:
         # this method is only used in terms of the
         # sqlalchemy.ext.serializer extension
         return (
@@ -178,21 +247,19 @@ class QueryableAttribute(
             ),
         )
 
-    @util.memoized_property
-    def _supports_population(self):
-        return self.impl.supports_population
-
     @property
-    def _impl_uses_objects(self):
+    def _impl_uses_objects(self) -> bool:
         return self.impl.uses_objects
 
-    def get_history(self, instance, passive=PASSIVE_OFF):
+    def get_history(
+        self, instance: Any, passive: PassiveFlag = PASSIVE_OFF
+    ) -> History:
         return self.impl.get_history(
             instance_state(instance), instance_dict(instance), passive
         )
 
-    @util.memoized_property
-    def info(self):
+    @property
+    def info(self) -> _InfoType:
         """Return the 'info' dictionary for the underlying SQL element.
 
         The behavior here is as follows:
@@ -233,27 +300,28 @@ class QueryableAttribute(
         """
         return self.comparator.info
 
-    @util.memoized_property
-    def parent(self):
-        """Return an inspection instance representing the parent.
+    parent: _InternalEntityType[Any]
+    """Return an inspection instance representing the parent.
 
-        This will be either an instance of :class:`_orm.Mapper`
-        or :class:`.AliasedInsp`, depending upon the nature
-        of the parent entity which this attribute is associated
-        with.
+    This will be either an instance of :class:`_orm.Mapper`
+    or :class:`.AliasedInsp`, depending upon the nature
+    of the parent entity which this attribute is associated
+    with.
 
-        """
-        return inspection.inspect(self._parententity)
+    """
 
-    @util.memoized_property
-    def expression(self):
-        """The SQL expression object represented by this
-        :class:`.QueryableAttribute`.
+    expression: ColumnElement[_T]
+    """The SQL expression object represented by this
+    :class:`.QueryableAttribute`.
 
-        This will typically be an instance of a :class:`_sql.ColumnElement`
-        subclass representing a column expression.
+    This will typically be an instance of a :class:`_sql.ColumnElement`
+    subclass representing a column expression.
+
+    """
+
+    def _memoized_attr_expression(self) -> ColumnElement[_T]:
+        annotations: _AnnotationDict
 
-        """
         if self.key is NO_KEY:
             annotations = {"entity_namespace": self._entity_namespace}
         else:
@@ -265,6 +333,8 @@ class QueryableAttribute(
 
         ce = self.comparator.__clause_element__()
         try:
+            if TYPE_CHECKING:
+                assert isinstance(ce, ColumnElement)
             anno = ce._annotate
         except AttributeError as ae:
             raise exc.InvalidRequestError(
@@ -275,29 +345,42 @@ class QueryableAttribute(
         else:
             return anno(annotations)
 
+    def _memoized_attr__propagate_attrs(self) -> _PropagateAttrsType:
+        # this suits the case in coercions where we don't actually
+        # call ``__clause_element__()`` but still need to get
+        # resolved._propagate_attrs.  See #6558.
+        return util.immutabledict(
+            {
+                "compile_state_plugin": "orm",
+                "plugin_subject": self._parentmapper,
+            }
+        )
+
     @property
-    def _entity_namespace(self):
+    def _entity_namespace(self) -> _InternalEntityType[Any]:
         return self._parententity
 
     @property
-    def _annotations(self):
+    def _annotations(self) -> _AnnotationDict:
         return self.__clause_element__()._annotations
 
     def __clause_element__(self) -> ColumnElement[_T]:
         return self.expression
 
     @property
-    def _from_objects(self):
+    def _from_objects(self) -> List[FromClause]:
         return self.expression._from_objects
 
     def _bulk_update_tuples(
         self, value: Any
-    ) -> List[Tuple[_DMLColumnElement, Any]]:
+    ) -> Sequence[Tuple[_DMLColumnArgument, Any]]:
         """Return setter tuples for a bulk UPDATE."""
 
         return self.comparator._bulk_update_tuples(value)
 
-    def adapt_to_entity(self, adapt_to_entity):
+    def adapt_to_entity(
+        self: SelfQueryableAttribute, adapt_to_entity: AliasedInsp[Any]
+    ) -> SelfQueryableAttribute:
         assert not self._of_type
         return self.__class__(
             adapt_to_entity.entity,
@@ -307,7 +390,7 @@ class QueryableAttribute(
             parententity=adapt_to_entity,
         )
 
-    def of_type(self, entity):
+    def of_type(self, entity: _EntityType[_T]) -> QueryableAttribute[_T]:
         return QueryableAttribute(
             self.class_,
             self.key,
@@ -318,18 +401,28 @@ class QueryableAttribute(
             extra_criteria=self._extra_criteria,
         )
 
-    def and_(self, *other):
+    def and_(
+        self, *clauses: _ColumnExpressionArgument[bool]
+    ) -> interfaces.PropComparator[bool]:
+        if TYPE_CHECKING:
+            assert isinstance(self.comparator, Relationship.Comparator)
+
+        exprs = tuple(
+            coercions.expect(roles.WhereHavingRole, clause)
+            for clause in util.coerce_generator_arg(clauses)
+        )
+
         return QueryableAttribute(
             self.class_,
             self.key,
             self._parententity,
             impl=self.impl,
-            comparator=self.comparator.and_(*other),
+            comparator=self.comparator.and_(*exprs),
             of_type=self._of_type,
-            extra_criteria=self._extra_criteria + other,
+            extra_criteria=self._extra_criteria + exprs,
         )
 
-    def _clone(self, **kw):
+    def _clone(self, **kw: Any) -> QueryableAttribute[_T]:
         return QueryableAttribute(
             self.class_,
             self.key,
@@ -340,19 +433,30 @@ class QueryableAttribute(
             extra_criteria=self._extra_criteria,
         )
 
-    def label(self, name):
+    def label(self, name: Optional[str]) -> Label[_T]:
         return self.__clause_element__().label(name)
 
-    def operate(self, op, *other, **kwargs):
-        return op(self.comparator, *other, **kwargs)
+    def operate(
+        self, op: OperatorType, *other: Any, **kwargs: Any
+    ) -> ColumnElement[Any]:
+        return op(self.comparator, *other, **kwargs)  # type: ignore[return-value]  # noqa: E501
 
-    def reverse_operate(self, op, other, **kwargs):
-        return op(other, self.comparator, **kwargs)
+    def reverse_operate(
+        self, op: OperatorType, other: Any, **kwargs: Any
+    ) -> ColumnElement[Any]:
+        return op(other, self.comparator, **kwargs)  # type: ignore[return-value]  # noqa: E501
 
-    def hasparent(self, state, optimistic=False):
+    def hasparent(
+        self, state: InstanceState[Any], optimistic: bool = False
+    ) -> bool:
         return self.impl.hasparent(state, optimistic=optimistic) is not False
 
-    def __getattr__(self, key):
+    def __getattr__(self, key: str) -> Any:
+        try:
+            return util.MemoizedSlots.__getattr__(self, key)
+        except AttributeError:
+            pass
+
         try:
             return getattr(self.comparator, key)
         except AttributeError as err:
@@ -367,27 +471,22 @@ class QueryableAttribute(
                 )
             ) from err
 
-    def __str__(self):
-        return "%s.%s" % (self.class_.__name__, self.key)
+    def __str__(self) -> str:
+        return f"{self.class_.__name__}.{self.key}"
 
-    @util.memoized_property
-    def property(self) -> MapperProperty[_T]:
-        """Return the :class:`.MapperProperty` associated with this
-        :class:`.QueryableAttribute`.
-
-
-        Return values here will commonly be instances of
-        :class:`.ColumnProperty` or :class:`.Relationship`.
-
-
-        """
+    def _memoized_attr_property(self) -> Optional[MapperProperty[Any]]:
         return self.comparator.property
 
 
-def _queryable_attribute_unreduce(key, mapped_class, parententity, entity):
+def _queryable_attribute_unreduce(
+    key: str,
+    mapped_class: Type[_O],
+    parententity: _InternalEntityType[_O],
+    entity: _ExternalEntityType[Any],
+) -> Any:
     # this method is only used in terms of the
     # sqlalchemy.ext.serializer extension
-    if parententity.is_aliased_class:
+    if insp_is_aliased_class(parententity):
         return entity._get_from_serialized(key, mapped_class, parententity)
     else:
         return getattr(entity, key)
@@ -402,45 +501,60 @@ class InstrumentedAttribute(QueryableAttribute[_T]):
 
     """
 
+    __slots__ = ()
+
     inherit_cache = True
 
-    def __set__(self, instance, value):
+    #    if not TYPE_CHECKING:
+
+    @property  # type: ignore
+    def __doc__(self) -> Optional[str]:  # type: ignore
+        return self._doc
+
+    @__doc__.setter
+    def __doc__(self, value: Optional[str]) -> None:
+        self._doc = value
+
+    def __set__(self, instance: object, value: Any) -> None:
         self.impl.set(
             instance_state(instance), instance_dict(instance), value, None
         )
 
-    def __delete__(self, instance):
+    def __delete__(self, instance: object) -> None:
         self.impl.delete(instance_state(instance), instance_dict(instance))
 
     @overload
-    def __get__(
-        self, instance: None, owner: Type[Any]
-    ) -> InstrumentedAttribute:
+    def __get__(self, instance: None, owner: Any) -> InstrumentedAttribute[_T]:
         ...
 
     @overload
-    def __get__(self, instance: object, owner: Type[Any]) -> Optional[_T]:
+    def __get__(self, instance: object, owner: Any) -> _T:
         ...
 
     def __get__(
-        self, instance: Optional[object], owner: Type[Any]
-    ) -> Union[InstrumentedAttribute, Optional[_T]]:
+        self, instance: Optional[object], owner: Any
+    ) -> Union[InstrumentedAttribute[_T], _T]:
         if instance is None:
             return self
 
         dict_ = instance_dict(instance)
-        if self._supports_population and self.key in dict_:
-            return dict_[self.key]
+        if self.impl.supports_population and self.key in dict_:
+            return dict_[self.key]  # type: ignore[no-any-return]
         else:
             try:
                 state = instance_state(instance)
             except AttributeError as err:
                 raise orm_exc.UnmappedInstanceError(instance) from err
-            return self.impl.get(state, dict_)
+            return self.impl.get(state, dict_)  # type: ignore[no-any-return]
 
 
-HasEntityNamespace = namedtuple("HasEntityNamespace", ["entity_namespace"])
-HasEntityNamespace.is_mapper = HasEntityNamespace.is_aliased_class = False
+@dataclasses.dataclass(frozen=True)
+class AdHocHasEntityNamespace:
+    # py37 compat, no slots=True on dataclass
+    __slots__ = ("entity_namespace",)
+    entity_namespace: _ExternalEntityType[Any]
+    is_mapper: ClassVar[bool] = False
+    is_aliased_class: ClassVar[bool] = False
 
 
 def create_proxied_attribute(
@@ -455,7 +569,7 @@ def create_proxied_attribute(
     # TODO: can move this to descriptor_props if the need for this
     # function is removed from ext/hybrid.py
 
-    class Proxy(QueryableAttribute):
+    class Proxy(QueryableAttribute[Any]):
         """Presents the :class:`.QueryableAttribute` interface as a
         proxy on top of a Python descriptor / :class:`.PropComparator`
         combination.
@@ -464,6 +578,10 @@ def create_proxied_attribute(
 
         _extra_criteria = ()
 
+        # the attribute error catches inside of __getattr__ basically create a
+        # singularity if you try putting slots on this too
+        # __slots__ = ("descriptor", "original_property", "_comparator")
+
         def __init__(
             self,
             class_,
@@ -480,7 +598,15 @@ def create_proxied_attribute(
             self.original_property = original_property
             self._comparator = comparator
             self._adapt_to_entity = adapt_to_entity
-            self.__doc__ = doc
+            self._doc = self.__doc__ = doc
+
+        @property
+        def _parententity(self):
+            return inspection.inspect(self.class_)
+
+        @property
+        def parent(self):
+            return inspection.inspect(self.class_)
 
         _is_internal_proxy = True
 
@@ -496,10 +622,6 @@ def create_proxied_attribute(
                 and getattr(self.class_, self.key).impl.uses_objects
             )
 
-        @property
-        def _parententity(self):
-            return inspection.inspect(self.class_, raiseerr=False)
-
         @property
         def _entity_namespace(self):
             if hasattr(self._comparator, "_parententity"):
@@ -507,7 +629,7 @@ def create_proxied_attribute(
             else:
                 # used by hybrid attributes which try to remain
                 # agnostic of any ORM concepts like mappers
-                return HasEntityNamespace(self.class_)
+                return AdHocHasEntityNamespace(self.class_)
 
         @property
         def property(self):
@@ -552,12 +674,22 @@ def create_proxied_attribute(
             else:
                 return retval
 
-        def __str__(self):
-            return "%s.%s" % (self.class_.__name__, self.key)
+        def __str__(self) -> str:
+            return f"{self.class_.__name__}.{self.key}"
 
         def __getattr__(self, attribute):
             """Delegate __getattr__ to the original descriptor and/or
             comparator."""
+
+            # this is unfortunately very complicated, and is easily prone
+            # to recursion overflows when implementations of related
+            # __getattr__ schemes are changed
+
+            try:
+                return util.MemoizedSlots.__getattr__(self, attribute)
+            except AttributeError:
+                pass
+
             try:
                 return getattr(descriptor, attribute)
             except AttributeError as err:
@@ -602,7 +734,7 @@ OP_BULK_REPLACE = util.symbol("BULK_REPLACE")
 OP_MODIFIED = util.symbol("MODIFIED")
 
 
-class AttributeEvent:
+class AttributeEventToken:
     """A token propagated throughout the course of a chain of attribute
     events.
 
@@ -619,7 +751,8 @@ class AttributeEvent:
     event handlers, and is used to control the propagation of operations
     across two mutually-dependent attributes.
 
-    .. versionadded:: 0.9.0
+    .. versionchanged:: 2.0  Changed the name from ``AttributeEvent``
+       to ``AttributeEventToken``.
 
     :attribute impl: The :class:`.AttributeImpl` which is the current event
      initiator.
@@ -639,7 +772,7 @@ class AttributeEvent:
 
     def __eq__(self, other):
         return (
-            isinstance(other, AttributeEvent)
+            isinstance(other, AttributeEventToken)
             and other.impl is self.impl
             and other.op == self.op
         )
@@ -652,28 +785,37 @@ class AttributeEvent:
         return self.impl.hasparent(state)
 
 
-Event = AttributeEvent
+AttributeEvent = AttributeEventToken  # legacy
+Event = AttributeEventToken  # legacy
 
 
 class AttributeImpl:
     """internal implementation for instrumented attributes."""
 
     collection: bool
+    default_accepts_scalar_loader: bool
+    uses_objects: bool
+    supports_population: bool
+    dynamic: bool
+
+    _replace_token: AttributeEventToken
+    _remove_token: AttributeEventToken
+    _append_token: AttributeEventToken
 
     def __init__(
         self,
-        class_,
-        key,
-        callable_,
-        dispatch,
-        trackparent=False,
-        compare_function=None,
-        active_history=False,
-        parent_token=None,
-        load_on_unexpire=True,
-        send_modified_events=True,
-        accepts_scalar_loader=None,
-        **kwargs,
+        class_: _ExternalEntityType[_O],
+        key: str,
+        callable_: _LoaderCallable,
+        dispatch: _Dispatch[QueryableAttribute[Any]],
+        trackparent: bool = False,
+        compare_function: Optional[Callable[..., bool]] = None,
+        active_history: bool = False,
+        parent_token: Optional[AttributeEventToken] = None,
+        load_on_unexpire: bool = True,
+        send_modified_events: bool = True,
+        accepts_scalar_loader: Optional[bool] = None,
+        **kwargs: Any,
     ):
         r"""Construct an AttributeImpl.
 
@@ -743,7 +885,7 @@ class AttributeImpl:
             self.dispatch._active_history = True
 
         self.load_on_unexpire = load_on_unexpire
-        self._modified_token = Event(self, OP_MODIFIED)
+        self._modified_token = AttributeEventToken(self, OP_MODIFIED)
 
     __slots__ = (
         "class_",
@@ -760,8 +902,8 @@ class AttributeImpl:
         "_deferred_history",
     )
 
-    def __str__(self):
-        return "%s.%s" % (self.class_.__name__, self.key)
+    def __str__(self) -> str:
+        return f"{self.class_.__name__}.{self.key}"
 
     def _get_active_history(self):
         """Backwards compat for impl.active_history"""
@@ -773,7 +915,9 @@ class AttributeImpl:
 
     active_history = property(_get_active_history, _set_active_history)
 
-    def hasparent(self, state, optimistic=False):
+    def hasparent(
+        self, state: InstanceState[Any], optimistic: bool = False
+    ) -> bool:
         """Return the boolean value of a `hasparent` flag attached to
         the given state.
 
@@ -796,7 +940,12 @@ class AttributeImpl:
             state.parents.get(id(self.parent_token), optimistic) is not False
         )
 
-    def sethasparent(self, state, parent_state, value):
+    def sethasparent(
+        self,
+        state: InstanceState[Any],
+        parent_state: InstanceState[Any],
+        value: bool,
+    ) -> None:
         """Set a boolean flag on the given item corresponding to
         whether or not it is attached to a parent object via the
         attribute represented by this ``InstrumentedAttribute``.
@@ -839,11 +988,16 @@ class AttributeImpl:
         self,
         state: InstanceState[Any],
         dict_: _InstanceDict,
-        passive=PASSIVE_OFF,
+        passive: PassiveFlag = PASSIVE_OFF,
     ) -> History:
         raise NotImplementedError()
 
-    def get_all_pending(self, state, dict_, passive=PASSIVE_NO_INITIALIZE):
+    def get_all_pending(
+        self,
+        state: InstanceState[Any],
+        dict_: _InstanceDict,
+        passive: PassiveFlag = PASSIVE_NO_INITIALIZE,
+    ) -> _AllPendingType:
         """Return a list of tuples of (state, obj)
         for all objects in this attribute's current state
         + history.
@@ -861,7 +1015,9 @@ class AttributeImpl:
         """
         raise NotImplementedError()
 
-    def _default_value(self, state, dict_):
+    def _default_value(
+        self, state: InstanceState[Any], dict_: _InstanceDict
+    ) -> Any:
         """Produce an empty value for an uninitialized scalar attribute."""
 
         assert self.key not in dict_, (
@@ -877,7 +1033,12 @@ class AttributeImpl:
 
         return value
 
-    def get(self, state, dict_, passive=PASSIVE_OFF):
+    def get(
+        self,
+        state: InstanceState[Any],
+        dict_: _InstanceDict,
+        passive: PassiveFlag = PASSIVE_OFF,
+    ) -> Any:
         """Retrieve a value from the given object.
         If a callable is assembled on this object's attribute, and
         passive is False, the callable will be executed and the
@@ -917,7 +1078,9 @@ class AttributeImpl:
             else:
                 return self._default_value(state, dict_)
 
-    def _fire_loader_callables(self, state, key, passive):
+    def _fire_loader_callables(
+        self, state: InstanceState[Any], key: str, passive: PassiveFlag
+    ) -> Any:
         if (
             self.accepts_scalar_loader
             and self.load_on_unexpire
@@ -932,15 +1095,36 @@ class AttributeImpl:
         else:
             return ATTR_EMPTY
 
-    def append(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
+    def append(
+        self,
+        state: InstanceState[Any],
+        dict_: _InstanceDict,
+        value: Any,
+        initiator: Optional[AttributeEventToken],
+        passive: PassiveFlag = PASSIVE_OFF,
+    ) -> None:
         self.set(state, dict_, value, initiator, passive=passive)
 
-    def remove(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
+    def remove(
+        self,
+        state: InstanceState[Any],
+        dict_: _InstanceDict,
+        value: Any,
+        initiator: Optional[AttributeEventToken],
+        passive: PassiveFlag = PASSIVE_OFF,
+    ) -> None:
         self.set(
             state, dict_, None, initiator, passive=passive, check_old=value
         )
 
-    def pop(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
+    def pop(
+        self,
+        state: InstanceState[Any],
+        dict_: _InstanceDict,
+        value: Any,
+        initiator: Optional[AttributeEventToken],
+        passive: PassiveFlag = PASSIVE_OFF,
+    ) -> None:
         self.set(
             state,
             dict_,
@@ -953,17 +1137,25 @@ class AttributeImpl:
 
     def set(
         self,
-        state,
-        dict_,
-        value,
-        initiator,
-        passive=PASSIVE_OFF,
-        check_old=None,
-        pop=False,
-    ):
+        state: InstanceState[Any],
+        dict_: _InstanceDict,
+        value: Any,
+        initiator: Optional[AttributeEventToken],
+        passive: PassiveFlag = PASSIVE_OFF,
+        check_old: Any = None,
+        pop: bool = False,
+    ) -> None:
+        raise NotImplementedError()
+
+    def delete(self, state: InstanceState[Any], dict_: _InstanceDict) -> None:
         raise NotImplementedError()
 
-    def get_committed_value(self, state, dict_, passive=PASSIVE_OFF):
+    def get_committed_value(
+        self,
+        state: InstanceState[Any],
+        dict_: _InstanceDict,
+        passive: PassiveFlag = PASSIVE_OFF,
+    ) -> Any:
         """return the unchanged value of this attribute"""
 
         if self.key in state.committed_state:
@@ -996,10 +1188,12 @@ class ScalarAttributeImpl(AttributeImpl):
 
     def __init__(self, *arg, **kw):
         super(ScalarAttributeImpl, self).__init__(*arg, **kw)
-        self._replace_token = self._append_token = Event(self, OP_REPLACE)
-        self._remove_token = Event(self, OP_REMOVE)
+        self._replace_token = self._append_token = AttributeEventToken(
+            self, OP_REPLACE
+        )
+        self._remove_token = AttributeEventToken(self, OP_REMOVE)
 
-    def delete(self, state, dict_):
+    def delete(self, state: InstanceState[Any], dict_: _InstanceDict) -> None:
         if self.dispatch._active_history:
             old = self.get(state, dict_, PASSIVE_RETURN_NO_VALUE)
         else:
@@ -1042,11 +1236,11 @@ class ScalarAttributeImpl(AttributeImpl):
         state: InstanceState[Any],
         dict_: Dict[str, Any],
         value: Any,
-        initiator: Optional[Event],
+        initiator: Optional[AttributeEventToken],
         passive: PassiveFlag = PASSIVE_OFF,
         check_old: Optional[object] = None,
         pop: bool = False,
-    ):
+    ) -> None:
         if self.dispatch._active_history:
             old = self.get(state, dict_, PASSIVE_RETURN_NO_VALUE)
         else:
@@ -1059,21 +1253,30 @@ class ScalarAttributeImpl(AttributeImpl):
         state._modified_event(dict_, self, old)
         dict_[self.key] = value
 
-    def fire_replace_event(self, state, dict_, value, previous, initiator):
+    def fire_replace_event(
+        self,
+        state: InstanceState[Any],
+        dict_: _InstanceDict,
+        value: _T,
+        previous: Any,
+        initiator: Optional[AttributeEventToken],
+    ) -> _T:
         for fn in self.dispatch.set:
             value = fn(
                 state, value, previous, initiator or self._replace_token
             )
         return value
 
-    def fire_remove_event(self, state, dict_, value, initiator):
+    def fire_remove_event(
+        self,
+        state: InstanceState[Any],
+        dict_: _InstanceDict,
+        value: Any,
+        initiator: Optional[AttributeEventToken],
+    ) -> None:
         for fn in self.dispatch.remove:
             fn(state, value, initiator or self._remove_token)
 
-    @property
-    def type(self):
-        self.property.columns[0].type
-
 
 class ScalarObjectAttributeImpl(ScalarAttributeImpl):
     """represents a scalar-holding InstrumentedAttribute,
@@ -1090,7 +1293,7 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
 
     __slots__ = ()
 
-    def delete(self, state, dict_):
+    def delete(self, state: InstanceState[Any], dict_: _InstanceDict) -> None:
         if self.dispatch._active_history:
             old = self.get(
                 state,
@@ -1122,7 +1325,12 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
         ):
             raise AttributeError("%s object does not have a value" % self)
 
-    def get_history(self, state, dict_, passive=PASSIVE_OFF):
+    def get_history(
+        self,
+        state: InstanceState[Any],
+        dict_: _InstanceDict,
+        passive: PassiveFlag = PASSIVE_OFF,
+    ) -> History:
         if self.key in dict_:
             current = dict_[self.key]
         else:
@@ -1152,7 +1360,12 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
                 self, state, current, original=original
             )
 
-    def get_all_pending(self, state, dict_, passive=PASSIVE_NO_INITIALIZE):
+    def get_all_pending(
+        self,
+        state: InstanceState[Any],
+        dict_: _InstanceDict,
+        passive: PassiveFlag = PASSIVE_NO_INITIALIZE,
+    ) -> _AllPendingType:
         if self.key in dict_:
             current = dict_[self.key]
         elif passive & CALLABLES_OK:
@@ -1160,6 +1373,8 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
         else:
             return []
 
+        ret: _AllPendingType
+
         # can't use __hash__(), can't use __eq__() here
         if (
             current is not None
@@ -1184,14 +1399,14 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
 
     def set(
         self,
-        state,
-        dict_,
-        value,
-        initiator,
-        passive=PASSIVE_OFF,
-        check_old=None,
-        pop=False,
-    ):
+        state: InstanceState[Any],
+        dict_: _InstanceDict,
+        value: Any,
+        initiator: Optional[AttributeEventToken],
+        passive: PassiveFlag = PASSIVE_OFF,
+        check_old: Any = None,
+        pop: bool = False,
+    ) -> None:
         """Set a value on the given InstanceState."""
 
         if self.dispatch._active_history:
@@ -1227,7 +1442,13 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
         value = self.fire_replace_event(state, dict_, value, old, initiator)
         dict_[self.key] = value
 
-    def fire_remove_event(self, state, dict_, value, initiator):
+    def fire_remove_event(
+        self,
+        state: InstanceState[Any],
+        dict_: _InstanceDict,
+        value: Any,
+        initiator: Optional[AttributeEventToken],
+    ) -> None:
         if self.trackparent and value not in (
             None,
             PASSIVE_NO_RESULT,
@@ -1240,7 +1461,14 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
 
         state._modified_event(dict_, self, value)
 
-    def fire_replace_event(self, state, dict_, value, previous, initiator):
+    def fire_replace_event(
+        self,
+        state: InstanceState[Any],
+        dict_: _InstanceDict,
+        value: _T,
+        previous: Any,
+        initiator: Optional[AttributeEventToken],
+    ) -> _T:
         if self.trackparent:
             if previous is not value and previous not in (
                 None,
@@ -1263,7 +1491,64 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
         return value
 
 
-class CollectionAttributeImpl(AttributeImpl):
+class HasCollectionAdapter:
+    __slots__ = ()
+
+    def _dispose_previous_collection(
+        self,
+        state: InstanceState[Any],
+        collection: _AdaptedCollectionProtocol,
+        adapter: CollectionAdapter,
+        fire_event: bool,
+    ) -> None:
+        raise NotImplementedError()
+
+    @overload
+    def get_collection(
+        self,
+        state: InstanceState[Any],
+        dict_: _InstanceDict,
+        user_data: Optional[_AdaptedCollectionProtocol] = None,
+        passive: Literal[PassiveFlag.PASSIVE_OFF] = ...,
+    ) -> CollectionAdapter:
+        ...
+
+    @overload
+    def get_collection(
+        self,
+        state: InstanceState[Any],
+        dict_: _InstanceDict,
+        user_data: Optional[_AdaptedCollectionProtocol] = None,
+        passive: PassiveFlag = PASSIVE_OFF,
+    ) -> Union[
+        Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter
+    ]:
+        ...
+
+    def get_collection(
+        self,
+        state: InstanceState[Any],
+        dict_: _InstanceDict,
+        user_data: Optional[_AdaptedCollectionProtocol] = None,
+        passive: PassiveFlag = PASSIVE_OFF,
+    ) -> Union[
+        Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter
+    ]:
+        raise NotImplementedError()
+
+
+if TYPE_CHECKING:
+
+    def _is_collection_attribute_impl(
+        impl: AttributeImpl,
+    ) -> TypeGuard[CollectionAttributeImpl]:
+        ...
+
+else:
+    _is_collection_attribute_impl = operator.attrgetter("collection")
+
+
+class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl):
     """A collection-holding attribute that instruments changes in membership.
 
     Only handles collections of instrumented objects.
@@ -1275,12 +1560,14 @@ class CollectionAttributeImpl(AttributeImpl):
 
     """
 
-    default_accepts_scalar_loader = False
     uses_objects = True
-    supports_population = True
     collection = True
+    default_accepts_scalar_loader = False
+    supports_population = True
     dynamic = False
 
+    _bulk_replace_token: AttributeEventToken
+
     __slots__ = (
         "copy",
         "collection_factory",
@@ -1316,9 +1603,9 @@ class CollectionAttributeImpl(AttributeImpl):
             copy_function = self.__copy
         self.copy = copy_function
         self.collection_factory = typecallable
-        self._append_token = Event(self, OP_APPEND)
-        self._remove_token = Event(self, OP_REMOVE)
-        self._bulk_replace_token = Event(self, OP_BULK_REPLACE)
+        self._append_token = AttributeEventToken(self, OP_APPEND)
+        self._remove_token = AttributeEventToken(self, OP_REMOVE)
+        self._bulk_replace_token = AttributeEventToken(self, OP_BULK_REPLACE)
         self._duck_typed_as = util.duck_type_collection(
             self.collection_factory()
         )
@@ -1336,14 +1623,24 @@ class CollectionAttributeImpl(AttributeImpl):
     def __copy(self, item):
         return [y for y in collections.collection_adapter(item)]
 
-    def get_history(self, state, dict_, passive=PASSIVE_OFF):
+    def get_history(
+        self,
+        state: InstanceState[Any],
+        dict_: _InstanceDict,
+        passive: PassiveFlag = PASSIVE_OFF,
+    ) -> History:
         current = self.get(state, dict_, passive=passive)
         if current is PASSIVE_NO_RESULT:
             return HISTORY_BLANK
         else:
             return History.from_collection(self, state, current)
 
-    def get_all_pending(self, state, dict_, passive=PASSIVE_NO_INITIALIZE):
+    def get_all_pending(
+        self,
+        state: InstanceState[Any],
+        dict_: _InstanceDict,
+        passive: PassiveFlag = PASSIVE_NO_INITIALIZE,
+    ) -> _AllPendingType:
         # NOTE: passive is ignored here at the moment
 
         if self.key not in dict_:
@@ -1383,7 +1680,13 @@ class CollectionAttributeImpl(AttributeImpl):
 
         return [(instance_state(o), o) for o in current]
 
-    def fire_append_event(self, state, dict_, value, initiator):
+    def fire_append_event(
+        self,
+        state: InstanceState[Any],
+        dict_: _InstanceDict,
+        value: _T,
+        initiator: Optional[AttributeEventToken],
+    ) -> _T:
         for fn in self.dispatch.append:
             value = fn(state, value, initiator or self._append_token)
 
@@ -1394,13 +1697,24 @@ class CollectionAttributeImpl(AttributeImpl):
 
         return value
 
-    def fire_append_wo_mutation_event(self, state, dict_, value, initiator):
+    def fire_append_wo_mutation_event(
+        self,
+        state: InstanceState[Any],
+        dict_: _InstanceDict,
+        value: _T,
+        initiator: Optional[AttributeEventToken],
+    ) -> _T:
         for fn in self.dispatch.append_wo_mutation:
             value = fn(state, value, initiator or self._append_token)
 
         return value
 
-    def fire_pre_remove_event(self, state, dict_, initiator):
+    def fire_pre_remove_event(
+        self,
+        state: InstanceState[Any],
+        dict_: _InstanceDict,
+        initiator: Optional[AttributeEventToken],
+    ) -> None:
         """A special event used for pop() operations.
 
         The "remove" event needs to have the item to be removed passed to
@@ -1411,7 +1725,13 @@ class CollectionAttributeImpl(AttributeImpl):
         """
         state._modified_event(dict_, self, NO_VALUE, True)
 
-    def fire_remove_event(self, state, dict_, value, initiator):
+    def fire_remove_event(
+        self,
+        state: InstanceState[Any],
+        dict_: _InstanceDict,
+        value: Any,
+        initiator: Optional[AttributeEventToken],
+    ) -> None:
         if self.trackparent and value is not None:
             self.sethasparent(instance_state(value), state, False)
 
@@ -1420,7 +1740,7 @@ class CollectionAttributeImpl(AttributeImpl):
 
         state._modified_event(dict_, self, NO_VALUE, True)
 
-    def delete(self, state, dict_):
+    def delete(self, state: InstanceState[Any], dict_: _InstanceDict) -> None:
         if self.key not in dict_:
             return
 
@@ -1433,7 +1753,9 @@ class CollectionAttributeImpl(AttributeImpl):
         # del is a no-op if collection not present.
         del dict_[self.key]
 
-    def _default_value(self, state, dict_):
+    def _default_value(
+        self, state: InstanceState[Any], dict_: _InstanceDict
+    ) -> _AdaptedCollectionProtocol:
         """Produce an empty collection for an un-initialized attribute"""
 
         assert self.key not in dict_, (
@@ -1448,7 +1770,9 @@ class CollectionAttributeImpl(AttributeImpl):
         adapter._set_empty(user_data)
         return user_data
 
-    def _initialize_collection(self, state):
+    def _initialize_collection(
+        self, state: InstanceState[Any]
+    ) -> Tuple[CollectionAdapter, _AdaptedCollectionProtocol]:
 
         adapter, collection = state.manager.initialize_collection(
             self.key, state, self.collection_factory
@@ -1458,7 +1782,14 @@ class CollectionAttributeImpl(AttributeImpl):
 
         return adapter, collection
 
-    def append(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
+    def append(
+        self,
+        state: InstanceState[Any],
+        dict_: _InstanceDict,
+        value: Any,
+        initiator: Optional[AttributeEventToken],
+        passive: PassiveFlag = PASSIVE_OFF,
+    ) -> None:
         collection = self.get_collection(state, dict_, passive=passive)
         if collection is PASSIVE_NO_RESULT:
             value = self.fire_append_event(state, dict_, value, initiator)
@@ -1467,9 +1798,18 @@ class CollectionAttributeImpl(AttributeImpl):
             ), "Collection was loaded during event handling."
             state._get_pending_mutation(self.key).append(value)
         else:
+            if TYPE_CHECKING:
+                assert isinstance(collection, CollectionAdapter)
             collection.append_with_event(value, initiator)
 
-    def remove(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
+    def remove(
+        self,
+        state: InstanceState[Any],
+        dict_: _InstanceDict,
+        value: Any,
+        initiator: Optional[AttributeEventToken],
+        passive: PassiveFlag = PASSIVE_OFF,
+    ) -> None:
         collection = self.get_collection(state, state.dict, passive=passive)
         if collection is PASSIVE_NO_RESULT:
             self.fire_remove_event(state, dict_, value, initiator)
@@ -1478,9 +1818,18 @@ class CollectionAttributeImpl(AttributeImpl):
             ), "Collection was loaded during event handling."
             state._get_pending_mutation(self.key).remove(value)
         else:
+            if TYPE_CHECKING:
+                assert isinstance(collection, CollectionAdapter)
             collection.remove_with_event(value, initiator)
 
-    def pop(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
+    def pop(
+        self,
+        state: InstanceState[Any],
+        dict_: _InstanceDict,
+        value: Any,
+        initiator: Optional[AttributeEventToken],
+        passive: PassiveFlag = PASSIVE_OFF,
+    ) -> None:
         try:
             # TODO: better solution here would be to add
             # a "popper" role to collections.py to complement
@@ -1491,15 +1840,15 @@ class CollectionAttributeImpl(AttributeImpl):
 
     def set(
         self,
-        state,
-        dict_,
-        value,
-        initiator=None,
-        passive=PASSIVE_OFF,
-        check_old=None,
-        pop=False,
-        _adapt=True,
-    ):
+        state: InstanceState[Any],
+        dict_: _InstanceDict,
+        value: Any,
+        initiator: Optional[AttributeEventToken] = None,
+        passive: PassiveFlag = PASSIVE_OFF,
+        check_old: Any = None,
+        pop: bool = False,
+        _adapt: bool = True,
+    ) -> None:
         iterable = orig_iterable = value
 
         # pulling a new collection first so that an adaptation exception does
@@ -1518,7 +1867,7 @@ class CollectionAttributeImpl(AttributeImpl):
                         and "None"
                         or iterable.__class__.__name__
                     )
-                    wanted = self._duck_typed_as.__name__
+                    wanted = self._duck_typed_as.__name__  # type: ignore
                     raise TypeError(
                         "Incompatible collection type: %s is not %s-like"
                         % (given, wanted)
@@ -1560,8 +1909,12 @@ class CollectionAttributeImpl(AttributeImpl):
         self._dispose_previous_collection(state, old, old_collection, True)
 
     def _dispose_previous_collection(
-        self, state, collection, adapter, fire_event
-    ):
+        self,
+        state: InstanceState[Any],
+        collection: _AdaptedCollectionProtocol,
+        adapter: CollectionAdapter,
+        fire_event: bool,
+    ) -> None:
         del collection._sa_adapter
 
         # discarding old collection make sure it is not referenced in empty
@@ -1570,11 +1923,15 @@ class CollectionAttributeImpl(AttributeImpl):
         if fire_event:
             self.dispatch.dispose_collection(state, collection, adapter)
 
-    def _invalidate_collection(self, collection: Collection) -> None:
+    def _invalidate_collection(
+        self, collection: _AdaptedCollectionProtocol
+    ) -> None:
         adapter = getattr(collection, "_sa_adapter")
         adapter.invalidated = True
 
-    def set_committed_value(self, state, dict_, value):
+    def set_committed_value(
+        self, state: InstanceState[Any], dict_: _InstanceDict, value: Any
+    ) -> _AdaptedCollectionProtocol:
         """Set an attribute value on the given instance and 'commit' it."""
 
         collection, user_data = self._initialize_collection(state)
@@ -1601,9 +1958,37 @@ class CollectionAttributeImpl(AttributeImpl):
 
         return user_data
 
+    @overload
     def get_collection(
-        self, state, dict_, user_data=None, passive=PASSIVE_OFF
-    ):
+        self,
+        state: InstanceState[Any],
+        dict_: _InstanceDict,
+        user_data: Optional[_AdaptedCollectionProtocol] = None,
+        passive: Literal[PassiveFlag.PASSIVE_OFF] = ...,
+    ) -> CollectionAdapter:
+        ...
+
+    @overload
+    def get_collection(
+        self,
+        state: InstanceState[Any],
+        dict_: _InstanceDict,
+        user_data: Optional[_AdaptedCollectionProtocol] = None,
+        passive: PassiveFlag = PASSIVE_OFF,
+    ) -> Union[
+        Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter
+    ]:
+        ...
+
+    def get_collection(
+        self,
+        state: InstanceState[Any],
+        dict_: _InstanceDict,
+        user_data: Optional[_AdaptedCollectionProtocol] = None,
+        passive: PassiveFlag = PASSIVE_OFF,
+    ) -> Union[
+        Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter
+    ]:
         """Retrieve the CollectionAdapter associated with the given state.
 
         if user_data is None, retrieves it from the state using normal
@@ -1612,14 +1997,18 @@ class CollectionAttributeImpl(AttributeImpl):
 
         """
         if user_data is None:
-            user_data = self.get(state, dict_, passive=passive)
-            if user_data is PASSIVE_NO_RESULT:
-                return user_data
+            fetch_user_data = self.get(state, dict_, passive=passive)
+            if fetch_user_data is LoaderCallableStatus.PASSIVE_NO_RESULT:
+                return fetch_user_data
+            else:
+                user_data = cast("_AdaptedCollectionProtocol", fetch_user_data)
 
         return user_data._sa_adapter
 
 
-def backref_listeners(attribute, key, uselist):
+def backref_listeners(
+    attribute: QueryableAttribute[Any], key: str, uselist: bool
+) -> None:
     """Apply listeners to synchronize a two-way relationship."""
 
     # use easily recognizable names for stack traces.
@@ -1695,7 +2084,7 @@ def backref_listeners(attribute, key, uselist):
             check_append_token = child_impl._append_token
             check_bulk_replace_token = (
                 child_impl._bulk_replace_token
-                if child_impl.collection
+                if _is_collection_attribute_impl(child_impl)
                 else None
             )
 
@@ -1728,7 +2117,9 @@ def backref_listeners(attribute, key, uselist):
         # tokens to test for a recursive loop.
         check_append_token = child_impl._append_token
         check_bulk_replace_token = (
-            child_impl._bulk_replace_token if child_impl.collection else None
+            child_impl._bulk_replace_token
+            if _is_collection_attribute_impl(child_impl)
+            else None
         )
 
         if (
@@ -1756,6 +2147,8 @@ def backref_listeners(attribute, key, uselist):
             )
             child_impl = child_state.manager[key].impl
 
+            check_replace_token: Optional[AttributeEventToken]
+
             # tokens to test for a recursive loop.
             if not child_impl.collection and not child_impl.dynamic:
                 check_remove_token = child_impl._remove_token
@@ -1765,7 +2158,7 @@ def backref_listeners(attribute, key, uselist):
                 check_remove_token = child_impl._remove_token
                 check_replace_token = (
                     child_impl._bulk_replace_token
-                    if child_impl.collection
+                    if _is_collection_attribute_impl(child_impl)
                     else None
                 )
                 check_for_dupes_on_remove = False
@@ -1848,10 +2241,10 @@ class History(NamedTuple):
     unchanged: Union[Tuple[()], List[Any]]
     deleted: Union[Tuple[()], List[Any]]
 
-    def __bool__(self):
+    def __bool__(self) -> bool:
         return self != HISTORY_BLANK
 
-    def empty(self):
+    def empty(self) -> bool:
         """Return True if this :class:`.History` has no changes
         and no existing, unchanged state.
 
@@ -1859,29 +2252,29 @@ class History(NamedTuple):
 
         return not bool((self.added or self.deleted) or self.unchanged)
 
-    def sum(self):
+    def sum(self) -> Sequence[Any]:
         """Return a collection of added + unchanged + deleted."""
 
         return (
             (self.added or []) + (self.unchanged or []) + (self.deleted or [])
         )
 
-    def non_deleted(self):
+    def non_deleted(self) -> Sequence[Any]:
         """Return a collection of added + unchanged."""
 
         return (self.added or []) + (self.unchanged or [])
 
-    def non_added(self):
+    def non_added(self) -> Sequence[Any]:
         """Return a collection of unchanged + deleted."""
 
         return (self.unchanged or []) + (self.deleted or [])
 
-    def has_changes(self):
+    def has_changes(self) -> bool:
         """Return True if this :class:`.History` has changes."""
 
         return bool(self.added or self.deleted)
 
-    def as_state(self):
+    def as_state(self) -> History:
         return History(
             [
                 (c is not None) and instance_state(c) or None
@@ -1898,9 +2291,16 @@ class History(NamedTuple):
         )
 
     @classmethod
-    def from_scalar_attribute(cls, attribute, state, current):
+    def from_scalar_attribute(
+        cls,
+        attribute: ScalarAttributeImpl,
+        state: InstanceState[Any],
+        current: Any,
+    ) -> History:
         original = state.committed_state.get(attribute.key, _NO_HISTORY)
 
+        deleted: Union[Tuple[()], List[Any]]
+
         if original is _NO_HISTORY:
             if current is NO_VALUE:
                 return cls((), (), ())
@@ -1933,8 +2333,14 @@ class History(NamedTuple):
 
     @classmethod
     def from_object_attribute(
-        cls, attribute, state, current, original=_NO_HISTORY
-    ):
+        cls,
+        attribute: ScalarObjectAttributeImpl,
+        state: InstanceState[Any],
+        current: Any,
+        original: Any = _NO_HISTORY,
+    ) -> History:
+        deleted: Union[Tuple[()], List[Any]]
+
         if original is _NO_HISTORY:
             original = state.committed_state.get(attribute.key, _NO_HISTORY)
 
@@ -1965,7 +2371,12 @@ class History(NamedTuple):
                 return cls([current], (), deleted)
 
     @classmethod
-    def from_collection(cls, attribute, state, current):
+    def from_collection(
+        cls,
+        attribute: CollectionAttributeImpl,
+        state: InstanceState[Any],
+        current: Any,
+    ) -> History:
         original = state.committed_state.get(attribute.key, _NO_HISTORY)
         if current is NO_VALUE:
             return cls((), (), ())
@@ -1999,7 +2410,9 @@ class History(NamedTuple):
 HISTORY_BLANK = History((), (), ())
 
 
-def get_history(obj, key, passive=PASSIVE_OFF):
+def get_history(
+    obj: object, key: str, passive: PassiveFlag = PASSIVE_OFF
+) -> History:
     """Return a :class:`.History` record for the given object
     and attribute key.
 
@@ -2037,36 +2450,47 @@ def get_history(obj, key, passive=PASSIVE_OFF):
     return get_state_history(instance_state(obj), key, passive)
 
 
-def get_state_history(state, key, passive=PASSIVE_OFF):
+def get_state_history(
+    state: InstanceState[Any], key: str, passive: PassiveFlag = PASSIVE_OFF
+) -> History:
     return state.get_history(key, passive)
 
 
-def has_parent(cls, obj, key, optimistic=False):
+def has_parent(
+    cls: Type[_O], obj: _O, key: str, optimistic: bool = False
+) -> bool:
     """TODO"""
     manager = manager_of_class(cls)
     state = instance_state(obj)
     return manager.has_parent(state, key, optimistic)
 
 
-def register_attribute(class_, key, **kw):
-    comparator = kw.pop("comparator", None)
-    parententity = kw.pop("parententity", None)
-    doc = kw.pop("doc", None)
-    desc = register_descriptor(class_, key, comparator, parententity, doc=doc)
+def register_attribute(
+    class_: Type[_O],
+    key: str,
+    *,
+    comparator: interfaces.PropComparator[_T],
+    parententity: _InternalEntityType[_O],
+    doc: Optional[str] = None,
+    **kw: Any,
+) -> InstrumentedAttribute[_T]:
+    desc = register_descriptor(
+        class_, key, comparator=comparator, parententity=parententity, doc=doc
+    )
     register_attribute_impl(class_, key, **kw)
     return desc
 
 
 def register_attribute_impl(
-    class_,
-    key,
-    uselist=False,
-    callable_=None,
-    useobject=False,
-    impl_class=None,
-    backref=None,
-    **kw,
-):
+    class_: Type[_O],
+    key: str,
+    uselist: bool = False,
+    callable_: Optional[_LoaderCallable] = None,
+    useobject: bool = False,
+    impl_class: Optional[Type[AttributeImpl]] = None,
+    backref: Optional[str] = None,
+    **kw: Any,
+) -> InstrumentedAttribute[Any]:
 
     manager = manager_of_class(class_)
     if uselist:
@@ -2077,10 +2501,18 @@ def register_attribute_impl(
     else:
         typecallable = kw.pop("typecallable", None)
 
-    dispatch = manager[key].dispatch
+    dispatch = cast(
+        "_Dispatch[QueryableAttribute[Any]]", manager[key].dispatch
+    )  # noqa: E501
+
+    impl: AttributeImpl
 
     if impl_class:
-        impl = impl_class(class_, key, typecallable, dispatch, **kw)
+        # TODO: this appears to be the DynamicAttributeImpl constructor
+        # which is hardcoded
+        impl = cast("Type[DynamicAttributeImpl]", impl_class)(
+            class_, key, typecallable, dispatch, **kw
+        )
     elif uselist:
         impl = CollectionAttributeImpl(
             class_, key, callable_, dispatch, typecallable=typecallable, **kw
@@ -2102,8 +2534,13 @@ def register_attribute_impl(
 
 
 def register_descriptor(
-    class_, key, comparator=None, parententity=None, doc=None
-):
+    class_: Type[Any],
+    key: str,
+    *,
+    comparator: interfaces.PropComparator[_T],
+    parententity: _InternalEntityType[Any],
+    doc: Optional[str] = None,
+) -> InstrumentedAttribute[_T]:
     manager = manager_of_class(class_)
 
     descriptor = InstrumentedAttribute(
@@ -2116,11 +2553,11 @@ def register_descriptor(
     return descriptor
 
 
-def unregister_attribute(class_, key):
+def unregister_attribute(class_: Type[Any], key: str) -> None:
     manager_of_class(class_).uninstrument_attribute(key)
 
 
-def init_collection(obj, key):
+def init_collection(obj: object, key: str) -> CollectionAdapter:
     """Initialize a collection attribute and return the collection adapter.
 
     This function is used to provide direct access to collection internals
@@ -2143,7 +2580,9 @@ def init_collection(obj, key):
     return init_state_collection(state, dict_, key)
 
 
-def init_state_collection(state, dict_, key):
+def init_state_collection(
+    state: InstanceState[Any], dict_: _InstanceDict, key: str
+) -> CollectionAdapter:
     """Initialize a collection attribute and return the collection adapter.
 
     Discards any existing collection which may be there.
@@ -2151,6 +2590,9 @@ def init_state_collection(state, dict_, key):
     """
     attr = state.manager[key].impl
 
+    if TYPE_CHECKING:
+        assert isinstance(attr, HasCollectionAdapter)
+
     old = dict_.pop(key, None)  # discard old collection
     if old is not None:
         old_collection = old._sa_adapter
@@ -2182,7 +2624,12 @@ def set_committed_value(instance, key, value):
     state.manager[key].impl.set_committed_value(state, dict_, value)
 
 
-def set_attribute(instance, key, value, initiator=None):
+def set_attribute(
+    instance: object,
+    key: str,
+    value: Any,
+    initiator: Optional[AttributeEventToken] = None,
+) -> None:
     """Set the value of an attribute, firing history events.
 
     This function may be used regardless of instrumentation
@@ -2211,7 +2658,7 @@ def set_attribute(instance, key, value, initiator=None):
     state.manager[key].impl.set(state, dict_, value, initiator)
 
 
-def get_attribute(instance, key):
+def get_attribute(instance: object, key: str) -> Any:
     """Get the value of an attribute, firing any callables required.
 
     This function may be used regardless of instrumentation
@@ -2225,7 +2672,7 @@ def get_attribute(instance, key):
     return state.manager[key].impl.get(state, dict_)
 
 
-def del_attribute(instance, key):
+def del_attribute(instance: object, key: str) -> None:
     """Delete the value of an attribute, firing history events.
 
     This function may be used regardless of instrumentation
@@ -2239,7 +2686,7 @@ def del_attribute(instance, key):
     state.manager[key].impl.delete(state, dict_)
 
 
-def flag_modified(instance, key):
+def flag_modified(instance: object, key: str) -> None:
     """Mark an attribute on an instance as 'modified'.
 
     This sets the 'modified' flag on the instance and
@@ -2262,7 +2709,7 @@ def flag_modified(instance, key):
     state._modified_event(dict_, impl, NO_VALUE, is_userland=True)
 
 
-def flag_dirty(instance):
+def flag_dirty(instance: object) -> None:
     """Mark an instance as 'dirty' without any specific attribute mentioned.
 
     This is a special operation that will allow the object to travel through
index 367a5332dee061c87da1b3d6fec685eba6e2526b..0ace9b1cb625ac933085f9903f50d53c571ad2c9 100644 (file)
@@ -38,14 +38,15 @@ from ..util.typing import Literal
 from ..util.typing import Self
 
 if typing.TYPE_CHECKING:
+    from ._typing import _ExternalEntityType
     from ._typing import _InternalEntityType
     from .attributes import InstrumentedAttribute
     from .instrumentation import ClassManager
     from .mapper import Mapper
     from .state import InstanceState
+    from .util import AliasedClass
     from ..sql._typing import _InfoType
 
-
 _T = TypeVar("_T", bound=Any)
 
 _O = TypeVar("_O", bound=object)
@@ -267,10 +268,22 @@ def _assertions(
 
 if TYPE_CHECKING:
 
-    def manager_of_class(cls: Type[Any]) -> ClassManager:
+    def manager_of_class(cls: Type[_O]) -> ClassManager[_O]:
+        ...
+
+    @overload
+    def opt_manager_of_class(cls: AliasedClass[Any]) -> None:
         ...
 
-    def opt_manager_of_class(cls: Type[Any]) -> Optional[ClassManager]:
+    @overload
+    def opt_manager_of_class(
+        cls: _ExternalEntityType[_O],
+    ) -> Optional[ClassManager[_O]]:
+        ...
+
+    def opt_manager_of_class(
+        cls: _ExternalEntityType[_O],
+    ) -> Optional[ClassManager[_O]]:
         ...
 
     def instance_state(instance: _O) -> InstanceState[_O]:
@@ -719,7 +732,7 @@ class Mapped(ORMDescriptor[_T], roles.TypedColumnsClauseRole[_T], TypingOnly):
             ...
 
         def __get__(
-            self, instance: object, owner: Any
+            self, instance: Optional[object], owner: Any
         ) -> Union[InstrumentedAttribute[_T], _T]:
             ...
 
@@ -729,10 +742,10 @@ class Mapped(ORMDescriptor[_T], roles.TypedColumnsClauseRole[_T], TypingOnly):
 
         def __set__(
             self, instance: Any, value: Union[SQLCoreOperations[_T], _T]
-        ):
+        ) -> None:
             ...
 
-        def __delete__(self, instance: Any):
+        def __delete__(self, instance: Any) -> None:
             ...
 
 
index 717f1d0d68b078fdb212e99d0e3189628b3b48a1..da0da0fcfc011cc5293477941e15f175ea96b0aa 100644 (file)
@@ -4,7 +4,7 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
+# mypy: allow-untyped-defs, allow-untyped-calls
 
 """Support for collections of mapped entities.
 
@@ -109,17 +109,34 @@ import operator
 import threading
 import typing
 from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Collection
+from typing import Dict
+from typing import Iterable
+from typing import List
+from typing import Optional
+from typing import Set
+from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
+from typing import TypeVar
+from typing import Union
 import weakref
 
 from .. import exc as sa_exc
 from .. import util
 from ..util.compat import inspect_getfullargspec
+from ..util.typing import Protocol
 
 if typing.TYPE_CHECKING:
+    from .attributes import CollectionAttributeImpl
     from .mapped_collection import attribute_mapped_collection
     from .mapped_collection import column_mapped_collection
     from .mapped_collection import mapped_collection
     from .mapped_collection import MappedCollection  # noqa: F401
+    from .state import InstanceState
+
 
 __all__ = [
     "collection",
@@ -132,6 +149,28 @@ __all__ = [
 __instrumentation_mutex = threading.Lock()
 
 
+_CollectionFactoryType = Callable[[], "_AdaptedCollectionProtocol"]
+
+_T = TypeVar("_T", bound=Any)
+_KT = TypeVar("_KT", bound=Any)
+_VT = TypeVar("_VT", bound=Any)
+_COL = TypeVar("_COL", bound="Collection[Any]")
+_FN = TypeVar("_FN", bound="Callable[..., Any]")
+
+
+class _CollectionConverterProtocol(Protocol):
+    def __call__(self, collection: _COL) -> _COL:
+        ...
+
+
+class _AdaptedCollectionProtocol(Protocol):
+    _sa_adapter: CollectionAdapter
+    _sa_appender: Callable[..., Any]
+    _sa_remover: Callable[..., Any]
+    _sa_iterator: Callable[..., Iterable[Any]]
+    _sa_converter: _CollectionConverterProtocol
+
+
 class collection:
     """Decorators for entity collection classes.
 
@@ -396,8 +435,13 @@ class collection:
         return decorator
 
 
-collection_adapter = operator.attrgetter("_sa_adapter")
-"""Fetch the :class:`.CollectionAdapter` for a collection."""
+if TYPE_CHECKING:
+
+    def collection_adapter(collection: Collection[Any]) -> CollectionAdapter:
+        """Fetch the :class:`.CollectionAdapter` for a collection."""
+
+else:
+    collection_adapter = operator.attrgetter("_sa_adapter")
 
 
 class CollectionAdapter:
@@ -423,10 +467,33 @@ class CollectionAdapter:
         "empty",
     )
 
-    def __init__(self, attr, owner_state, data):
+    attr: CollectionAttributeImpl
+    _key: str
+
+    # this is actually a weakref; see note in constructor
+    _data: Callable[..., _AdaptedCollectionProtocol]
+
+    owner_state: InstanceState[Any]
+    _converter: _CollectionConverterProtocol
+    invalidated: bool
+    empty: bool
+
+    def __init__(
+        self,
+        attr: CollectionAttributeImpl,
+        owner_state: InstanceState[Any],
+        data: _AdaptedCollectionProtocol,
+    ):
         self.attr = attr
         self._key = attr.key
-        self._data = weakref.ref(data)
+
+        # this weakref stays referenced throughout the lifespan of
+        # CollectionAdapter.  so while the weakref can return None, this
+        # is realistically only during garbage collection of this object, so
+        # we type this as a callable that returns _AdaptedCollectionProtocol
+        # in all cases.
+        self._data = weakref.ref(data)  # type: ignore
+
         self.owner_state = owner_state
         data._sa_adapter = self
         self._converter = data._sa_converter
@@ -437,7 +504,7 @@ class CollectionAdapter:
         util.warn("This collection has been invalidated.")
 
     @property
-    def data(self):
+    def data(self) -> _AdaptedCollectionProtocol:
         "The entity collection being adapted."
         return self._data()
 
@@ -634,7 +701,10 @@ class CollectionAdapter:
     def __setstate__(self, d):
         self._key = d["key"]
         self.owner_state = d["owner_state"]
-        self._data = weakref.ref(d["data"])
+
+        # see note in constructor regarding this type: ignore
+        self._data = weakref.ref(d["data"])  # type: ignore
+
         self._converter = d["data"]._sa_converter
         d["data"]._sa_adapter = self
         self.invalidated = d["invalidated"]
@@ -682,7 +752,9 @@ def bulk_replace(values, existing_adapter, new_adapter, initiator=None):
             existing_adapter.fire_remove_event(member, initiator=initiator)
 
 
-def prepare_instrumentation(factory):
+def prepare_instrumentation(
+    factory: Union[Type[Collection[Any]], _CollectionFactoryType],
+) -> _CollectionFactoryType:
     """Prepare a callable for future use as a collection class factory.
 
     Given a collection class factory (either a type or no-arg callable),
@@ -693,18 +765,30 @@ def prepare_instrumentation(factory):
     into the run-time behavior of collection_class=InstrumentedList.
 
     """
+
+    impl_factory: _CollectionFactoryType
+
     # Convert a builtin to 'Instrumented*'
     if factory in __canned_instrumentation:
-        factory = __canned_instrumentation[factory]
+        impl_factory = __canned_instrumentation[factory]
+    else:
+        impl_factory = cast(_CollectionFactoryType, factory)
+
+    cls: Union[_CollectionFactoryType, Type[Collection[Any]]]
 
     # Create a specimen
-    cls = type(factory())
+    cls = type(impl_factory())
 
     # Did factory callable return a builtin?
     if cls in __canned_instrumentation:
-        # Wrap it so that it returns our 'Instrumented*'
-        factory = __converting_factory(cls, factory)
-        cls = factory()
+
+        # if so, just convert.
+        # in previous major releases, this codepath wasn't working and was
+        # not covered by tests.   prior to that it supplied a "wrapper"
+        # function that would return the class, though the rationale for this
+        # case is not known
+        impl_factory = __canned_instrumentation[cls]
+        cls = type(impl_factory())
 
     # Instrument the class if needed.
     if __instrumentation_mutex.acquire():
@@ -714,26 +798,7 @@ def prepare_instrumentation(factory):
         finally:
             __instrumentation_mutex.release()
 
-    return factory
-
-
-def __converting_factory(specimen_cls, original_factory):
-    """Return a wrapper that converts a "canned" collection like
-    set, dict, list into the Instrumented* version.
-
-    """
-
-    instrumented_cls = __canned_instrumentation[specimen_cls]
-
-    def wrapper():
-        collection = original_factory()
-        return instrumented_cls(collection)
-
-    # often flawed but better than nothing
-    wrapper.__name__ = "%sWrapper" % original_factory.__name__
-    wrapper.__doc__ = original_factory.__doc__
-
-    return wrapper
+    return impl_factory
 
 
 def _instrument_class(cls):
@@ -763,8 +828,8 @@ def _locate_roles_and_methods(cls):
 
     """
 
-    roles = {}
-    methods = {}
+    roles: Dict[str, str] = {}
+    methods: Dict[str, Tuple[Optional[str], Optional[int], Optional[str]]] = {}
 
     for supercls in cls.__mro__:
         for name, method in vars(supercls).items():
@@ -784,7 +849,9 @@ def _locate_roles_and_methods(cls):
 
             # transfer instrumentation requests from decorated function
             # to the combined queue
-            before, after = None, None
+            before: Optional[Tuple[str, int]] = None
+            after: Optional[str] = None
+
             if hasattr(method, "_sa_instrument_before"):
                 op, argument = method._sa_instrument_before
                 assert op in ("fire_append_event", "fire_remove_event")
@@ -809,6 +876,7 @@ def _setup_canned_roles(cls, roles, methods):
     """
     collection_type = util.duck_type_collection(cls)
     if collection_type in __interfaces:
+        assert collection_type is not None
         canned_roles, decorators = __interfaces[collection_type]
         for role, name in canned_roles.items():
             roles.setdefault(role, name)
@@ -934,9 +1002,9 @@ def _instrument_membership_mutator(method, before, argument, after):
                 getattr(executor, after)(res, initiator)
             return res
 
-    wrapper._sa_instrumented = True
+    wrapper._sa_instrumented = True  # type: ignore[attr-defined]
     if hasattr(method, "_sa_instrument_role"):
-        wrapper._sa_instrument_role = method._sa_instrument_role
+        wrapper._sa_instrument_role = method._sa_instrument_role  # type: ignore[attr-defined]  # noqa: E501
     wrapper.__name__ = method.__name__
     wrapper.__doc__ = method.__doc__
     return wrapper
@@ -990,7 +1058,7 @@ def __before_pop(collection, _sa_initiator=None):
         executor.fire_pre_remove_event(_sa_initiator)
 
 
-def _list_decorators():
+def _list_decorators() -> Dict[str, Callable[[_FN], _FN]]:
     """Tailored instrumentation wrappers for any list-like class."""
 
     def _tidy(fn):
@@ -1131,7 +1199,7 @@ def _list_decorators():
     return l
 
 
-def _dict_decorators():
+def _dict_decorators() -> Dict[str, Callable[[_FN], _FN]]:
     """Tailored instrumentation wrappers for any dict-like mapping class."""
 
     def _tidy(fn):
@@ -1255,7 +1323,7 @@ def _set_binops_check_loose(self: Any, obj: Any) -> bool:
     )
 
 
-def _set_decorators():
+def _set_decorators() -> Dict[str, Callable[[_FN], _FN]]:
     """Tailored instrumentation wrappers for any set-like class."""
 
     def _tidy(fn):
@@ -1420,36 +1488,52 @@ def _set_decorators():
     return l
 
 
-class InstrumentedList(list):
+class InstrumentedList(List[_T]):
     """An instrumented version of the built-in list."""
 
 
-class InstrumentedSet(set):
+class InstrumentedSet(Set[_T]):
     """An instrumented version of the built-in set."""
 
 
-class InstrumentedDict(dict):
+class InstrumentedDict(Dict[_KT, _VT]):
     """An instrumented version of the built-in dict."""
 
 
-__canned_instrumentation = {
-    list: InstrumentedList,
-    set: InstrumentedSet,
-    dict: InstrumentedDict,
-}
-
-__interfaces = {
-    list: (
-        {"appender": "append", "remover": "remove", "iterator": "__iter__"},
-        _list_decorators(),
-    ),
-    set: (
-        {"appender": "add", "remover": "remove", "iterator": "__iter__"},
-        _set_decorators(),
-    ),
-    # decorators are required for dicts and object collections.
-    dict: ({"iterator": "values"}, _dict_decorators()),
-}
+__canned_instrumentation: util.immutabledict[
+    Any, _CollectionFactoryType
+] = util.immutabledict(
+    {
+        list: InstrumentedList,
+        set: InstrumentedSet,
+        dict: InstrumentedDict,
+    }
+)
+
+__interfaces: util.immutabledict[
+    Any,
+    Tuple[
+        Dict[str, str],
+        Dict[str, Callable[..., Any]],
+    ],
+] = util.immutabledict(
+    {
+        list: (
+            {
+                "appender": "append",
+                "remover": "remove",
+                "iterator": "__iter__",
+            },
+            _list_decorators(),
+        ),
+        set: (
+            {"appender": "add", "remover": "remove", "iterator": "__iter__"},
+            _set_decorators(),
+        ),
+        # decorators are required for dicts and object collections.
+        dict: ({"iterator": "values"}, _dict_decorators()),
+    }
+)
 
 
 def __go(lcls):
index 63a37d0dae368b51397a28580ab06883331d24fe..1b4f573b506bf63ff54377931ba84c1bf46d76cf 100644 (file)
@@ -64,7 +64,9 @@ class DynaLoader(strategies.AbstractRelationshipLoader):
         )
 
 
-class DynamicAttributeImpl(attributes.AttributeImpl):
+class DynamicAttributeImpl(
+    attributes.HasCollectionAdapter, attributes.AttributeImpl
+):
     uses_objects = True
     default_accepts_scalar_loader = False
     supports_population = False
@@ -120,11 +122,11 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
 
     @util.memoized_property
     def _append_token(self):
-        return attributes.Event(self, attributes.OP_APPEND)
+        return attributes.AttributeEventToken(self, attributes.OP_APPEND)
 
     @util.memoized_property
     def _remove_token(self):
-        return attributes.Event(self, attributes.OP_REMOVE)
+        return attributes.AttributeEventToken(self, attributes.OP_REMOVE)
 
     def fire_append_event(
         self, state, dict_, value, initiator, collection_history=None
index 356958562f4d7e2d913140a57d44ae518ae7c741..85b85215ea019043b3f2881957973d446a760d4e 100644 (file)
@@ -4,7 +4,7 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
+# mypy: allow-untyped-defs, allow-untyped-calls
 
 """Defines SQLAlchemy's system of class instrumentation.
 
@@ -35,14 +35,19 @@ from __future__ import annotations
 
 from typing import Any
 from typing import Callable
+from typing import cast
+from typing import Collection
 from typing import Dict
 from typing import Generic
+from typing import Iterable
+from typing import List
 from typing import Optional
 from typing import Set
 from typing import Tuple
 from typing import Type
 from typing import TYPE_CHECKING
 from typing import TypeVar
+from typing import Union
 import weakref
 
 from . import base
@@ -51,15 +56,21 @@ from . import exc
 from . import interfaces
 from . import state
 from ._typing import _O
+from .attributes import _is_collection_attribute_impl
 from .. import util
 from ..event import EventTarget
 from ..util import HasMemoized
+from ..util.typing import Literal
 from ..util.typing import Protocol
 
 if TYPE_CHECKING:
     from ._typing import _RegistryType
+    from .attributes import AttributeImpl
     from .attributes import InstrumentedAttribute
+    from .collections import _AdaptedCollectionProtocol
+    from .collections import _CollectionFactoryType
     from .decl_base import _MapperConfig
+    from .events import InstanceEvents
     from .mapper import Mapper
     from .state import InstanceState
     from ..event import dispatcher
@@ -74,7 +85,7 @@ class _ExpiredAttributeLoaderProto(Protocol):
         state: state.InstanceState[Any],
         toload: Set[str],
         passive: base.PassiveFlag,
-    ):
+    ) -> None:
         ...
 
 
@@ -91,7 +102,7 @@ class ClassManager(
 ):
     """Tracks state information at the class level."""
 
-    dispatch: dispatcher[ClassManager]
+    dispatch: dispatcher[ClassManager[_O]]
 
     MANAGER_ATTR = base.DEFAULT_MANAGER_ATTR
     STATE_ATTR = base.DEFAULT_STATE_ATTR
@@ -108,8 +119,9 @@ class ClassManager(
     declarative_scan: Optional[weakref.ref[_MapperConfig]] = None
     registry: Optional[_RegistryType] = None
 
-    @property
-    @util.deprecated(
+    _bases: List[ClassManager[Any]]
+
+    @util.deprecated_property(
         "1.4",
         message="The ClassManager.deferred_scalar_loader attribute is now "
         "named expired_attribute_loader",
@@ -117,7 +129,7 @@ class ClassManager(
     def deferred_scalar_loader(self):
         return self.expired_attribute_loader
 
-    @deferred_scalar_loader.setter
+    @deferred_scalar_loader.setter  # type: ignore[no-redef]
     @util.deprecated(
         "1.4",
         message="The ClassManager.deferred_scalar_loader attribute is now "
@@ -138,18 +150,23 @@ class ClassManager(
 
         self._bases = [
             mgr
-            for mgr in [
-                opt_manager_of_class(base)
-                for base in self.class_.__bases__
-                if isinstance(base, type)
-            ]
+            for mgr in cast(
+                "List[Optional[ClassManager[Any]]]",
+                [
+                    opt_manager_of_class(base)
+                    for base in self.class_.__bases__
+                    if isinstance(base, type)
+                ],
+            )
             if mgr is not None
         ]
 
         for base_ in self._bases:
             self.update(base_)
 
-        self.dispatch._events._new_classmanager_instance(class_, self)
+        cast(
+            "InstanceEvents", self.dispatch._events
+        )._new_classmanager_instance(class_, self)
 
         for basecls in class_.__mro__:
             mgr = opt_manager_of_class(basecls)
@@ -263,7 +280,7 @@ class ClassManager(
 
         """
 
-        found = {}
+        found: Dict[str, Any] = {}
 
         # constraints:
         # 1. yield keys in cls.__dict__ order
@@ -303,7 +320,7 @@ class ClassManager(
 
         return key in self and self[key].impl is not None
 
-    def _subclass_manager(self, cls):
+    def _subclass_manager(self, cls: Type[_T]) -> ClassManager[_T]:
         """Create a new ClassManager for a subclass of this ClassManager's
         class.
 
@@ -321,7 +338,7 @@ class ClassManager(
         self.install_member("__init__", self.new_init)
 
     @util.memoized_property
-    def _state_constructor(self):
+    def _state_constructor(self) -> Type[state.InstanceState[_O]]:
         self.dispatch.first_init(self, self.class_)
         return state.InstanceState
 
@@ -393,13 +410,15 @@ class ClassManager(
             if manager:
                 manager.uninstrument_attribute(key, True)
 
-    def unregister(self):
+    def unregister(self) -> None:
         """remove all instrumentation established by this ClassManager."""
 
         for key in list(self.originals):
             self.uninstall_member(key)
 
-        self.mapper = self.dispatch = self.new_init = None
+        self.mapper = None  # type: ignore
+        self.dispatch = None  # type: ignore
+        self.new_init = None
         self.info.clear()
 
         for key in list(self):
@@ -409,7 +428,9 @@ class ClassManager(
         if self.MANAGER_ATTR in self.class_.__dict__:
             delattr(self.class_, self.MANAGER_ATTR)
 
-    def install_descriptor(self, key, inst):
+    def install_descriptor(
+        self, key: str, inst: InstrumentedAttribute[Any]
+    ) -> None:
         if key in (self.STATE_ATTR, self.MANAGER_ATTR):
             raise KeyError(
                 "%r: requested attribute name conflicts with "
@@ -417,10 +438,10 @@ class ClassManager(
             )
         setattr(self.class_, key, inst)
 
-    def uninstall_descriptor(self, key):
+    def uninstall_descriptor(self, key: str) -> None:
         delattr(self.class_, key)
 
-    def install_member(self, key, implementation):
+    def install_member(self, key: str, implementation: Any) -> None:
         if key in (self.STATE_ATTR, self.MANAGER_ATTR):
             raise KeyError(
                 "%r: requested attribute name conflicts with "
@@ -429,34 +450,41 @@ class ClassManager(
         self.originals.setdefault(key, self.class_.__dict__.get(key, DEL_ATTR))
         setattr(self.class_, key, implementation)
 
-    def uninstall_member(self, key):
+    def uninstall_member(self, key: str) -> None:
         original = self.originals.pop(key, None)
         if original is not DEL_ATTR:
             setattr(self.class_, key, original)
         else:
             delattr(self.class_, key)
 
-    def instrument_collection_class(self, key, collection_class):
+    def instrument_collection_class(
+        self, key: str, collection_class: Type[Collection[Any]]
+    ) -> _CollectionFactoryType:
         return collections.prepare_instrumentation(collection_class)
 
-    def initialize_collection(self, key, state, factory):
+    def initialize_collection(
+        self,
+        key: str,
+        state: InstanceState[_O],
+        factory: _CollectionFactoryType,
+    ) -> Tuple[collections.CollectionAdapter, _AdaptedCollectionProtocol]:
         user_data = factory()
-        adapter = collections.CollectionAdapter(
-            self.get_impl(key), state, user_data
-        )
+        impl = self.get_impl(key)
+        assert _is_collection_attribute_impl(impl)
+        adapter = collections.CollectionAdapter(impl, state, user_data)
         return adapter, user_data
 
-    def is_instrumented(self, key, search=False):
+    def is_instrumented(self, key: str, search: bool = False) -> bool:
         if search:
             return key in self
         else:
             return key in self.local_attrs
 
-    def get_impl(self, key):
+    def get_impl(self, key: str) -> AttributeImpl:
         return self[key].impl
 
     @property
-    def attributes(self):
+    def attributes(self) -> Iterable[Any]:
         return iter(self.values())
 
     # InstanceState management
@@ -466,22 +494,26 @@ class ClassManager(
         if state is None:
             state = self._state_constructor(instance, self)
         self._state_setter(instance, state)
-        return instance
+        return instance  # type: ignore[no-any-return]
 
-    def setup_instance(self, instance, state=None):
+    def setup_instance(
+        self, instance: _O, state: Optional[InstanceState[_O]] = None
+    ) -> None:
         if state is None:
             state = self._state_constructor(instance, self)
         self._state_setter(instance, state)
 
-    def teardown_instance(self, instance):
+    def teardown_instance(self, instance: _O) -> None:
         delattr(instance, self.STATE_ATTR)
 
     def _serialize(
-        self, state: state.InstanceState, state_dict: Dict[str, Any]
+        self, state: InstanceState[_O], state_dict: Dict[str, Any]
     ) -> _SerializeManager:
         return _SerializeManager(state, state_dict)
 
-    def _new_state_if_none(self, instance):
+    def _new_state_if_none(
+        self, instance: _O
+    ) -> Union[Literal[False], InstanceState[_O]]:
         """Install a default InstanceState if none is present.
 
         A private convenience method used by the __init__ decorator.
@@ -503,20 +535,20 @@ class ClassManager(
             self._state_setter(instance, state)
             return state
 
-    def has_state(self, instance):
+    def has_state(self, instance: _O) -> bool:
         return hasattr(instance, self.STATE_ATTR)
 
-    def has_parent(self, state, key, optimistic=False):
+    def has_parent(
+        self, state: InstanceState[_O], key: str, optimistic: bool = False
+    ) -> bool:
         """TODO"""
         return self.get_impl(key).hasparent(state, optimistic=optimistic)
 
-    def __bool__(self):
+    def __bool__(self) -> bool:
         """All ClassManagers are non-zero regardless of attribute state."""
         return True
 
-    __nonzero__ = __bool__
-
-    def __repr__(self):
+    def __repr__(self) -> str:
         return "<%s of %r at %x>" % (
             self.__class__.__name__,
             self.class_,
@@ -558,9 +590,11 @@ class _SerializeManager:
         manager.dispatch.unpickle(state, state_dict)
 
 
-class InstrumentationFactory:
+class InstrumentationFactory(EventTarget):
     """Factory for new ClassManager instances."""
 
+    dispatch: dispatcher[InstrumentationFactory]
+
     def create_manager_for_cls(self, class_: Type[_O]) -> ClassManager[_O]:
         assert class_ is not None
         assert opt_manager_of_class(class_) is None
@@ -589,11 +623,10 @@ class InstrumentationFactory:
 
     def _check_conflicts(
         self, class_: Type[_O], factory: Callable[[Type[_O]], ClassManager[_O]]
-    ):
+    ) -> None:
         """Overridden by a subclass to test for conflicting factories."""
-        return
 
-    def unregister(self, class_):
+    def unregister(self, class_: Type[_O]) -> None:
         manager = manager_of_class(class_)
         manager.unregister()
         self.dispatch.class_uninstrument(class_)
index 3e21b01023352337f263ebcb6d636ef10b870f2e..3d093d367c47e61b64e3968ba0888c5cc28c02c8 100644 (file)
@@ -60,6 +60,7 @@ from ..sql.base import ExecutableOption
 from ..sql.cache_key import HasCacheKey
 from ..sql.schema import Column
 from ..sql.type_api import TypeEngine
+from ..util.typing import DescriptorReference
 from ..util.typing import TypedDict
 
 if typing.TYPE_CHECKING:
@@ -68,7 +69,6 @@ if typing.TYPE_CHECKING:
     from ._typing import _InstanceDict
     from ._typing import _InternalEntityType
     from ._typing import _ORMAdapterProto
-    from ._typing import _ORMColumnExprArgument
     from .attributes import InstrumentedAttribute
     from .context import _MapperEntity
     from .context import ORMCompileState
@@ -89,7 +89,6 @@ if typing.TYPE_CHECKING:
     from ..sql._typing import _ColumnsClauseArgument
     from ..sql._typing import _DMLColumnArgument
     from ..sql._typing import _InfoType
-    from ..sql._typing import _PropagateAttrsType
     from ..sql.operators import OperatorType
     from ..sql.util import ColumnAdapter
     from ..sql.visitors import _TraverseInternalsType
@@ -171,12 +170,18 @@ class _MapsColumns(_MappedAttribute[_T]):
         raise NotImplementedError()
 
 
+# NOTE: MapperProperty needs to extend _MappedAttribute so that declarative
+# typing works, i.e. "Mapped[A] = relationship()".   This introduces an
+# inconvenience which is that all the MapperProperty objects are treated
+# as descriptors by typing tools, which are misled by this as assignment /
+# access to a descriptor attribute wants to move through __get__.
+# Therefore, references to MapperProperty as an instance variable, such
+# as in PropComparator, may have some special typing workarounds such as the
+# use of sqlalchemy.util.typing.DescriptorReference to avoid mis-interpretation
+# by typing tools
 @inspection._self_inspects
 class MapperProperty(
-    HasCacheKey,
-    _MappedAttribute[_T],
-    InspectionAttrInfo,
-    util.MemoizedSlots,
+    HasCacheKey, _MappedAttribute[_T], InspectionAttrInfo, util.MemoizedSlots
 ):
     """Represent a particular class attribute mapped by :class:`_orm.Mapper`.
 
@@ -522,6 +527,7 @@ class PropComparator(SQLORMOperations[_T]):
 
     _parententity: _InternalEntityType[Any]
     _adapt_to_entity: Optional[AliasedInsp[Any]]
+    prop: DescriptorReference[MapperProperty[_T]]
 
     def __init__(
         self,
@@ -533,11 +539,20 @@ class PropComparator(SQLORMOperations[_T]):
         self._parententity = adapt_to_entity or parentmapper
         self._adapt_to_entity = adapt_to_entity
 
-    @util.ro_non_memoized_property
+    @util.non_memoized_property
     def property(self) -> Optional[MapperProperty[_T]]:
+        """Return the :class:`.MapperProperty` associated with this
+        :class:`.PropComparator`.
+
+
+        Return values here will commonly be instances of
+        :class:`.ColumnProperty` or :class:`.Relationship`.
+
+
+        """
         return self.prop
 
-    def __clause_element__(self) -> _ORMColumnExprArgument[_T]:
+    def __clause_element__(self) -> roles.ColumnsClauseRole:
         raise NotImplementedError("%r" % self)
 
     def _bulk_update_tuples(
@@ -567,18 +582,6 @@ class PropComparator(SQLORMOperations[_T]):
         compatible with QueryableAttribute."""
         return self._parententity.mapper
 
-    @util.memoized_property
-    def _propagate_attrs(self) -> _PropagateAttrsType:
-        # this suits the case in coercions where we don't actually
-        # call ``__clause_element__()`` but still need to get
-        # resolved._propagate_attrs.  See #6558.
-        return util.immutabledict(
-            {
-                "compile_state_plugin": "orm",
-                "plugin_subject": self._parentmapper,
-            }
-        )
-
     def _criterion_exists(
         self,
         criterion: Optional[_ColumnExpressionArgument[bool]] = None,
@@ -657,7 +660,7 @@ class PropComparator(SQLORMOperations[_T]):
 
     def and_(
         self, *criteria: _ColumnExpressionArgument[bool]
-    ) -> ColumnElement[bool]:
+    ) -> PropComparator[bool]:
         """Add additional criteria to the ON clause that's represented by this
         relationship attribute.
 
index 66021c9c2028f6d693221330a66b183d46f2491c..514ad7023ee175db23c519d706bba47b6083c9fd 100644 (file)
@@ -71,6 +71,7 @@ from ..util.typing import Literal
 
 if typing.TYPE_CHECKING:
     from ._typing import _EntityType
+    from ._typing import _InternalEntityType
     from .mapper import Mapper
     from .util import AliasedClass
     from .util import AliasedInsp
@@ -348,7 +349,7 @@ class Relationship(
             doc=self.doc,
         )
 
-    class Comparator(PropComparator[_PT]):
+    class Comparator(util.MemoizedSlots, PropComparator[_PT]):
         """Produce boolean, comparison, and other operators for
         :class:`.Relationship` attributes.
 
@@ -369,8 +370,13 @@ class Relationship(
 
         """
 
-        _of_type = None
-        _extra_criteria = ()
+        __slots__ = (
+            "entity",
+            "mapper",
+            "property",
+            "_of_type",
+            "_extra_criteria",
+        )
 
         def __init__(
             self,
@@ -389,6 +395,8 @@ class Relationship(
             self._adapt_to_entity = adapt_to_entity
             if of_type:
                 self._of_type = of_type
+            else:
+                self._of_type = None
             self._extra_criteria = extra_criteria
 
         def adapt_to_entity(self, adapt_to_entity):
@@ -399,40 +407,35 @@ class Relationship(
                 of_type=self._of_type,
             )
 
-        @util.memoized_property
-        def entity(self):
-            """The target entity referred to by this
-            :class:`.Relationship.Comparator`.
+        entity: _InternalEntityType
+        """The target entity referred to by this
+        :class:`.Relationship.Comparator`.
 
-            This is either a :class:`_orm.Mapper` or :class:`.AliasedInsp`
-            object.
+        This is either a :class:`_orm.Mapper` or :class:`.AliasedInsp`
+        object.
 
-            This is the "target" or "remote" side of the
-            :func:`_orm.relationship`.
+        This is the "target" or "remote" side of the
+        :func:`_orm.relationship`.
 
-            """
-            # this is a relatively recent change made for
-            # 1.4.27 as part of #7244.
-            # TODO: shouldn't _of_type be inspected up front when received?
-            if self._of_type is not None:
-                return inspect(self._of_type)
-            else:
-                return self.property.entity
+        """
 
-        @util.memoized_property
-        def mapper(self):
-            """The target :class:`_orm.Mapper` referred to by this
-            :class:`.Relationship.Comparator`.
+        mapper: Mapper[Any]
+        """The target :class:`_orm.Mapper` referred to by this
+        :class:`.Relationship.Comparator`.
 
-            This is the "target" or "remote" side of the
-            :func:`_orm.relationship`.
+        This is the "target" or "remote" side of the
+        :func:`_orm.relationship`.
 
-            """
-            return self.property.mapper
+        """
 
-        @util.memoized_property
-        def _parententity(self):
-            return self.property.parent
+        def _memoized_attr_entity(self) -> _InternalEntityType:
+            if self._of_type:
+                return inspect(self._of_type)
+            else:
+                return self.prop.entity
+
+        def _memoized_attr_mapper(self) -> Mapper[Any]:
+            return self.entity.mapper
 
         def _source_selectable(self):
             if self._adapt_to_entity:
@@ -481,7 +484,9 @@ class Relationship(
                 extra_criteria=self._extra_criteria,
             )
 
-        def and_(self, *other):
+        def and_(
+            self, *criteria: _ColumnExpressionArgument[bool]
+        ) -> interfaces.PropComparator[bool]:
             """Add AND criteria.
 
             See :meth:`.PropComparator.and_` for an example.
@@ -489,12 +494,17 @@ class Relationship(
             .. versionadded:: 1.4
 
             """
+            exprs = tuple(
+                coercions.expect(roles.WhereHavingRole, clause)
+                for clause in util.coerce_generator_arg(criteria)
+            )
+
             return Relationship.Comparator(
                 self.property,
                 self._parententity,
                 adapt_to_entity=self._adapt_to_entity,
                 of_type=self._of_type,
-                extra_criteria=self._extra_criteria + other,
+                extra_criteria=self._extra_criteria + exprs,
             )
 
         def in_(self, other):
@@ -924,8 +934,7 @@ class Relationship(
             else:
                 return _orm_annotate(self.__negated_contains_or_equals(other))
 
-        @util.memoized_property
-        def property(self):
+        def _memoized_attr_property(self):
             self.prop.parent._check_configure()
             return self.prop
 
index ab32a3981a9e715be6d03f2e951eec2a69a4cc9f..49ee701b442c911f5b14edd531b5eeaa89a979cb 100644 (file)
@@ -23,6 +23,7 @@ from typing import Optional
 from typing import Set
 from typing import Tuple
 from typing import TYPE_CHECKING
+from typing import Union
 import weakref
 
 from . import base
@@ -43,6 +44,7 @@ from .path_registry import PathRegistry
 from .. import exc as sa_exc
 from .. import inspection
 from .. import util
+from ..util.typing import Literal
 from ..util.typing import Protocol
 
 if TYPE_CHECKING:
@@ -53,6 +55,7 @@ if TYPE_CHECKING:
     from .attributes import History
     from .base import LoaderCallableStatus
     from .base import PassiveFlag
+    from .collections import _AdaptedCollectionProtocol
     from .identity import IdentityMap
     from .instrumentation import ClassManager
     from .interfaces import ORMOption
@@ -421,7 +424,7 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_O]):
         return self.key
 
     @util.memoized_property
-    def parents(self) -> Dict[int, InstanceState[Any]]:
+    def parents(self) -> Dict[int, Union[Literal[False], InstanceState[Any]]]:
         return {}
 
     @util.memoized_property
@@ -429,7 +432,7 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_O]):
         return {}
 
     @util.memoized_property
-    def _empty_collections(self) -> Dict[Any, Any]:
+    def _empty_collections(self) -> Dict[str, _AdaptedCollectionProtocol]:
         return {}
 
     @util.memoized_property
@@ -844,7 +847,7 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_O]):
     def _modified_event(
         self,
         dict_: _InstanceDict,
-        attr: AttributeImpl,
+        attr: Optional[AttributeImpl],
         previous: Any,
         collection: bool = False,
         is_userland: bool = False,
index 7e8a6b4c6e9ad2597a96a8e8a4e6e573bcae9890..b095e3f7aedd2a786b1be7e780c0a1c4bf62145b 100644 (file)
@@ -24,6 +24,7 @@ from typing import Optional
 from typing import Sequence
 from typing import Tuple
 from typing import Type
+from typing import TYPE_CHECKING
 from typing import TypeVar
 from typing import Union
 import weakref
@@ -531,6 +532,8 @@ class AliasedClass(
 
     """
 
+    __name__: str
+
     def __init__(
         self,
         mapped_class_or_ac: _EntityType[_O],
@@ -1529,7 +1532,7 @@ class _ORMJoin(expression.Join):
         full: bool = False,
         _left_memo: Optional[Any] = None,
         _right_memo: Optional[Any] = None,
-        _extra_criteria: Sequence[ColumnElement[bool]] = (),
+        _extra_criteria: Tuple[ColumnElement[bool], ...] = (),
     ):
         left_info = cast(
             "Union[FromClause, _InternalEntityType[Any]]",
@@ -1547,6 +1550,8 @@ class _ORMJoin(expression.Join):
         self._right_memo = _right_memo
 
         if isinstance(onclause, attributes.QueryableAttribute):
+            if TYPE_CHECKING:
+                assert isinstance(onclause.comparator, Relationship.Comparator)
             on_selectable = onclause.comparator._source_selectable()
             prop = onclause.property
             _extra_criteria += onclause._extra_criteria
@@ -1728,12 +1733,15 @@ def with_parent(
     elif isinstance(prop, attributes.QueryableAttribute):
         if prop._of_type:
             from_entity = prop._of_type
-        if not prop_is_relationship(prop.property):
+        mapper_property = prop.property
+        if mapper_property is None or not prop_is_relationship(
+            mapper_property
+        ):
             raise sa_exc.ArgumentError(
                 f"Expected relationship property for with_parent(), "
-                f"got {prop.property}"
+                f"got {mapper_property}"
             )
-        prop_t = prop.property
+        prop_t = mapper_property
     else:
         prop_t = prop
 
index fb959654febce0fba2993b7410207758e993178d..248b48a250b33195c5b8558e6e35cccd7a516235 100644 (file)
@@ -133,6 +133,8 @@ class Immutable:
 
     """
 
+    __slots__ = ()
+
     _is_immutable = True
 
     def unique_params(self, *optionaldict, **kwargs):
@@ -145,7 +147,7 @@ class Immutable:
         return self
 
     def _copy_internals(
-        self, omit_attrs: Iterable[str] = (), **kw: Any
+        self, *, omit_attrs: Iterable[str] = (), **kw: Any
     ) -> None:
         pass
 
index 15fbc2afb9b48fc4fc90037ce542f166b760919b..c16fbdae1366349fc69da8a7eee278e8d3f23723 100644 (file)
@@ -36,7 +36,6 @@ if typing.TYPE_CHECKING:
     from .elements import BindParameter
     from .elements import ClauseElement
     from .visitors import _TraverseInternalsType
-    from ..engine.base import _CompiledCacheType
     from ..engine.interfaces import _CoreSingleExecuteParams
 
 
@@ -393,6 +392,13 @@ class MemoizedHasCacheKey(HasCacheKey, HasMemoized):
         return HasCacheKey._generate_cache_key(self)
 
 
+class SlotsMemoizedHasCacheKey(HasCacheKey, util.MemoizedSlots):
+    __slots__ = ()
+
+    def _memoized_method__generate_cache_key(self) -> Optional[CacheKey]:
+        return HasCacheKey._generate_cache_key(self)
+
+
 class CacheKey(NamedTuple):
     """The key used to identify a SQL statement construct in the
     SQL compilation cache.
index 2e0112f08f3cc1d5f1651b676b399072eabacf57..2655adbdc93e221b27569f2e48d8d07f482f3126 100644 (file)
@@ -1265,7 +1265,16 @@ class ColumnAdapter(ClauseAdapter):
         if self.adapt_required and c is col:
             return None
 
-        c._allow_label_resolve = self.allow_label_resolve
+        # allow_label_resolve is consumed by one case for joined eager loading
+        # as part of its logic to prevent its own columns from being affected
+        # by .order_by().  Before full typing were applied to the ORM, this
+        # logic would set this attribute on the incoming object (which is
+        # typically a column, but we have a test for it being a non-column
+        # object) if no column were found.  While this seemed to
+        # have no negative effects, this adjustment should only occur on the
+        # new column which is assumed to be local to an adapted selectable.
+        if c is not col:
+            c._allow_label_resolve = self.allow_label_resolve
 
         return c
 
index e9b0c93f283d6dac601d306e41fef3f730207f2f..7150dedcf89ac61c6207563f78d3042a22d7522c 100644 (file)
@@ -410,11 +410,11 @@ class UniqueAppender(Generic[_T]):
         return iter(self.data)
 
 
-def coerce_generator_arg(arg):
+def coerce_generator_arg(arg: Any) -> List[Any]:
     if len(arg) == 1 and isinstance(arg[0], types.GeneratorType):
         return list(arg[0])
     else:
-        return arg
+        return cast("List[Any]", arg)
 
 
 def to_list(x: Any, default: Optional[List[Any]] = None) -> List[Any]:
index 10110dbbee66ae00b2dcc113054af24de3a0fb0e..24c66bfa4e4373d53cf284c009b2f25d37b0114e 100644 (file)
@@ -1272,17 +1272,20 @@ class MemoizedSlots:
     def _fallback_getattr(self, key):
         raise AttributeError(key)
 
-    def __getattr__(self, key):
+    def __getattr__(self, key: str) -> Any:
         if key.startswith("_memoized_attr_") or key.startswith(
             "_memoized_method_"
         ):
             raise AttributeError(key)
-        elif hasattr(self, "_memoized_attr_%s" % key):
-            value = getattr(self, "_memoized_attr_%s" % key)()
+        # to avoid recursion errors when interacting with other __getattr__
+        # schemes that refer to this one, when testing for memoized method
+        # look at __class__ only rather than going into __getattr__ again.
+        elif hasattr(self.__class__, f"_memoized_attr_{key}"):
+            value = getattr(self, f"_memoized_attr_{key}")()
             setattr(self, key, value)
             return value
-        elif hasattr(self, "_memoized_method_%s" % key):
-            fn = getattr(self, "_memoized_method_%s" % key)
+        elif hasattr(self.__class__, f"_memoized_method_{key}"):
+            fn = getattr(self, f"_memoized_method_{key}")
 
             def oneshot(*args, **kw):
                 result = fn(*args, **kw)
index 4929ba1a658c65b1dd5ae27577ebe6081666a328..a95f5ab93cf14dc0f0775359aa387240d6975cd1 100644 (file)
@@ -9,6 +9,7 @@ from typing import Callable
 from typing import cast
 from typing import Dict
 from typing import ForwardRef
+from typing import Generic
 from typing import Iterable
 from typing import Optional
 from typing import Tuple
@@ -213,3 +214,41 @@ def _get_type_name(type_: Type[Any]) -> str:
             typ_name = getattr(type_, "_name", None)
 
         return typ_name  # type: ignore
+
+
+class DescriptorProto(Protocol):
+    def __get__(self, instance: object, owner: Any) -> Any:
+        ...
+
+    def __set__(self, instance: Any, value: Any) -> None:
+        ...
+
+    def __delete__(self, instance: Any) -> None:
+        ...
+
+
+_DESC = TypeVar("_DESC", bound=DescriptorProto)
+
+
+class DescriptorReference(Generic[_DESC]):
+    """a descriptor that refers to a descriptor.
+
+    used for cases where we need to have an instance variable referring to an
+    object that is itself a descriptor, which typically confuses typing tools
+    as they don't know when they should use ``__get__`` or not when referring
+    to the descriptor assignment as an instance variable. See
+    sqlalchemy.orm.interfaces.PropComparator.prop
+
+    """
+
+    def __get__(self, instance: object, owner: Any) -> _DESC:
+        ...
+
+    def __set__(self, instance: Any, value: _DESC) -> None:
+        ...
+
+    def __delete__(self, instance: Any) -> None:
+        ...
+
+
+# $def ro_descriptor_reference(fn: Callable[])
index 7830fcee68fc05a12815344fd214865b5f786560..c222a0a93309bba9e2c0dee5f71849f49c589228 100644 (file)
@@ -26,6 +26,13 @@ from sqlalchemy.testing import ne_
 from sqlalchemy.testing.util import decorator
 
 
+def _register_attribute(class_, key, **kw):
+    kw.setdefault("comparator", object())
+    kw.setdefault("parententity", object())
+
+    attributes.register_attribute(class_, key, **kw)
+
+
 @decorator
 def modifies_instrumentation_finders(fn, *args, **kw):
     pristine = instrumentation.instrumentation_finders[:]
@@ -205,13 +212,9 @@ class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest):
             pass
 
         register_class(User)
-        attributes.register_attribute(
-            User, "user_id", uselist=False, useobject=False
-        )
-        attributes.register_attribute(
-            User, "user_name", uselist=False, useobject=False
-        )
-        attributes.register_attribute(
+        _register_attribute(User, "user_id", uselist=False, useobject=False)
+        _register_attribute(User, "user_name", uselist=False, useobject=False)
+        _register_attribute(
             User, "email_address", uselist=False, useobject=False
         )
 
@@ -238,13 +241,13 @@ class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest):
                 pass
 
             register_class(User)
-            attributes.register_attribute(
+            _register_attribute(
                 User, "user_id", uselist=False, useobject=False
             )
-            attributes.register_attribute(
+            _register_attribute(
                 User, "user_name", uselist=False, useobject=False
             )
-            attributes.register_attribute(
+            _register_attribute(
                 User, "email_address", uselist=False, useobject=False
             )
 
@@ -284,12 +287,8 @@ class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest):
 
             manager = register_class(Foo)
             manager.expired_attribute_loader = loader
-            attributes.register_attribute(
-                Foo, "a", uselist=False, useobject=False
-            )
-            attributes.register_attribute(
-                Foo, "b", uselist=False, useobject=False
-            )
+            _register_attribute(Foo, "a", uselist=False, useobject=False)
+            _register_attribute(Foo, "b", uselist=False, useobject=False)
 
             if base is object:
                 assert Foo not in (
@@ -360,13 +359,13 @@ class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest):
             def func3(state, passive):
                 return "this is the shared attr"
 
-            attributes.register_attribute(
+            _register_attribute(
                 Foo, "element", uselist=False, callable_=func1, useobject=True
             )
-            attributes.register_attribute(
+            _register_attribute(
                 Foo, "element2", uselist=False, callable_=func3, useobject=True
             )
-            attributes.register_attribute(
+            _register_attribute(
                 Bar, "element", uselist=False, callable_=func2, useobject=True
             )
 
@@ -388,7 +387,7 @@ class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest):
 
             register_class(Post)
             register_class(Blog)
-            attributes.register_attribute(
+            _register_attribute(
                 Post,
                 "blog",
                 uselist=False,
@@ -396,7 +395,7 @@ class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest):
                 trackparent=True,
                 useobject=True,
             )
-            attributes.register_attribute(
+            _register_attribute(
                 Blog,
                 "posts",
                 uselist=True,
@@ -438,15 +437,11 @@ class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest):
 
             register_class(Foo)
             register_class(Bar)
-            attributes.register_attribute(
-                Foo, "name", uselist=False, useobject=False
-            )
-            attributes.register_attribute(
+            _register_attribute(Foo, "name", uselist=False, useobject=False)
+            _register_attribute(
                 Foo, "bars", uselist=True, trackparent=True, useobject=True
             )
-            attributes.register_attribute(
-                Bar, "name", uselist=False, useobject=False
-            )
+            _register_attribute(Bar, "name", uselist=False, useobject=False)
 
             f1 = Foo()
             f1.name = "f1"
@@ -517,10 +512,8 @@ class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest):
             pass
 
         register_class(Foo)
-        attributes.register_attribute(
-            Foo, "name", uselist=False, useobject=False
-        )
-        attributes.register_attribute(
+        _register_attribute(Foo, "name", uselist=False, useobject=False)
+        _register_attribute(
             Foo, "bars", uselist=True, trackparent=True, useobject=True
         )
 
index 3370fa1b550f12752fe9fe3f46c31331b6381469..81cecb08b3196e4308675954f46f53a79d2c8259 100644 (file)
@@ -147,7 +147,10 @@ class _MutableDictTestBase(_MutableDictTestFixture):
             canary.mock_calls,
             [
                 mock.call(
-                    f1, attributes.Event(Foo.data.impl, attributes.OP_MODIFIED)
+                    f1,
+                    attributes.AttributeEventToken(
+                        Foo.data.impl, attributes.OP_MODIFIED
+                    ),
                 )
             ],
         )
index 8e6ddf9d4e717af0ac742761dabc29eb28ec858e..e1274a8051e6dbebe943812752055d0ce96102b1 100644 (file)
@@ -36,6 +36,13 @@ def _set_callable(state, dict_, key, callable_):
     fn(state, dict_, None)
 
 
+def _register_attribute(class_, key, **kw):
+    kw.setdefault("comparator", object())
+    kw.setdefault("parententity", object())
+
+    attributes.register_attribute(class_, key, **kw)
+
+
 class AttributeImplAPITest(fixtures.MappedTest):
     def _scalar_obj_fixture(self):
         class A:
@@ -46,7 +53,7 @@ class AttributeImplAPITest(fixtures.MappedTest):
 
         instrumentation.register_class(A)
         instrumentation.register_class(B)
-        attributes.register_attribute(A, "b", uselist=False, useobject=True)
+        _register_attribute(A, "b", uselist=False, useobject=True)
         return A, B
 
     def _collection_obj_fixture(self):
@@ -58,7 +65,7 @@ class AttributeImplAPITest(fixtures.MappedTest):
 
         instrumentation.register_class(A)
         instrumentation.register_class(B)
-        attributes.register_attribute(A, "b", uselist=True, useobject=True)
+        _register_attribute(A, "b", uselist=True, useobject=True)
         return A, B
 
     def test_scalar_obj_remove_invalid(self):
@@ -228,13 +235,9 @@ class AttributesTest(fixtures.ORMTest):
             pass
 
         instrumentation.register_class(User)
-        attributes.register_attribute(
-            User, "user_id", uselist=False, useobject=False
-        )
-        attributes.register_attribute(
-            User, "user_name", uselist=False, useobject=False
-        )
-        attributes.register_attribute(
+        _register_attribute(User, "user_id", uselist=False, useobject=False)
+        _register_attribute(User, "user_name", uselist=False, useobject=False)
+        _register_attribute(
             User, "email_address", uselist=False, useobject=False
         )
         u = User()
@@ -263,28 +266,22 @@ class AttributesTest(fixtures.ORMTest):
     def test_pickleness(self):
         instrumentation.register_class(MyTest)
         instrumentation.register_class(MyTest2)
-        attributes.register_attribute(
-            MyTest, "user_id", uselist=False, useobject=False
-        )
-        attributes.register_attribute(
+        _register_attribute(MyTest, "user_id", uselist=False, useobject=False)
+        _register_attribute(
             MyTest, "user_name", uselist=False, useobject=False
         )
-        attributes.register_attribute(
+        _register_attribute(
             MyTest, "email_address", uselist=False, useobject=False
         )
-        attributes.register_attribute(
-            MyTest2, "a", uselist=False, useobject=False
-        )
-        attributes.register_attribute(
-            MyTest2, "b", uselist=False, useobject=False
-        )
+        _register_attribute(MyTest2, "a", uselist=False, useobject=False)
+        _register_attribute(MyTest2, "b", uselist=False, useobject=False)
 
         # shouldn't be pickling callables at the class level
 
         def somecallable(state, passive):
             return None
 
-        attributes.register_attribute(
+        _register_attribute(
             MyTest,
             "mt2",
             uselist=True,
@@ -350,9 +347,7 @@ class AttributesTest(fixtures.ORMTest):
 
         instrumentation.register_class(Foo)
         instrumentation.register_class(Bar)
-        attributes.register_attribute(
-            Foo, "bars", uselist=True, useobject=True
-        )
+        _register_attribute(Foo, "bars", uselist=True, useobject=True)
 
         assert_raises_message(
             orm_exc.ObjectDereferencedError,
@@ -367,9 +362,7 @@ class AttributesTest(fixtures.ORMTest):
             pass
 
         instrumentation.register_class(User)
-        attributes.register_attribute(
-            User, "user_name", uselist=False, useobject=False
-        )
+        _register_attribute(User, "user_name", uselist=False, useobject=False)
 
         class Blog:
             name = User.user_name
@@ -388,7 +381,7 @@ class AttributesTest(fixtures.ORMTest):
             pass
 
         instrumentation.register_class(Foo)
-        attributes.register_attribute(Foo, "b", uselist=False, useobject=False)
+        _register_attribute(Foo, "b", uselist=False, useobject=False)
 
         f1 = Foo()
 
@@ -417,7 +410,7 @@ class AttributesTest(fixtures.ORMTest):
 
         instrumentation.register_class(Foo)
         instrumentation.register_class(Bar)
-        attributes.register_attribute(Foo, "b", uselist=False, useobject=True)
+        _register_attribute(Foo, "b", uselist=False, useobject=True)
 
         f1 = Foo()
 
@@ -444,7 +437,7 @@ class AttributesTest(fixtures.ORMTest):
 
         instrumentation.register_class(Foo)
         instrumentation.register_class(Bar)
-        attributes.register_attribute(Foo, "b", uselist=True, useobject=True)
+        _register_attribute(Foo, "b", uselist=True, useobject=True)
 
         f1 = Foo()
 
@@ -472,8 +465,8 @@ class AttributesTest(fixtures.ORMTest):
         instrumentation.register_class(Foo)
         manager = attributes.manager_of_class(Foo)
         manager.expired_attribute_loader = loader
-        attributes.register_attribute(Foo, "a", uselist=False, useobject=False)
-        attributes.register_attribute(Foo, "b", uselist=False, useobject=False)
+        _register_attribute(Foo, "a", uselist=False, useobject=False)
+        _register_attribute(Foo, "b", uselist=False, useobject=False)
 
         f = Foo()
         attributes.instance_state(f)._expire(
@@ -518,12 +511,8 @@ class AttributesTest(fixtures.ORMTest):
         instrumentation.register_class(MyTest)
         manager = attributes.manager_of_class(MyTest)
         manager.expired_attribute_loader = loader
-        attributes.register_attribute(
-            MyTest, "a", uselist=False, useobject=False
-        )
-        attributes.register_attribute(
-            MyTest, "b", uselist=False, useobject=False
-        )
+        _register_attribute(MyTest, "a", uselist=False, useobject=False)
+        _register_attribute(MyTest, "b", uselist=False, useobject=False)
 
         m = MyTest()
         attributes.instance_state(m)._expire(
@@ -544,19 +533,13 @@ class AttributesTest(fixtures.ORMTest):
 
         instrumentation.register_class(User)
         instrumentation.register_class(Address)
-        attributes.register_attribute(
-            User, "user_id", uselist=False, useobject=False
-        )
-        attributes.register_attribute(
-            User, "user_name", uselist=False, useobject=False
-        )
-        attributes.register_attribute(
-            User, "addresses", uselist=True, useobject=True
-        )
-        attributes.register_attribute(
+        _register_attribute(User, "user_id", uselist=False, useobject=False)
+        _register_attribute(User, "user_name", uselist=False, useobject=False)
+        _register_attribute(User, "addresses", uselist=True, useobject=True)
+        _register_attribute(
             Address, "address_id", uselist=False, useobject=False
         )
-        attributes.register_attribute(
+        _register_attribute(
             Address, "email_address", uselist=False, useobject=False
         )
 
@@ -613,7 +596,7 @@ class AttributesTest(fixtures.ORMTest):
         instrumentation.register_class(Blog)
 
         # set up instrumented attributes with backrefs
-        attributes.register_attribute(
+        _register_attribute(
             Post,
             "blog",
             uselist=False,
@@ -621,7 +604,7 @@ class AttributesTest(fixtures.ORMTest):
             trackparent=True,
             useobject=True,
         )
-        attributes.register_attribute(
+        _register_attribute(
             Blog,
             "posts",
             uselist=True,
@@ -675,7 +658,7 @@ class AttributesTest(fixtures.ORMTest):
         instrumentation.register_class(Post)
         instrumentation.register_class(Blog)
 
-        attributes.register_attribute(Post, "blog", useobject=True)
+        _register_attribute(Post, "blog", useobject=True)
         assert_raises_message(
             AssertionError,
             "This AttributeImpl is not configured to track parents.",
@@ -714,13 +697,13 @@ class AttributesTest(fixtures.ORMTest):
         def func3(state, passive):
             return "this is the shared attr"
 
-        attributes.register_attribute(
+        _register_attribute(
             Foo, "element", uselist=False, callable_=func1, useobject=True
         )
-        attributes.register_attribute(
+        _register_attribute(
             Foo, "element2", uselist=False, callable_=func3, useobject=True
         )
-        attributes.register_attribute(
+        _register_attribute(
             Bar, "element", uselist=False, callable_=func2, useobject=True
         )
 
@@ -766,9 +749,7 @@ class AttributesTest(fixtures.ORMTest):
 
         instrumentation.register_class(Foo)
         instrumentation.register_class(Bar)
-        attributes.register_attribute(
-            Foo, "element", uselist=False, useobject=True
-        )
+        _register_attribute(Foo, "element", uselist=False, useobject=True)
         el = Element()
         x = Bar()
         x.element = el
@@ -804,13 +785,13 @@ class AttributesTest(fixtures.ORMTest):
         def func2(state, passive):
             return [bar1, bar2, bar3]
 
-        attributes.register_attribute(
+        _register_attribute(
             Foo, "col1", uselist=False, callable_=func1, useobject=True
         )
-        attributes.register_attribute(
+        _register_attribute(
             Foo, "col2", uselist=True, callable_=func2, useobject=True
         )
-        attributes.register_attribute(Bar, "id", uselist=False, useobject=True)
+        _register_attribute(Bar, "id", uselist=False, useobject=True)
         x = Foo()
         attributes.instance_state(x)._commit_all(attributes.instance_dict(x))
         x.col2.append(bar4)
@@ -828,10 +809,10 @@ class AttributesTest(fixtures.ORMTest):
 
         instrumentation.register_class(Foo)
         instrumentation.register_class(Bar)
-        attributes.register_attribute(
+        _register_attribute(
             Foo, "element", uselist=False, trackparent=True, useobject=True
         )
-        attributes.register_attribute(
+        _register_attribute(
             Bar, "element", uselist=False, trackparent=True, useobject=True
         )
         f1 = Foo()
@@ -878,7 +859,7 @@ class AttributesTest(fixtures.ORMTest):
             pass
 
         instrumentation.register_class(Foo)
-        attributes.register_attribute(
+        _register_attribute(
             Foo, "collection", uselist=True, typecallable=set, useobject=True
         )
         assert attributes.manager_of_class(Foo).is_instrumented("collection")
@@ -888,7 +869,7 @@ class AttributesTest(fixtures.ORMTest):
             "collection"
         )
         try:
-            attributes.register_attribute(
+            _register_attribute(
                 Foo,
                 "collection",
                 uselist=True,
@@ -911,7 +892,7 @@ class AttributesTest(fixtures.ORMTest):
             def remove(self, item):
                 del self[item.foo]
 
-        attributes.register_attribute(
+        _register_attribute(
             Foo,
             "collection",
             uselist=True,
@@ -925,7 +906,7 @@ class AttributesTest(fixtures.ORMTest):
             pass
 
         try:
-            attributes.register_attribute(
+            _register_attribute(
                 Foo,
                 "collection",
                 uselist=True,
@@ -952,7 +933,7 @@ class AttributesTest(fixtures.ORMTest):
             def remove(self, item):
                 pass
 
-        attributes.register_attribute(
+        _register_attribute(
             Foo,
             "collection",
             uselist=True,
@@ -970,9 +951,9 @@ class AttributesTest(fixtures.ORMTest):
             pass
 
         instrumentation.register_class(Foo)
-        attributes.register_attribute(Foo, "a", useobject=False)
-        attributes.register_attribute(Foo, "b", useobject=False)
-        attributes.register_attribute(Foo, "c", useobject=False)
+        _register_attribute(Foo, "a", useobject=False)
+        _register_attribute(Foo, "b", useobject=False)
+        _register_attribute(Foo, "c", useobject=False)
 
         f1 = Foo()
         state = attributes.instance_state(f1)
@@ -1026,7 +1007,7 @@ class GetNoValueTest(fixtures.ORMTest):
         instrumentation.register_class(Foo)
         instrumentation.register_class(Bar)
         if expected is not None:
-            attributes.register_attribute(
+            _register_attribute(
                 Foo,
                 "attr",
                 useobject=True,
@@ -1034,9 +1015,7 @@ class GetNoValueTest(fixtures.ORMTest):
                 callable_=lazy_callable,
             )
         else:
-            attributes.register_attribute(
-                Foo, "attr", useobject=True, uselist=False
-            )
+            _register_attribute(Foo, "attr", useobject=True, uselist=False)
 
         f1 = self.f1 = Foo()
         return (
@@ -1092,9 +1071,7 @@ class UtilTest(fixtures.ORMTest):
 
         instrumentation.register_class(Foo)
         instrumentation.register_class(Bar)
-        attributes.register_attribute(
-            Foo, "coll", uselist=True, useobject=True
-        )
+        _register_attribute(Foo, "coll", uselist=True, useobject=True)
 
         f1 = Foo()
         b1 = Bar()
@@ -1123,10 +1100,8 @@ class UtilTest(fixtures.ORMTest):
 
         instrumentation.register_class(Foo)
         instrumentation.register_class(Bar)
-        attributes.register_attribute(
-            Foo, "col_list", uselist=True, useobject=True
-        )
-        attributes.register_attribute(
+        _register_attribute(Foo, "col_list", uselist=True, useobject=True)
+        _register_attribute(
             Foo, "col_set", uselist=True, useobject=True, typecallable=set
         )
 
@@ -1146,8 +1121,8 @@ class UtilTest(fixtures.ORMTest):
 
         instrumentation.register_class(Foo)
         instrumentation.register_class(Bar)
-        attributes.register_attribute(Foo, "a", uselist=False, useobject=False)
-        attributes.register_attribute(Bar, "b", uselist=False, useobject=False)
+        _register_attribute(Foo, "a", uselist=False, useobject=False)
+        _register_attribute(Bar, "b", uselist=False, useobject=False)
 
         @event.listens_for(Foo.a, "set")
         def sync_a(target, value, oldvalue, initiator):
@@ -1182,14 +1157,14 @@ class BackrefTest(fixtures.ORMTest):
 
         instrumentation.register_class(Student)
         instrumentation.register_class(Course)
-        attributes.register_attribute(
+        _register_attribute(
             Student,
             "courses",
             uselist=True,
             backref="students",
             useobject=True,
         )
-        attributes.register_attribute(
+        _register_attribute(
             Course, "students", uselist=True, backref="courses", useobject=True
         )
 
@@ -1217,7 +1192,7 @@ class BackrefTest(fixtures.ORMTest):
 
         instrumentation.register_class(Post)
         instrumentation.register_class(Blog)
-        attributes.register_attribute(
+        _register_attribute(
             Post,
             "blog",
             uselist=False,
@@ -1225,7 +1200,7 @@ class BackrefTest(fixtures.ORMTest):
             trackparent=True,
             useobject=True,
         )
-        attributes.register_attribute(
+        _register_attribute(
             Blog,
             "posts",
             uselist=True,
@@ -1266,11 +1241,11 @@ class BackrefTest(fixtures.ORMTest):
         instrumentation.register_class(Port)
         instrumentation.register_class(Jack)
 
-        attributes.register_attribute(
+        _register_attribute(
             Port, "jack", uselist=False, useobject=True, backref="port"
         )
 
-        attributes.register_attribute(
+        _register_attribute(
             Jack, "port", uselist=False, useobject=True, backref="jack"
         )
 
@@ -1306,7 +1281,7 @@ class BackrefTest(fixtures.ORMTest):
         instrumentation.register_class(Parent)
         instrumentation.register_class(Child)
         instrumentation.register_class(SubChild)
-        attributes.register_attribute(
+        _register_attribute(
             Parent,
             "child",
             uselist=False,
@@ -1314,7 +1289,7 @@ class BackrefTest(fixtures.ORMTest):
             parent_token=p_token,
             useobject=True,
         )
-        attributes.register_attribute(
+        _register_attribute(
             Child,
             "parent",
             uselist=False,
@@ -1322,7 +1297,7 @@ class BackrefTest(fixtures.ORMTest):
             parent_token=c_token,
             useobject=True,
         )
-        attributes.register_attribute(
+        _register_attribute(
             SubChild,
             "parent",
             uselist=False,
@@ -1354,7 +1329,7 @@ class BackrefTest(fixtures.ORMTest):
         instrumentation.register_class(Parent)
         instrumentation.register_class(SubParent)
         instrumentation.register_class(Child)
-        attributes.register_attribute(
+        _register_attribute(
             Parent,
             "children",
             uselist=True,
@@ -1362,7 +1337,7 @@ class BackrefTest(fixtures.ORMTest):
             parent_token=p_token,
             useobject=True,
         )
-        attributes.register_attribute(
+        _register_attribute(
             SubParent,
             "children",
             uselist=True,
@@ -1370,7 +1345,7 @@ class BackrefTest(fixtures.ORMTest):
             parent_token=p_token,
             useobject=True,
         )
-        attributes.register_attribute(
+        _register_attribute(
             Child,
             "parent",
             uselist=False,
@@ -1444,15 +1419,11 @@ class CyclicBackrefAssertionTest(fixtures.TestBase):
         instrumentation.register_class(A)
         instrumentation.register_class(B)
         instrumentation.register_class(C)
-        attributes.register_attribute(C, "a", backref="c", useobject=True)
-        attributes.register_attribute(C, "b", backref="c", useobject=True)
+        _register_attribute(C, "a", backref="c", useobject=True)
+        _register_attribute(C, "b", backref="c", useobject=True)
 
-        attributes.register_attribute(
-            A, "c", backref="a", useobject=True, uselist=True
-        )
-        attributes.register_attribute(
-            B, "c", backref="b", useobject=True, uselist=True
-        )
+        _register_attribute(A, "c", backref="a", useobject=True, uselist=True)
+        _register_attribute(B, "c", backref="b", useobject=True, uselist=True)
 
         return A, B, C
 
@@ -1470,15 +1441,11 @@ class CyclicBackrefAssertionTest(fixtures.TestBase):
         instrumentation.register_class(B)
         instrumentation.register_class(C)
 
-        attributes.register_attribute(
-            C, "a", backref="c", useobject=True, uselist=True
-        )
-        attributes.register_attribute(
-            C, "b", backref="c", useobject=True, uselist=True
-        )
+        _register_attribute(C, "a", backref="c", useobject=True, uselist=True)
+        _register_attribute(C, "b", backref="c", useobject=True, uselist=True)
 
-        attributes.register_attribute(A, "c", backref="a", useobject=True)
-        attributes.register_attribute(B, "c", backref="b", useobject=True)
+        _register_attribute(A, "c", backref="a", useobject=True)
+        _register_attribute(B, "c", backref="b", useobject=True)
 
         return A, B, C
 
@@ -1492,14 +1459,10 @@ class CyclicBackrefAssertionTest(fixtures.TestBase):
         instrumentation.register_class(A)
         instrumentation.register_class(B)
 
-        attributes.register_attribute(A, "b", backref="a1", useobject=True)
-        attributes.register_attribute(
-            B, "a1", backref="b", useobject=True, uselist=True
-        )
+        _register_attribute(A, "b", backref="a1", useobject=True)
+        _register_attribute(B, "a1", backref="b", useobject=True, uselist=True)
 
-        attributes.register_attribute(
-            B, "a2", backref="b", useobject=True, uselist=True
-        )
+        _register_attribute(B, "a2", backref="b", useobject=True, uselist=True)
 
         return A, B
 
@@ -1542,7 +1505,7 @@ class PendingBackrefTest(fixtures.ORMTest):
 
         instrumentation.register_class(Post)
         instrumentation.register_class(Blog)
-        attributes.register_attribute(
+        _register_attribute(
             Post,
             "blog",
             uselist=False,
@@ -1550,7 +1513,7 @@ class PendingBackrefTest(fixtures.ORMTest):
             trackparent=True,
             useobject=True,
         )
-        attributes.register_attribute(
+        _register_attribute(
             Blog,
             "posts",
             uselist=True,
@@ -1777,7 +1740,7 @@ class HistoryTest(fixtures.TestBase):
             pass
 
         instrumentation.register_class(Foo)
-        attributes.register_attribute(
+        _register_attribute(
             Foo,
             "someattr",
             uselist=uselist,
@@ -1797,7 +1760,7 @@ class HistoryTest(fixtures.TestBase):
 
         instrumentation.register_class(Foo)
         instrumentation.register_class(Bar)
-        attributes.register_attribute(
+        _register_attribute(
             Foo,
             "someattr",
             uselist=uselist,
@@ -2617,7 +2580,7 @@ class HistoryTest(fixtures.TestBase):
 
         instrumentation.register_class(Foo)
         instrumentation.register_class(Bar)
-        attributes.register_attribute(
+        _register_attribute(
             Foo,
             "someattr",
             uselist=True,
@@ -2675,12 +2638,8 @@ class HistoryTest(fixtures.TestBase):
             pass
 
         instrumentation.register_class(Foo)
-        attributes.register_attribute(
-            Foo, "someattr", uselist=True, useobject=True
-        )
-        attributes.register_attribute(
-            Foo, "id", uselist=False, useobject=False
-        )
+        _register_attribute(Foo, "someattr", uselist=True, useobject=True)
+        _register_attribute(Foo, "id", uselist=False, useobject=False)
         instrumentation.register_class(Bar)
         hi = Bar(name="hi")
         there = Bar(name="there")
@@ -2868,7 +2827,7 @@ class HistoryTest(fixtures.TestBase):
 
         instrumentation.register_class(Foo)
         instrumentation.register_class(Bar)
-        attributes.register_attribute(
+        _register_attribute(
             Foo,
             "bars",
             uselist=True,
@@ -2876,7 +2835,7 @@ class HistoryTest(fixtures.TestBase):
             trackparent=True,
             useobject=True,
         )
-        attributes.register_attribute(
+        _register_attribute(
             Bar,
             "foo",
             uselist=False,
@@ -2945,7 +2904,7 @@ class LazyloadHistoryTest(fixtures.TestBase):
 
         instrumentation.register_class(Foo)
         instrumentation.register_class(Bar)
-        attributes.register_attribute(
+        _register_attribute(
             Foo,
             "bars",
             uselist=True,
@@ -2954,7 +2913,7 @@ class LazyloadHistoryTest(fixtures.TestBase):
             callable_=lazyload,
             useobject=True,
         )
-        attributes.register_attribute(
+        _register_attribute(
             Bar,
             "foo",
             uselist=False,
@@ -3004,7 +2963,7 @@ class LazyloadHistoryTest(fixtures.TestBase):
 
         instrumentation.register_class(Foo)
         instrumentation.register_class(Bar)
-        attributes.register_attribute(
+        _register_attribute(
             Foo,
             "bars",
             uselist=True,
@@ -3063,7 +3022,7 @@ class LazyloadHistoryTest(fixtures.TestBase):
             return lazy_load
 
         instrumentation.register_class(Foo)
-        attributes.register_attribute(
+        _register_attribute(
             Foo, "bar", uselist=False, callable_=lazyload, useobject=False
         )
         lazy_load = "hi"
@@ -3119,7 +3078,7 @@ class LazyloadHistoryTest(fixtures.TestBase):
             return lazy_load
 
         instrumentation.register_class(Foo)
-        attributes.register_attribute(
+        _register_attribute(
             Foo,
             "bar",
             uselist=False,
@@ -3184,7 +3143,7 @@ class LazyloadHistoryTest(fixtures.TestBase):
 
         instrumentation.register_class(Foo)
         instrumentation.register_class(Bar)
-        attributes.register_attribute(
+        _register_attribute(
             Foo,
             "bar",
             uselist=False,
@@ -3254,18 +3213,12 @@ class ListenerTest(fixtures.ORMTest):
 
         instrumentation.register_class(Foo)
         instrumentation.register_class(Bar)
-        attributes.register_attribute(
-            Foo, "data", uselist=False, useobject=False
-        )
-        attributes.register_attribute(
-            Foo, "barlist", uselist=True, useobject=True
-        )
-        attributes.register_attribute(
+        _register_attribute(Foo, "data", uselist=False, useobject=False)
+        _register_attribute(Foo, "barlist", uselist=True, useobject=True)
+        _register_attribute(
             Foo, "barset", typecallable=set, uselist=True, useobject=True
         )
-        attributes.register_attribute(
-            Bar, "data", uselist=False, useobject=False
-        )
+        _register_attribute(Bar, "data", uselist=False, useobject=False)
         event.listen(Foo.data, "set", on_set, retval=True)
         event.listen(Foo.barlist, "append", append, retval=True)
         event.listen(Foo.barset, "append", append, retval=True)
@@ -3291,12 +3244,8 @@ class ListenerTest(fixtures.ORMTest):
 
         instrumentation.register_class(Foo)
         instrumentation.register_class(Bar)
-        attributes.register_attribute(
-            Foo, "data", uselist=False, useobject=False
-        )
-        attributes.register_attribute(
-            Foo, "barlist", uselist=True, useobject=True
-        )
+        _register_attribute(Foo, "data", uselist=False, useobject=False)
+        _register_attribute(Foo, "barlist", uselist=True, useobject=True)
 
         event.listen(Foo.data, "set", canary.set, named=True)
         event.listen(Foo.barlist, "append", canary.append, named=True)
@@ -3312,21 +3261,21 @@ class ListenerTest(fixtures.ORMTest):
             [
                 call.set(
                     oldvalue=attributes.NO_VALUE,
-                    initiator=attributes.Event(
+                    initiator=attributes.AttributeEventToken(
                         Foo.data.impl, attributes.OP_REPLACE
                     ),
                     target=f1,
                     value=5,
                 ),
                 call.append(
-                    initiator=attributes.Event(
+                    initiator=attributes.AttributeEventToken(
                         Foo.barlist.impl, attributes.OP_APPEND
                     ),
                     target=f1,
                     value=b1,
                 ),
                 call.remove(
-                    initiator=attributes.Event(
+                    initiator=attributes.AttributeEventToken(
                         Foo.barlist.impl, attributes.OP_REMOVE
                     ),
                     target=f1,
@@ -3344,9 +3293,7 @@ class ListenerTest(fixtures.ORMTest):
 
         instrumentation.register_class(Foo)
         instrumentation.register_class(Bar)
-        attributes.register_attribute(
-            Foo, "barlist", uselist=True, useobject=True
-        )
+        _register_attribute(Foo, "barlist", uselist=True, useobject=True)
 
         canary = Mock()
         event.listen(Foo.barlist, "init_collection", canary.init)
@@ -3389,9 +3336,7 @@ class ListenerTest(fixtures.ORMTest):
 
         instrumentation.register_class(Foo)
         instrumentation.register_class(Bar)
-        attributes.register_attribute(
-            Foo, "barlist", uselist=True, useobject=True
-        )
+        _register_attribute(Foo, "barlist", uselist=True, useobject=True)
         canary = []
 
         def append(state, child, initiator):
@@ -3428,7 +3373,7 @@ class ListenerTest(fixtures.ORMTest):
             pass
 
         instrumentation.register_class(Foo)
-        attributes.register_attribute(Foo, "bar")
+        _register_attribute(Foo, "bar")
 
         event.listen(Foo.bar, "modified", canary)
         f1 = Foo()
@@ -3436,7 +3381,14 @@ class ListenerTest(fixtures.ORMTest):
         attributes.flag_modified(f1, "bar")
         eq_(
             canary.mock_calls,
-            [call(f1, attributes.Event(Foo.bar.impl, attributes.OP_MODIFIED))],
+            [
+                call(
+                    f1,
+                    attributes.AttributeEventToken(
+                        Foo.bar.impl, attributes.OP_MODIFIED
+                    ),
+                )
+            ],
         )
 
     def test_none_init_scalar(self):
@@ -3446,7 +3398,7 @@ class ListenerTest(fixtures.ORMTest):
             pass
 
         instrumentation.register_class(Foo)
-        attributes.register_attribute(Foo, "bar")
+        _register_attribute(Foo, "bar")
 
         event.listen(Foo.bar, "set", canary)
 
@@ -3462,7 +3414,7 @@ class ListenerTest(fixtures.ORMTest):
             pass
 
         instrumentation.register_class(Foo)
-        attributes.register_attribute(Foo, "bar", useobject=True)
+        _register_attribute(Foo, "bar", useobject=True)
 
         event.listen(Foo.bar, "set", canary)
 
@@ -3482,7 +3434,7 @@ class ListenerTest(fixtures.ORMTest):
 
         instrumentation.register_class(Foo)
         instrumentation.register_class(Bar)
-        attributes.register_attribute(Foo, "bar", useobject=True, uselist=True)
+        _register_attribute(Foo, "bar", useobject=True, uselist=True)
 
         event.listen(Foo.bar, "set", canary)
 
@@ -3622,17 +3574,17 @@ class EventPropagateTest(fixtures.TestBase):
             instrumentation.register_class(classes[3])
 
         def attr_a():
-            attributes.register_attribute(
+            _register_attribute(
                 classes[0], "attrib", uselist=False, useobject=useobject
             )
 
         def attr_b():
-            attributes.register_attribute(
+            _register_attribute(
                 classes[1], "attrib", uselist=False, useobject=useobject
             )
 
         def attr_c():
-            attributes.register_attribute(
+            _register_attribute(
                 classes[2], "attrib", uselist=False, useobject=useobject
             )
 
@@ -3702,7 +3654,7 @@ class CollectionInitTest(fixtures.TestBase):
         self.B = B
         instrumentation.register_class(A)
         instrumentation.register_class(B)
-        attributes.register_attribute(A, "bs", uselist=True, useobject=True)
+        _register_attribute(A, "bs", uselist=True, useobject=True)
 
     def test_bulk_replace_resets_empty(self):
         A = self.A
@@ -3761,7 +3713,7 @@ class TestUnlink(fixtures.TestBase):
         self.B = B
         instrumentation.register_class(A)
         instrumentation.register_class(B)
-        attributes.register_attribute(A, "bs", uselist=True, useobject=True)
+        _register_attribute(A, "bs", uselist=True, useobject=True)
 
     def test_expired(self):
         A, B = self.A, self.B
index 806d98a69b51d294c87ac79382cbf020a1265d66..f7e8ac9f68ae3f9ee799dbba732d70d83107bb5d 100644 (file)
@@ -28,6 +28,13 @@ from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
 
 
+def _register_attribute(class_, key, **kw):
+    kw.setdefault("comparator", object())
+    kw.setdefault("parententity", object())
+
+    return attributes.register_attribute(class_, key, **kw)
+
+
 class Canary:
     def __init__(self):
         self.data = set()
@@ -123,7 +130,7 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest):
 
         canary = Canary()
         instrumentation.register_class(Foo)
-        d = attributes.register_attribute(
+        d = _register_attribute(
             Foo,
             "attr",
             uselist=True,
@@ -175,7 +182,7 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest):
             pass
 
         instrumentation.register_class(Foo)
-        attributes.register_attribute(
+        _register_attribute(
             Foo,
             "attr",
             uselist=True,
@@ -212,7 +219,7 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest):
 
         canary = Canary()
         instrumentation.register_class(Foo)
-        d = attributes.register_attribute(
+        d = _register_attribute(
             Foo,
             "attr",
             uselist=True,
@@ -457,7 +464,7 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest):
 
         canary = Canary()
         instrumentation.register_class(Foo)
-        d = attributes.register_attribute(
+        d = _register_attribute(
             Foo,
             "attr",
             uselist=True,
@@ -664,7 +671,7 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest):
 
         canary = Canary()
         instrumentation.register_class(Foo)
-        d = attributes.register_attribute(
+        d = _register_attribute(
             Foo,
             "attr",
             uselist=True,
@@ -706,7 +713,7 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest):
 
         canary = Canary()
         instrumentation.register_class(Foo)
-        d = attributes.register_attribute(
+        d = _register_attribute(
             Foo,
             "attr",
             uselist=True,
@@ -974,7 +981,7 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest):
 
         canary = Canary()
         instrumentation.register_class(Foo)
-        d = attributes.register_attribute(
+        d = _register_attribute(
             Foo,
             "attr",
             uselist=True,
@@ -1115,7 +1122,7 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest):
 
         canary = Canary()
         instrumentation.register_class(Foo)
-        d = attributes.register_attribute(
+        d = _register_attribute(
             Foo,
             "attr",
             uselist=True,
@@ -1176,7 +1183,7 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest):
 
         canary = Canary()
         instrumentation.register_class(Foo)
-        d = attributes.register_attribute(
+        d = _register_attribute(
             Foo,
             "attr",
             uselist=True,
@@ -1304,7 +1311,7 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest):
 
         canary = Canary()
         instrumentation.register_class(Foo)
-        d = attributes.register_attribute(
+        d = _register_attribute(
             Foo,
             "attr",
             uselist=True,
@@ -1534,7 +1541,7 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest):
 
         canary = Canary()
         instrumentation.register_class(Foo)
-        d = attributes.register_attribute(
+        d = _register_attribute(
             Foo,
             "attr",
             uselist=True,
@@ -1695,7 +1702,7 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest):
 
         canary = Canary()
         instrumentation.register_class(Foo)
-        d = attributes.register_attribute(
+        d = _register_attribute(
             Foo, "attr", uselist=True, typecallable=Custom, useobject=True
         )
         canary.listen(d)
@@ -1769,9 +1776,7 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest):
         canary = Canary()
         creator = self.entity_maker
         instrumentation.register_class(Foo)
-        d = attributes.register_attribute(
-            Foo, "attr", uselist=True, useobject=True
-        )
+        d = _register_attribute(Foo, "attr", uselist=True, useobject=True)
         canary.listen(d)
 
         obj = Foo()
@@ -2321,10 +2326,19 @@ class CustomCollectionsTest(fixtures.MappedTest):
         replaced = set([id(b) for b in list(f.bars.values())])
         ne_(existing, replaced)
 
-    def test_list(self):
-        self._test_list(list)
+    @testing.combinations("direct", "as_callable", argnames="factory_type")
+    def test_list(self, factory_type):
+        if factory_type == "as_callable":
+            # test passing as callable
 
-    def test_list_no_setslice(self):
+            # this codepath likely was not working for many major
+            # versions, at least through 1.3
+            self._test_list(lambda: [])
+        else:
+            self._test_list(list)
+
+    @testing.combinations("direct", "as_callable", argnames="factory_type")
+    def test_list_no_setslice(self, factory_type):
         class ListLike:
             def __init__(self):
                 self.data = list()
@@ -2367,7 +2381,15 @@ class CustomCollectionsTest(fixtures.MappedTest):
             def __repr__(self):
                 return "ListLike(%s)" % repr(self.data)
 
-        self._test_list(ListLike)
+        if factory_type == "as_callable":
+            # test passing as callable
+
+            # this codepath likely was not working for many major
+            # versions, at least through 1.3
+
+            self._test_list(lambda: ListLike())
+        else:
+            self._test_list(ListLike)
 
     def _test_list(self, listcls):
         someothertable, sometable = (
@@ -2598,9 +2620,7 @@ class InstrumentationTest(fixtures.ORMTest):
             pass
 
         instrumentation.register_class(Foo)
-        attributes.register_attribute(
-            Foo, "attr", uselist=True, useobject=True
-        )
+        _register_attribute(Foo, "attr", uselist=True, useobject=True)
 
         f1 = Foo()
         f1.attr.append(3)
index 587c498eae91f4801ccc8d0bfb5f5aa4609e8262..a5506a2967de260427ade26b246f80de3ad2a755 100644 (file)
@@ -887,7 +887,13 @@ class InstrumentationTest(fixtures.ORMTest):
 
         instrumentation.register_class(Foo)
         attributes.register_attribute(
-            Foo, "attr", uselist=True, typecallable=MyDict, useobject=True
+            Foo,
+            "attr",
+            parententity=object(),
+            comparator=object(),
+            uselist=True,
+            typecallable=MyDict,
+            useobject=True,
         )
 
         f = Foo()
index 437129af16cda65615c5db531559c3257bcaef11..99d5498d6878b374b1a682cb9df2a1f6fdc16981 100644 (file)
@@ -736,7 +736,14 @@ class MiscTest(fixtures.MappedTest):
             pass
 
         manager = instrumentation.register_class(A)
-        attributes.register_attribute(A, "x", uselist=False, useobject=False)
+        attributes.register_attribute(
+            A,
+            "x",
+            comparator=object(),
+            parententity=object(),
+            uselist=False,
+            useobject=False,
+        )
 
         assert instrumentation.manager_of_class(A) is manager
         instrumentation.unregister_class(A)