See the example ``examples/association/proxied_association.py``.
"""
-import operator
+from __future__ import annotations
+import operator
+import typing
+from typing import AbstractSet
+from typing import Any
+from typing import cast
+from typing import Collection
+from typing import Dict
+from typing import Generic
+from typing import ItemsView
+from typing import Iterable
+from typing import Iterator
+from typing import KeysView
+from typing import Mapping
+from typing import MutableMapping
+from typing import MutableSequence
+from typing import MutableSet
+from typing import NoReturn
+from typing import Optional
+from typing import overload
+from typing import Set
+from typing import Tuple
+from typing import Type
+from typing import TypeVar
+from typing import Union
+from typing import ValuesView
+
+from .. import ColumnElement
from .. import exc
from .. import inspect
from .. import orm
from .. import util
from ..orm import collections
+from ..orm import InspectionAttrExtensionType
from ..orm import interfaces
+from ..orm import ORMDescriptor
+from ..orm.base import SQLORMOperations
+from ..sql import operators
from ..sql import or_
+from ..sql.elements import SQLCoreOperations
from ..sql.operators import ColumnOperators
-
-
-def association_proxy(target_collection, attr, **kw):
+from ..util.typing import Literal
+from ..util.typing import Protocol
+from ..util.typing import Self
+from ..util.typing import SupportsIndex
+
+if typing.TYPE_CHECKING:
+ from ..orm.attributes import InstrumentedAttribute
+ from ..orm.interfaces import MapperProperty
+ from ..orm.interfaces import PropComparator
+ from ..orm.mapper import Mapper
+
+_T = TypeVar("_T", bound=Any)
+_T_co = TypeVar("_T_co", bound=Any, covariant=True)
+_T_con = TypeVar("_T_con", bound=Any, contravariant=True)
+_S = TypeVar("_S", bound=Any)
+_KT = TypeVar("_KT", bound=Any)
+_VT = TypeVar("_VT", bound=Any)
+
+
+def association_proxy(
+ target_collection: str, attr: str, **kw: Any
+) -> AssociationProxy[Any]:
r"""Return a Python property implementing a view of a target
attribute which references an attribute on members of the
target.
return AssociationProxy(target_collection, attr, **kw)
-ASSOCIATION_PROXY = util.symbol("ASSOCIATION_PROXY")
-"""Symbol indicating an :class:`.InspectionAttr` that's
+class AssociationProxyExtensionType(InspectionAttrExtensionType):
+ ASSOCIATION_PROXY = "ASSOCIATION_PROXY"
+ """Symbol indicating an :class:`.InspectionAttr` that's
of type :class:`.AssociationProxy`.
- Is assigned to the :attr:`.InspectionAttr.extension_type`
- attribute.
+ Is assigned to the :attr:`.InspectionAttr.extension_type`
+ attribute.
+
+ """
+
+
+class _GetterProtocol(Protocol[_T_co]):
+ def __call__(self, instance: Any) -> _T_co:
+ ...
+
+
+class _SetterProtocol(Protocol[_T_co]):
+ ...
+
+
+class _PlainSetterProtocol(_SetterProtocol[_T_con]):
+ def __call__(self, instance: Any, value: _T_con) -> None:
+ ...
+
+
+class _DictSetterProtocol(_SetterProtocol[_T_con]):
+ def __call__(self, instance: Any, key: Any, value: _T_con) -> None:
+ ...
+
+
+class _CreatorProtocol(Protocol[_T_co]):
+ ...
+
+
+class _PlainCreatorProtocol(_CreatorProtocol[_T_con]):
+ def __call__(self, value: _T_con) -> Any:
+ ...
+
+
+class _KeyCreatorProtocol(_CreatorProtocol[_T_con]):
+ def __call__(self, key: Any, value: Optional[_T_con]) -> Any:
+ ...
-"""
+class _LazyCollectionProtocol(Protocol[_T]):
+ def __call__(
+ self,
+ ) -> Union[MutableSet[_T], MutableMapping[Any, _T], MutableSequence[_T]]:
+ ...
+
+
+class _GetSetFactoryProtocol(Protocol):
+ def __call__(
+ self,
+ collection_class: Optional[Type[Any]],
+ assoc_instance: AssociationProxyInstance[Any],
+ ) -> Tuple[_GetterProtocol[Any], _SetterProtocol[Any]]:
+ ...
+
+
+class _ProxyFactoryProtocol(Protocol):
+ def __call__(
+ self,
+ lazy_collection: _LazyCollectionProtocol[Any],
+ creator: _CreatorProtocol[Any],
+ value_attr: str,
+ parent: AssociationProxyInstance[Any],
+ ) -> _T:
+ ...
+
+
+class _ProxyBulkSetProtocol(Protocol):
+ def __call__(
+ self, proxy: _AssociationCollection[Any], collection: Iterable[Any]
+ ) -> None:
+ ...
+
+
+class _AssociationProxyProtocol(Protocol[_T]):
+ """describes the interface of :class:`.AssociationProxy`
+ without including descriptor methods in the interface."""
-class AssociationProxy(interfaces.InspectionAttrInfo):
+ creator: Optional[_CreatorProtocol[Any]]
+ key: str
+ target_collection: str
+ value_attr: str
+ getset_factory: Optional[_GetSetFactoryProtocol]
+ proxy_factory: Optional[_ProxyFactoryProtocol]
+ proxy_bulk_set: Optional[_ProxyBulkSetProtocol]
+
+ @util.memoized_property
+ def info(self) -> Dict[Any, Any]:
+ ...
+
+ def for_class(
+ self, class_: Type[Any], obj: Optional[object] = None
+ ) -> AssociationProxyInstance[_T]:
+ ...
+
+ def _default_getset(
+ self, collection_class: Any
+ ) -> Tuple[_GetterProtocol[Any], _SetterProtocol[Any]]:
+ ...
+
+
+_SelfAssociationProxy = TypeVar(
+ "_SelfAssociationProxy", bound="AssociationProxy[Any]"
+)
+
+
+class AssociationProxy(
+ interfaces.InspectionAttrInfo,
+ ORMDescriptor[_T],
+ _AssociationProxyProtocol[_T],
+):
"""A descriptor that presents a read/write view of an object attribute."""
is_attribute = True
- extension_type = ASSOCIATION_PROXY
+ extension_type = AssociationProxyExtensionType.ASSOCIATION_PROXY
def __init__(
self,
- target_collection,
- attr,
- creator=None,
- getset_factory=None,
- proxy_factory=None,
- proxy_bulk_set=None,
- info=None,
- cascade_scalar_deletes=False,
+ target_collection: str,
+ attr: str,
+ creator: Optional[_CreatorProtocol[Any]] = None,
+ getset_factory: Optional[_GetSetFactoryProtocol] = None,
+ proxy_factory: Optional[_ProxyFactoryProtocol] = None,
+ proxy_bulk_set: Optional[_ProxyBulkSetProtocol] = None,
+ info: Optional[Dict[Any, Any]] = None,
+ cascade_scalar_deletes: bool = False,
):
"""Construct a new :class:`.AssociationProxy`.
if info:
self.info = info
- def __get__(self, obj, class_):
- if class_ is None:
+ @overload
+ def __get__(
+ self: _SelfAssociationProxy, instance: Any, owner: Literal[None]
+ ) -> _SelfAssociationProxy:
+ ...
+
+ @overload
+ def __get__(
+ self, instance: Literal[None], owner: Any
+ ) -> AssociationProxyInstance[_T]:
+ ...
+
+ @overload
+ def __get__(self, instance: object, owner: Any) -> _T:
+ ...
+
+ def __get__(
+ self, instance: object, owner: Any
+ ) -> Union[AssociationProxyInstance[_T], _T, AssociationProxy[_T]]:
+ if owner is None:
return self
- inst = self._as_instance(class_, obj)
+ inst = self._as_instance(owner, instance)
if inst:
- return inst.get(obj)
+ return inst.get(instance)
- # obj has to be None here
- # assert obj is None
+ assert instance is None
return self
- def __set__(self, obj, values):
- class_ = type(obj)
- return self._as_instance(class_, obj).set(obj, values)
+ def __set__(self, instance: object, values: _T) -> None:
+ class_ = type(instance)
+ self._as_instance(class_, instance).set(instance, values)
- def __delete__(self, obj):
- class_ = type(obj)
- return self._as_instance(class_, obj).delete(obj)
+ def __delete__(self, instance: object) -> None:
+ class_ = type(instance)
+ self._as_instance(class_, instance).delete(instance)
- def for_class(self, class_, obj=None):
+ def for_class(
+ self, class_: Type[Any], obj: Optional[object] = None
+ ) -> AssociationProxyInstance[_T]:
r"""Return the internal state local to a specific mapped class.
E.g., given a class ``User``::
"""
return self._as_instance(class_, obj)
- def _as_instance(self, class_, obj):
+ def _as_instance(
+ self, class_: Any, obj: Any
+ ) -> AssociationProxyInstance[_T]:
try:
inst = class_.__dict__[self.key + "_inst"]
except KeyError:
# class, only on subclasses of it, which might be
# different. only return for the specific
# object's current value
- return inst._non_canonical_get_for_object(obj)
+ return inst._non_canonical_get_for_object(obj) # type: ignore
else:
- return inst
+ return inst # type: ignore # TODO
- def _calc_owner(self, target_cls):
+ def _calc_owner(self, target_cls: Any) -> Any:
# we might be getting invoked for a subclass
# that is not mapped yet, in some declarative situations.
# save until we are mapped
else:
return insp.mapper.class_manager.class_
- def _default_getset(self, collection_class):
+ def _default_getset(
+ self, collection_class: Any
+ ) -> Tuple[_GetterProtocol[Any], _SetterProtocol[Any]]:
attr = self.value_attr
_getter = operator.attrgetter(attr)
- def getter(target):
- return _getter(target) if target is not None else None
+ def getter(instance: Any) -> Optional[Any]:
+ return _getter(instance) if instance is not None else None
if collection_class is dict:
- def setter(o, k, v):
- setattr(o, attr, v)
+ def dict_setter(instance: Any, k: Any, value: Any) -> None:
+ setattr(instance, attr, value)
+
+ return getter, dict_setter
else:
- def setter(o, v):
+ def plain_setter(o: Any, v: Any) -> None:
setattr(o, attr, v)
- return getter, setter
+ return getter, plain_setter
- def __repr__(self):
+ def __repr__(self) -> str:
return "AssociationProxy(%r, %r)" % (
self.target_collection,
self.value_attr,
)
-class AssociationProxyInstance:
+_SelfAssociationProxyInstance = TypeVar(
+ "_SelfAssociationProxyInstance", bound="AssociationProxyInstance[Any]"
+)
+
+
+class AssociationProxyInstance(
+ SQLORMOperations[_T], ColumnOperators[SQLORMOperations[_T]]
+):
"""A per-class object that serves class- and object-specific results.
This is used by :class:`.AssociationProxy` when it is invoked
""" # noqa
- def __init__(self, parent, owning_class, target_class, value_attr):
+ collection_class: Optional[Type[Any]]
+ parent: _AssociationProxyProtocol[_T]
+
+ def __init__(
+ self,
+ parent: _AssociationProxyProtocol[_T],
+ owning_class: Type[Any],
+ target_class: Type[Any],
+ value_attr: str,
+ ):
self.parent = parent
self.key = parent.key
self.owning_class = owning_class
self.target_class = target_class
self.value_attr = value_attr
- target_class = None
+ target_class: Type[Any]
"""The intermediary class handled by this
:class:`.AssociationProxyInstance`.
"""
@classmethod
- def for_proxy(cls, parent, owning_class, parent_instance):
+ def for_proxy(
+ cls,
+ parent: AssociationProxy[_T],
+ owning_class: Type[Any],
+ parent_instance: Any,
+ ) -> AssociationProxyInstance[_T]:
target_collection = parent.target_collection
value_attr = parent.value_attr
- prop = orm.class_mapper(owning_class).get_property(target_collection)
+ prop = cast(
+ "orm.Relationship[_T]",
+ orm.class_mapper(owning_class).get_property(target_collection),
+ )
# this was never asserted before but this should be made clear.
if not isinstance(prop, orm.Relationship):
target_class = prop.mapper.class_
try:
- target_assoc = cls._cls_unwrap_target_assoc_proxy(
- target_class, value_attr
+ target_assoc = cast(
+ "AssociationProxyInstance[_T]",
+ cls._cls_unwrap_target_assoc_proxy(target_class, value_attr),
)
except AttributeError:
# the proxied attribute doesn't exist on the target class;
@classmethod
def _construct_for_assoc(
- cls, target_assoc, parent, owning_class, target_class, value_attr
- ):
+ cls,
+ target_assoc: Optional[AssociationProxyInstance[_T]],
+ parent: _AssociationProxyProtocol[_T],
+ owning_class: Type[Any],
+ target_class: Type[Any],
+ value_attr: str,
+ ) -> AssociationProxyInstance[_T]:
if target_assoc is not None:
return ObjectAssociationProxyInstance(
parent, owning_class, target_class, value_attr
parent, owning_class, target_class, value_attr
)
- def _get_property(self):
+ def _get_property(self) -> MapperProperty[Any]:
return orm.class_mapper(self.owning_class).get_property(
self.target_collection
)
@property
- def _comparator(self):
+ def _comparator(self) -> PropComparator[Any]:
return self._get_property().comparator
- def __clause_element__(self):
+ def __clause_element__(self) -> NoReturn:
raise NotImplementedError(
"The association proxy can't be used as a plain column "
"expression; it only works inside of a comparison expression"
)
@classmethod
- def _cls_unwrap_target_assoc_proxy(cls, target_class, value_attr):
+ def _cls_unwrap_target_assoc_proxy(
+ cls, target_class: Any, value_attr: str
+ ) -> Optional[AssociationProxyInstance[_T]]:
attr = getattr(target_class, value_attr)
- if isinstance(attr, (AssociationProxy, AssociationProxyInstance)):
+ assert not isinstance(attr, AssociationProxy)
+ if isinstance(attr, AssociationProxyInstance):
return attr
return None
@util.memoized_property
- def _unwrap_target_assoc_proxy(self):
+ def _unwrap_target_assoc_proxy(
+ self,
+ ) -> Optional[AssociationProxyInstance[_T]]:
return self._cls_unwrap_target_assoc_proxy(
self.target_class, self.value_attr
)
@property
- def remote_attr(self):
+ def remote_attr(self) -> SQLORMOperations[_T]:
"""The 'remote' class attribute referenced by this
:class:`.AssociationProxyInstance`.
:attr:`.AssociationProxyInstance.local_attr`
"""
- return getattr(self.target_class, self.value_attr)
+ return cast(
+ "SQLORMOperations[_T]", getattr(self.target_class, self.value_attr)
+ )
@property
- def local_attr(self):
+ def local_attr(self) -> SQLORMOperations[Any]:
"""The 'local' class attribute referenced by this
:class:`.AssociationProxyInstance`.
:attr:`.AssociationProxyInstance.remote_attr`
"""
- return getattr(self.owning_class, self.target_collection)
+ return cast(
+ "SQLORMOperations[Any]",
+ getattr(self.owning_class, self.target_collection),
+ )
@property
- def attr(self):
+ def attr(self) -> Tuple[SQLORMOperations[Any], SQLORMOperations[_T]]:
"""Return a tuple of ``(local_attr, remote_attr)``.
This attribute was originally intended to facilitate using the
return (self.local_attr, self.remote_attr)
@util.memoized_property
- def scalar(self):
+ def scalar(self) -> bool:
"""Return ``True`` if this :class:`.AssociationProxyInstance`
proxies a scalar relationship on the local side."""
return scalar
@util.memoized_property
- def _value_is_scalar(self):
+ def _value_is_scalar(self) -> bool:
return (
not self._get_property()
.mapper.get_property(self.value_attr)
)
@property
- def _target_is_object(self):
+ def _target_is_object(self) -> bool:
raise NotImplementedError()
- def _initialize_scalar_accessors(self):
+ _scalar_get: _GetterProtocol[_T]
+ _scalar_set: _PlainSetterProtocol[_T]
+
+ def _initialize_scalar_accessors(self) -> None:
if self.parent.getset_factory:
get, set_ = self.parent.getset_factory(None, self)
else:
get, set_ = self.parent._default_getset(None)
- self._scalar_get, self._scalar_set = get, set_
+ self._scalar_get, self._scalar_set = get, cast(
+ "_PlainSetterProtocol[_T]", set_
+ )
- def _default_getset(self, collection_class):
+ def _default_getset(
+ self, collection_class: Any
+ ) -> Tuple[_GetterProtocol[Any], _SetterProtocol[Any]]:
attr = self.value_attr
_getter = operator.attrgetter(attr)
- def getter(target):
- return _getter(target) if target is not None else None
+ def getter(instance: Any) -> Optional[_T]:
+ return _getter(instance) if instance is not None else None
if collection_class is dict:
- def setter(o, k, v):
- return setattr(o, attr, v)
+ def dict_setter(instance: Any, k: Any, value: _T) -> None:
+ setattr(instance, attr, value)
+ return getter, dict_setter
else:
- def setter(o, v):
- return setattr(o, attr, v)
+ def plain_setter(o: Any, v: _T) -> None:
+ setattr(o, attr, v)
- return getter, setter
+ return getter, plain_setter
@property
- def info(self):
+ def info(self) -> Dict[Any, Any]:
return self.parent.info
- def get(self, obj):
+ @overload
+ def get(self: Self, obj: Literal[None]) -> Self:
+ ...
+
+ @overload
+ def get(self, obj: Any) -> _T:
+ ...
+
+ def get(
+ self, obj: Any
+ ) -> Union[Optional[_T], AssociationProxyInstance[_T]]:
if obj is None:
return self
+ proxy: _T
+
if self.scalar:
target = getattr(obj, self.target_collection)
return self._scalar_get(target)
try:
# If the owning instance is reborn (orm session resurrect,
# etc.), refresh the proxy cache.
- creator_id, self_id, proxy = getattr(obj, self.key)
+ creator_id, self_id, proxy = cast(
+ "Tuple[int, int, _T]", getattr(obj, self.key)
+ )
except AttributeError:
pass
else:
setattr(obj, self.key, (id(obj), id(self), proxy))
return proxy
- def set(self, obj, values):
+ def set(self, obj: Any, values: _T) -> None:
if self.scalar:
- creator = (
- self.parent.creator
- if self.parent.creator
- else self.target_class
+ creator = cast(
+ "_PlainCreatorProtocol[_T]",
+ (
+ self.parent.creator
+ if self.parent.creator
+ else self.target_class
+ ),
)
target = getattr(obj, self.target_collection)
if target is None:
if proxy is not values:
proxy._bulk_replace(self, values)
- def delete(self, obj):
+ def delete(self, obj: Any) -> None:
if self.owning_class is None:
self._calc_owner(obj, None)
delattr(target, self.value_attr)
delattr(obj, self.target_collection)
- def _new(self, lazy_collection):
+ def _new(
+ self, lazy_collection: _LazyCollectionProtocol[_T]
+ ) -> Tuple[Type[Any], _T]:
creator = (
- self.parent.creator if self.parent.creator else self.target_class
+ self.parent.creator
+ if self.parent.creator is not None
+ else cast("_CreatorProtocol[_T]", self.target_class)
)
collection_class = util.duck_type_collection(lazy_collection())
+ if collection_class is None:
+ raise exc.InvalidRequestError(
+ f"lazy collection factory did not return a "
+ f"valid collection type, got {collection_class}"
+ )
if self.parent.proxy_factory:
return (
collection_class,
if collection_class is list:
return (
collection_class,
- _AssociationList(
- lazy_collection, creator, getter, setter, self
+ cast(
+ _T,
+ _AssociationList(
+ lazy_collection, creator, getter, setter, self
+ ),
),
)
elif collection_class is dict:
return (
collection_class,
- _AssociationDict(
- lazy_collection, creator, getter, setter, self
+ cast(
+ _T,
+ _AssociationDict(
+ lazy_collection, creator, getter, setter, self
+ ),
),
)
elif collection_class is set:
return (
collection_class,
- _AssociationSet(
- lazy_collection, creator, getter, setter, self
+ cast(
+ _T,
+ _AssociationSet(
+ lazy_collection, creator, getter, setter, self
+ ),
),
)
else:
"could not guess which interface to use for "
'collection_class "%s" backing "%s"; specify a '
"proxy_factory and proxy_bulk_set manually"
- % (self.collection_class.__name__, self.target_collection)
+ % (self.collection_class, self.target_collection)
)
- def _set(self, proxy, values):
+ def _set(
+ self, proxy: _AssociationCollection[Any], values: Iterable[Any]
+ ) -> None:
if self.parent.proxy_bulk_set:
self.parent.proxy_bulk_set(proxy, values)
elif self.collection_class is list:
- proxy.extend(values)
+ cast("_AssociationList[Any]", proxy).extend(values)
elif self.collection_class is dict:
- proxy.update(values)
+ cast("_AssociationDict[Any, Any]", proxy).update(values)
elif self.collection_class is set:
- proxy.update(values)
+ cast("_AssociationSet[Any]", proxy).update(values)
else:
raise exc.ArgumentError(
"no proxy_bulk_set supplied for custom "
"collection_class implementation"
)
- def _inflate(self, proxy):
+ def _inflate(self, proxy: _AssociationCollection[Any]) -> None:
creator = (
- self.parent.creator and self.parent.creator or self.target_class
+ self.parent.creator
+ and self.parent.creator
+ or cast(_CreatorProtocol[Any], self.target_class)
)
if self.parent.getset_factory:
proxy.getter = getter
proxy.setter = setter
- def _criterion_exists(self, criterion=None, **kwargs):
+ def _criterion_exists(
+ self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs: Any
+ ) -> ColumnElement[bool]:
is_has = kwargs.pop("is_has", None)
target_assoc = self._unwrap_target_assoc_proxy
if target_assoc is not None:
- inner = target_assoc._criterion_exists(
+ inner = target_assoc._criterion_exists( # type: ignore
criterion=criterion, **kwargs
)
return self._comparator._criterion_exists(inner)
return self._comparator._criterion_exists(value_expr)
- def any(self, criterion=None, **kwargs):
+ def any(
+ self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs: Any
+ ) -> SQLCoreOperations[Any]:
"""Produce a proxied 'any' expression using EXISTS.
This expression will be a composed product
criterion=criterion, is_has=False, **kwargs
)
- def has(self, criterion=None, **kwargs):
+ def has(
+ self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs: Any
+ ) -> SQLCoreOperations[Any]:
"""Produce a proxied 'has' expression using EXISTS.
This expression will be a composed product
criterion=criterion, is_has=True, **kwargs
)
- def __repr__(self):
+ def __repr__(self) -> str:
return "%s(%r)" % (self.__class__.__name__, self.parent)
-class AmbiguousAssociationProxyInstance(AssociationProxyInstance):
+class AmbiguousAssociationProxyInstance(AssociationProxyInstance[_T]):
"""an :class:`.AssociationProxyInstance` where we cannot determine
the type of target object.
"""
_is_canonical = False
- def _ambiguous(self):
+ def _ambiguous(self) -> NoReturn:
raise AttributeError(
"Association proxy %s.%s refers to an attribute '%s' that is not "
"directly mapped on class %s; therefore this operation cannot "
)
)
- def get(self, obj):
+ def get(self, obj: Any) -> Any:
if obj is None:
return self
else:
return super(AmbiguousAssociationProxyInstance, self).get(obj)
- def __eq__(self, obj):
+ def __eq__(self, obj: object) -> NoReturn:
self._ambiguous()
- def __ne__(self, obj):
+ def __ne__(self, obj: object) -> NoReturn:
self._ambiguous()
- def any(self, criterion=None, **kwargs):
+ def any(
+ self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs: Any
+ ) -> NoReturn:
self._ambiguous()
- def has(self, criterion=None, **kwargs):
+ def has(
+ self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs: Any
+ ) -> NoReturn:
self._ambiguous()
@util.memoized_property
- def _lookup_cache(self):
+ def _lookup_cache(self) -> Dict[Type[Any], AssociationProxyInstance[_T]]:
# mapping of <subclass>->AssociationProxyInstance.
# e.g. proxy is A-> A.b -> B -> B.b_attr, but B.b_attr doesn't exist;
# only B1(B) and B2(B) have "b_attr", keys in here would be B1, B2
return {}
- def _non_canonical_get_for_object(self, parent_instance):
+ def _non_canonical_get_for_object(
+ self, parent_instance: Any
+ ) -> AssociationProxyInstance[_T]:
if parent_instance is not None:
actual_obj = getattr(parent_instance, self.target_collection)
if actual_obj is not None:
# is a proxy with generally only instance-level functionality
return self
- def _populate_cache(self, instance_class, mapper):
+ def _populate_cache(
+ self, instance_class: Any, mapper: Mapper[Any]
+ ) -> None:
prop = orm.class_mapper(self.owning_class).get_property(
self.target_collection
)
pass
else:
self._lookup_cache[instance_class] = self._construct_for_assoc(
- target_assoc,
+ cast("AssociationProxyInstance[_T]", target_assoc),
self.parent,
self.owning_class,
target_class,
)
-class ObjectAssociationProxyInstance(AssociationProxyInstance):
+class ObjectAssociationProxyInstance(AssociationProxyInstance[_T]):
"""an :class:`.AssociationProxyInstance` that has an object as a target."""
- _target_is_object = True
+ _target_is_object: bool = True
_is_canonical = True
- def contains(self, obj):
+ def contains(self, other: Any, **kw: Any) -> ColumnElement[bool]:
"""Produce a proxied 'contains' expression using EXISTS.
This expression will be a composed product
target_assoc = self._unwrap_target_assoc_proxy
if target_assoc is not None:
return self._comparator._criterion_exists(
- target_assoc.contains(obj)
+ target_assoc.contains(other)
if not target_assoc.scalar
- else target_assoc == obj
+ else target_assoc == other
)
elif (
self._target_is_object
and not self._value_is_scalar
):
return self._comparator.has(
- getattr(self.target_class, self.value_attr).contains(obj)
+ getattr(self.target_class, self.value_attr).contains(other)
)
elif self._target_is_object and self.scalar and self._value_is_scalar:
raise exc.InvalidRequestError(
)
else:
- return self._comparator._criterion_exists(**{self.value_attr: obj})
+ return self._comparator._criterion_exists(
+ **{self.value_attr: other}
+ )
- def __eq__(self, obj):
+ def __eq__(self, obj: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa E501
# note the has() here will fail for collections; eq_()
# is only allowed with a scalar.
if obj is None:
else:
return self._comparator.has(**{self.value_attr: obj})
- def __ne__(self, obj):
+ def __ne__(self, obj: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa E501
# note the has() here will fail for collections; eq_()
# is only allowed with a scalar.
return self._comparator.has(
)
-class ColumnAssociationProxyInstance(
- ColumnOperators, AssociationProxyInstance
-):
+class ColumnAssociationProxyInstance(AssociationProxyInstance[_T]):
"""an :class:`.AssociationProxyInstance` that has a database column as a
target.
"""
- _target_is_object = False
+ _target_is_object: bool = False
_is_canonical = True
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa E501
# special case "is None" to check for no related row as well
expr = self._criterion_exists(
- self.remote_attr.operate(operator.eq, other)
+ self.remote_attr.operate(operators.eq, other)
)
if other is None:
return or_(expr, self._comparator == None)
else:
return expr
- def operate(self, op, *other, **kwargs):
+ def operate(
+ self, op: operators.OperatorType, *other: Any, **kwargs: Any
+ ) -> ColumnElement[Any]:
return self._criterion_exists(
self.remote_attr.operate(op, *other, **kwargs)
)
-class _lazy_collection:
- def __init__(self, obj, target):
+class _lazy_collection(_LazyCollectionProtocol[_T]):
+ def __init__(self, obj: Any, target: str):
self.parent = obj
self.target = target
- def __call__(self):
- return getattr(self.parent, self.target)
+ def __call__(
+ self,
+ ) -> Union[MutableSet[_T], MutableMapping[Any, _T], MutableSequence[_T]]:
+ return getattr(self.parent, self.target) # type: ignore[no-any-return]
- def __getstate__(self):
+ def __getstate__(self) -> Any:
return {"obj": self.parent, "target": self.target}
- def __setstate__(self, state):
+ def __setstate__(self, state: Any) -> None:
self.parent = state["obj"]
self.target = state["target"]
-class _AssociationCollection:
- def __init__(self, lazy_collection, creator, getter, setter, parent):
- """Constructs an _AssociationCollection.
+_IT = TypeVar("_IT", bound="Any")
+"""instance type - this is the type of object inside a collection.
- This will always be a subclass of either _AssociationList,
- _AssociationSet, or _AssociationDict.
+this is not the same as the _T of AssociationProxy and
+AssociationProxyInstance itself, which will often refer to the
+collection[_IT] type.
- lazy_collection
- A callable returning a list-based collection of entities (usually an
- object attribute managed by a SQLAlchemy relationship())
+"""
- creator
- A function that creates new target entities. Given one parameter:
- value. This assertion is assumed::
- obj = creator(somevalue)
- assert getter(obj) == somevalue
+class _AssociationCollection(Generic[_IT]):
+ getter: _GetterProtocol[_IT]
+ """A function. Given an associated object, return the 'value'."""
- getter
- A function. Given an associated object, return the 'value'.
+ creator: _CreatorProtocol[_IT]
+ """
+ A function that creates new target entities. Given one parameter:
+ value. This assertion is assumed::
- setter
- A function. Given an associated object and a value, store that
- value on the object.
+ obj = creator(somevalue)
+ assert getter(obj) == somevalue
+ """
+
+ parent: AssociationProxyInstance[_IT]
+ setter: _SetterProtocol[_IT]
+ """A function. Given an associated object and a value, store that
+ value on the object.
+ """
+
+ lazy_collection: _LazyCollectionProtocol[_IT]
+ """A callable returning a list-based collection of entities (usually an
+ object attribute managed by a SQLAlchemy relationship())"""
+
+ def __init__(
+ self,
+ lazy_collection: _LazyCollectionProtocol[_IT],
+ creator: _CreatorProtocol[_IT],
+ getter: _GetterProtocol[_IT],
+ setter: _SetterProtocol[_IT],
+ parent: AssociationProxyInstance[_IT],
+ ):
+ """Constructs an _AssociationCollection.
+
+ This will always be a subclass of either _AssociationList,
+ _AssociationSet, or _AssociationDict.
"""
self.lazy_collection = lazy_collection
self.setter = setter
self.parent = parent
- col = property(lambda self: self.lazy_collection())
+ if typing.TYPE_CHECKING:
+ col: Collection[_IT]
+ else:
+ col = property(lambda self: self.lazy_collection())
- def __len__(self):
+ def __len__(self) -> int:
return len(self.col)
- def __bool__(self):
+ def __bool__(self) -> bool:
return bool(self.col)
__nonzero__ = __bool__
- def __getstate__(self):
+ def __getstate__(self) -> Any:
return {"parent": self.parent, "lazy_collection": self.lazy_collection}
- def __setstate__(self, state):
+ def __setstate__(self, state: Any) -> None:
self.parent = state["parent"]
self.lazy_collection = state["lazy_collection"]
self.parent._inflate(self)
- def _bulk_replace(self, assoc_proxy, values):
+ def clear(self) -> None:
+ raise NotImplementedError()
+
+
+class _AssociationSingleItem(_AssociationCollection[_T]):
+ setter: _PlainSetterProtocol[_T]
+ creator: _PlainCreatorProtocol[_T]
+
+ def _create(self, value: _T) -> Any:
+ return self.creator(value)
+
+ def _get(self, object_: Any) -> _T:
+ return self.getter(object_)
+
+ def _bulk_replace(
+ self, assoc_proxy: AssociationProxyInstance[Any], values: Iterable[_IT]
+ ) -> None:
self.clear()
assoc_proxy._set(self, values)
-class _AssociationList(_AssociationCollection):
+class _AssociationList(_AssociationSingleItem[_T], MutableSequence[_T]):
"""Generic, converting, list-to-list proxy."""
- def _create(self, value):
- return self.creator(value)
+ col: MutableSequence[_T]
- def _get(self, object_):
- return self.getter(object_)
+ def _set(self, object_: Any, value: _T) -> None:
+ self.setter(object_, value)
+
+ @overload
+ def __getitem__(self, index: int) -> _T:
+ ...
- def _set(self, object_, value):
- return self.setter(object_, value)
+ @overload
+ def __getitem__(self, index: slice) -> MutableSequence[_T]:
+ ...
- def __getitem__(self, index):
+ def __getitem__(
+ self, index: Union[int, slice]
+ ) -> Union[_T, MutableSequence[_T]]:
if not isinstance(index, slice):
return self._get(self.col[index])
else:
return [self._get(member) for member in self.col[index]]
- def __setitem__(self, index, value):
+ @overload
+ def __setitem__(self, index: int, value: _T) -> None:
+ ...
+
+ @overload
+ def __setitem__(self, index: slice, value: Iterable[_T]) -> None:
+ ...
+
+ def __setitem__(
+ self, index: Union[int, slice], value: Union[_T, Iterable[_T]]
+ ) -> None:
if not isinstance(index, slice):
- self._set(self.col[index], value)
+ self._set(self.col[index], cast("_T", value))
else:
if index.stop is None:
stop = len(self)
start = index.start or 0
rng = list(range(index.start or 0, stop, step))
+
+ sized_value = list(value)
+
if step == 1:
for i in rng:
del self[start]
i = start
- for item in value:
+ for item in sized_value:
self.insert(i, item)
i += 1
else:
- if len(value) != len(rng):
+ if len(sized_value) != len(rng):
raise ValueError(
"attempt to assign sequence of size %s to "
- "extended slice of size %s" % (len(value), len(rng))
+ "extended slice of size %s"
+ % (len(sized_value), len(rng))
)
for i, item in zip(rng, value):
self._set(self.col[i], item)
- def __delitem__(self, index):
+ @overload
+ def __delitem__(self, index: int) -> None:
+ ...
+
+ @overload
+ def __delitem__(self, index: slice) -> None:
+ ...
+
+ def __delitem__(self, index: Union[slice, int]) -> None:
del self.col[index]
- def __contains__(self, value):
+ def __contains__(self, value: object) -> bool:
for member in self.col:
# testlib.pragma exempt:__eq__
if self._get(member) == value:
return True
return False
- def __getslice__(self, start, end):
- return [self._get(member) for member in self.col[start:end]]
-
- def __setslice__(self, start, end, values):
- members = [self._create(v) for v in values]
- self.col[start:end] = members
-
- def __delslice__(self, start, end):
- del self.col[start:end]
-
- def __iter__(self):
+ def __iter__(self) -> Iterator[_T]:
"""Iterate over proxied values.
For the actual domain objects, iterate over .col instead or
yield self._get(member)
return
- def append(self, value):
+ def append(self, value: _T) -> None:
col = self.col
item = self._create(value)
col.append(item)
- def count(self, value):
+ def count(self, value: Any) -> int:
count = 0
for v in self:
if v == value:
count += 1
return count
- def extend(self, values):
+ def extend(self, values: Iterable[_T]) -> None:
for v in values:
self.append(v)
- def insert(self, index, value):
+ def insert(self, index: int, value: _T) -> None:
self.col[index:index] = [self._create(value)]
- def pop(self, index=-1):
+ def pop(self, index: int = -1) -> _T:
return self.getter(self.col.pop(index))
- def remove(self, value):
+ def remove(self, value: _T) -> None:
for i, val in enumerate(self):
if val == value:
del self.col[i]
return
raise ValueError("value not in list")
- def reverse(self):
+ def reverse(self) -> NoReturn:
"""Not supported, use reversed(mylist)"""
- raise NotImplementedError
+ raise NotImplementedError()
- def sort(self):
+ def sort(self) -> NoReturn:
"""Not supported, use sorted(mylist)"""
- raise NotImplementedError
+ raise NotImplementedError()
- def clear(self):
+ def clear(self) -> None:
del self.col[0 : len(self.col)]
- def __eq__(self, other):
+ def __eq__(self, other: object) -> bool:
return list(self) == other
- def __ne__(self, other):
+ def __ne__(self, other: object) -> bool:
return list(self) != other
- def __lt__(self, other):
+ def __lt__(self, other: list[_T]) -> bool:
return list(self) < other
- def __le__(self, other):
+ def __le__(self, other: list[_T]) -> bool:
return list(self) <= other
- def __gt__(self, other):
+ def __gt__(self, other: list[_T]) -> bool:
return list(self) > other
- def __ge__(self, other):
+ def __ge__(self, other: list[_T]) -> bool:
return list(self) >= other
- def __cmp__(self, other):
- return util.cmp(list(self), other)
-
- def __add__(self, iterable):
+ def __add__(self, other: list[_T]) -> list[_T]:
try:
- other = list(iterable)
+ other = list(other)
except TypeError:
return NotImplemented
return list(self) + other
- def __radd__(self, iterable):
+ def __radd__(self, other: list[_T]) -> list[_T]:
try:
- other = list(iterable)
+ other = list(other)
except TypeError:
return NotImplemented
return other + list(self)
- def __mul__(self, n):
+ def __mul__(self, n: SupportsIndex) -> list[_T]:
if not isinstance(n, int):
return NotImplemented
return list(self) * n
- __rmul__ = __mul__
+ def __rmul__(self, n: SupportsIndex) -> list[_T]:
+ if not isinstance(n, int):
+ return NotImplemented
+ return n * list(self)
- def __iadd__(self, iterable):
+ def __iadd__(self: Self, iterable: Iterable[_T]) -> Self:
self.extend(iterable)
return self
- def __imul__(self, n):
+ def __imul__(self: Self, n: SupportsIndex) -> Self:
# unlike a regular list *=, proxied __imul__ will generate unique
# backing objects for each copy. *= on proxied lists is a bit of
# a stretch anyhow, and this interpretation of the __imul__ contract
# is more plausibly useful than copying the backing objects.
if not isinstance(n, int):
- return NotImplemented
+ raise NotImplementedError()
if n == 0:
self.clear()
elif n > 1:
self.extend(list(self) * (n - 1))
return self
- def index(self, item, *args):
- return list(self).index(item, *args)
+ if typing.TYPE_CHECKING:
+ # TODO: no idea how to do this without separate "stub"
+ def index(self, value: Any, start: int = ..., stop: int = ...) -> int:
+ ...
+
+ else:
+
+ def index(self, value: Any, *arg) -> int:
+ ls = list(self)
+ return ls.index(value, *arg)
- def copy(self):
+ def copy(self) -> list[_T]:
return list(self)
- def __repr__(self):
+ def __repr__(self) -> str:
return repr(list(self))
- def __hash__(self):
+ def __hash__(self) -> NoReturn:
raise TypeError("%s objects are unhashable" % type(self).__name__)
- for func_name, func in list(locals().items()):
- if (
- callable(func)
- and func.__name__ == func_name
- and not func.__doc__
- and hasattr(list, func_name)
- ):
- func.__doc__ = getattr(list, func_name).__doc__
- del func_name, func
-
+ if not typing.TYPE_CHECKING:
+ for func_name, func in list(locals().items()):
+ if (
+ callable(func)
+ and func.__name__ == func_name
+ and not func.__doc__
+ and hasattr(list, func_name)
+ ):
+ func.__doc__ = getattr(list, func_name).__doc__
+ del func_name, func
-_NotProvided = util.symbol("_NotProvided")
-
-class _AssociationDict(_AssociationCollection):
+class _AssociationDict(_AssociationCollection[_VT], MutableMapping[_KT, _VT]):
"""Generic, converting, dict-to-dict proxy."""
- def _create(self, key, value):
+ setter: _DictSetterProtocol[_VT]
+ creator: _KeyCreatorProtocol[_VT]
+ col: MutableMapping[_KT, Optional[_VT]]
+
+ def _create(self, key: _KT, value: Optional[_VT]) -> Any:
return self.creator(key, value)
- def _get(self, object_):
+ def _get(self, object_: Any) -> _VT:
return self.getter(object_)
- def _set(self, object_, key, value):
+ def _set(self, object_: Any, key: _KT, value: _VT) -> None:
return self.setter(object_, key, value)
- def __getitem__(self, key):
+ def __getitem__(self, key: _KT) -> _VT:
return self._get(self.col[key])
- def __setitem__(self, key, value):
+ def __setitem__(self, key: _KT, value: _VT) -> None:
if key in self.col:
self._set(self.col[key], key, value)
else:
self.col[key] = self._create(key, value)
- def __delitem__(self, key):
+ def __delitem__(self, key: _KT) -> None:
del self.col[key]
- def __contains__(self, key):
- # testlib.pragma exempt:__hash__
+ def __contains__(self, key: object) -> bool:
return key in self.col
- def has_key(self, key):
- # testlib.pragma exempt:__hash__
- return key in self.col
-
- def __iter__(self):
+ def __iter__(self) -> Iterator[_KT]:
return iter(self.col.keys())
- def clear(self):
+ def clear(self) -> None:
self.col.clear()
- def __eq__(self, other):
+ def __eq__(self, other: object) -> bool:
return dict(self) == other
- def __ne__(self, other):
+ def __ne__(self, other: object) -> bool:
return dict(self) != other
- def __lt__(self, other):
- return dict(self) < other
-
- def __le__(self, other):
- return dict(self) <= other
+ def __repr__(self) -> str:
+ return repr(dict(self))
- def __gt__(self, other):
- return dict(self) > other
+ @overload
+ def get(self, __key: _KT) -> Optional[_VT]:
+ ...
- def __ge__(self, other):
- return dict(self) >= other
+ @overload
+ def get(self, __key: _KT, default: Union[_VT, _T]) -> Union[_VT, _T]:
+ ...
- def __cmp__(self, other):
- return util.cmp(dict(self), other)
-
- def __repr__(self):
- return repr(dict(self.items()))
-
- def get(self, key, default=None):
+ def get(
+ self, key: _KT, default: Optional[Union[_VT, _T]] = None
+ ) -> Union[_VT, _T, None]:
try:
return self[key]
except KeyError:
return default
- def setdefault(self, key, default=None):
+ def setdefault(self, key: _KT, default: Optional[_VT] = None) -> _VT:
+ # TODO: again, no idea how to create an actual MutableMapping.
+ # default must allow None, return type can't include None,
+ # the stub explicitly allows for default of None with a cryptic message
+ # "This overload should be allowed only if the value type is
+ # compatible with None.".
if key not in self.col:
self.col[key] = self._create(key, default)
- return default
+ return default # type: ignore
else:
return self[key]
- def keys(self):
+ def keys(self) -> KeysView[_KT]:
return self.col.keys()
- def items(self):
- return ((key, self._get(self.col[key])) for key in self.col)
+ def items(self) -> ItemsView[_KT, _VT]:
+ return ItemsView(self)
- def values(self):
- return (self._get(self.col[key]) for key in self.col)
+ def values(self) -> ValuesView[_VT]:
+ return ValuesView(self)
- def pop(self, key, default=_NotProvided):
- if default is _NotProvided:
- member = self.col.pop(key)
- else:
- member = self.col.pop(key, default)
+ @overload
+ def pop(self, __key: _KT) -> _VT:
+ ...
+
+ @overload
+ def pop(self, __key: _KT, default: Union[_VT, _T] = ...) -> Union[_VT, _T]:
+ ...
+
+ def pop(self, __key: _KT, *arg: Any, **kw: Any) -> Union[_VT, _T]:
+ member = self.col.pop(__key, *arg, **kw)
return self._get(member)
- def popitem(self):
+ def popitem(self) -> Tuple[_KT, _VT]:
item = self.col.popitem()
return (item[0], self._get(item[1]))
- def update(self, *a, **kw):
- if len(a) > 1:
- raise TypeError(
- "update expected at most 1 arguments, got %i" % len(a)
- )
- elif len(a) == 1:
- seq_or_map = a[0]
- # discern dict from sequence - took the advice from
- # https://www.voidspace.org.uk/python/articles/duck_typing.shtml
- # still not perfect :(
- if hasattr(seq_or_map, "keys"):
- for item in seq_or_map:
- self[item] = seq_or_map[item]
- else:
- try:
- for k, v in seq_or_map:
- self[k] = v
- except ValueError as err:
- raise ValueError(
- "dictionary update sequence "
- "requires 2-element tuples"
- ) from err
+ @overload
+ def update(self, __m: Mapping[_KT, _VT], **kwargs: _VT) -> None:
+ ...
+
+ @overload
+ def update(self, __m: Iterable[tuple[_KT, _VT]], **kwargs: _VT) -> None:
+ ...
- for key, value in kw:
+ @overload
+ def update(self, **kwargs: _VT) -> None:
+ ...
+
+ def update(self, *a: Any, **kw: Any) -> None:
+ up: Dict[_KT, _VT] = {}
+ up.update(*a, **kw)
+
+ for key, value in up.items():
self[key] = value
- def _bulk_replace(self, assoc_proxy, values):
+ def _bulk_replace(
+ self,
+ assoc_proxy: AssociationProxyInstance[Any],
+ values: Mapping[_KT, _VT],
+ ) -> None:
existing = set(self)
constants = existing.intersection(values or ())
additions = set(values or ()).difference(constants)
for key in removals:
del self[key]
- def copy(self):
+ def copy(self) -> dict[_KT, _VT]:
return dict(self.items())
- def __hash__(self):
+ def __hash__(self) -> NoReturn:
raise TypeError("%s objects are unhashable" % type(self).__name__)
- for func_name, func in list(locals().items()):
- if (
- callable(func)
- and func.__name__ == func_name
- and not func.__doc__
- and hasattr(dict, func_name)
- ):
- func.__doc__ = getattr(dict, func_name).__doc__
- del func_name, func
+ if not typing.TYPE_CHECKING:
+ for func_name, func in list(locals().items()):
+ if (
+ callable(func)
+ and func.__name__ == func_name
+ and not func.__doc__
+ and hasattr(dict, func_name)
+ ):
+ func.__doc__ = getattr(dict, func_name).__doc__
+ del func_name, func
-class _AssociationSet(_AssociationCollection):
+class _AssociationSet(_AssociationSingleItem[_T], MutableSet[_T]):
"""Generic, converting, set-to-set proxy."""
- def _create(self, value):
- return self.creator(value)
+ col: MutableSet[_T]
- def _get(self, object_):
- return self.getter(object_)
-
- def __len__(self):
+ def __len__(self) -> int:
return len(self.col)
- def __bool__(self):
+ def __bool__(self) -> bool:
if self.col:
return True
else:
__nonzero__ = __bool__
- def __contains__(self, value):
+ def __contains__(self, __o: object) -> bool:
for member in self.col:
- # testlib.pragma exempt:__eq__
- if self._get(member) == value:
+ if self._get(member) == __o:
return True
return False
- def __iter__(self):
+ def __iter__(self) -> Iterator[_T]:
"""Iterate over proxied values.
For the actual domain objects, iterate over .col instead or just use
yield self._get(member)
return
- def add(self, value):
- if value not in self:
- self.col.add(self._create(value))
+ def add(self, __element: _T) -> None:
+ if __element not in self:
+ self.col.add(self._create(__element))
# for discard and remove, choosing a more expensive check strategy rather
# than call self.creator()
- def discard(self, value):
+ def discard(self, __element: _T) -> None:
for member in self.col:
- if self._get(member) == value:
+ if self._get(member) == __element:
self.col.discard(member)
break
- def remove(self, value):
+ def remove(self, __element: _T) -> None:
for member in self.col:
- if self._get(member) == value:
+ if self._get(member) == __element:
self.col.discard(member)
return
- raise KeyError(value)
+ raise KeyError(__element)
- def pop(self):
+ def pop(self) -> _T:
if not self.col:
raise KeyError("pop from an empty set")
member = self.col.pop()
return self._get(member)
- def update(self, other):
- for value in other:
- self.add(value)
+ def update(self, *s: Iterable[_T]) -> None:
+ for iterable in s:
+ for value in iterable:
+ self.add(value)
- def _bulk_replace(self, assoc_proxy, values):
+ def _bulk_replace(self, assoc_proxy: Any, values: Iterable[_T]) -> None:
existing = set(self)
constants = existing.intersection(values or ())
additions = set(values or ()).difference(constants)
for member in removals:
remover(member)
- def __ior__(self, other):
+ def __ior__(
+ self: Self, other: AbstractSet[_S]
+ ) -> MutableSet[Union[_T, _S]]:
if not collections._set_binops_check_strict(self, other):
- return NotImplemented
+ raise NotImplementedError()
for value in other:
self.add(value)
return self
- def _set(self):
+ def _set(self) -> Set[_T]:
return set(iter(self))
- def union(self, other):
- return set(self).union(other)
+ def union(self, *s: Iterable[_S]) -> MutableSet[Union[_T, _S]]:
+ return set(self).union(*s)
- __or__ = union
+ def __or__(self, __s: AbstractSet[_S]) -> MutableSet[Union[_T, _S]]:
+ return self.union(__s)
- def difference(self, other):
- return set(self).difference(other)
+ def difference(self, *s: Iterable[Any]) -> MutableSet[_T]:
+ return set(self).difference(*s)
- __sub__ = difference
+ def __sub__(self, s: AbstractSet[Any]) -> MutableSet[_T]:
+ return self.difference(s)
- def difference_update(self, other):
- for value in other:
- self.discard(value)
+ def difference_update(self, *s: Iterable[Any]) -> None:
+ for other in s:
+ for value in other:
+ self.discard(value)
- def __isub__(self, other):
- if not collections._set_binops_check_strict(self, other):
- return NotImplemented
- for value in other:
+ def __isub__(self: Self, s: AbstractSet[Any]) -> Self:
+ if not collections._set_binops_check_strict(self, s):
+ raise NotImplementedError()
+ for value in s:
self.discard(value)
return self
- def intersection(self, other):
- return set(self).intersection(other)
+ def intersection(self, *s: Iterable[Any]) -> MutableSet[_T]:
+ return set(self).intersection(*s)
- __and__ = intersection
+ def __and__(self, s: AbstractSet[Any]) -> MutableSet[_T]:
+ return self.intersection(s)
- def intersection_update(self, other):
- want, have = self.intersection(other), set(self)
+ def intersection_update(self, *s: Iterable[Any]) -> None:
+ for other in s:
+ want, have = self.intersection(other), set(self)
- remove, add = have - want, want - have
+ remove, add = have - want, want - have
- for value in remove:
- self.remove(value)
- for value in add:
- self.add(value)
+ for value in remove:
+ self.remove(value)
+ for value in add:
+ self.add(value)
- def __iand__(self, other):
- if not collections._set_binops_check_strict(self, other):
- return NotImplemented
- want, have = self.intersection(other), set(self)
+ def __iand__(self: Self, s: AbstractSet[Any]) -> Self:
+ if not collections._set_binops_check_strict(self, s):
+ raise NotImplementedError()
+ want = self.intersection(s)
+ have: Set[_T] = set(self)
remove, add = have - want, want - have
self.add(value)
return self
- def symmetric_difference(self, other):
- return set(self).symmetric_difference(other)
+ def symmetric_difference(self, __s: Iterable[_T]) -> MutableSet[_T]:
+ return set(self).symmetric_difference(__s)
- __xor__ = symmetric_difference
+ def __xor__(self, s: AbstractSet[_S]) -> MutableSet[Union[_T, _S]]:
+ return self.symmetric_difference(s) # type: ignore
- def symmetric_difference_update(self, other):
+ def symmetric_difference_update(self, other: Iterable[Any]) -> None:
want, have = self.symmetric_difference(other), set(self)
remove, add = have - want, want - have
for value in add:
self.add(value)
- def __ixor__(self, other):
+ def __ixor__(self, other: AbstractSet[_S]) -> MutableSet[Union[_T, _S]]:
if not collections._set_binops_check_strict(self, other):
- return NotImplemented
- want, have = self.symmetric_difference(other), set(self)
-
- remove, add = have - want, want - have
+ raise NotImplementedError()
- for value in remove:
- self.remove(value)
- for value in add:
- self.add(value)
+ self.symmetric_difference_update(other)
return self
- def issubset(self, other):
- return set(self).issubset(other)
+ def issubset(self, __s: Iterable[Any]) -> bool:
+ return set(self).issubset(__s)
- def issuperset(self, other):
- return set(self).issuperset(other)
+ def issuperset(self, __s: Iterable[Any]) -> bool:
+ return set(self).issuperset(__s)
- def clear(self):
+ def clear(self) -> None:
self.col.clear()
- def copy(self):
+ def copy(self) -> AbstractSet[_T]:
return set(self)
- def __eq__(self, other):
+ def __eq__(self, other: object) -> bool:
return set(self) == other
- def __ne__(self, other):
+ def __ne__(self, other: object) -> bool:
return set(self) != other
- def __lt__(self, other):
+ def __lt__(self, other: AbstractSet[Any]) -> bool:
return set(self) < other
- def __le__(self, other):
+ def __le__(self, other: AbstractSet[Any]) -> bool:
return set(self) <= other
- def __gt__(self, other):
+ def __gt__(self, other: AbstractSet[Any]) -> bool:
return set(self) > other
- def __ge__(self, other):
+ def __ge__(self, other: AbstractSet[Any]) -> bool:
return set(self) >= other
- def __repr__(self):
+ def __repr__(self) -> str:
return repr(set(self))
- def __hash__(self):
+ def __hash__(self) -> NoReturn:
raise TypeError("%s objects are unhashable" % type(self).__name__)
- for func_name, func in list(locals().items()):
- if (
- callable(func)
- and func.__name__ == func_name
- and not func.__doc__
- and hasattr(set, func_name)
- ):
- func.__doc__ = getattr(set, func_name).__doc__
- del func_name, func
+ if not typing.TYPE_CHECKING:
+ for func_name, func in list(locals().items()):
+ if (
+ callable(func)
+ and func.__name__ == func_name
+ and not func.__doc__
+ and hasattr(set, func_name)
+ ):
+ func.__doc__ = getattr(set, func_name).__doc__
+ del func_name, func