]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
pep484 + abc bases for assocaitionproxy
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 28 Feb 2022 04:05:46 +0000 (23:05 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 2 Mar 2022 02:05:14 +0000 (21:05 -0500)
went to this one next as it was going to be hard,
and also exercises the ORM expression hierarchy a bit.
made some adjustments to SQLCoreOperations etc.

Change-Id: Ie5dde9218dc1318252826b766d3e70b17dd24ea7
References: #6810
References: #7774

21 files changed:
doc/build/orm/internals.rst
lib/sqlalchemy/ext/associationproxy.py
lib/sqlalchemy/ext/hybrid.py
lib/sqlalchemy/orm/__init__.py
lib/sqlalchemy/orm/base.py
lib/sqlalchemy/orm/clsregistry.py
lib/sqlalchemy/orm/collections.py
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/relationships.py
lib/sqlalchemy/sql/_elements_constructors.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/operators.py
lib/sqlalchemy/sql/selectable.py
lib/sqlalchemy/util/langhelpers.py
lib/sqlalchemy/util/typing.py
pyproject.toml
test/ext/mypy/plain_files/association_proxy_one.py [new file with mode: 0644]
test/ext/mypy/plain_files/sql_operations.py [new file with mode: 0644]
test/ext/test_associationproxy.py
test/orm/test_inspect.py

index 05cf83b394eef8fb481a32b402c73f9850608f3d..f251e43bd0a052cc1d44e560850e9a2b196acc5b 100644 (file)
@@ -88,7 +88,10 @@ sections, are listed here.
 
             :attr:`.SchemaItem.info`
 
-.. autodata:: NOT_EXTENSION
+.. autoclass:: InspectionAttrExtensionType
+
+.. autoclass:: NotExtension
+    :members:
 
 .. autofunction:: merge_result
 
index d5119907eda88bed0dff97e981955d48d6926127..709c13c1468e33bee1bd289d277a77dbb03454b8 100644 (file)
@@ -13,19 +13,70 @@ transparent proxied access to the endpoint of an association object.
 See the example ``examples/association/proxied_association.py``.
 
 """
-import operator
+from __future__ import annotations
 
+import operator
+import typing
+from typing import AbstractSet
+from typing import Any
+from typing import cast
+from typing import Collection
+from typing import Dict
+from typing import Generic
+from typing import ItemsView
+from typing import Iterable
+from typing import Iterator
+from typing import KeysView
+from typing import Mapping
+from typing import MutableMapping
+from typing import MutableSequence
+from typing import MutableSet
+from typing import NoReturn
+from typing import Optional
+from typing import overload
+from typing import Set
+from typing import Tuple
+from typing import Type
+from typing import TypeVar
+from typing import Union
+from typing import ValuesView
+
+from .. import ColumnElement
 from .. import exc
 from .. import inspect
 from .. import orm
 from .. import util
 from ..orm import collections
+from ..orm import InspectionAttrExtensionType
 from ..orm import interfaces
+from ..orm import ORMDescriptor
+from ..orm.base import SQLORMOperations
+from ..sql import operators
 from ..sql import or_
+from ..sql.elements import SQLCoreOperations
 from ..sql.operators import ColumnOperators
-
-
-def association_proxy(target_collection, attr, **kw):
+from ..util.typing import Literal
+from ..util.typing import Protocol
+from ..util.typing import Self
+from ..util.typing import SupportsIndex
+
+if typing.TYPE_CHECKING:
+    from ..orm.attributes import InstrumentedAttribute
+    from ..orm.interfaces import MapperProperty
+    from ..orm.interfaces import PropComparator
+    from ..orm.mapper import Mapper
+
+_T = TypeVar("_T", bound=Any)
+_T_co = TypeVar("_T_co", bound=Any, covariant=True)
+_T_con = TypeVar("_T_con", bound=Any, contravariant=True)
+_S = TypeVar("_S", bound=Any)
+_KT = TypeVar("_KT", bound=Any)
+_VT = TypeVar("_VT", bound=Any)
+
+
+def association_proxy(
+    target_collection: str, attr: str, **kw: Any
+) -> AssociationProxy[Any]:
     r"""Return a Python property implementing a view of a target
     attribute which references an attribute on members of the
     target.
@@ -80,32 +131,136 @@ def association_proxy(target_collection, attr, **kw):
     return AssociationProxy(target_collection, attr, **kw)
 
 
-ASSOCIATION_PROXY = util.symbol("ASSOCIATION_PROXY")
-"""Symbol indicating an :class:`.InspectionAttr` that's
+class AssociationProxyExtensionType(InspectionAttrExtensionType):
+    ASSOCIATION_PROXY = "ASSOCIATION_PROXY"
+    """Symbol indicating an :class:`.InspectionAttr` that's
     of type :class:`.AssociationProxy`.
 
-   Is assigned to the :attr:`.InspectionAttr.extension_type`
-   attribute.
+    Is assigned to the :attr:`.InspectionAttr.extension_type`
+    attribute.
+
+    """
+
+
+class _GetterProtocol(Protocol[_T_co]):
+    def __call__(self, instance: Any) -> _T_co:
+        ...
+
+
+class _SetterProtocol(Protocol[_T_co]):
+    ...
+
+
+class _PlainSetterProtocol(_SetterProtocol[_T_con]):
+    def __call__(self, instance: Any, value: _T_con) -> None:
+        ...
+
+
+class _DictSetterProtocol(_SetterProtocol[_T_con]):
+    def __call__(self, instance: Any, key: Any, value: _T_con) -> None:
+        ...
+
+
+class _CreatorProtocol(Protocol[_T_co]):
+    ...
+
+
+class _PlainCreatorProtocol(_CreatorProtocol[_T_con]):
+    def __call__(self, value: _T_con) -> Any:
+        ...
+
+
+class _KeyCreatorProtocol(_CreatorProtocol[_T_con]):
+    def __call__(self, key: Any, value: Optional[_T_con]) -> Any:
+        ...
 
-"""
 
+class _LazyCollectionProtocol(Protocol[_T]):
+    def __call__(
+        self,
+    ) -> Union[MutableSet[_T], MutableMapping[Any, _T], MutableSequence[_T]]:
+        ...
+
+
+class _GetSetFactoryProtocol(Protocol):
+    def __call__(
+        self,
+        collection_class: Optional[Type[Any]],
+        assoc_instance: AssociationProxyInstance[Any],
+    ) -> Tuple[_GetterProtocol[Any], _SetterProtocol[Any]]:
+        ...
+
+
+class _ProxyFactoryProtocol(Protocol):
+    def __call__(
+        self,
+        lazy_collection: _LazyCollectionProtocol[Any],
+        creator: _CreatorProtocol[Any],
+        value_attr: str,
+        parent: AssociationProxyInstance[Any],
+    ) -> _T:
+        ...
+
+
+class _ProxyBulkSetProtocol(Protocol):
+    def __call__(
+        self, proxy: _AssociationCollection[Any], collection: Iterable[Any]
+    ) -> None:
+        ...
+
+
+class _AssociationProxyProtocol(Protocol[_T]):
+    """describes the interface of :class:`.AssociationProxy`
+    without including descriptor methods in the interface."""
 
-class AssociationProxy(interfaces.InspectionAttrInfo):
+    creator: Optional[_CreatorProtocol[Any]]
+    key: str
+    target_collection: str
+    value_attr: str
+    getset_factory: Optional[_GetSetFactoryProtocol]
+    proxy_factory: Optional[_ProxyFactoryProtocol]
+    proxy_bulk_set: Optional[_ProxyBulkSetProtocol]
+
+    @util.memoized_property
+    def info(self) -> Dict[Any, Any]:
+        ...
+
+    def for_class(
+        self, class_: Type[Any], obj: Optional[object] = None
+    ) -> AssociationProxyInstance[_T]:
+        ...
+
+    def _default_getset(
+        self, collection_class: Any
+    ) -> Tuple[_GetterProtocol[Any], _SetterProtocol[Any]]:
+        ...
+
+
+_SelfAssociationProxy = TypeVar(
+    "_SelfAssociationProxy", bound="AssociationProxy[Any]"
+)
+
+
+class AssociationProxy(
+    interfaces.InspectionAttrInfo,
+    ORMDescriptor[_T],
+    _AssociationProxyProtocol[_T],
+):
     """A descriptor that presents a read/write view of an object attribute."""
 
     is_attribute = True
-    extension_type = ASSOCIATION_PROXY
+    extension_type = AssociationProxyExtensionType.ASSOCIATION_PROXY
 
     def __init__(
         self,
-        target_collection,
-        attr,
-        creator=None,
-        getset_factory=None,
-        proxy_factory=None,
-        proxy_bulk_set=None,
-        info=None,
-        cascade_scalar_deletes=False,
+        target_collection: str,
+        attr: str,
+        creator: Optional[_CreatorProtocol[Any]] = None,
+        getset_factory: Optional[_GetSetFactoryProtocol] = None,
+        proxy_factory: Optional[_ProxyFactoryProtocol] = None,
+        proxy_bulk_set: Optional[_ProxyBulkSetProtocol] = None,
+        info: Optional[Dict[Any, Any]] = None,
+        cascade_scalar_deletes: bool = False,
     ):
         """Construct a new :class:`.AssociationProxy`.
 
@@ -185,27 +340,46 @@ class AssociationProxy(interfaces.InspectionAttrInfo):
         if info:
             self.info = info
 
-    def __get__(self, obj, class_):
-        if class_ is None:
+    @overload
+    def __get__(
+        self: _SelfAssociationProxy, instance: Any, owner: Literal[None]
+    ) -> _SelfAssociationProxy:
+        ...
+
+    @overload
+    def __get__(
+        self, instance: Literal[None], owner: Any
+    ) -> AssociationProxyInstance[_T]:
+        ...
+
+    @overload
+    def __get__(self, instance: object, owner: Any) -> _T:
+        ...
+
+    def __get__(
+        self, instance: object, owner: Any
+    ) -> Union[AssociationProxyInstance[_T], _T, AssociationProxy[_T]]:
+        if owner is None:
             return self
-        inst = self._as_instance(class_, obj)
+        inst = self._as_instance(owner, instance)
         if inst:
-            return inst.get(obj)
+            return inst.get(instance)
 
-        # obj has to be None here
-        # assert obj is None
+        assert instance is None
 
         return self
 
-    def __set__(self, obj, values):
-        class_ = type(obj)
-        return self._as_instance(class_, obj).set(obj, values)
+    def __set__(self, instance: object, values: _T) -> None:
+        class_ = type(instance)
+        self._as_instance(class_, instance).set(instance, values)
 
-    def __delete__(self, obj):
-        class_ = type(obj)
-        return self._as_instance(class_, obj).delete(obj)
+    def __delete__(self, instance: object) -> None:
+        class_ = type(instance)
+        self._as_instance(class_, instance).delete(instance)
 
-    def for_class(self, class_, obj=None):
+    def for_class(
+        self, class_: Type[Any], obj: Optional[object] = None
+    ) -> AssociationProxyInstance[_T]:
         r"""Return the internal state local to a specific mapped class.
 
         E.g., given a class ``User``::
@@ -240,7 +414,9 @@ class AssociationProxy(interfaces.InspectionAttrInfo):
         """
         return self._as_instance(class_, obj)
 
-    def _as_instance(self, class_, obj):
+    def _as_instance(
+        self, class_: Any, obj: Any
+    ) -> AssociationProxyInstance[_T]:
         try:
             inst = class_.__dict__[self.key + "_inst"]
         except KeyError:
@@ -261,11 +437,11 @@ class AssociationProxy(interfaces.InspectionAttrInfo):
             # class, only on subclasses of it, which might be
             # different.  only return for the specific
             # object's current value
-            return inst._non_canonical_get_for_object(obj)
+            return inst._non_canonical_get_for_object(obj)  # type: ignore
         else:
-            return inst
+            return inst  # type: ignore  # TODO
 
-    def _calc_owner(self, target_cls):
+    def _calc_owner(self, target_cls: Any) -> Any:
         # we might be getting invoked for a subclass
         # that is not mapped yet, in some declarative situations.
         # save until we are mapped
@@ -280,33 +456,44 @@ class AssociationProxy(interfaces.InspectionAttrInfo):
         else:
             return insp.mapper.class_manager.class_
 
-    def _default_getset(self, collection_class):
+    def _default_getset(
+        self, collection_class: Any
+    ) -> Tuple[_GetterProtocol[Any], _SetterProtocol[Any]]:
         attr = self.value_attr
         _getter = operator.attrgetter(attr)
 
-        def getter(target):
-            return _getter(target) if target is not None else None
+        def getter(instance: Any) -> Optional[Any]:
+            return _getter(instance) if instance is not None else None
 
         if collection_class is dict:
 
-            def setter(o, k, v):
-                setattr(o, attr, v)
+            def dict_setter(instance: Any, k: Any, value: Any) -> None:
+                setattr(instance, attr, value)
+
+            return getter, dict_setter
 
         else:
 
-            def setter(o, v):
+            def plain_setter(o: Any, v: Any) -> None:
                 setattr(o, attr, v)
 
-        return getter, setter
+            return getter, plain_setter
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return "AssociationProxy(%r, %r)" % (
             self.target_collection,
             self.value_attr,
         )
 
 
-class AssociationProxyInstance:
+_SelfAssociationProxyInstance = TypeVar(
+    "_SelfAssociationProxyInstance", bound="AssociationProxyInstance[Any]"
+)
+
+
+class AssociationProxyInstance(
+    SQLORMOperations[_T], ColumnOperators[SQLORMOperations[_T]]
+):
     """A per-class object that serves class- and object-specific results.
 
     This is used by :class:`.AssociationProxy` when it is invoked
@@ -336,7 +523,16 @@ class AssociationProxyInstance:
 
     """  # noqa
 
-    def __init__(self, parent, owning_class, target_class, value_attr):
+    collection_class: Optional[Type[Any]]
+    parent: _AssociationProxyProtocol[_T]
+
+    def __init__(
+        self,
+        parent: _AssociationProxyProtocol[_T],
+        owning_class: Type[Any],
+        target_class: Type[Any],
+        value_attr: str,
+    ):
         self.parent = parent
         self.key = parent.key
         self.owning_class = owning_class
@@ -345,7 +541,7 @@ class AssociationProxyInstance:
         self.target_class = target_class
         self.value_attr = value_attr
 
-    target_class = None
+    target_class: Type[Any]
     """The intermediary class handled by this
     :class:`.AssociationProxyInstance`.
 
@@ -355,10 +551,18 @@ class AssociationProxyInstance:
     """
 
     @classmethod
-    def for_proxy(cls, parent, owning_class, parent_instance):
+    def for_proxy(
+        cls,
+        parent: AssociationProxy[_T],
+        owning_class: Type[Any],
+        parent_instance: Any,
+    ) -> AssociationProxyInstance[_T]:
         target_collection = parent.target_collection
         value_attr = parent.value_attr
-        prop = orm.class_mapper(owning_class).get_property(target_collection)
+        prop = cast(
+            "orm.Relationship[_T]",
+            orm.class_mapper(owning_class).get_property(target_collection),
+        )
 
         # this was never asserted before but this should be made clear.
         if not isinstance(prop, orm.Relationship):
@@ -370,8 +574,9 @@ class AssociationProxyInstance:
         target_class = prop.mapper.class_
 
         try:
-            target_assoc = cls._cls_unwrap_target_assoc_proxy(
-                target_class, value_attr
+            target_assoc = cast(
+                "AssociationProxyInstance[_T]",
+                cls._cls_unwrap_target_assoc_proxy(target_class, value_attr),
             )
         except AttributeError:
             # the proxied attribute doesn't exist on the target class;
@@ -387,8 +592,13 @@ class AssociationProxyInstance:
 
     @classmethod
     def _construct_for_assoc(
-        cls, target_assoc, parent, owning_class, target_class, value_attr
-    ):
+        cls,
+        target_assoc: Optional[AssociationProxyInstance[_T]],
+        parent: _AssociationProxyProtocol[_T],
+        owning_class: Type[Any],
+        target_class: Type[Any],
+        value_attr: str,
+    ) -> AssociationProxyInstance[_T]:
         if target_assoc is not None:
             return ObjectAssociationProxyInstance(
                 parent, owning_class, target_class, value_attr
@@ -409,36 +619,41 @@ class AssociationProxyInstance:
                 parent, owning_class, target_class, value_attr
             )
 
-    def _get_property(self):
+    def _get_property(self) -> MapperProperty[Any]:
         return orm.class_mapper(self.owning_class).get_property(
             self.target_collection
         )
 
     @property
-    def _comparator(self):
+    def _comparator(self) -> PropComparator[Any]:
         return self._get_property().comparator
 
-    def __clause_element__(self):
+    def __clause_element__(self) -> NoReturn:
         raise NotImplementedError(
             "The association proxy can't be used as a plain column "
             "expression; it only works inside of a comparison expression"
         )
 
     @classmethod
-    def _cls_unwrap_target_assoc_proxy(cls, target_class, value_attr):
+    def _cls_unwrap_target_assoc_proxy(
+        cls, target_class: Any, value_attr: str
+    ) -> Optional[AssociationProxyInstance[_T]]:
         attr = getattr(target_class, value_attr)
-        if isinstance(attr, (AssociationProxy, AssociationProxyInstance)):
+        assert not isinstance(attr, AssociationProxy)
+        if isinstance(attr, AssociationProxyInstance):
             return attr
         return None
 
     @util.memoized_property
-    def _unwrap_target_assoc_proxy(self):
+    def _unwrap_target_assoc_proxy(
+        self,
+    ) -> Optional[AssociationProxyInstance[_T]]:
         return self._cls_unwrap_target_assoc_proxy(
             self.target_class, self.value_attr
         )
 
     @property
-    def remote_attr(self):
+    def remote_attr(self) -> SQLORMOperations[_T]:
         """The 'remote' class attribute referenced by this
         :class:`.AssociationProxyInstance`.
 
@@ -449,10 +664,12 @@ class AssociationProxyInstance:
             :attr:`.AssociationProxyInstance.local_attr`
 
         """
-        return getattr(self.target_class, self.value_attr)
+        return cast(
+            "SQLORMOperations[_T]", getattr(self.target_class, self.value_attr)
+        )
 
     @property
-    def local_attr(self):
+    def local_attr(self) -> SQLORMOperations[Any]:
         """The 'local' class attribute referenced by this
         :class:`.AssociationProxyInstance`.
 
@@ -463,10 +680,13 @@ class AssociationProxyInstance:
             :attr:`.AssociationProxyInstance.remote_attr`
 
         """
-        return getattr(self.owning_class, self.target_collection)
+        return cast(
+            "SQLORMOperations[Any]",
+            getattr(self.owning_class, self.target_collection),
+        )
 
     @property
-    def attr(self):
+    def attr(self) -> Tuple[SQLORMOperations[Any], SQLORMOperations[_T]]:
         """Return a tuple of ``(local_attr, remote_attr)``.
 
         This attribute was originally intended to facilitate using the
@@ -497,7 +717,7 @@ class AssociationProxyInstance:
         return (self.local_attr, self.remote_attr)
 
     @util.memoized_property
-    def scalar(self):
+    def scalar(self) -> bool:
         """Return ``True`` if this :class:`.AssociationProxyInstance`
         proxies a scalar relationship on the local side."""
 
@@ -507,7 +727,7 @@ class AssociationProxyInstance:
         return scalar
 
     @util.memoized_property
-    def _value_is_scalar(self):
+    def _value_is_scalar(self) -> bool:
         return (
             not self._get_property()
             .mapper.get_property(self.value_attr)
@@ -515,43 +735,63 @@ class AssociationProxyInstance:
         )
 
     @property
-    def _target_is_object(self):
+    def _target_is_object(self) -> bool:
         raise NotImplementedError()
 
-    def _initialize_scalar_accessors(self):
+    _scalar_get: _GetterProtocol[_T]
+    _scalar_set: _PlainSetterProtocol[_T]
+
+    def _initialize_scalar_accessors(self) -> None:
         if self.parent.getset_factory:
             get, set_ = self.parent.getset_factory(None, self)
         else:
             get, set_ = self.parent._default_getset(None)
-        self._scalar_get, self._scalar_set = get, set_
+        self._scalar_get, self._scalar_set = get, cast(
+            "_PlainSetterProtocol[_T]", set_
+        )
 
-    def _default_getset(self, collection_class):
+    def _default_getset(
+        self, collection_class: Any
+    ) -> Tuple[_GetterProtocol[Any], _SetterProtocol[Any]]:
         attr = self.value_attr
         _getter = operator.attrgetter(attr)
 
-        def getter(target):
-            return _getter(target) if target is not None else None
+        def getter(instance: Any) -> Optional[_T]:
+            return _getter(instance) if instance is not None else None
 
         if collection_class is dict:
 
-            def setter(o, k, v):
-                return setattr(o, attr, v)
+            def dict_setter(instance: Any, k: Any, value: _T) -> None:
+                setattr(instance, attr, value)
 
+            return getter, dict_setter
         else:
 
-            def setter(o, v):
-                return setattr(o, attr, v)
+            def plain_setter(o: Any, v: _T) -> None:
+                setattr(o, attr, v)
 
-        return getter, setter
+            return getter, plain_setter
 
     @property
-    def info(self):
+    def info(self) -> Dict[Any, Any]:
         return self.parent.info
 
-    def get(self, obj):
+    @overload
+    def get(self: Self, obj: Literal[None]) -> Self:
+        ...
+
+    @overload
+    def get(self, obj: Any) -> _T:
+        ...
+
+    def get(
+        self, obj: Any
+    ) -> Union[Optional[_T], AssociationProxyInstance[_T]]:
         if obj is None:
             return self
 
+        proxy: _T
+
         if self.scalar:
             target = getattr(obj, self.target_collection)
             return self._scalar_get(target)
@@ -559,7 +799,9 @@ class AssociationProxyInstance:
             try:
                 # If the owning instance is reborn (orm session resurrect,
                 # etc.), refresh the proxy cache.
-                creator_id, self_id, proxy = getattr(obj, self.key)
+                creator_id, self_id, proxy = cast(
+                    "Tuple[int, int, _T]", getattr(obj, self.key)
+                )
             except AttributeError:
                 pass
             else:
@@ -573,12 +815,15 @@ class AssociationProxyInstance:
             setattr(obj, self.key, (id(obj), id(self), proxy))
             return proxy
 
-    def set(self, obj, values):
+    def set(self, obj: Any, values: _T) -> None:
         if self.scalar:
-            creator = (
-                self.parent.creator
-                if self.parent.creator
-                else self.target_class
+            creator = cast(
+                "_PlainCreatorProtocol[_T]",
+                (
+                    self.parent.creator
+                    if self.parent.creator
+                    else self.target_class
+                ),
             )
             target = getattr(obj, self.target_collection)
             if target is None:
@@ -595,7 +840,7 @@ class AssociationProxyInstance:
             if proxy is not values:
                 proxy._bulk_replace(self, values)
 
-    def delete(self, obj):
+    def delete(self, obj: Any) -> None:
         if self.owning_class is None:
             self._calc_owner(obj, None)
 
@@ -605,12 +850,21 @@ class AssociationProxyInstance:
                 delattr(target, self.value_attr)
         delattr(obj, self.target_collection)
 
-    def _new(self, lazy_collection):
+    def _new(
+        self, lazy_collection: _LazyCollectionProtocol[_T]
+    ) -> Tuple[Type[Any], _T]:
         creator = (
-            self.parent.creator if self.parent.creator else self.target_class
+            self.parent.creator
+            if self.parent.creator is not None
+            else cast("_CreatorProtocol[_T]", self.target_class)
         )
         collection_class = util.duck_type_collection(lazy_collection())
 
+        if collection_class is None:
+            raise exc.InvalidRequestError(
+                f"lazy collection factory did not return a "
+                f"valid collection type, got {collection_class}"
+            )
         if self.parent.proxy_factory:
             return (
                 collection_class,
@@ -627,22 +881,31 @@ class AssociationProxyInstance:
         if collection_class is list:
             return (
                 collection_class,
-                _AssociationList(
-                    lazy_collection, creator, getter, setter, self
+                cast(
+                    _T,
+                    _AssociationList(
+                        lazy_collection, creator, getter, setter, self
+                    ),
                 ),
             )
         elif collection_class is dict:
             return (
                 collection_class,
-                _AssociationDict(
-                    lazy_collection, creator, getter, setter, self
+                cast(
+                    _T,
+                    _AssociationDict(
+                        lazy_collection, creator, getter, setter, self
+                    ),
                 ),
             )
         elif collection_class is set:
             return (
                 collection_class,
-                _AssociationSet(
-                    lazy_collection, creator, getter, setter, self
+                cast(
+                    _T,
+                    _AssociationSet(
+                        lazy_collection, creator, getter, setter, self
+                    ),
                 ),
             )
         else:
@@ -650,27 +913,31 @@ class AssociationProxyInstance:
                 "could not guess which interface to use for "
                 'collection_class "%s" backing "%s"; specify a '
                 "proxy_factory and proxy_bulk_set manually"
-                % (self.collection_class.__name__, self.target_collection)
+                % (self.collection_class, self.target_collection)
             )
 
-    def _set(self, proxy, values):
+    def _set(
+        self, proxy: _AssociationCollection[Any], values: Iterable[Any]
+    ) -> None:
         if self.parent.proxy_bulk_set:
             self.parent.proxy_bulk_set(proxy, values)
         elif self.collection_class is list:
-            proxy.extend(values)
+            cast("_AssociationList[Any]", proxy).extend(values)
         elif self.collection_class is dict:
-            proxy.update(values)
+            cast("_AssociationDict[Any, Any]", proxy).update(values)
         elif self.collection_class is set:
-            proxy.update(values)
+            cast("_AssociationSet[Any]", proxy).update(values)
         else:
             raise exc.ArgumentError(
                 "no proxy_bulk_set supplied for custom "
                 "collection_class implementation"
             )
 
-    def _inflate(self, proxy):
+    def _inflate(self, proxy: _AssociationCollection[Any]) -> None:
         creator = (
-            self.parent.creator and self.parent.creator or self.target_class
+            self.parent.creator
+            and self.parent.creator
+            or cast(_CreatorProtocol[Any], self.target_class)
         )
 
         if self.parent.getset_factory:
@@ -684,12 +951,14 @@ class AssociationProxyInstance:
         proxy.getter = getter
         proxy.setter = setter
 
-    def _criterion_exists(self, criterion=None, **kwargs):
+    def _criterion_exists(
+        self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs: Any
+    ) -> ColumnElement[bool]:
         is_has = kwargs.pop("is_has", None)
 
         target_assoc = self._unwrap_target_assoc_proxy
         if target_assoc is not None:
-            inner = target_assoc._criterion_exists(
+            inner = target_assoc._criterion_exists(  # type: ignore
                 criterion=criterion, **kwargs
             )
             return self._comparator._criterion_exists(inner)
@@ -713,7 +982,9 @@ class AssociationProxyInstance:
 
         return self._comparator._criterion_exists(value_expr)
 
-    def any(self, criterion=None, **kwargs):
+    def any(
+        self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs: Any
+    ) -> SQLCoreOperations[Any]:
         """Produce a proxied 'any' expression using EXISTS.
 
         This expression will be a composed product
@@ -733,7 +1004,9 @@ class AssociationProxyInstance:
             criterion=criterion, is_has=False, **kwargs
         )
 
-    def has(self, criterion=None, **kwargs):
+    def has(
+        self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs: Any
+    ) -> SQLCoreOperations[Any]:
         """Produce a proxied 'has' expression using EXISTS.
 
         This expression will be a composed product
@@ -753,18 +1026,18 @@ class AssociationProxyInstance:
             criterion=criterion, is_has=True, **kwargs
         )
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return "%s(%r)" % (self.__class__.__name__, self.parent)
 
 
-class AmbiguousAssociationProxyInstance(AssociationProxyInstance):
+class AmbiguousAssociationProxyInstance(AssociationProxyInstance[_T]):
     """an :class:`.AssociationProxyInstance` where we cannot determine
     the type of target object.
     """
 
     _is_canonical = False
 
-    def _ambiguous(self):
+    def _ambiguous(self) -> NoReturn:
         raise AttributeError(
             "Association proxy %s.%s refers to an attribute '%s' that is not "
             "directly mapped on class %s; therefore this operation cannot "
@@ -778,32 +1051,38 @@ class AmbiguousAssociationProxyInstance(AssociationProxyInstance):
             )
         )
 
-    def get(self, obj):
+    def get(self, obj: Any) -> Any:
         if obj is None:
             return self
         else:
             return super(AmbiguousAssociationProxyInstance, self).get(obj)
 
-    def __eq__(self, obj):
+    def __eq__(self, obj: object) -> NoReturn:
         self._ambiguous()
 
-    def __ne__(self, obj):
+    def __ne__(self, obj: object) -> NoReturn:
         self._ambiguous()
 
-    def any(self, criterion=None, **kwargs):
+    def any(
+        self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs: Any
+    ) -> NoReturn:
         self._ambiguous()
 
-    def has(self, criterion=None, **kwargs):
+    def has(
+        self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs: Any
+    ) -> NoReturn:
         self._ambiguous()
 
     @util.memoized_property
-    def _lookup_cache(self):
+    def _lookup_cache(self) -> Dict[Type[Any], AssociationProxyInstance[_T]]:
         # mapping of <subclass>->AssociationProxyInstance.
         # e.g. proxy is A-> A.b -> B -> B.b_attr, but B.b_attr doesn't exist;
         # only B1(B) and B2(B) have "b_attr", keys in here would be B1, B2
         return {}
 
-    def _non_canonical_get_for_object(self, parent_instance):
+    def _non_canonical_get_for_object(
+        self, parent_instance: Any
+    ) -> AssociationProxyInstance[_T]:
         if parent_instance is not None:
             actual_obj = getattr(parent_instance, self.target_collection)
             if actual_obj is not None:
@@ -826,7 +1105,9 @@ class AmbiguousAssociationProxyInstance(AssociationProxyInstance):
         # is a proxy with generally only instance-level functionality
         return self
 
-    def _populate_cache(self, instance_class, mapper):
+    def _populate_cache(
+        self, instance_class: Any, mapper: Mapper[Any]
+    ) -> None:
         prop = orm.class_mapper(self.owning_class).get_property(
             self.target_collection
         )
@@ -841,7 +1122,7 @@ class AmbiguousAssociationProxyInstance(AssociationProxyInstance):
                 pass
             else:
                 self._lookup_cache[instance_class] = self._construct_for_assoc(
-                    target_assoc,
+                    cast("AssociationProxyInstance[_T]", target_assoc),
                     self.parent,
                     self.owning_class,
                     target_class,
@@ -849,13 +1130,13 @@ class AmbiguousAssociationProxyInstance(AssociationProxyInstance):
                 )
 
 
-class ObjectAssociationProxyInstance(AssociationProxyInstance):
+class ObjectAssociationProxyInstance(AssociationProxyInstance[_T]):
     """an :class:`.AssociationProxyInstance` that has an object as a target."""
 
-    _target_is_object = True
+    _target_is_object: bool = True
     _is_canonical = True
 
-    def contains(self, obj):
+    def contains(self, other: Any, **kw: Any) -> ColumnElement[bool]:
         """Produce a proxied 'contains' expression using EXISTS.
 
         This expression will be a composed product
@@ -868,9 +1149,9 @@ class ObjectAssociationProxyInstance(AssociationProxyInstance):
         target_assoc = self._unwrap_target_assoc_proxy
         if target_assoc is not None:
             return self._comparator._criterion_exists(
-                target_assoc.contains(obj)
+                target_assoc.contains(other)
                 if not target_assoc.scalar
-                else target_assoc == obj
+                else target_assoc == other
             )
         elif (
             self._target_is_object
@@ -878,7 +1159,7 @@ class ObjectAssociationProxyInstance(AssociationProxyInstance):
             and not self._value_is_scalar
         ):
             return self._comparator.has(
-                getattr(self.target_class, self.value_attr).contains(obj)
+                getattr(self.target_class, self.value_attr).contains(other)
             )
         elif self._target_is_object and self.scalar and self._value_is_scalar:
             raise exc.InvalidRequestError(
@@ -886,9 +1167,11 @@ class ObjectAssociationProxyInstance(AssociationProxyInstance):
             )
         else:
 
-            return self._comparator._criterion_exists(**{self.value_attr: obj})
+            return self._comparator._criterion_exists(
+                **{self.value_attr: other}
+            )
 
-    def __eq__(self, obj):
+    def __eq__(self, obj: Any) -> ColumnElement[bool]:  # type: ignore[override]  # noqa E501
         # note the has() here will fail for collections; eq_()
         # is only allowed with a scalar.
         if obj is None:
@@ -899,7 +1182,7 @@ class ObjectAssociationProxyInstance(AssociationProxyInstance):
         else:
             return self._comparator.has(**{self.value_attr: obj})
 
-    def __ne__(self, obj):
+    def __ne__(self, obj: Any) -> ColumnElement[bool]:  # type: ignore[override]  # noqa E501
         # note the has() here will fail for collections; eq_()
         # is only allowed with a scalar.
         return self._comparator.has(
@@ -907,72 +1190,95 @@ class ObjectAssociationProxyInstance(AssociationProxyInstance):
         )
 
 
-class ColumnAssociationProxyInstance(
-    ColumnOperators, AssociationProxyInstance
-):
+class ColumnAssociationProxyInstance(AssociationProxyInstance[_T]):
     """an :class:`.AssociationProxyInstance` that has a database column as a
     target.
     """
 
-    _target_is_object = False
+    _target_is_object: bool = False
     _is_canonical = True
 
-    def __eq__(self, other):
+    def __eq__(self, other: Any) -> ColumnElement[bool]:  # type: ignore[override]  # noqa E501
         # special case "is None" to check for no related row as well
         expr = self._criterion_exists(
-            self.remote_attr.operate(operator.eq, other)
+            self.remote_attr.operate(operators.eq, other)
         )
         if other is None:
             return or_(expr, self._comparator == None)
         else:
             return expr
 
-    def operate(self, op, *other, **kwargs):
+    def operate(
+        self, op: operators.OperatorType, *other: Any, **kwargs: Any
+    ) -> ColumnElement[Any]:
         return self._criterion_exists(
             self.remote_attr.operate(op, *other, **kwargs)
         )
 
 
-class _lazy_collection:
-    def __init__(self, obj, target):
+class _lazy_collection(_LazyCollectionProtocol[_T]):
+    def __init__(self, obj: Any, target: str):
         self.parent = obj
         self.target = target
 
-    def __call__(self):
-        return getattr(self.parent, self.target)
+    def __call__(
+        self,
+    ) -> Union[MutableSet[_T], MutableMapping[Any, _T], MutableSequence[_T]]:
+        return getattr(self.parent, self.target)  # type: ignore[no-any-return]
 
-    def __getstate__(self):
+    def __getstate__(self) -> Any:
         return {"obj": self.parent, "target": self.target}
 
-    def __setstate__(self, state):
+    def __setstate__(self, state: Any) -> None:
         self.parent = state["obj"]
         self.target = state["target"]
 
 
-class _AssociationCollection:
-    def __init__(self, lazy_collection, creator, getter, setter, parent):
-        """Constructs an _AssociationCollection.
+_IT = TypeVar("_IT", bound="Any")
+"""instance type - this is the type of object inside a collection.
 
-        This will always be a subclass of either _AssociationList,
-        _AssociationSet, or _AssociationDict.
+this is not the same as the _T of AssociationProxy and
+AssociationProxyInstance itself, which will often refer to the
+collection[_IT] type.
 
-        lazy_collection
-          A callable returning a list-based collection of entities (usually an
-          object attribute managed by a SQLAlchemy relationship())
+"""
 
-        creator
-          A function that creates new target entities.  Given one parameter:
-          value.  This assertion is assumed::
 
-            obj = creator(somevalue)
-            assert getter(obj) == somevalue
+class _AssociationCollection(Generic[_IT]):
+    getter: _GetterProtocol[_IT]
+    """A function.  Given an associated object, return the 'value'."""
 
-        getter
-          A function.  Given an associated object, return the 'value'.
+    creator: _CreatorProtocol[_IT]
+    """
+    A function that creates new target entities.  Given one parameter:
+    value.  This assertion is assumed::
 
-        setter
-          A function.  Given an associated object and a value, store that
-          value on the object.
+    obj = creator(somevalue)
+    assert getter(obj) == somevalue
+    """
+
+    parent: AssociationProxyInstance[_IT]
+    setter: _SetterProtocol[_IT]
+    """A function.  Given an associated object and a value, store that
+        value on the object.
+    """
+
+    lazy_collection: _LazyCollectionProtocol[_IT]
+    """A callable returning a list-based collection of entities (usually an
+          object attribute managed by a SQLAlchemy relationship())"""
+
+    def __init__(
+        self,
+        lazy_collection: _LazyCollectionProtocol[_IT],
+        creator: _CreatorProtocol[_IT],
+        getter: _GetterProtocol[_IT],
+        setter: _SetterProtocol[_IT],
+        parent: AssociationProxyInstance[_IT],
+    ):
+        """Constructs an _AssociationCollection.
+
+        This will always be a subclass of either _AssociationList,
+        _AssociationSet, or _AssociationDict.
 
         """
         self.lazy_collection = lazy_collection
@@ -981,50 +1287,85 @@ class _AssociationCollection:
         self.setter = setter
         self.parent = parent
 
-    col = property(lambda self: self.lazy_collection())
+    if typing.TYPE_CHECKING:
+        col: Collection[_IT]
+    else:
+        col = property(lambda self: self.lazy_collection())
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.col)
 
-    def __bool__(self):
+    def __bool__(self) -> bool:
         return bool(self.col)
 
     __nonzero__ = __bool__
 
-    def __getstate__(self):
+    def __getstate__(self) -> Any:
         return {"parent": self.parent, "lazy_collection": self.lazy_collection}
 
-    def __setstate__(self, state):
+    def __setstate__(self, state: Any) -> None:
         self.parent = state["parent"]
         self.lazy_collection = state["lazy_collection"]
         self.parent._inflate(self)
 
-    def _bulk_replace(self, assoc_proxy, values):
+    def clear(self) -> None:
+        raise NotImplementedError()
+
+
+class _AssociationSingleItem(_AssociationCollection[_T]):
+    setter: _PlainSetterProtocol[_T]
+    creator: _PlainCreatorProtocol[_T]
+
+    def _create(self, value: _T) -> Any:
+        return self.creator(value)
+
+    def _get(self, object_: Any) -> _T:
+        return self.getter(object_)
+
+    def _bulk_replace(
+        self, assoc_proxy: AssociationProxyInstance[Any], values: Iterable[_IT]
+    ) -> None:
         self.clear()
         assoc_proxy._set(self, values)
 
 
-class _AssociationList(_AssociationCollection):
+class _AssociationList(_AssociationSingleItem[_T], MutableSequence[_T]):
     """Generic, converting, list-to-list proxy."""
 
-    def _create(self, value):
-        return self.creator(value)
+    col: MutableSequence[_T]
 
-    def _get(self, object_):
-        return self.getter(object_)
+    def _set(self, object_: Any, value: _T) -> None:
+        self.setter(object_, value)
+
+    @overload
+    def __getitem__(self, index: int) -> _T:
+        ...
 
-    def _set(self, object_, value):
-        return self.setter(object_, value)
+    @overload
+    def __getitem__(self, index: slice) -> MutableSequence[_T]:
+        ...
 
-    def __getitem__(self, index):
+    def __getitem__(
+        self, index: Union[int, slice]
+    ) -> Union[_T, MutableSequence[_T]]:
         if not isinstance(index, slice):
             return self._get(self.col[index])
         else:
             return [self._get(member) for member in self.col[index]]
 
-    def __setitem__(self, index, value):
+    @overload
+    def __setitem__(self, index: int, value: _T) -> None:
+        ...
+
+    @overload
+    def __setitem__(self, index: slice, value: Iterable[_T]) -> None:
+        ...
+
+    def __setitem__(
+        self, index: Union[int, slice], value: Union[_T, Iterable[_T]]
+    ) -> None:
         if not isinstance(index, slice):
-            self._set(self.col[index], value)
+            self._set(self.col[index], cast("_T", value))
         else:
             if index.stop is None:
                 stop = len(self)
@@ -1036,43 +1377,45 @@ class _AssociationList(_AssociationCollection):
 
             start = index.start or 0
             rng = list(range(index.start or 0, stop, step))
+
+            sized_value = list(value)
+
             if step == 1:
                 for i in rng:
                     del self[start]
                 i = start
-                for item in value:
+                for item in sized_value:
                     self.insert(i, item)
                     i += 1
             else:
-                if len(value) != len(rng):
+                if len(sized_value) != len(rng):
                     raise ValueError(
                         "attempt to assign sequence of size %s to "
-                        "extended slice of size %s" % (len(value), len(rng))
+                        "extended slice of size %s"
+                        % (len(sized_value), len(rng))
                     )
                 for i, item in zip(rng, value):
                     self._set(self.col[i], item)
 
-    def __delitem__(self, index):
+    @overload
+    def __delitem__(self, index: int) -> None:
+        ...
+
+    @overload
+    def __delitem__(self, index: slice) -> None:
+        ...
+
+    def __delitem__(self, index: Union[slice, int]) -> None:
         del self.col[index]
 
-    def __contains__(self, value):
+    def __contains__(self, value: object) -> bool:
         for member in self.col:
             # testlib.pragma exempt:__eq__
             if self._get(member) == value:
                 return True
         return False
 
-    def __getslice__(self, start, end):
-        return [self._get(member) for member in self.col[start:end]]
-
-    def __setslice__(self, start, end, values):
-        members = [self._create(v) for v in values]
-        self.col[start:end] = members
-
-    def __delslice__(self, start, end):
-        del self.col[start:end]
-
-    def __iter__(self):
+    def __iter__(self) -> Iterator[_T]:
         """Iterate over proxied values.
 
         For the actual domain objects, iterate over .col instead or
@@ -1084,255 +1427,262 @@ class _AssociationList(_AssociationCollection):
             yield self._get(member)
         return
 
-    def append(self, value):
+    def append(self, value: _T) -> None:
         col = self.col
         item = self._create(value)
         col.append(item)
 
-    def count(self, value):
+    def count(self, value: Any) -> int:
         count = 0
         for v in self:
             if v == value:
                 count += 1
         return count
 
-    def extend(self, values):
+    def extend(self, values: Iterable[_T]) -> None:
         for v in values:
             self.append(v)
 
-    def insert(self, index, value):
+    def insert(self, index: int, value: _T) -> None:
         self.col[index:index] = [self._create(value)]
 
-    def pop(self, index=-1):
+    def pop(self, index: int = -1) -> _T:
         return self.getter(self.col.pop(index))
 
-    def remove(self, value):
+    def remove(self, value: _T) -> None:
         for i, val in enumerate(self):
             if val == value:
                 del self.col[i]
                 return
         raise ValueError("value not in list")
 
-    def reverse(self):
+    def reverse(self) -> NoReturn:
         """Not supported, use reversed(mylist)"""
 
-        raise NotImplementedError
+        raise NotImplementedError()
 
-    def sort(self):
+    def sort(self) -> NoReturn:
         """Not supported, use sorted(mylist)"""
 
-        raise NotImplementedError
+        raise NotImplementedError()
 
-    def clear(self):
+    def clear(self) -> None:
         del self.col[0 : len(self.col)]
 
-    def __eq__(self, other):
+    def __eq__(self, other: object) -> bool:
         return list(self) == other
 
-    def __ne__(self, other):
+    def __ne__(self, other: object) -> bool:
         return list(self) != other
 
-    def __lt__(self, other):
+    def __lt__(self, other: list[_T]) -> bool:
         return list(self) < other
 
-    def __le__(self, other):
+    def __le__(self, other: list[_T]) -> bool:
         return list(self) <= other
 
-    def __gt__(self, other):
+    def __gt__(self, other: list[_T]) -> bool:
         return list(self) > other
 
-    def __ge__(self, other):
+    def __ge__(self, other: list[_T]) -> bool:
         return list(self) >= other
 
-    def __cmp__(self, other):
-        return util.cmp(list(self), other)
-
-    def __add__(self, iterable):
+    def __add__(self, other: list[_T]) -> list[_T]:
         try:
-            other = list(iterable)
+            other = list(other)
         except TypeError:
             return NotImplemented
         return list(self) + other
 
-    def __radd__(self, iterable):
+    def __radd__(self, other: list[_T]) -> list[_T]:
         try:
-            other = list(iterable)
+            other = list(other)
         except TypeError:
             return NotImplemented
         return other + list(self)
 
-    def __mul__(self, n):
+    def __mul__(self, n: SupportsIndex) -> list[_T]:
         if not isinstance(n, int):
             return NotImplemented
         return list(self) * n
 
-    __rmul__ = __mul__
+    def __rmul__(self, n: SupportsIndex) -> list[_T]:
+        if not isinstance(n, int):
+            return NotImplemented
+        return n * list(self)
 
-    def __iadd__(self, iterable):
+    def __iadd__(self: Self, iterable: Iterable[_T]) -> Self:
         self.extend(iterable)
         return self
 
-    def __imul__(self, n):
+    def __imul__(self: Self, n: SupportsIndex) -> Self:
         # unlike a regular list *=, proxied __imul__ will generate unique
         # backing objects for each copy.  *= on proxied lists is a bit of
         # a stretch anyhow, and this interpretation of the __imul__ contract
         # is more plausibly useful than copying the backing objects.
         if not isinstance(n, int):
-            return NotImplemented
+            raise NotImplementedError()
         if n == 0:
             self.clear()
         elif n > 1:
             self.extend(list(self) * (n - 1))
         return self
 
-    def index(self, item, *args):
-        return list(self).index(item, *args)
+    if typing.TYPE_CHECKING:
+        # TODO: no idea how to do this without separate "stub"
+        def index(self, value: Any, start: int = ..., stop: int = ...) -> int:
+            ...
+
+    else:
+
+        def index(self, value: Any, *arg) -> int:
+            ls = list(self)
+            return ls.index(value, *arg)
 
-    def copy(self):
+    def copy(self) -> list[_T]:
         return list(self)
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return repr(list(self))
 
-    def __hash__(self):
+    def __hash__(self) -> NoReturn:
         raise TypeError("%s objects are unhashable" % type(self).__name__)
 
-    for func_name, func in list(locals().items()):
-        if (
-            callable(func)
-            and func.__name__ == func_name
-            and not func.__doc__
-            and hasattr(list, func_name)
-        ):
-            func.__doc__ = getattr(list, func_name).__doc__
-    del func_name, func
-
+    if not typing.TYPE_CHECKING:
+        for func_name, func in list(locals().items()):
+            if (
+                callable(func)
+                and func.__name__ == func_name
+                and not func.__doc__
+                and hasattr(list, func_name)
+            ):
+                func.__doc__ = getattr(list, func_name).__doc__
+        del func_name, func
 
-_NotProvided = util.symbol("_NotProvided")
 
-
-class _AssociationDict(_AssociationCollection):
+class _AssociationDict(_AssociationCollection[_VT], MutableMapping[_KT, _VT]):
     """Generic, converting, dict-to-dict proxy."""
 
-    def _create(self, key, value):
+    setter: _DictSetterProtocol[_VT]
+    creator: _KeyCreatorProtocol[_VT]
+    col: MutableMapping[_KT, Optional[_VT]]
+
+    def _create(self, key: _KT, value: Optional[_VT]) -> Any:
         return self.creator(key, value)
 
-    def _get(self, object_):
+    def _get(self, object_: Any) -> _VT:
         return self.getter(object_)
 
-    def _set(self, object_, key, value):
+    def _set(self, object_: Any, key: _KT, value: _VT) -> None:
         return self.setter(object_, key, value)
 
-    def __getitem__(self, key):
+    def __getitem__(self, key: _KT) -> _VT:
         return self._get(self.col[key])
 
-    def __setitem__(self, key, value):
+    def __setitem__(self, key: _KT, value: _VT) -> None:
         if key in self.col:
             self._set(self.col[key], key, value)
         else:
             self.col[key] = self._create(key, value)
 
-    def __delitem__(self, key):
+    def __delitem__(self, key: _KT) -> None:
         del self.col[key]
 
-    def __contains__(self, key):
-        # testlib.pragma exempt:__hash__
+    def __contains__(self, key: object) -> bool:
         return key in self.col
 
-    def has_key(self, key):
-        # testlib.pragma exempt:__hash__
-        return key in self.col
-
-    def __iter__(self):
+    def __iter__(self) -> Iterator[_KT]:
         return iter(self.col.keys())
 
-    def clear(self):
+    def clear(self) -> None:
         self.col.clear()
 
-    def __eq__(self, other):
+    def __eq__(self, other: object) -> bool:
         return dict(self) == other
 
-    def __ne__(self, other):
+    def __ne__(self, other: object) -> bool:
         return dict(self) != other
 
-    def __lt__(self, other):
-        return dict(self) < other
-
-    def __le__(self, other):
-        return dict(self) <= other
+    def __repr__(self) -> str:
+        return repr(dict(self))
 
-    def __gt__(self, other):
-        return dict(self) > other
+    @overload
+    def get(self, __key: _KT) -> Optional[_VT]:
+        ...
 
-    def __ge__(self, other):
-        return dict(self) >= other
+    @overload
+    def get(self, __key: _KT, default: Union[_VT, _T]) -> Union[_VT, _T]:
+        ...
 
-    def __cmp__(self, other):
-        return util.cmp(dict(self), other)
-
-    def __repr__(self):
-        return repr(dict(self.items()))
-
-    def get(self, key, default=None):
+    def get(
+        self, key: _KT, default: Optional[Union[_VT, _T]] = None
+    ) -> Union[_VT, _T, None]:
         try:
             return self[key]
         except KeyError:
             return default
 
-    def setdefault(self, key, default=None):
+    def setdefault(self, key: _KT, default: Optional[_VT] = None) -> _VT:
+        # TODO: again, no idea how to create an actual MutableMapping.
+        # default must allow None, return type can't include None,
+        # the stub explicitly allows for default of None with a cryptic message
+        # "This overload should be allowed only if the value type is
+        # compatible with None.".
         if key not in self.col:
             self.col[key] = self._create(key, default)
-            return default
+            return default  # type: ignore
         else:
             return self[key]
 
-    def keys(self):
+    def keys(self) -> KeysView[_KT]:
         return self.col.keys()
 
-    def items(self):
-        return ((key, self._get(self.col[key])) for key in self.col)
+    def items(self) -> ItemsView[_KT, _VT]:
+        return ItemsView(self)
 
-    def values(self):
-        return (self._get(self.col[key]) for key in self.col)
+    def values(self) -> ValuesView[_VT]:
+        return ValuesView(self)
 
-    def pop(self, key, default=_NotProvided):
-        if default is _NotProvided:
-            member = self.col.pop(key)
-        else:
-            member = self.col.pop(key, default)
+    @overload
+    def pop(self, __key: _KT) -> _VT:
+        ...
+
+    @overload
+    def pop(self, __key: _KT, default: Union[_VT, _T] = ...) -> Union[_VT, _T]:
+        ...
+
+    def pop(self, __key: _KT, *arg: Any, **kw: Any) -> Union[_VT, _T]:
+        member = self.col.pop(__key, *arg, **kw)
         return self._get(member)
 
-    def popitem(self):
+    def popitem(self) -> Tuple[_KT, _VT]:
         item = self.col.popitem()
         return (item[0], self._get(item[1]))
 
-    def update(self, *a, **kw):
-        if len(a) > 1:
-            raise TypeError(
-                "update expected at most 1 arguments, got %i" % len(a)
-            )
-        elif len(a) == 1:
-            seq_or_map = a[0]
-            # discern dict from sequence - took the advice from
-            # https://www.voidspace.org.uk/python/articles/duck_typing.shtml
-            # still not perfect :(
-            if hasattr(seq_or_map, "keys"):
-                for item in seq_or_map:
-                    self[item] = seq_or_map[item]
-            else:
-                try:
-                    for k, v in seq_or_map:
-                        self[k] = v
-                except ValueError as err:
-                    raise ValueError(
-                        "dictionary update sequence "
-                        "requires 2-element tuples"
-                    ) from err
+    @overload
+    def update(self, __m: Mapping[_KT, _VT], **kwargs: _VT) -> None:
+        ...
+
+    @overload
+    def update(self, __m: Iterable[tuple[_KT, _VT]], **kwargs: _VT) -> None:
+        ...
 
-        for key, value in kw:
+    @overload
+    def update(self, **kwargs: _VT) -> None:
+        ...
+
+    def update(self, *a: Any, **kw: Any) -> None:
+        up: Dict[_KT, _VT] = {}
+        up.update(*a, **kw)
+
+        for key, value in up.items():
             self[key] = value
 
-    def _bulk_replace(self, assoc_proxy, values):
+    def _bulk_replace(
+        self,
+        assoc_proxy: AssociationProxyInstance[Any],
+        values: Mapping[_KT, _VT],
+    ) -> None:
         existing = set(self)
         constants = existing.intersection(values or ())
         additions = set(values or ()).difference(constants)
@@ -1347,36 +1697,33 @@ class _AssociationDict(_AssociationCollection):
         for key in removals:
             del self[key]
 
-    def copy(self):
+    def copy(self) -> dict[_KT, _VT]:
         return dict(self.items())
 
-    def __hash__(self):
+    def __hash__(self) -> NoReturn:
         raise TypeError("%s objects are unhashable" % type(self).__name__)
 
-    for func_name, func in list(locals().items()):
-        if (
-            callable(func)
-            and func.__name__ == func_name
-            and not func.__doc__
-            and hasattr(dict, func_name)
-        ):
-            func.__doc__ = getattr(dict, func_name).__doc__
-    del func_name, func
+    if not typing.TYPE_CHECKING:
+        for func_name, func in list(locals().items()):
+            if (
+                callable(func)
+                and func.__name__ == func_name
+                and not func.__doc__
+                and hasattr(dict, func_name)
+            ):
+                func.__doc__ = getattr(dict, func_name).__doc__
+        del func_name, func
 
 
-class _AssociationSet(_AssociationCollection):
+class _AssociationSet(_AssociationSingleItem[_T], MutableSet[_T]):
     """Generic, converting, set-to-set proxy."""
 
-    def _create(self, value):
-        return self.creator(value)
+    col: MutableSet[_T]
 
-    def _get(self, object_):
-        return self.getter(object_)
-
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.col)
 
-    def __bool__(self):
+    def __bool__(self) -> bool:
         if self.col:
             return True
         else:
@@ -1384,14 +1731,13 @@ class _AssociationSet(_AssociationCollection):
 
     __nonzero__ = __bool__
 
-    def __contains__(self, value):
+    def __contains__(self, __o: object) -> bool:
         for member in self.col:
-            # testlib.pragma exempt:__eq__
-            if self._get(member) == value:
+            if self._get(member) == __o:
                 return True
         return False
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[_T]:
         """Iterate over proxied values.
 
         For the actual domain objects, iterate over .col instead or just use
@@ -1402,36 +1748,37 @@ class _AssociationSet(_AssociationCollection):
             yield self._get(member)
         return
 
-    def add(self, value):
-        if value not in self:
-            self.col.add(self._create(value))
+    def add(self, __element: _T) -> None:
+        if __element not in self:
+            self.col.add(self._create(__element))
 
     # for discard and remove, choosing a more expensive check strategy rather
     # than call self.creator()
-    def discard(self, value):
+    def discard(self, __element: _T) -> None:
         for member in self.col:
-            if self._get(member) == value:
+            if self._get(member) == __element:
                 self.col.discard(member)
                 break
 
-    def remove(self, value):
+    def remove(self, __element: _T) -> None:
         for member in self.col:
-            if self._get(member) == value:
+            if self._get(member) == __element:
                 self.col.discard(member)
                 return
-        raise KeyError(value)
+        raise KeyError(__element)
 
-    def pop(self):
+    def pop(self) -> _T:
         if not self.col:
             raise KeyError("pop from an empty set")
         member = self.col.pop()
         return self._get(member)
 
-    def update(self, other):
-        for value in other:
-            self.add(value)
+    def update(self, *s: Iterable[_T]) -> None:
+        for iterable in s:
+            for value in iterable:
+                self.add(value)
 
-    def _bulk_replace(self, assoc_proxy, values):
+    def _bulk_replace(self, assoc_proxy: Any, values: Iterable[_T]) -> None:
         existing = set(self)
         constants = existing.intersection(values or ())
         additions = set(values or ()).difference(constants)
@@ -1449,56 +1796,64 @@ class _AssociationSet(_AssociationCollection):
         for member in removals:
             remover(member)
 
-    def __ior__(self, other):
+    def __ior__(
+        self: Self, other: AbstractSet[_S]
+    ) -> MutableSet[Union[_T, _S]]:
         if not collections._set_binops_check_strict(self, other):
-            return NotImplemented
+            raise NotImplementedError()
         for value in other:
             self.add(value)
         return self
 
-    def _set(self):
+    def _set(self) -> Set[_T]:
         return set(iter(self))
 
-    def union(self, other):
-        return set(self).union(other)
+    def union(self, *s: Iterable[_S]) -> MutableSet[Union[_T, _S]]:
+        return set(self).union(*s)
 
-    __or__ = union
+    def __or__(self, __s: AbstractSet[_S]) -> MutableSet[Union[_T, _S]]:
+        return self.union(__s)
 
-    def difference(self, other):
-        return set(self).difference(other)
+    def difference(self, *s: Iterable[Any]) -> MutableSet[_T]:
+        return set(self).difference(*s)
 
-    __sub__ = difference
+    def __sub__(self, s: AbstractSet[Any]) -> MutableSet[_T]:
+        return self.difference(s)
 
-    def difference_update(self, other):
-        for value in other:
-            self.discard(value)
+    def difference_update(self, *s: Iterable[Any]) -> None:
+        for other in s:
+            for value in other:
+                self.discard(value)
 
-    def __isub__(self, other):
-        if not collections._set_binops_check_strict(self, other):
-            return NotImplemented
-        for value in other:
+    def __isub__(self: Self, s: AbstractSet[Any]) -> Self:
+        if not collections._set_binops_check_strict(self, s):
+            raise NotImplementedError()
+        for value in s:
             self.discard(value)
         return self
 
-    def intersection(self, other):
-        return set(self).intersection(other)
+    def intersection(self, *s: Iterable[Any]) -> MutableSet[_T]:
+        return set(self).intersection(*s)
 
-    __and__ = intersection
+    def __and__(self, s: AbstractSet[Any]) -> MutableSet[_T]:
+        return self.intersection(s)
 
-    def intersection_update(self, other):
-        want, have = self.intersection(other), set(self)
+    def intersection_update(self, *s: Iterable[Any]) -> None:
+        for other in s:
+            want, have = self.intersection(other), set(self)
 
-        remove, add = have - want, want - have
+            remove, add = have - want, want - have
 
-        for value in remove:
-            self.remove(value)
-        for value in add:
-            self.add(value)
+            for value in remove:
+                self.remove(value)
+            for value in add:
+                self.add(value)
 
-    def __iand__(self, other):
-        if not collections._set_binops_check_strict(self, other):
-            return NotImplemented
-        want, have = self.intersection(other), set(self)
+    def __iand__(self: Self, s: AbstractSet[Any]) -> Self:
+        if not collections._set_binops_check_strict(self, s):
+            raise NotImplementedError()
+        want = self.intersection(s)
+        have: Set[_T] = set(self)
 
         remove, add = have - want, want - have
 
@@ -1508,12 +1863,13 @@ class _AssociationSet(_AssociationCollection):
             self.add(value)
         return self
 
-    def symmetric_difference(self, other):
-        return set(self).symmetric_difference(other)
+    def symmetric_difference(self, __s: Iterable[_T]) -> MutableSet[_T]:
+        return set(self).symmetric_difference(__s)
 
-    __xor__ = symmetric_difference
+    def __xor__(self, s: AbstractSet[_S]) -> MutableSet[Union[_T, _S]]:
+        return self.symmetric_difference(s)  # type: ignore
 
-    def symmetric_difference_update(self, other):
+    def symmetric_difference_update(self, other: Iterable[Any]) -> None:
         want, have = self.symmetric_difference(other), set(self)
 
         remove, add = have - want, want - have
@@ -1523,61 +1879,56 @@ class _AssociationSet(_AssociationCollection):
         for value in add:
             self.add(value)
 
-    def __ixor__(self, other):
+    def __ixor__(self, other: AbstractSet[_S]) -> MutableSet[Union[_T, _S]]:
         if not collections._set_binops_check_strict(self, other):
-            return NotImplemented
-        want, have = self.symmetric_difference(other), set(self)
-
-        remove, add = have - want, want - have
+            raise NotImplementedError()
 
-        for value in remove:
-            self.remove(value)
-        for value in add:
-            self.add(value)
+        self.symmetric_difference_update(other)
         return self
 
-    def issubset(self, other):
-        return set(self).issubset(other)
+    def issubset(self, __s: Iterable[Any]) -> bool:
+        return set(self).issubset(__s)
 
-    def issuperset(self, other):
-        return set(self).issuperset(other)
+    def issuperset(self, __s: Iterable[Any]) -> bool:
+        return set(self).issuperset(__s)
 
-    def clear(self):
+    def clear(self) -> None:
         self.col.clear()
 
-    def copy(self):
+    def copy(self) -> AbstractSet[_T]:
         return set(self)
 
-    def __eq__(self, other):
+    def __eq__(self, other: object) -> bool:
         return set(self) == other
 
-    def __ne__(self, other):
+    def __ne__(self, other: object) -> bool:
         return set(self) != other
 
-    def __lt__(self, other):
+    def __lt__(self, other: AbstractSet[Any]) -> bool:
         return set(self) < other
 
-    def __le__(self, other):
+    def __le__(self, other: AbstractSet[Any]) -> bool:
         return set(self) <= other
 
-    def __gt__(self, other):
+    def __gt__(self, other: AbstractSet[Any]) -> bool:
         return set(self) > other
 
-    def __ge__(self, other):
+    def __ge__(self, other: AbstractSet[Any]) -> bool:
         return set(self) >= other
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return repr(set(self))
 
-    def __hash__(self):
+    def __hash__(self) -> NoReturn:
         raise TypeError("%s objects are unhashable" % type(self).__name__)
 
-    for func_name, func in list(locals().items()):
-        if (
-            callable(func)
-            and func.__name__ == func_name
-            and not func.__doc__
-            and hasattr(set, func_name)
-        ):
-            func.__doc__ = getattr(set, func_name).__doc__
-    del func_name, func
+    if not typing.TYPE_CHECKING:
+        for func_name, func in list(locals().items()):
+            if (
+                callable(func)
+                and func.__name__ == func_name
+                and not func.__doc__
+                and hasattr(set, func_name)
+            ):
+                func.__doc__ = getattr(set, func_name).__doc__
+        del func_name, func
index c7d9d4f887a68dddbb9ae3df1951d411bc837124..dc34a2ef58b333cad01b65523753258a97ec3f46 100644 (file)
@@ -807,45 +807,51 @@ from typing import TypeVar
 
 from .. import util
 from ..orm import attributes
+from ..orm import InspectionAttrExtensionType
 from ..orm import interfaces
+from ..orm import ORMDescriptor
+
 
 _T = TypeVar("_T", bound=Any)
 
-HYBRID_METHOD = util.symbol("HYBRID_METHOD")
-"""Symbol indicating an :class:`InspectionAttr` that's
-   of type :class:`.hybrid_method`.
 
-   Is assigned to the :attr:`.InspectionAttr.extension_type`
-   attribute.
+class HybridExtensionType(InspectionAttrExtensionType):
 
-   .. seealso::
+    HYBRID_METHOD = "HYBRID_METHOD"
+    """Symbol indicating an :class:`InspectionAttr` that's
+    of type :class:`.hybrid_method`.
 
-    :attr:`_orm.Mapper.all_orm_attributes`
+    Is assigned to the :attr:`.InspectionAttr.extension_type`
+    attribute.
 
-"""
+    .. seealso::
 
-HYBRID_PROPERTY = util.symbol("HYBRID_PROPERTY")
-"""Symbol indicating an :class:`InspectionAttr` that's
-    of type :class:`.hybrid_method`.
+        :attr:`_orm.Mapper.all_orm_attributes`
 
-   Is assigned to the :attr:`.InspectionAttr.extension_type`
-   attribute.
+    """
 
-   .. seealso::
+    HYBRID_PROPERTY = "HYBRID_PROPERTY"
+    """Symbol indicating an :class:`InspectionAttr` that's
+        of type :class:`.hybrid_method`.
 
-    :attr:`_orm.Mapper.all_orm_attributes`
+    Is assigned to the :attr:`.InspectionAttr.extension_type`
+    attribute.
 
-"""
+    .. seealso::
+
+        :attr:`_orm.Mapper.all_orm_attributes`
+
+    """
 
 
-class hybrid_method(interfaces.InspectionAttrInfo):
+class hybrid_method(interfaces.InspectionAttrInfo, ORMDescriptor[_T]):
     """A decorator which allows definition of a Python object method with both
     instance-level and class-level behavior.
 
     """
 
     is_attribute = True
-    extension_type = HYBRID_METHOD
+    extension_type = HybridExtensionType.HYBRID_METHOD
 
     def __init__(self, func, expr=None):
         """Create a new :class:`.hybrid_method`.
@@ -890,7 +896,7 @@ class hybrid_property(interfaces.InspectionAttrInfo):
     """
 
     is_attribute = True
-    extension_type = HYBRID_PROPERTY
+    extension_type = HybridExtensionType.HYBRID_PROPERTY
 
     def __init__(
         self,
index 5a8a0f6cf60143f8553fda527a69d631f790f59a..141702ae65cde21729b33e9f819abb2f8ffcc8c0 100644 (file)
@@ -43,7 +43,11 @@ from ._orm_constructors import with_polymorphic as with_polymorphic
 from .attributes import AttributeEvent as AttributeEvent
 from .attributes import InstrumentedAttribute as InstrumentedAttribute
 from .attributes import QueryableAttribute as QueryableAttribute
+from .base import class_mapper as class_mapper
+from .base import InspectionAttrExtensionType as InspectionAttrExtensionType
 from .base import Mapped as Mapped
+from .base import NotExtension as NotExtension
+from .base import ORMDescriptor as ORMDescriptor
 from .context import QueryContext as QueryContext
 from .decl_api import add_mapped_attribute as add_mapped_attribute
 from .decl_api import as_declarative as as_declarative
@@ -75,13 +79,11 @@ from .interfaces import InspectionAttrInfo as InspectionAttrInfo
 from .interfaces import MANYTOMANY as MANYTOMANY
 from .interfaces import MANYTOONE as MANYTOONE
 from .interfaces import MapperProperty as MapperProperty
-from .interfaces import NOT_EXTENSION as NOT_EXTENSION
 from .interfaces import ONETOMANY as ONETOMANY
 from .interfaces import PropComparator as PropComparator
 from .interfaces import UserDefinedOption as UserDefinedOption
 from .loading import merge_frozen_result as merge_frozen_result
 from .loading import merge_result as merge_result
-from .mapper import class_mapper as class_mapper
 from .mapper import configure_mappers as configure_mappers
 from .mapper import Mapper as Mapper
 from .mapper import reconstructor as reconstructor
index b9c881cfe01f1064639b3a34201107716eec4292..c63a89c70427b6973bc2932a3c51be27cc64a22e 100644 (file)
 
 from __future__ import annotations
 
+from enum import Enum
 import operator
 import typing
 from typing import Any
 from typing import Callable
+from typing import Dict
 from typing import Generic
 from typing import Optional
 from typing import overload
 from typing import Tuple
+from typing import Type
 from typing import TypeVar
 from typing import Union
 
@@ -29,11 +32,13 @@ from .. import util
 from ..sql.elements import SQLCoreOperations
 from ..util.langhelpers import TypingOnly
 from ..util.typing import Concatenate
+from ..util.typing import Literal
 from ..util.typing import ParamSpec
-
+from ..util.typing import Self
 
 if typing.TYPE_CHECKING:
     from .attributes import InstrumentedAttribute
+    from .mapper import Mapper
 
 _T = TypeVar("_T", bound=Any)
 
@@ -223,16 +228,22 @@ MANYTOMANY = util.symbol(
     """,
 )
 
-NOT_EXTENSION = util.symbol(
-    "NOT_EXTENSION",
+
+class InspectionAttrExtensionType(Enum):
+    """Symbols indicating the type of extension that a
+    :class:`.InspectionAttr` is part of."""
+
+
+class NotExtension(InspectionAttrExtensionType):
+    NOT_EXTENSION = "not_extension"
     """Symbol indicating an :class:`InspectionAttr` that's
     not part of sqlalchemy.ext.
 
     Is assigned to the :attr:`.InspectionAttr.extension_type`
     attribute.
 
-    """,
-)
+    """
+
 
 _never_set = frozenset([NEVER_SET])
 
@@ -455,7 +466,7 @@ def _inspect_mapped_class(class_, configure=False):
         return mapper
 
 
-def class_mapper(class_, configure=True):
+def class_mapper(class_: Type[_T], configure: bool = True) -> Mapper[_T]:
     """Given a class, return the primary :class:`_orm.Mapper` associated
     with the key.
 
@@ -546,17 +557,15 @@ class InspectionAttr:
     """True if this object is an instance of
     :class:`_expression.ClauseElement`."""
 
-    extension_type = NOT_EXTENSION
+    extension_type: InspectionAttrExtensionType = NotExtension.NOT_EXTENSION
     """The extension type, if any.
-    Defaults to :data:`.interfaces.NOT_EXTENSION`
+    Defaults to :attr:`.interfaces.NotExtension.NOT_EXTENSION`
 
     .. seealso::
 
-        :data:`.HYBRID_METHOD`
+        :class:`.HybridExtensionType`
 
-        :data:`.HYBRID_PROPERTY`
-
-        :data:`.ASSOCIATION_PROXY`
+        :class:`.AssociationProxyExtensionType`
 
     """
 
@@ -571,7 +580,7 @@ class InspectionAttrInfo(InspectionAttr):
     """
 
     @util.memoized_property
-    def info(self):
+    def info(self) -> Dict[Any, Any]:
         """Info dictionary associated with the object, allowing user-defined
         data to be associated with this :class:`.InspectionAttr`.
 
@@ -614,7 +623,35 @@ class SQLORMOperations(SQLCoreOperations[_T], TypingOnly):
             ...
 
 
-class Mapped(Generic[_T], TypingOnly):
+class ORMDescriptor(Generic[_T], TypingOnly):
+    """Represent any Python descriptor that provides a SQL expression
+    construct at the class level."""
+
+    __slots__ = ()
+
+    if typing.TYPE_CHECKING:
+
+        @overload
+        def __get__(self: Self, instance: Any, owner: Literal[None]) -> Self:
+            ...
+
+        @overload
+        def __get__(
+            self, instance: Literal[None], owner: Any
+        ) -> SQLORMOperations[_T]:
+            ...
+
+        @overload
+        def __get__(self, instance: object, owner: Any) -> _T:
+            ...
+
+        def __get__(
+            self, instance: object, owner: Any
+        ) -> Union[SQLORMOperations[_T], _T]:
+            ...
+
+
+class Mapped(ORMDescriptor[_T], TypingOnly):
     """Represent an ORM mapped attribute on a mapped class.
 
     This class represents the complete descriptor interface for any class
@@ -646,7 +683,7 @@ class Mapped(Generic[_T], TypingOnly):
         @overload
         def __get__(
             self, instance: None, owner: Any
-        ) -> "InstrumentedAttribute[_T]":
+        ) -> InstrumentedAttribute[_T]:
             ...
 
         @overload
@@ -655,11 +692,11 @@ class Mapped(Generic[_T], TypingOnly):
 
         def __get__(
             self, instance: object, owner: Any
-        ) -> Union["InstrumentedAttribute[_T]", _T]:
+        ) -> Union[InstrumentedAttribute[_T], _T]:
             ...
 
         @classmethod
-        def _empty_constructor(cls, arg1: Any) -> "Mapped[_T]":
+        def _empty_constructor(cls, arg1: Any) -> Mapped[_T]:
             ...
 
         def __set__(
index d0cb53e29b3e9fd4f19842d79cb7d981979dd3a2..fe6dbfdc9abae09f4fdb73557b64835eb8b8fd64 100644 (file)
@@ -293,7 +293,7 @@ class _GetColumns:
                 )
 
             desc = mp.all_orm_descriptors[key]
-            if desc.extension_type is interfaces.NOT_EXTENSION:
+            if desc.extension_type is interfaces.NotExtension.NOT_EXTENSION:
                 prop = desc.property
                 if isinstance(prop, Synonym):
                     key = prop.name
index 00ae9dac7501d3da729fa89d541fabdf79a6040f..b1854de5a33f9ac9545e60598fe15a2f138edbb5 100644 (file)
@@ -107,6 +107,7 @@ from __future__ import annotations
 import operator
 import threading
 import typing
+from typing import Any
 import weakref
 
 from .. import exc as sa_exc
@@ -1239,13 +1240,13 @@ def _dict_decorators():
 _set_binop_bases = (set, frozenset)
 
 
-def _set_binops_check_strict(self, obj):
+def _set_binops_check_strict(self: Any, obj: Any) -> bool:
     """Allow only set, frozenset and self.__class__-derived
     objects in binops."""
     return isinstance(obj, _set_binop_bases + (self.__class__,))
 
 
-def _set_binops_check_loose(self, obj):
+def _set_binops_check_loose(self: Any, obj: Any) -> bool:
     """Allow anything set-like to participate in set binops."""
     return (
         isinstance(obj, _set_binop_bases + (self.__class__,))
index eed97352635d1725fa2e7651c5737e7d8cbec2c8..04fc07f61bc6212f974ba8a73ed3e11600cb89e9 100644 (file)
@@ -31,17 +31,19 @@ from typing import Union
 
 from . import exc as orm_exc
 from . import path_registry
-from .base import _MappedAttribute  # noqa
-from .base import EXT_CONTINUE
-from .base import EXT_SKIP
-from .base import EXT_STOP
-from .base import InspectionAttr  # noqa
-from .base import InspectionAttrInfo  # noqa
-from .base import MANYTOMANY
-from .base import MANYTOONE
-from .base import NOT_EXTENSION
-from .base import ONETOMANY
+from .base import _MappedAttribute as _MappedAttribute
+from .base import EXT_CONTINUE as EXT_CONTINUE
+from .base import EXT_SKIP as EXT_SKIP
+from .base import EXT_STOP as EXT_STOP
+from .base import InspectionAttr as InspectionAttr
+from .base import InspectionAttrExtensionType as InspectionAttrExtensionType
+from .base import InspectionAttrInfo as InspectionAttrInfo
+from .base import MANYTOMANY as MANYTOMANY
+from .base import MANYTOONE as MANYTOONE
+from .base import NotExtension as NotExtension
+from .base import ONETOMANY as ONETOMANY
 from .base import SQLORMOperations
+from .. import ColumnElement
 from .. import inspect
 from .. import inspection
 from .. import util
@@ -51,6 +53,7 @@ from ..sql import visitors
 from ..sql._typing import _ColumnsClauseElement
 from ..sql.base import ExecutableOption
 from ..sql.cache_key import HasCacheKey
+from ..sql.elements import SQLCoreOperations
 from ..sql.schema import Column
 from ..sql.type_api import TypeEngine
 from ..util.typing import TypedDict
@@ -60,22 +63,6 @@ if typing.TYPE_CHECKING:
 
 _T = TypeVar("_T", bound=Any)
 
-__all__ = (
-    "EXT_CONTINUE",
-    "EXT_STOP",
-    "EXT_SKIP",
-    "ONETOMANY",
-    "MANYTOMANY",
-    "MANYTOONE",
-    "NOT_EXTENSION",
-    "LoaderStrategy",
-    "MapperOption",
-    "LoaderOption",
-    "MapperProperty",
-    "PropComparator",
-    "StrategizedProperty",
-)
-
 
 class ORMStatementRole(roles.StatementRole):
     __slots__ = ()
@@ -190,6 +177,10 @@ class MapperProperty(
 
     """
 
+    comparator: PropComparator[_T]
+    """The :class:`_orm.PropComparator` instance that implements SQL
+    expression construction on behalf of this mapped attribute."""
+
     @property
     def _links_to_entity(self):
         """True if this MapperProperty refers to a mapped entity.
@@ -512,6 +503,11 @@ class PropComparator(
             }
         )
 
+    def _criterion_exists(
+        self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs: Any
+    ) -> ColumnElement[Any]:
+        return self.prop.comparator._criterion_exists(criterion, **kwargs)
+
     @property
     def adapter(self):
         """Produce a callable that adapts column expressions
@@ -547,12 +543,12 @@ class PropComparator(
 
         def operate(
             self, op: operators.OperatorType, *other: Any, **kwargs: Any
-        ) -> "SQLORMOperations":
+        ) -> "SQLCoreOperations[Any]":
             ...
 
         def reverse_operate(
             self, op: operators.OperatorType, other: Any, **kwargs: Any
-        ) -> "SQLORMOperations":
+        ) -> "SQLCoreOperations[Any]":
             ...
 
     def of_type(self, class_) -> "SQLORMOperations[_T]":
@@ -609,9 +605,11 @@ class PropComparator(
         """
         return self.operate(operators.and_, *criteria)
 
-    def any(self, criterion=None, **kwargs) -> "SQLORMOperations[_T]":
-        r"""Return true if this collection contains any member that meets the
-        given criterion.
+    def any(
+        self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs
+    ) -> ColumnElement[bool]:
+        r"""Return a SQL expression representing true if this element
+        references a member which meets the given criterion.
 
         The usual implementation of ``any()`` is
         :meth:`.Relationship.Comparator.any`.
@@ -627,9 +625,11 @@ class PropComparator(
 
         return self.operate(PropComparator.any_op, criterion, **kwargs)
 
-    def has(self, criterion=None, **kwargs) -> "SQLORMOperations[_T]":
-        r"""Return true if this element references a member which meets the
-        given criterion.
+    def has(
+        self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs
+    ) -> ColumnElement[bool]:
+        r"""Return a SQL expression representing true if this element
+        references a member which meets the given criterion.
 
         The usual implementation of ``has()`` is
         :meth:`.Relationship.Comparator.has`.
index 15e9b84311c73b0f70544b5f1d7a308615cec8fe..5a34188a9cf28441ccd26fe40cb8397ba3ffaa90 100644 (file)
@@ -21,6 +21,7 @@ from functools import reduce
 from itertools import chain
 import sys
 import threading
+from typing import Any
 from typing import Generic
 from typing import Type
 from typing import TypeVar
@@ -113,6 +114,9 @@ class Mapper(
     _dispose_called = False
     _ready_for_configure = False
 
+    class_: Type[_MC]
+    """The class to which this :class:`_orm.Mapper` is mapped."""
+
     @util.deprecated_params(
         non_primary=(
             "1.3",
@@ -1984,10 +1988,12 @@ class Mapper(
         else:
             return False
 
-    def has_property(self, key):
+    def has_property(self, key: str) -> bool:
         return key in self._props
 
-    def get_property(self, key, _configure_mappers=True):
+    def get_property(
+        self, key: str, _configure_mappers: bool = True
+    ) -> MapperProperty[Any]:
         """return a MapperProperty associated with the given key."""
 
         if _configure_mappers:
@@ -2715,7 +2721,7 @@ class Mapper(
         else:
             return _state_mapper(state) is s
 
-    def isa(self, other):
+    def isa(self, other: Mapper[Any]) -> bool:
         """Return True if the this mapper inherits from the given mapper."""
 
         m = self
index 1b8f778c0a80eb83b3660838257cd139092553b7..b4697912bce9ee5acc8836c34133435e95dd5573 100644 (file)
@@ -42,6 +42,7 @@ from .util import _orm_annotate
 from .util import _orm_deannotate
 from .util import CascadeOptions
 from .. import exc as sa_exc
+from .. import Exists
 from .. import log
 from .. import schema
 from .. import sql
@@ -52,6 +53,7 @@ from ..sql import expression
 from ..sql import operators
 from ..sql import roles
 from ..sql import visitors
+from ..sql.elements import SQLCoreOperations
 from ..sql.util import _deep_deannotate
 from ..sql.util import _shallow_annotate
 from ..sql.util import adapt_criterion_to_null
@@ -534,7 +536,11 @@ class Relationship(
                     )
                 )
 
-        def _criterion_exists(self, criterion=None, **kwargs):
+        def _criterion_exists(
+            self,
+            criterion: Optional[SQLCoreOperations[Any]] = None,
+            **kwargs: Any,
+        ) -> Exists[bool]:
             if getattr(self, "_of_type", None):
                 info = inspect(self._of_type)
                 target_mapper, to_selectable, is_aliased_class = (
@@ -1327,7 +1333,7 @@ class Relationship(
         return self.entity
 
     @util.memoized_property
-    def mapper(self) -> "Mapper":
+    def mapper(self) -> Mapper[_T]:
         """Return the targeted :class:`_orm.Mapper` for this
         :class:`.Relationship`.
 
index 4f6ff06889129dd73096ba85c44fd1b7b4153488..9d15cdcc3a89a3620fb593ed4732e69e53327ae1 100644 (file)
@@ -35,6 +35,7 @@ from .elements import FunctionFilter
 from .elements import Label
 from .elements import Null
 from .elements import Over
+from .elements import SQLCoreOperations
 from .elements import TextClause
 from .elements import True_
 from .elements import Tuple
@@ -1228,7 +1229,7 @@ def nulls_last(column):
     return UnaryExpression._create_nulls_last(column)
 
 
-def or_(*clauses):
+def or_(*clauses: SQLCoreOperations) -> BooleanClauseList:
     """Produce a conjunction of expressions joined by ``OR``.
 
     E.g.::
index ac5dc46db11a0bbfdcffa43152d5f4c7af42134f..4c38c4efabe491941f80a79166ab5e76d1705d74 100644 (file)
@@ -44,6 +44,7 @@ from .base import SingletonConstant
 from .cache_key import MemoizedHasCacheKey
 from .cache_key import NO_CACHE
 from .coercions import _document_text_coercion  # noqa
+from .operators import ColumnOperators
 from .traversals import HasCopyInternals
 from .visitors import cloned_traverse
 from .visitors import InternalTraversal
@@ -57,6 +58,7 @@ from ..util.langhelpers import TypingOnly
 if typing.TYPE_CHECKING:
     from decimal import Decimal
 
+    from .operators import OperatorType
     from .selectable import FromClause
     from .selectable import Select
     from .sqltypes import Boolean  # noqa
@@ -586,7 +588,9 @@ class CompilerColumnElement(
     __slots__ = ()
 
 
-class SQLCoreOperations(Generic[_T], TypingOnly):
+class SQLCoreOperations(
+    Generic[_T], ColumnOperators["SQLCoreOperations"], TypingOnly
+):
     __slots__ = ()
 
     # annotations for comparison methods
@@ -594,6 +598,16 @@ class SQLCoreOperations(Generic[_T], TypingOnly):
     # redefined with the specific types returned by ColumnElement hierarchies
     if typing.TYPE_CHECKING:
 
+        def operate(
+            self, op: OperatorType, *other: Any, **kwargs: Any
+        ) -> ColumnElement:
+            ...
+
+        def reverse_operate(
+            self, op: OperatorType, other: Any, **kwargs: Any
+        ) -> ColumnElement:
+            ...
+
         def op(
             self,
             opstring: Any,
@@ -620,34 +634,34 @@ class SQLCoreOperations(Generic[_T], TypingOnly):
         def __invert__(self) -> "UnaryExpression[_T]":
             ...
 
-        def __lt__(self, other: Any) -> "BinaryExpression[bool]":
+        def __lt__(self, other: Any) -> "ColumnElement[bool]":
             ...
 
-        def __le__(self, other: Any) -> "BinaryExpression[bool]":
+        def __le__(self, other: Any) -> "ColumnElement[bool]":
             ...
 
-        def __eq__(self, other: Any) -> "BinaryExpression[bool]":
+        def __eq__(self, other: Any) -> "ColumnElement[bool]":  # type: ignore[override]  # noqa: E501
             ...
 
-        def __ne__(self, other: Any) -> "BinaryExpression[bool]":
+        def __ne__(self, other: Any) -> "ColumnElement[bool]":  # type: ignore[override]  # noqa: E501
             ...
 
-        def is_distinct_from(self, other: Any) -> "BinaryExpression[bool]":
+        def is_distinct_from(self, other: Any) -> "ColumnElement[bool]":
             ...
 
-        def is_not_distinct_from(self, other: Any) -> "BinaryExpression[bool]":
+        def is_not_distinct_from(self, other: Any) -> "ColumnElement[bool]":
             ...
 
-        def __gt__(self, other: Any) -> "BinaryExpression[bool]":
+        def __gt__(self, other: Any) -> "ColumnElement[bool]":
             ...
 
-        def __ge__(self, other: Any) -> "BinaryExpression[bool]":
+        def __ge__(self, other: Any) -> "ColumnElement[bool]":
             ...
 
         def __neg__(self) -> "UnaryExpression[_T]":
             ...
 
-        def __contains__(self, other: Any) -> "BinaryExpression[bool]":
+        def __contains__(self, other: Any) -> "ColumnElement[bool]":
             ...
 
         def __getitem__(self, index: Any) -> "ColumnElement":
@@ -656,14 +670,14 @@ class SQLCoreOperations(Generic[_T], TypingOnly):
         @overload
         def concat(
             self: "SQLCoreOperations[_ST]", other: Any
-        ) -> "BinaryExpression[_ST]":
+        ) -> "ColumnElement[_ST]":
             ...
 
         @overload
-        def concat(self, other: Any) -> "BinaryExpression":
+        def concat(self, other: Any) -> "ColumnElement":
             ...
 
-        def concat(self, other: Any) -> "BinaryExpression":
+        def concat(self, other: Any) -> "ColumnElement":
             ...
 
         def like(self, other: Any, escape=None) -> "BinaryExpression[bool]":
@@ -702,30 +716,26 @@ class SQLCoreOperations(Generic[_T], TypingOnly):
 
         def startswith(
             self, other: Any, escape=None, autoescape=False
-        ) -> "BinaryExpression[bool]":
+        ) -> "ColumnElement[bool]":
             ...
 
         def endswith(
             self, other: Any, escape=None, autoescape=False
-        ) -> "BinaryExpression[bool]":
+        ) -> "ColumnElement[bool]":
             ...
 
-        def contains(
-            self, other: Any, escape=None, autoescape=False
-        ) -> "BinaryExpression[bool]":
+        def contains(self, other: Any, **kw: Any) -> "ColumnElement[bool]":
             ...
 
-        def match(self, other: Any, **kwargs) -> "BinaryExpression[bool]":
+        def match(self, other: Any, **kwargs) -> "ColumnElement[bool]":
             ...
 
-        def regexp_match(
-            self, pattern, flags=None
-        ) -> "BinaryExpression[bool]":
+        def regexp_match(self, pattern, flags=None) -> "ColumnElement[bool]":
             ...
 
         def regexp_replace(
             self, pattern, replacement, flags=None
-        ) -> "BinaryExpression":
+        ) -> "ColumnElement":
             ...
 
         def desc(self) -> "UnaryExpression[_T]":
@@ -745,7 +755,7 @@ class SQLCoreOperations(Generic[_T], TypingOnly):
 
         def between(
             self, cleft, cright, symmetric=False
-        ) -> "BinaryExpression[bool]":
+        ) -> "ColumnElement[bool]":
             ...
 
         def distinct(self: "SQLCoreOperations[_T]") -> "UnaryExpression[_T]":
@@ -766,166 +776,166 @@ class SQLCoreOperations(Generic[_T], TypingOnly):
         def __add__(
             self: "Union[_SQO[_NT], _SQO[Optional[_NT]]]",
             other: "Union[_SQO[Optional[_NT]], _SQO[_NT], _NT]",
-        ) -> "BinaryExpression[_NT]":
+        ) -> "ColumnElement[_NT]":
             ...
 
         @overload
         def __add__(
             self: "Union[_SQO[_NT], _SQO[Optional[_NT]]]",
             other: Any,
-        ) -> "BinaryExpression[_NUMERIC]":
+        ) -> "ColumnElement[_NUMERIC]":
             ...
 
         @overload
         def __add__(
             self: "Union[_SQO[_ST], _SQO[Optional[_ST]]]",
             other: Any,
-        ) -> "BinaryExpression[_ST]":
+        ) -> "ColumnElement[_ST]":
             ...
 
-        def __add__(self, other: Any) -> "BinaryExpression":
+        def __add__(self, other: Any) -> "ColumnElement":
             ...
 
         @overload
-        def __radd__(self, other: Any) -> "BinaryExpression[_NUMERIC]":
+        def __radd__(self, other: Any) -> "ColumnElement[_NUMERIC]":
             ...
 
         @overload
-        def __radd__(self, other: Any) -> "BinaryExpression":
+        def __radd__(self, other: Any) -> "ColumnElement":
             ...
 
-        def __radd__(self, other: Any) -> "BinaryExpression":
+        def __radd__(self, other: Any) -> "ColumnElement":
             ...
 
         @overload
         def __sub__(
             self: "SQLCoreOperations[_NT]",
             other: "Union[SQLCoreOperations[_NT], _NT]",
-        ) -> "BinaryExpression[_NT]":
+        ) -> "ColumnElement[_NT]":
             ...
 
         @overload
-        def __sub__(self, other: Any) -> "BinaryExpression":
+        def __sub__(self, other: Any) -> "ColumnElement":
             ...
 
-        def __sub__(self, other: Any) -> "BinaryExpression":
+        def __sub__(self, other: Any) -> "ColumnElement":
             ...
 
         @overload
         def __rsub__(
             self: "SQLCoreOperations[_NT]", other: Any
-        ) -> "BinaryExpression[_NUMERIC]":
+        ) -> "ColumnElement[_NUMERIC]":
             ...
 
         @overload
-        def __rsub__(self, other: Any) -> "BinaryExpression":
+        def __rsub__(self, other: Any) -> "ColumnElement":
             ...
 
-        def __rsub__(self, other: Any) -> "BinaryExpression":
+        def __rsub__(self, other: Any) -> "ColumnElement":
             ...
 
         @overload
         def __mul__(
             self: "SQLCoreOperations[_NT]", other: Any
-        ) -> "BinaryExpression[_NUMERIC]":
+        ) -> "ColumnElement[_NUMERIC]":
             ...
 
         @overload
-        def __mul__(self, other: Any) -> "BinaryExpression":
+        def __mul__(self, other: Any) -> "ColumnElement":
             ...
 
-        def __mul__(self, other: Any) -> "BinaryExpression":
+        def __mul__(self, other: Any) -> "ColumnElement":
             ...
 
         @overload
         def __rmul__(
             self: "SQLCoreOperations[_NT]", other: Any
-        ) -> "BinaryExpression[_NUMERIC]":
+        ) -> "ColumnElement[_NUMERIC]":
             ...
 
         @overload
-        def __rmul__(self, other: Any) -> "BinaryExpression":
+        def __rmul__(self, other: Any) -> "ColumnElement":
             ...
 
-        def __rmul__(self, other: Any) -> "BinaryExpression":
+        def __rmul__(self, other: Any) -> "ColumnElement":
             ...
 
         @overload
         def __mod__(
             self: "SQLCoreOperations[_NT]", other: Any
-        ) -> "BinaryExpression[_NUMERIC]":
+        ) -> "ColumnElement[_NUMERIC]":
             ...
 
         @overload
-        def __mod__(self, other: Any) -> "BinaryExpression":
+        def __mod__(self, other: Any) -> "ColumnElement":
             ...
 
-        def __mod__(self, other: Any) -> "BinaryExpression":
+        def __mod__(self, other: Any) -> "ColumnElement":
             ...
 
         @overload
         def __rmod__(
             self: "SQLCoreOperations[_NT]", other: Any
-        ) -> "BinaryExpression[_NUMERIC]":
+        ) -> "ColumnElement[_NUMERIC]":
             ...
 
         @overload
-        def __rmod__(self, other: Any) -> "BinaryExpression":
+        def __rmod__(self, other: Any) -> "ColumnElement":
             ...
 
-        def __rmod__(self, other: Any) -> "BinaryExpression":
+        def __rmod__(self, other: Any) -> "ColumnElement":
             ...
 
         @overload
         def __truediv__(
             self: "SQLCoreOperations[_NT]", other: Any
-        ) -> "BinaryExpression[_NUMERIC]":
+        ) -> "ColumnElement[_NUMERIC]":
             ...
 
         @overload
-        def __truediv__(self, other: Any) -> "BinaryExpression":
+        def __truediv__(self, other: Any) -> "ColumnElement":
             ...
 
-        def __truediv__(self, other: Any) -> "BinaryExpression":
+        def __truediv__(self, other: Any) -> "ColumnElement":
             ...
 
         @overload
         def __rtruediv__(
             self: "SQLCoreOperations[_NT]", other: Any
-        ) -> "BinaryExpression[_NUMERIC]":
+        ) -> "ColumnElement[_NUMERIC]":
             ...
 
         @overload
-        def __rtruediv__(self, other: Any) -> "BinaryExpression":
+        def __rtruediv__(self, other: Any) -> "ColumnElement":
             ...
 
-        def __rtruediv__(self, other: Any) -> "BinaryExpression":
+        def __rtruediv__(self, other: Any) -> "ColumnElement":
             ...
 
         @overload
         def __floordiv__(
             self: "SQLCoreOperations[_NT]", other: Any
-        ) -> "BinaryExpression[_NUMERIC]":
+        ) -> "ColumnElement[_NUMERIC]":
             ...
 
         @overload
-        def __floordiv__(self, other: Any) -> "BinaryExpression":
+        def __floordiv__(self, other: Any) -> "ColumnElement":
             ...
 
-        def __floordiv__(self, other: Any) -> "BinaryExpression":
+        def __floordiv__(self, other: Any) -> "ColumnElement":
             ...
 
         @overload
         def __rfloordiv__(
             self: "SQLCoreOperations[_NT]", other: Any
-        ) -> "BinaryExpression[_NUMERIC]":
+        ) -> "ColumnElement[_NUMERIC]":
             ...
 
         @overload
-        def __rfloordiv__(self, other: Any) -> "BinaryExpression":
+        def __rfloordiv__(self, other: Any) -> "ColumnElement":
             ...
 
-        def __rfloordiv__(self, other: Any) -> "BinaryExpression":
+        def __rfloordiv__(self, other: Any) -> "ColumnElement":
             ...
 
 
index d4fa8042dd5d6d8084a5676c11ba14490ede7a95..f08e71bcd70e8cafadd917e62824b50ff5220fef 100644 (file)
@@ -179,10 +179,10 @@ class Operators(Generic[_OP_RETURN]):
         precedence: int = 0,
         is_comparison: bool = False,
         return_type: Optional[
-            Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"]
+            Union[Type["TypeEngine[Any]"], "TypeEngine[Any]"]
         ] = None,
         python_impl=None,
-    ) -> Callable[[Any], _OP_RETURN]:
+    ) -> Callable[[Any], Any]:
         """Produce a generic operator function.
 
         e.g.::
@@ -270,7 +270,7 @@ class Operators(Generic[_OP_RETURN]):
 
     def bool_op(
         self, opstring: Any, precedence: int = 0, python_impl=None
-    ) -> Callable[[Any], _OP_RETURN]:
+    ) -> Callable[[Any], Any]:
         """Return a custom boolean operator.
 
         This method is shorthand for calling
@@ -1021,9 +1021,7 @@ class ColumnOperators(Operators[_OP_RETURN]):
             endswith_op, other, escape=escape, autoescape=autoescape
         )
 
-    def contains(
-        self, other: Any, escape=None, autoescape=False
-    ) -> "ColumnOperators":
+    def contains(self, other: Any, **kw: Any) -> "ColumnOperators":
         r"""Implement the 'contains' operator.
 
         Produces a LIKE expression that tests against a match for the middle
@@ -1101,9 +1099,7 @@ class ColumnOperators(Operators[_OP_RETURN]):
 
 
         """
-        return self.operate(
-            contains_op, other, escape=escape, autoescape=autoescape
-        )
+        return self.operate(contains_op, other, **kw)
 
     def match(self, other: Any, **kwargs) -> "ColumnOperators":
         """Implements a database-specific 'match' operator.
index 836c30af7420558ff3b15d7aefb1371bd24fecd9..a5cbffb5e1efb2f2d81bf5b4f79f08727f028dd8 100644 (file)
@@ -19,9 +19,11 @@ import itertools
 from operator import attrgetter
 import typing
 from typing import Any as TODO_Any
+from typing import Any
 from typing import NamedTuple
 from typing import Optional
 from typing import Tuple
+from typing import TypeVar
 
 from . import cache_key
 from . import coercions
@@ -72,6 +74,8 @@ from .. import util
 
 and_ = BooleanClauseList.and_
 
+_T = TypeVar("_T", bound=Any)
+
 
 class _OffsetLimitParam(BindParameter):
     inherit_cache = True
@@ -5528,7 +5532,7 @@ class ScalarSelect(roles.InElementRole, Generative, Grouping):
         return self
 
 
-class Exists(UnaryExpression):
+class Exists(UnaryExpression[_T]):
     """Represent an ``EXISTS`` clause.
 
     See :func:`_sql.exists` for a description of usage.
index 1e79fd5474c281ed6c2548afd3b39ca531289fd9..5674e19afe2e70834fe68370d159b52b96ce494f 100644 (file)
@@ -1418,40 +1418,38 @@ def counter() -> Callable[[], int]:
 
 
 def duck_type_collection(
-    specimen: Union[object, Type[Any]], default: Optional[Type[Any]] = None
-) -> Type[Any]:
+    specimen: Any, default: Optional[Type[Any]] = None
+) -> Optional[Type[Any]]:
     """Given an instance or class, guess if it is or is acting as one of
     the basic collection types: list, set and dict.  If the __emulates__
     property is present, return that preferentially.
     """
-    if typing.TYPE_CHECKING:
-        return object
-    else:
-        if hasattr(specimen, "__emulates__"):
-            # canonicalize set vs sets.Set to a standard: the builtin set
-            if specimen.__emulates__ is not None and issubclass(
-                specimen.__emulates__, set
-            ):
-                return set
-            else:
-                return specimen.__emulates__
 
-        isa = isinstance(specimen, type) and issubclass or isinstance
-        if isa(specimen, list):
-            return list
-        elif isa(specimen, set):
-            return set
-        elif isa(specimen, dict):
-            return dict
-
-        if hasattr(specimen, "append"):
-            return list
-        elif hasattr(specimen, "add"):
+    if hasattr(specimen, "__emulates__"):
+        # canonicalize set vs sets.Set to a standard: the builtin set
+        if specimen.__emulates__ is not None and issubclass(
+            specimen.__emulates__, set
+        ):
             return set
-        elif hasattr(specimen, "set"):
-            return dict
         else:
-            return default
+            return specimen.__emulates__  # type: ignore
+
+    isa = isinstance(specimen, type) and issubclass or isinstance
+    if isa(specimen, list):
+        return list
+    elif isa(specimen, set):
+        return set
+    elif isa(specimen, dict):
+        return dict
+
+    if hasattr(specimen, "append"):
+        return list
+    elif hasattr(specimen, "add"):
+        return set
+    elif hasattr(specimen, "set"):
+        return dict
+    else:
+        return default
 
 
 def assert_arg_type(arg: Any, argtype: Type[Any], name: str) -> Any:
index ad9c8e531423ea6f5ddbe2cc1c8d83bbd1350e3a..291061561d7437eca980d4875a4605c04abba1fc 100644 (file)
@@ -19,6 +19,8 @@ from . import compat
 
 _T = TypeVar("_T", bound=Any)
 
+Self = TypeVar("Self", bound=Any)
+
 if compat.py310:
     # why they took until py310 to put this in stdlib is beyond me,
     # I've been wanting it since py27
@@ -26,6 +28,11 @@ if compat.py310:
 else:
     NoneType = type(None)  # type: ignore
 
+if typing.TYPE_CHECKING or compat.py38:
+    from typing import SupportsIndex as SupportsIndex
+else:
+    from typing_extensions import SupportsIndex as SupportsIndex
+
 if typing.TYPE_CHECKING or compat.py310:
     from typing import Annotated as Annotated
 else:
index b2754b193dd83df93a78d3c84b1564f8ededb956..cefda245f7e2ab7b82637941f3ea42b60a832e97 100644 (file)
@@ -100,6 +100,7 @@ ignore_errors = true
 module = [
     "sqlalchemy.connectors.*",
     "sqlalchemy.engine.*",
+    "sqlalchemy.ext.associationproxy",
     "sqlalchemy.pool.*",
     "sqlalchemy.event.*",
     "sqlalchemy.events",
diff --git a/test/ext/mypy/plain_files/association_proxy_one.py b/test/ext/mypy/plain_files/association_proxy_one.py
new file mode 100644 (file)
index 0000000..c5c8979
--- /dev/null
@@ -0,0 +1,47 @@
+import typing
+from typing import Set
+
+from sqlalchemy import ForeignKey
+from sqlalchemy import Integer
+from sqlalchemy import String
+from sqlalchemy.ext.associationproxy import association_proxy
+from sqlalchemy.ext.associationproxy import AssociationProxy
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
+from sqlalchemy.orm import relationship
+
+
+class Base(DeclarativeBase):
+    pass
+
+
+class User(Base):
+    __tablename__ = "user"
+
+    id = mapped_column(Integer, primary_key=True)
+    name = mapped_column(String, nullable=False)
+
+    addresses: Mapped[Set["Address"]] = relationship()
+
+    email_addresses: AssociationProxy[Set[str]] = association_proxy(
+        "addresses", "email"
+    )
+
+
+class Address(Base):
+    __tablename__ = "address"
+
+    id = mapped_column(Integer, primary_key=True)
+    user_id = mapped_column(ForeignKey("user.id"))
+    email = mapped_column(String, nullable=False)
+
+
+u1 = User()
+
+if typing.TYPE_CHECKING:
+    # EXPECTED_TYPE: sqlalchemy.*.associationproxy.AssociationProxyInstance\[builtins.set\*\[builtins.str\]\]
+    reveal_type(User.email_addresses)
+
+    # EXPECTED_TYPE: builtins.set\*\[builtins.str\]
+    reveal_type(u1.email_addresses)
diff --git a/test/ext/mypy/plain_files/sql_operations.py b/test/ext/mypy/plain_files/sql_operations.py
new file mode 100644 (file)
index 0000000..2cee1dd
--- /dev/null
@@ -0,0 +1,20 @@
+import typing
+
+from sqlalchemy import column
+from sqlalchemy import Integer
+
+# builtin.pyi stubs define object.__eq__() as returning bool,  which
+# can't be overridden (it's final).  So for us to type `__eq__()` and
+# `__ne__()`, we have to use type: ignore[override].  Test if does this mean
+# the typing tools don't know the type, or if they just ignore the error.
+# (it's fortunately the former)
+expr1 = column("x", Integer) == 10
+
+
+if typing.TYPE_CHECKING:
+
+    # as far as if this is ColumnElement, BinaryElement, SQLCoreOperations,
+    # that might change.  main thing is it's SomeSQLColThing[bool] and
+    # not 'bool' or 'Any'.
+    # EXPECTED_TYPE: sqlalchemy..*ColumnElement\[builtins.bool\]
+    reveal_type(expr1)
index 583898ce9f463bba9c22c4c49df09969c953c877..484aed7953d7c4f4c10fccef22d06814cf181da5 100644 (file)
@@ -1,3 +1,4 @@
+from collections import abc
 import copy
 import pickle
 from unittest.mock import call
@@ -256,6 +257,23 @@ class _CollectionOperations(fixtures.MappedTest):
         )
         cls.mapper_registry.map_imperatively(Child, children_table)
 
+    def test_abc(self):
+        Parent = self.classes.Parent
+
+        p1 = Parent("x")
+
+        collection_class = self.collection_class or list
+
+        for abc_ in (abc.Set, abc.MutableMapping, abc.MutableSequence):
+            if issubclass(collection_class, abc_):
+                break
+        else:
+            abc_ = None
+
+        if abc_:
+            p1 = Parent("x")
+            assert isinstance(p1.children, abc_)
+
     def roundtrip(self, obj):
         if obj not in self.session:
             self.session.add(obj)
@@ -512,6 +530,10 @@ class CustomDictTest(_CollectionOperations):
 
         p1.children["b"] = "proxied"
 
+        eq_(list(p1.children.keys()), ["a", "b"])
+        eq_(list(p1.children.items()), [("a", "regular"), ("b", "proxied")])
+        eq_(list(p1.children.values()), ["regular", "proxied"])
+
         self.assert_("proxied" in list(p1.children.values()))
         self.assert_("b" in p1.children)
         self.assert_("proxied" not in p1._children)
@@ -2364,7 +2386,8 @@ class DictOfTupleUpdateTest(fixtures.MappedTest):
         a1 = self.classes.A()
         assert_raises_message(
             ValueError,
-            "dictionary update sequence requires " "2-element tuples",
+            "dictionary update sequence element #1 has length 5; "
+            "2 is required",
             a1.elements.update,
             (("B", 3), "elem2"),
         )
@@ -2373,7 +2396,7 @@ class DictOfTupleUpdateTest(fixtures.MappedTest):
         a1 = self.classes.A()
         assert_raises_message(
             TypeError,
-            "update expected at most 1 arguments, got 2",
+            "update expected at most 1 arguments?, got 2",
             a1.elements.update,
             (("B", 3), "elem2"),
             (("C", 4), "elem3"),
index 6933777373ec43189a6f342f4fcb98cd66f30c0f..c6fc3ace192edd2f0c8f297b8142c51a8a87c39e 100644 (file)
@@ -235,16 +235,15 @@ class TestORMInspection(_fixtures.FixtureTest):
     def test_extension_types(self):
         from sqlalchemy.ext.associationproxy import (
             association_proxy,
-            ASSOCIATION_PROXY,
+            AssociationProxyExtensionType,
         )
         from sqlalchemy.ext.hybrid import (
             hybrid_property,
             hybrid_method,
-            HYBRID_PROPERTY,
-            HYBRID_METHOD,
+            HybridExtensionType,
         )
         from sqlalchemy import Table, MetaData, Integer, Column
-        from sqlalchemy.orm.interfaces import NOT_EXTENSION
+        from sqlalchemy.orm.interfaces import NotExtension
 
         class SomeClass(self.classes.User):
             some_assoc = association_proxy("addresses", "email_address")
@@ -290,15 +289,15 @@ class TestORMInspection(_fixtures.FixtureTest):
                 for k, v in list(insp.all_orm_descriptors.items())
             ),
             {
-                "id": NOT_EXTENSION,
-                "name": NOT_EXTENSION,
-                "name_syn": NOT_EXTENSION,
-                "addresses": NOT_EXTENSION,
-                "orders": NOT_EXTENSION,
-                "upper_name": HYBRID_PROPERTY,
-                "foo": HYBRID_PROPERTY,
-                "conv": HYBRID_METHOD,
-                "some_assoc": ASSOCIATION_PROXY,
+                "id": NotExtension.NOT_EXTENSION,
+                "name": NotExtension.NOT_EXTENSION,
+                "name_syn": NotExtension.NOT_EXTENSION,
+                "addresses": NotExtension.NOT_EXTENSION,
+                "orders": NotExtension.NOT_EXTENSION,
+                "upper_name": HybridExtensionType.HYBRID_PROPERTY,
+                "foo": HybridExtensionType.HYBRID_PROPERTY,
+                "conv": HybridExtensionType.HYBRID_METHOD,
+                "some_assoc": AssociationProxyExtensionType.ASSOCIATION_PROXY,
             },
         )
         is_(