From: Mike Bayer Date: Mon, 28 Feb 2022 04:05:46 +0000 (-0500) Subject: pep484 + abc bases for assocaitionproxy X-Git-Tag: rel_2_0_0b1~459 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=afb9634fb28b00c7b0979660e3e0bfed6caafde5;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git pep484 + abc bases for assocaitionproxy went to this one next as it was going to be hard, and also exercises the ORM expression hierarchy a bit. made some adjustments to SQLCoreOperations etc. Change-Id: Ie5dde9218dc1318252826b766d3e70b17dd24ea7 References: #6810 References: #7774 --- diff --git a/doc/build/orm/internals.rst b/doc/build/orm/internals.rst index 05cf83b394..f251e43bd0 100644 --- a/doc/build/orm/internals.rst +++ b/doc/build/orm/internals.rst @@ -88,7 +88,10 @@ sections, are listed here. :attr:`.SchemaItem.info` -.. autodata:: NOT_EXTENSION +.. autoclass:: InspectionAttrExtensionType + +.. autoclass:: NotExtension + :members: .. autofunction:: merge_result diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py index d5119907ed..709c13c146 100644 --- a/lib/sqlalchemy/ext/associationproxy.py +++ b/lib/sqlalchemy/ext/associationproxy.py @@ -13,19 +13,70 @@ transparent proxied access to the endpoint of an association object. See the example ``examples/association/proxied_association.py``. """ -import operator +from __future__ import annotations +import operator +import typing +from typing import AbstractSet +from typing import Any +from typing import cast +from typing import Collection +from typing import Dict +from typing import Generic +from typing import ItemsView +from typing import Iterable +from typing import Iterator +from typing import KeysView +from typing import Mapping +from typing import MutableMapping +from typing import MutableSequence +from typing import MutableSet +from typing import NoReturn +from typing import Optional +from typing import overload +from typing import Set +from typing import Tuple +from typing import Type +from typing import TypeVar +from typing import Union +from typing import ValuesView + +from .. import ColumnElement from .. import exc from .. import inspect from .. import orm from .. import util from ..orm import collections +from ..orm import InspectionAttrExtensionType from ..orm import interfaces +from ..orm import ORMDescriptor +from ..orm.base import SQLORMOperations +from ..sql import operators from ..sql import or_ +from ..sql.elements import SQLCoreOperations from ..sql.operators import ColumnOperators - - -def association_proxy(target_collection, attr, **kw): +from ..util.typing import Literal +from ..util.typing import Protocol +from ..util.typing import Self +from ..util.typing import SupportsIndex + +if typing.TYPE_CHECKING: + from ..orm.attributes import InstrumentedAttribute + from ..orm.interfaces import MapperProperty + from ..orm.interfaces import PropComparator + from ..orm.mapper import Mapper + +_T = TypeVar("_T", bound=Any) +_T_co = TypeVar("_T_co", bound=Any, covariant=True) +_T_con = TypeVar("_T_con", bound=Any, contravariant=True) +_S = TypeVar("_S", bound=Any) +_KT = TypeVar("_KT", bound=Any) +_VT = TypeVar("_VT", bound=Any) + + +def association_proxy( + target_collection: str, attr: str, **kw: Any +) -> AssociationProxy[Any]: r"""Return a Python property implementing a view of a target attribute which references an attribute on members of the target. @@ -80,32 +131,136 @@ def association_proxy(target_collection, attr, **kw): return AssociationProxy(target_collection, attr, **kw) -ASSOCIATION_PROXY = util.symbol("ASSOCIATION_PROXY") -"""Symbol indicating an :class:`.InspectionAttr` that's +class AssociationProxyExtensionType(InspectionAttrExtensionType): + ASSOCIATION_PROXY = "ASSOCIATION_PROXY" + """Symbol indicating an :class:`.InspectionAttr` that's of type :class:`.AssociationProxy`. - Is assigned to the :attr:`.InspectionAttr.extension_type` - attribute. + Is assigned to the :attr:`.InspectionAttr.extension_type` + attribute. + + """ + + +class _GetterProtocol(Protocol[_T_co]): + def __call__(self, instance: Any) -> _T_co: + ... + + +class _SetterProtocol(Protocol[_T_co]): + ... + + +class _PlainSetterProtocol(_SetterProtocol[_T_con]): + def __call__(self, instance: Any, value: _T_con) -> None: + ... + + +class _DictSetterProtocol(_SetterProtocol[_T_con]): + def __call__(self, instance: Any, key: Any, value: _T_con) -> None: + ... + + +class _CreatorProtocol(Protocol[_T_co]): + ... + + +class _PlainCreatorProtocol(_CreatorProtocol[_T_con]): + def __call__(self, value: _T_con) -> Any: + ... + + +class _KeyCreatorProtocol(_CreatorProtocol[_T_con]): + def __call__(self, key: Any, value: Optional[_T_con]) -> Any: + ... -""" +class _LazyCollectionProtocol(Protocol[_T]): + def __call__( + self, + ) -> Union[MutableSet[_T], MutableMapping[Any, _T], MutableSequence[_T]]: + ... + + +class _GetSetFactoryProtocol(Protocol): + def __call__( + self, + collection_class: Optional[Type[Any]], + assoc_instance: AssociationProxyInstance[Any], + ) -> Tuple[_GetterProtocol[Any], _SetterProtocol[Any]]: + ... + + +class _ProxyFactoryProtocol(Protocol): + def __call__( + self, + lazy_collection: _LazyCollectionProtocol[Any], + creator: _CreatorProtocol[Any], + value_attr: str, + parent: AssociationProxyInstance[Any], + ) -> _T: + ... + + +class _ProxyBulkSetProtocol(Protocol): + def __call__( + self, proxy: _AssociationCollection[Any], collection: Iterable[Any] + ) -> None: + ... + + +class _AssociationProxyProtocol(Protocol[_T]): + """describes the interface of :class:`.AssociationProxy` + without including descriptor methods in the interface.""" -class AssociationProxy(interfaces.InspectionAttrInfo): + creator: Optional[_CreatorProtocol[Any]] + key: str + target_collection: str + value_attr: str + getset_factory: Optional[_GetSetFactoryProtocol] + proxy_factory: Optional[_ProxyFactoryProtocol] + proxy_bulk_set: Optional[_ProxyBulkSetProtocol] + + @util.memoized_property + def info(self) -> Dict[Any, Any]: + ... + + def for_class( + self, class_: Type[Any], obj: Optional[object] = None + ) -> AssociationProxyInstance[_T]: + ... + + def _default_getset( + self, collection_class: Any + ) -> Tuple[_GetterProtocol[Any], _SetterProtocol[Any]]: + ... + + +_SelfAssociationProxy = TypeVar( + "_SelfAssociationProxy", bound="AssociationProxy[Any]" +) + + +class AssociationProxy( + interfaces.InspectionAttrInfo, + ORMDescriptor[_T], + _AssociationProxyProtocol[_T], +): """A descriptor that presents a read/write view of an object attribute.""" is_attribute = True - extension_type = ASSOCIATION_PROXY + extension_type = AssociationProxyExtensionType.ASSOCIATION_PROXY def __init__( self, - target_collection, - attr, - creator=None, - getset_factory=None, - proxy_factory=None, - proxy_bulk_set=None, - info=None, - cascade_scalar_deletes=False, + target_collection: str, + attr: str, + creator: Optional[_CreatorProtocol[Any]] = None, + getset_factory: Optional[_GetSetFactoryProtocol] = None, + proxy_factory: Optional[_ProxyFactoryProtocol] = None, + proxy_bulk_set: Optional[_ProxyBulkSetProtocol] = None, + info: Optional[Dict[Any, Any]] = None, + cascade_scalar_deletes: bool = False, ): """Construct a new :class:`.AssociationProxy`. @@ -185,27 +340,46 @@ class AssociationProxy(interfaces.InspectionAttrInfo): if info: self.info = info - def __get__(self, obj, class_): - if class_ is None: + @overload + def __get__( + self: _SelfAssociationProxy, instance: Any, owner: Literal[None] + ) -> _SelfAssociationProxy: + ... + + @overload + def __get__( + self, instance: Literal[None], owner: Any + ) -> AssociationProxyInstance[_T]: + ... + + @overload + def __get__(self, instance: object, owner: Any) -> _T: + ... + + def __get__( + self, instance: object, owner: Any + ) -> Union[AssociationProxyInstance[_T], _T, AssociationProxy[_T]]: + if owner is None: return self - inst = self._as_instance(class_, obj) + inst = self._as_instance(owner, instance) if inst: - return inst.get(obj) + return inst.get(instance) - # obj has to be None here - # assert obj is None + assert instance is None return self - def __set__(self, obj, values): - class_ = type(obj) - return self._as_instance(class_, obj).set(obj, values) + def __set__(self, instance: object, values: _T) -> None: + class_ = type(instance) + self._as_instance(class_, instance).set(instance, values) - def __delete__(self, obj): - class_ = type(obj) - return self._as_instance(class_, obj).delete(obj) + def __delete__(self, instance: object) -> None: + class_ = type(instance) + self._as_instance(class_, instance).delete(instance) - def for_class(self, class_, obj=None): + def for_class( + self, class_: Type[Any], obj: Optional[object] = None + ) -> AssociationProxyInstance[_T]: r"""Return the internal state local to a specific mapped class. E.g., given a class ``User``:: @@ -240,7 +414,9 @@ class AssociationProxy(interfaces.InspectionAttrInfo): """ return self._as_instance(class_, obj) - def _as_instance(self, class_, obj): + def _as_instance( + self, class_: Any, obj: Any + ) -> AssociationProxyInstance[_T]: try: inst = class_.__dict__[self.key + "_inst"] except KeyError: @@ -261,11 +437,11 @@ class AssociationProxy(interfaces.InspectionAttrInfo): # class, only on subclasses of it, which might be # different. only return for the specific # object's current value - return inst._non_canonical_get_for_object(obj) + return inst._non_canonical_get_for_object(obj) # type: ignore else: - return inst + return inst # type: ignore # TODO - def _calc_owner(self, target_cls): + def _calc_owner(self, target_cls: Any) -> Any: # we might be getting invoked for a subclass # that is not mapped yet, in some declarative situations. # save until we are mapped @@ -280,33 +456,44 @@ class AssociationProxy(interfaces.InspectionAttrInfo): else: return insp.mapper.class_manager.class_ - def _default_getset(self, collection_class): + def _default_getset( + self, collection_class: Any + ) -> Tuple[_GetterProtocol[Any], _SetterProtocol[Any]]: attr = self.value_attr _getter = operator.attrgetter(attr) - def getter(target): - return _getter(target) if target is not None else None + def getter(instance: Any) -> Optional[Any]: + return _getter(instance) if instance is not None else None if collection_class is dict: - def setter(o, k, v): - setattr(o, attr, v) + def dict_setter(instance: Any, k: Any, value: Any) -> None: + setattr(instance, attr, value) + + return getter, dict_setter else: - def setter(o, v): + def plain_setter(o: Any, v: Any) -> None: setattr(o, attr, v) - return getter, setter + return getter, plain_setter - def __repr__(self): + def __repr__(self) -> str: return "AssociationProxy(%r, %r)" % ( self.target_collection, self.value_attr, ) -class AssociationProxyInstance: +_SelfAssociationProxyInstance = TypeVar( + "_SelfAssociationProxyInstance", bound="AssociationProxyInstance[Any]" +) + + +class AssociationProxyInstance( + SQLORMOperations[_T], ColumnOperators[SQLORMOperations[_T]] +): """A per-class object that serves class- and object-specific results. This is used by :class:`.AssociationProxy` when it is invoked @@ -336,7 +523,16 @@ class AssociationProxyInstance: """ # noqa - def __init__(self, parent, owning_class, target_class, value_attr): + collection_class: Optional[Type[Any]] + parent: _AssociationProxyProtocol[_T] + + def __init__( + self, + parent: _AssociationProxyProtocol[_T], + owning_class: Type[Any], + target_class: Type[Any], + value_attr: str, + ): self.parent = parent self.key = parent.key self.owning_class = owning_class @@ -345,7 +541,7 @@ class AssociationProxyInstance: self.target_class = target_class self.value_attr = value_attr - target_class = None + target_class: Type[Any] """The intermediary class handled by this :class:`.AssociationProxyInstance`. @@ -355,10 +551,18 @@ class AssociationProxyInstance: """ @classmethod - def for_proxy(cls, parent, owning_class, parent_instance): + def for_proxy( + cls, + parent: AssociationProxy[_T], + owning_class: Type[Any], + parent_instance: Any, + ) -> AssociationProxyInstance[_T]: target_collection = parent.target_collection value_attr = parent.value_attr - prop = orm.class_mapper(owning_class).get_property(target_collection) + prop = cast( + "orm.Relationship[_T]", + orm.class_mapper(owning_class).get_property(target_collection), + ) # this was never asserted before but this should be made clear. if not isinstance(prop, orm.Relationship): @@ -370,8 +574,9 @@ class AssociationProxyInstance: target_class = prop.mapper.class_ try: - target_assoc = cls._cls_unwrap_target_assoc_proxy( - target_class, value_attr + target_assoc = cast( + "AssociationProxyInstance[_T]", + cls._cls_unwrap_target_assoc_proxy(target_class, value_attr), ) except AttributeError: # the proxied attribute doesn't exist on the target class; @@ -387,8 +592,13 @@ class AssociationProxyInstance: @classmethod def _construct_for_assoc( - cls, target_assoc, parent, owning_class, target_class, value_attr - ): + cls, + target_assoc: Optional[AssociationProxyInstance[_T]], + parent: _AssociationProxyProtocol[_T], + owning_class: Type[Any], + target_class: Type[Any], + value_attr: str, + ) -> AssociationProxyInstance[_T]: if target_assoc is not None: return ObjectAssociationProxyInstance( parent, owning_class, target_class, value_attr @@ -409,36 +619,41 @@ class AssociationProxyInstance: parent, owning_class, target_class, value_attr ) - def _get_property(self): + def _get_property(self) -> MapperProperty[Any]: return orm.class_mapper(self.owning_class).get_property( self.target_collection ) @property - def _comparator(self): + def _comparator(self) -> PropComparator[Any]: return self._get_property().comparator - def __clause_element__(self): + def __clause_element__(self) -> NoReturn: raise NotImplementedError( "The association proxy can't be used as a plain column " "expression; it only works inside of a comparison expression" ) @classmethod - def _cls_unwrap_target_assoc_proxy(cls, target_class, value_attr): + def _cls_unwrap_target_assoc_proxy( + cls, target_class: Any, value_attr: str + ) -> Optional[AssociationProxyInstance[_T]]: attr = getattr(target_class, value_attr) - if isinstance(attr, (AssociationProxy, AssociationProxyInstance)): + assert not isinstance(attr, AssociationProxy) + if isinstance(attr, AssociationProxyInstance): return attr return None @util.memoized_property - def _unwrap_target_assoc_proxy(self): + def _unwrap_target_assoc_proxy( + self, + ) -> Optional[AssociationProxyInstance[_T]]: return self._cls_unwrap_target_assoc_proxy( self.target_class, self.value_attr ) @property - def remote_attr(self): + def remote_attr(self) -> SQLORMOperations[_T]: """The 'remote' class attribute referenced by this :class:`.AssociationProxyInstance`. @@ -449,10 +664,12 @@ class AssociationProxyInstance: :attr:`.AssociationProxyInstance.local_attr` """ - return getattr(self.target_class, self.value_attr) + return cast( + "SQLORMOperations[_T]", getattr(self.target_class, self.value_attr) + ) @property - def local_attr(self): + def local_attr(self) -> SQLORMOperations[Any]: """The 'local' class attribute referenced by this :class:`.AssociationProxyInstance`. @@ -463,10 +680,13 @@ class AssociationProxyInstance: :attr:`.AssociationProxyInstance.remote_attr` """ - return getattr(self.owning_class, self.target_collection) + return cast( + "SQLORMOperations[Any]", + getattr(self.owning_class, self.target_collection), + ) @property - def attr(self): + def attr(self) -> Tuple[SQLORMOperations[Any], SQLORMOperations[_T]]: """Return a tuple of ``(local_attr, remote_attr)``. This attribute was originally intended to facilitate using the @@ -497,7 +717,7 @@ class AssociationProxyInstance: return (self.local_attr, self.remote_attr) @util.memoized_property - def scalar(self): + def scalar(self) -> bool: """Return ``True`` if this :class:`.AssociationProxyInstance` proxies a scalar relationship on the local side.""" @@ -507,7 +727,7 @@ class AssociationProxyInstance: return scalar @util.memoized_property - def _value_is_scalar(self): + def _value_is_scalar(self) -> bool: return ( not self._get_property() .mapper.get_property(self.value_attr) @@ -515,43 +735,63 @@ class AssociationProxyInstance: ) @property - def _target_is_object(self): + def _target_is_object(self) -> bool: raise NotImplementedError() - def _initialize_scalar_accessors(self): + _scalar_get: _GetterProtocol[_T] + _scalar_set: _PlainSetterProtocol[_T] + + def _initialize_scalar_accessors(self) -> None: if self.parent.getset_factory: get, set_ = self.parent.getset_factory(None, self) else: get, set_ = self.parent._default_getset(None) - self._scalar_get, self._scalar_set = get, set_ + self._scalar_get, self._scalar_set = get, cast( + "_PlainSetterProtocol[_T]", set_ + ) - def _default_getset(self, collection_class): + def _default_getset( + self, collection_class: Any + ) -> Tuple[_GetterProtocol[Any], _SetterProtocol[Any]]: attr = self.value_attr _getter = operator.attrgetter(attr) - def getter(target): - return _getter(target) if target is not None else None + def getter(instance: Any) -> Optional[_T]: + return _getter(instance) if instance is not None else None if collection_class is dict: - def setter(o, k, v): - return setattr(o, attr, v) + def dict_setter(instance: Any, k: Any, value: _T) -> None: + setattr(instance, attr, value) + return getter, dict_setter else: - def setter(o, v): - return setattr(o, attr, v) + def plain_setter(o: Any, v: _T) -> None: + setattr(o, attr, v) - return getter, setter + return getter, plain_setter @property - def info(self): + def info(self) -> Dict[Any, Any]: return self.parent.info - def get(self, obj): + @overload + def get(self: Self, obj: Literal[None]) -> Self: + ... + + @overload + def get(self, obj: Any) -> _T: + ... + + def get( + self, obj: Any + ) -> Union[Optional[_T], AssociationProxyInstance[_T]]: if obj is None: return self + proxy: _T + if self.scalar: target = getattr(obj, self.target_collection) return self._scalar_get(target) @@ -559,7 +799,9 @@ class AssociationProxyInstance: try: # If the owning instance is reborn (orm session resurrect, # etc.), refresh the proxy cache. - creator_id, self_id, proxy = getattr(obj, self.key) + creator_id, self_id, proxy = cast( + "Tuple[int, int, _T]", getattr(obj, self.key) + ) except AttributeError: pass else: @@ -573,12 +815,15 @@ class AssociationProxyInstance: setattr(obj, self.key, (id(obj), id(self), proxy)) return proxy - def set(self, obj, values): + def set(self, obj: Any, values: _T) -> None: if self.scalar: - creator = ( - self.parent.creator - if self.parent.creator - else self.target_class + creator = cast( + "_PlainCreatorProtocol[_T]", + ( + self.parent.creator + if self.parent.creator + else self.target_class + ), ) target = getattr(obj, self.target_collection) if target is None: @@ -595,7 +840,7 @@ class AssociationProxyInstance: if proxy is not values: proxy._bulk_replace(self, values) - def delete(self, obj): + def delete(self, obj: Any) -> None: if self.owning_class is None: self._calc_owner(obj, None) @@ -605,12 +850,21 @@ class AssociationProxyInstance: delattr(target, self.value_attr) delattr(obj, self.target_collection) - def _new(self, lazy_collection): + def _new( + self, lazy_collection: _LazyCollectionProtocol[_T] + ) -> Tuple[Type[Any], _T]: creator = ( - self.parent.creator if self.parent.creator else self.target_class + self.parent.creator + if self.parent.creator is not None + else cast("_CreatorProtocol[_T]", self.target_class) ) collection_class = util.duck_type_collection(lazy_collection()) + if collection_class is None: + raise exc.InvalidRequestError( + f"lazy collection factory did not return a " + f"valid collection type, got {collection_class}" + ) if self.parent.proxy_factory: return ( collection_class, @@ -627,22 +881,31 @@ class AssociationProxyInstance: if collection_class is list: return ( collection_class, - _AssociationList( - lazy_collection, creator, getter, setter, self + cast( + _T, + _AssociationList( + lazy_collection, creator, getter, setter, self + ), ), ) elif collection_class is dict: return ( collection_class, - _AssociationDict( - lazy_collection, creator, getter, setter, self + cast( + _T, + _AssociationDict( + lazy_collection, creator, getter, setter, self + ), ), ) elif collection_class is set: return ( collection_class, - _AssociationSet( - lazy_collection, creator, getter, setter, self + cast( + _T, + _AssociationSet( + lazy_collection, creator, getter, setter, self + ), ), ) else: @@ -650,27 +913,31 @@ class AssociationProxyInstance: "could not guess which interface to use for " 'collection_class "%s" backing "%s"; specify a ' "proxy_factory and proxy_bulk_set manually" - % (self.collection_class.__name__, self.target_collection) + % (self.collection_class, self.target_collection) ) - def _set(self, proxy, values): + def _set( + self, proxy: _AssociationCollection[Any], values: Iterable[Any] + ) -> None: if self.parent.proxy_bulk_set: self.parent.proxy_bulk_set(proxy, values) elif self.collection_class is list: - proxy.extend(values) + cast("_AssociationList[Any]", proxy).extend(values) elif self.collection_class is dict: - proxy.update(values) + cast("_AssociationDict[Any, Any]", proxy).update(values) elif self.collection_class is set: - proxy.update(values) + cast("_AssociationSet[Any]", proxy).update(values) else: raise exc.ArgumentError( "no proxy_bulk_set supplied for custom " "collection_class implementation" ) - def _inflate(self, proxy): + def _inflate(self, proxy: _AssociationCollection[Any]) -> None: creator = ( - self.parent.creator and self.parent.creator or self.target_class + self.parent.creator + and self.parent.creator + or cast(_CreatorProtocol[Any], self.target_class) ) if self.parent.getset_factory: @@ -684,12 +951,14 @@ class AssociationProxyInstance: proxy.getter = getter proxy.setter = setter - def _criterion_exists(self, criterion=None, **kwargs): + def _criterion_exists( + self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs: Any + ) -> ColumnElement[bool]: is_has = kwargs.pop("is_has", None) target_assoc = self._unwrap_target_assoc_proxy if target_assoc is not None: - inner = target_assoc._criterion_exists( + inner = target_assoc._criterion_exists( # type: ignore criterion=criterion, **kwargs ) return self._comparator._criterion_exists(inner) @@ -713,7 +982,9 @@ class AssociationProxyInstance: return self._comparator._criterion_exists(value_expr) - def any(self, criterion=None, **kwargs): + def any( + self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs: Any + ) -> SQLCoreOperations[Any]: """Produce a proxied 'any' expression using EXISTS. This expression will be a composed product @@ -733,7 +1004,9 @@ class AssociationProxyInstance: criterion=criterion, is_has=False, **kwargs ) - def has(self, criterion=None, **kwargs): + def has( + self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs: Any + ) -> SQLCoreOperations[Any]: """Produce a proxied 'has' expression using EXISTS. This expression will be a composed product @@ -753,18 +1026,18 @@ class AssociationProxyInstance: criterion=criterion, is_has=True, **kwargs ) - def __repr__(self): + def __repr__(self) -> str: return "%s(%r)" % (self.__class__.__name__, self.parent) -class AmbiguousAssociationProxyInstance(AssociationProxyInstance): +class AmbiguousAssociationProxyInstance(AssociationProxyInstance[_T]): """an :class:`.AssociationProxyInstance` where we cannot determine the type of target object. """ _is_canonical = False - def _ambiguous(self): + def _ambiguous(self) -> NoReturn: raise AttributeError( "Association proxy %s.%s refers to an attribute '%s' that is not " "directly mapped on class %s; therefore this operation cannot " @@ -778,32 +1051,38 @@ class AmbiguousAssociationProxyInstance(AssociationProxyInstance): ) ) - def get(self, obj): + def get(self, obj: Any) -> Any: if obj is None: return self else: return super(AmbiguousAssociationProxyInstance, self).get(obj) - def __eq__(self, obj): + def __eq__(self, obj: object) -> NoReturn: self._ambiguous() - def __ne__(self, obj): + def __ne__(self, obj: object) -> NoReturn: self._ambiguous() - def any(self, criterion=None, **kwargs): + def any( + self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs: Any + ) -> NoReturn: self._ambiguous() - def has(self, criterion=None, **kwargs): + def has( + self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs: Any + ) -> NoReturn: self._ambiguous() @util.memoized_property - def _lookup_cache(self): + def _lookup_cache(self) -> Dict[Type[Any], AssociationProxyInstance[_T]]: # mapping of ->AssociationProxyInstance. # e.g. proxy is A-> A.b -> B -> B.b_attr, but B.b_attr doesn't exist; # only B1(B) and B2(B) have "b_attr", keys in here would be B1, B2 return {} - def _non_canonical_get_for_object(self, parent_instance): + def _non_canonical_get_for_object( + self, parent_instance: Any + ) -> AssociationProxyInstance[_T]: if parent_instance is not None: actual_obj = getattr(parent_instance, self.target_collection) if actual_obj is not None: @@ -826,7 +1105,9 @@ class AmbiguousAssociationProxyInstance(AssociationProxyInstance): # is a proxy with generally only instance-level functionality return self - def _populate_cache(self, instance_class, mapper): + def _populate_cache( + self, instance_class: Any, mapper: Mapper[Any] + ) -> None: prop = orm.class_mapper(self.owning_class).get_property( self.target_collection ) @@ -841,7 +1122,7 @@ class AmbiguousAssociationProxyInstance(AssociationProxyInstance): pass else: self._lookup_cache[instance_class] = self._construct_for_assoc( - target_assoc, + cast("AssociationProxyInstance[_T]", target_assoc), self.parent, self.owning_class, target_class, @@ -849,13 +1130,13 @@ class AmbiguousAssociationProxyInstance(AssociationProxyInstance): ) -class ObjectAssociationProxyInstance(AssociationProxyInstance): +class ObjectAssociationProxyInstance(AssociationProxyInstance[_T]): """an :class:`.AssociationProxyInstance` that has an object as a target.""" - _target_is_object = True + _target_is_object: bool = True _is_canonical = True - def contains(self, obj): + def contains(self, other: Any, **kw: Any) -> ColumnElement[bool]: """Produce a proxied 'contains' expression using EXISTS. This expression will be a composed product @@ -868,9 +1149,9 @@ class ObjectAssociationProxyInstance(AssociationProxyInstance): target_assoc = self._unwrap_target_assoc_proxy if target_assoc is not None: return self._comparator._criterion_exists( - target_assoc.contains(obj) + target_assoc.contains(other) if not target_assoc.scalar - else target_assoc == obj + else target_assoc == other ) elif ( self._target_is_object @@ -878,7 +1159,7 @@ class ObjectAssociationProxyInstance(AssociationProxyInstance): and not self._value_is_scalar ): return self._comparator.has( - getattr(self.target_class, self.value_attr).contains(obj) + getattr(self.target_class, self.value_attr).contains(other) ) elif self._target_is_object and self.scalar and self._value_is_scalar: raise exc.InvalidRequestError( @@ -886,9 +1167,11 @@ class ObjectAssociationProxyInstance(AssociationProxyInstance): ) else: - return self._comparator._criterion_exists(**{self.value_attr: obj}) + return self._comparator._criterion_exists( + **{self.value_attr: other} + ) - def __eq__(self, obj): + def __eq__(self, obj: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa E501 # note the has() here will fail for collections; eq_() # is only allowed with a scalar. if obj is None: @@ -899,7 +1182,7 @@ class ObjectAssociationProxyInstance(AssociationProxyInstance): else: return self._comparator.has(**{self.value_attr: obj}) - def __ne__(self, obj): + def __ne__(self, obj: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa E501 # note the has() here will fail for collections; eq_() # is only allowed with a scalar. return self._comparator.has( @@ -907,72 +1190,95 @@ class ObjectAssociationProxyInstance(AssociationProxyInstance): ) -class ColumnAssociationProxyInstance( - ColumnOperators, AssociationProxyInstance -): +class ColumnAssociationProxyInstance(AssociationProxyInstance[_T]): """an :class:`.AssociationProxyInstance` that has a database column as a target. """ - _target_is_object = False + _target_is_object: bool = False _is_canonical = True - def __eq__(self, other): + def __eq__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa E501 # special case "is None" to check for no related row as well expr = self._criterion_exists( - self.remote_attr.operate(operator.eq, other) + self.remote_attr.operate(operators.eq, other) ) if other is None: return or_(expr, self._comparator == None) else: return expr - def operate(self, op, *other, **kwargs): + def operate( + self, op: operators.OperatorType, *other: Any, **kwargs: Any + ) -> ColumnElement[Any]: return self._criterion_exists( self.remote_attr.operate(op, *other, **kwargs) ) -class _lazy_collection: - def __init__(self, obj, target): +class _lazy_collection(_LazyCollectionProtocol[_T]): + def __init__(self, obj: Any, target: str): self.parent = obj self.target = target - def __call__(self): - return getattr(self.parent, self.target) + def __call__( + self, + ) -> Union[MutableSet[_T], MutableMapping[Any, _T], MutableSequence[_T]]: + return getattr(self.parent, self.target) # type: ignore[no-any-return] - def __getstate__(self): + def __getstate__(self) -> Any: return {"obj": self.parent, "target": self.target} - def __setstate__(self, state): + def __setstate__(self, state: Any) -> None: self.parent = state["obj"] self.target = state["target"] -class _AssociationCollection: - def __init__(self, lazy_collection, creator, getter, setter, parent): - """Constructs an _AssociationCollection. +_IT = TypeVar("_IT", bound="Any") +"""instance type - this is the type of object inside a collection. - This will always be a subclass of either _AssociationList, - _AssociationSet, or _AssociationDict. +this is not the same as the _T of AssociationProxy and +AssociationProxyInstance itself, which will often refer to the +collection[_IT] type. - lazy_collection - A callable returning a list-based collection of entities (usually an - object attribute managed by a SQLAlchemy relationship()) +""" - creator - A function that creates new target entities. Given one parameter: - value. This assertion is assumed:: - obj = creator(somevalue) - assert getter(obj) == somevalue +class _AssociationCollection(Generic[_IT]): + getter: _GetterProtocol[_IT] + """A function. Given an associated object, return the 'value'.""" - getter - A function. Given an associated object, return the 'value'. + creator: _CreatorProtocol[_IT] + """ + A function that creates new target entities. Given one parameter: + value. This assertion is assumed:: - setter - A function. Given an associated object and a value, store that - value on the object. + obj = creator(somevalue) + assert getter(obj) == somevalue + """ + + parent: AssociationProxyInstance[_IT] + setter: _SetterProtocol[_IT] + """A function. Given an associated object and a value, store that + value on the object. + """ + + lazy_collection: _LazyCollectionProtocol[_IT] + """A callable returning a list-based collection of entities (usually an + object attribute managed by a SQLAlchemy relationship())""" + + def __init__( + self, + lazy_collection: _LazyCollectionProtocol[_IT], + creator: _CreatorProtocol[_IT], + getter: _GetterProtocol[_IT], + setter: _SetterProtocol[_IT], + parent: AssociationProxyInstance[_IT], + ): + """Constructs an _AssociationCollection. + + This will always be a subclass of either _AssociationList, + _AssociationSet, or _AssociationDict. """ self.lazy_collection = lazy_collection @@ -981,50 +1287,85 @@ class _AssociationCollection: self.setter = setter self.parent = parent - col = property(lambda self: self.lazy_collection()) + if typing.TYPE_CHECKING: + col: Collection[_IT] + else: + col = property(lambda self: self.lazy_collection()) - def __len__(self): + def __len__(self) -> int: return len(self.col) - def __bool__(self): + def __bool__(self) -> bool: return bool(self.col) __nonzero__ = __bool__ - def __getstate__(self): + def __getstate__(self) -> Any: return {"parent": self.parent, "lazy_collection": self.lazy_collection} - def __setstate__(self, state): + def __setstate__(self, state: Any) -> None: self.parent = state["parent"] self.lazy_collection = state["lazy_collection"] self.parent._inflate(self) - def _bulk_replace(self, assoc_proxy, values): + def clear(self) -> None: + raise NotImplementedError() + + +class _AssociationSingleItem(_AssociationCollection[_T]): + setter: _PlainSetterProtocol[_T] + creator: _PlainCreatorProtocol[_T] + + def _create(self, value: _T) -> Any: + return self.creator(value) + + def _get(self, object_: Any) -> _T: + return self.getter(object_) + + def _bulk_replace( + self, assoc_proxy: AssociationProxyInstance[Any], values: Iterable[_IT] + ) -> None: self.clear() assoc_proxy._set(self, values) -class _AssociationList(_AssociationCollection): +class _AssociationList(_AssociationSingleItem[_T], MutableSequence[_T]): """Generic, converting, list-to-list proxy.""" - def _create(self, value): - return self.creator(value) + col: MutableSequence[_T] - def _get(self, object_): - return self.getter(object_) + def _set(self, object_: Any, value: _T) -> None: + self.setter(object_, value) + + @overload + def __getitem__(self, index: int) -> _T: + ... - def _set(self, object_, value): - return self.setter(object_, value) + @overload + def __getitem__(self, index: slice) -> MutableSequence[_T]: + ... - def __getitem__(self, index): + def __getitem__( + self, index: Union[int, slice] + ) -> Union[_T, MutableSequence[_T]]: if not isinstance(index, slice): return self._get(self.col[index]) else: return [self._get(member) for member in self.col[index]] - def __setitem__(self, index, value): + @overload + def __setitem__(self, index: int, value: _T) -> None: + ... + + @overload + def __setitem__(self, index: slice, value: Iterable[_T]) -> None: + ... + + def __setitem__( + self, index: Union[int, slice], value: Union[_T, Iterable[_T]] + ) -> None: if not isinstance(index, slice): - self._set(self.col[index], value) + self._set(self.col[index], cast("_T", value)) else: if index.stop is None: stop = len(self) @@ -1036,43 +1377,45 @@ class _AssociationList(_AssociationCollection): start = index.start or 0 rng = list(range(index.start or 0, stop, step)) + + sized_value = list(value) + if step == 1: for i in rng: del self[start] i = start - for item in value: + for item in sized_value: self.insert(i, item) i += 1 else: - if len(value) != len(rng): + if len(sized_value) != len(rng): raise ValueError( "attempt to assign sequence of size %s to " - "extended slice of size %s" % (len(value), len(rng)) + "extended slice of size %s" + % (len(sized_value), len(rng)) ) for i, item in zip(rng, value): self._set(self.col[i], item) - def __delitem__(self, index): + @overload + def __delitem__(self, index: int) -> None: + ... + + @overload + def __delitem__(self, index: slice) -> None: + ... + + def __delitem__(self, index: Union[slice, int]) -> None: del self.col[index] - def __contains__(self, value): + def __contains__(self, value: object) -> bool: for member in self.col: # testlib.pragma exempt:__eq__ if self._get(member) == value: return True return False - def __getslice__(self, start, end): - return [self._get(member) for member in self.col[start:end]] - - def __setslice__(self, start, end, values): - members = [self._create(v) for v in values] - self.col[start:end] = members - - def __delslice__(self, start, end): - del self.col[start:end] - - def __iter__(self): + def __iter__(self) -> Iterator[_T]: """Iterate over proxied values. For the actual domain objects, iterate over .col instead or @@ -1084,255 +1427,262 @@ class _AssociationList(_AssociationCollection): yield self._get(member) return - def append(self, value): + def append(self, value: _T) -> None: col = self.col item = self._create(value) col.append(item) - def count(self, value): + def count(self, value: Any) -> int: count = 0 for v in self: if v == value: count += 1 return count - def extend(self, values): + def extend(self, values: Iterable[_T]) -> None: for v in values: self.append(v) - def insert(self, index, value): + def insert(self, index: int, value: _T) -> None: self.col[index:index] = [self._create(value)] - def pop(self, index=-1): + def pop(self, index: int = -1) -> _T: return self.getter(self.col.pop(index)) - def remove(self, value): + def remove(self, value: _T) -> None: for i, val in enumerate(self): if val == value: del self.col[i] return raise ValueError("value not in list") - def reverse(self): + def reverse(self) -> NoReturn: """Not supported, use reversed(mylist)""" - raise NotImplementedError + raise NotImplementedError() - def sort(self): + def sort(self) -> NoReturn: """Not supported, use sorted(mylist)""" - raise NotImplementedError + raise NotImplementedError() - def clear(self): + def clear(self) -> None: del self.col[0 : len(self.col)] - def __eq__(self, other): + def __eq__(self, other: object) -> bool: return list(self) == other - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return list(self) != other - def __lt__(self, other): + def __lt__(self, other: list[_T]) -> bool: return list(self) < other - def __le__(self, other): + def __le__(self, other: list[_T]) -> bool: return list(self) <= other - def __gt__(self, other): + def __gt__(self, other: list[_T]) -> bool: return list(self) > other - def __ge__(self, other): + def __ge__(self, other: list[_T]) -> bool: return list(self) >= other - def __cmp__(self, other): - return util.cmp(list(self), other) - - def __add__(self, iterable): + def __add__(self, other: list[_T]) -> list[_T]: try: - other = list(iterable) + other = list(other) except TypeError: return NotImplemented return list(self) + other - def __radd__(self, iterable): + def __radd__(self, other: list[_T]) -> list[_T]: try: - other = list(iterable) + other = list(other) except TypeError: return NotImplemented return other + list(self) - def __mul__(self, n): + def __mul__(self, n: SupportsIndex) -> list[_T]: if not isinstance(n, int): return NotImplemented return list(self) * n - __rmul__ = __mul__ + def __rmul__(self, n: SupportsIndex) -> list[_T]: + if not isinstance(n, int): + return NotImplemented + return n * list(self) - def __iadd__(self, iterable): + def __iadd__(self: Self, iterable: Iterable[_T]) -> Self: self.extend(iterable) return self - def __imul__(self, n): + def __imul__(self: Self, n: SupportsIndex) -> Self: # unlike a regular list *=, proxied __imul__ will generate unique # backing objects for each copy. *= on proxied lists is a bit of # a stretch anyhow, and this interpretation of the __imul__ contract # is more plausibly useful than copying the backing objects. if not isinstance(n, int): - return NotImplemented + raise NotImplementedError() if n == 0: self.clear() elif n > 1: self.extend(list(self) * (n - 1)) return self - def index(self, item, *args): - return list(self).index(item, *args) + if typing.TYPE_CHECKING: + # TODO: no idea how to do this without separate "stub" + def index(self, value: Any, start: int = ..., stop: int = ...) -> int: + ... + + else: + + def index(self, value: Any, *arg) -> int: + ls = list(self) + return ls.index(value, *arg) - def copy(self): + def copy(self) -> list[_T]: return list(self) - def __repr__(self): + def __repr__(self) -> str: return repr(list(self)) - def __hash__(self): + def __hash__(self) -> NoReturn: raise TypeError("%s objects are unhashable" % type(self).__name__) - for func_name, func in list(locals().items()): - if ( - callable(func) - and func.__name__ == func_name - and not func.__doc__ - and hasattr(list, func_name) - ): - func.__doc__ = getattr(list, func_name).__doc__ - del func_name, func - + if not typing.TYPE_CHECKING: + for func_name, func in list(locals().items()): + if ( + callable(func) + and func.__name__ == func_name + and not func.__doc__ + and hasattr(list, func_name) + ): + func.__doc__ = getattr(list, func_name).__doc__ + del func_name, func -_NotProvided = util.symbol("_NotProvided") - -class _AssociationDict(_AssociationCollection): +class _AssociationDict(_AssociationCollection[_VT], MutableMapping[_KT, _VT]): """Generic, converting, dict-to-dict proxy.""" - def _create(self, key, value): + setter: _DictSetterProtocol[_VT] + creator: _KeyCreatorProtocol[_VT] + col: MutableMapping[_KT, Optional[_VT]] + + def _create(self, key: _KT, value: Optional[_VT]) -> Any: return self.creator(key, value) - def _get(self, object_): + def _get(self, object_: Any) -> _VT: return self.getter(object_) - def _set(self, object_, key, value): + def _set(self, object_: Any, key: _KT, value: _VT) -> None: return self.setter(object_, key, value) - def __getitem__(self, key): + def __getitem__(self, key: _KT) -> _VT: return self._get(self.col[key]) - def __setitem__(self, key, value): + def __setitem__(self, key: _KT, value: _VT) -> None: if key in self.col: self._set(self.col[key], key, value) else: self.col[key] = self._create(key, value) - def __delitem__(self, key): + def __delitem__(self, key: _KT) -> None: del self.col[key] - def __contains__(self, key): - # testlib.pragma exempt:__hash__ + def __contains__(self, key: object) -> bool: return key in self.col - def has_key(self, key): - # testlib.pragma exempt:__hash__ - return key in self.col - - def __iter__(self): + def __iter__(self) -> Iterator[_KT]: return iter(self.col.keys()) - def clear(self): + def clear(self) -> None: self.col.clear() - def __eq__(self, other): + def __eq__(self, other: object) -> bool: return dict(self) == other - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return dict(self) != other - def __lt__(self, other): - return dict(self) < other - - def __le__(self, other): - return dict(self) <= other + def __repr__(self) -> str: + return repr(dict(self)) - def __gt__(self, other): - return dict(self) > other + @overload + def get(self, __key: _KT) -> Optional[_VT]: + ... - def __ge__(self, other): - return dict(self) >= other + @overload + def get(self, __key: _KT, default: Union[_VT, _T]) -> Union[_VT, _T]: + ... - def __cmp__(self, other): - return util.cmp(dict(self), other) - - def __repr__(self): - return repr(dict(self.items())) - - def get(self, key, default=None): + def get( + self, key: _KT, default: Optional[Union[_VT, _T]] = None + ) -> Union[_VT, _T, None]: try: return self[key] except KeyError: return default - def setdefault(self, key, default=None): + def setdefault(self, key: _KT, default: Optional[_VT] = None) -> _VT: + # TODO: again, no idea how to create an actual MutableMapping. + # default must allow None, return type can't include None, + # the stub explicitly allows for default of None with a cryptic message + # "This overload should be allowed only if the value type is + # compatible with None.". if key not in self.col: self.col[key] = self._create(key, default) - return default + return default # type: ignore else: return self[key] - def keys(self): + def keys(self) -> KeysView[_KT]: return self.col.keys() - def items(self): - return ((key, self._get(self.col[key])) for key in self.col) + def items(self) -> ItemsView[_KT, _VT]: + return ItemsView(self) - def values(self): - return (self._get(self.col[key]) for key in self.col) + def values(self) -> ValuesView[_VT]: + return ValuesView(self) - def pop(self, key, default=_NotProvided): - if default is _NotProvided: - member = self.col.pop(key) - else: - member = self.col.pop(key, default) + @overload + def pop(self, __key: _KT) -> _VT: + ... + + @overload + def pop(self, __key: _KT, default: Union[_VT, _T] = ...) -> Union[_VT, _T]: + ... + + def pop(self, __key: _KT, *arg: Any, **kw: Any) -> Union[_VT, _T]: + member = self.col.pop(__key, *arg, **kw) return self._get(member) - def popitem(self): + def popitem(self) -> Tuple[_KT, _VT]: item = self.col.popitem() return (item[0], self._get(item[1])) - def update(self, *a, **kw): - if len(a) > 1: - raise TypeError( - "update expected at most 1 arguments, got %i" % len(a) - ) - elif len(a) == 1: - seq_or_map = a[0] - # discern dict from sequence - took the advice from - # https://www.voidspace.org.uk/python/articles/duck_typing.shtml - # still not perfect :( - if hasattr(seq_or_map, "keys"): - for item in seq_or_map: - self[item] = seq_or_map[item] - else: - try: - for k, v in seq_or_map: - self[k] = v - except ValueError as err: - raise ValueError( - "dictionary update sequence " - "requires 2-element tuples" - ) from err + @overload + def update(self, __m: Mapping[_KT, _VT], **kwargs: _VT) -> None: + ... + + @overload + def update(self, __m: Iterable[tuple[_KT, _VT]], **kwargs: _VT) -> None: + ... - for key, value in kw: + @overload + def update(self, **kwargs: _VT) -> None: + ... + + def update(self, *a: Any, **kw: Any) -> None: + up: Dict[_KT, _VT] = {} + up.update(*a, **kw) + + for key, value in up.items(): self[key] = value - def _bulk_replace(self, assoc_proxy, values): + def _bulk_replace( + self, + assoc_proxy: AssociationProxyInstance[Any], + values: Mapping[_KT, _VT], + ) -> None: existing = set(self) constants = existing.intersection(values or ()) additions = set(values or ()).difference(constants) @@ -1347,36 +1697,33 @@ class _AssociationDict(_AssociationCollection): for key in removals: del self[key] - def copy(self): + def copy(self) -> dict[_KT, _VT]: return dict(self.items()) - def __hash__(self): + def __hash__(self) -> NoReturn: raise TypeError("%s objects are unhashable" % type(self).__name__) - for func_name, func in list(locals().items()): - if ( - callable(func) - and func.__name__ == func_name - and not func.__doc__ - and hasattr(dict, func_name) - ): - func.__doc__ = getattr(dict, func_name).__doc__ - del func_name, func + if not typing.TYPE_CHECKING: + for func_name, func in list(locals().items()): + if ( + callable(func) + and func.__name__ == func_name + and not func.__doc__ + and hasattr(dict, func_name) + ): + func.__doc__ = getattr(dict, func_name).__doc__ + del func_name, func -class _AssociationSet(_AssociationCollection): +class _AssociationSet(_AssociationSingleItem[_T], MutableSet[_T]): """Generic, converting, set-to-set proxy.""" - def _create(self, value): - return self.creator(value) + col: MutableSet[_T] - def _get(self, object_): - return self.getter(object_) - - def __len__(self): + def __len__(self) -> int: return len(self.col) - def __bool__(self): + def __bool__(self) -> bool: if self.col: return True else: @@ -1384,14 +1731,13 @@ class _AssociationSet(_AssociationCollection): __nonzero__ = __bool__ - def __contains__(self, value): + def __contains__(self, __o: object) -> bool: for member in self.col: - # testlib.pragma exempt:__eq__ - if self._get(member) == value: + if self._get(member) == __o: return True return False - def __iter__(self): + def __iter__(self) -> Iterator[_T]: """Iterate over proxied values. For the actual domain objects, iterate over .col instead or just use @@ -1402,36 +1748,37 @@ class _AssociationSet(_AssociationCollection): yield self._get(member) return - def add(self, value): - if value not in self: - self.col.add(self._create(value)) + def add(self, __element: _T) -> None: + if __element not in self: + self.col.add(self._create(__element)) # for discard and remove, choosing a more expensive check strategy rather # than call self.creator() - def discard(self, value): + def discard(self, __element: _T) -> None: for member in self.col: - if self._get(member) == value: + if self._get(member) == __element: self.col.discard(member) break - def remove(self, value): + def remove(self, __element: _T) -> None: for member in self.col: - if self._get(member) == value: + if self._get(member) == __element: self.col.discard(member) return - raise KeyError(value) + raise KeyError(__element) - def pop(self): + def pop(self) -> _T: if not self.col: raise KeyError("pop from an empty set") member = self.col.pop() return self._get(member) - def update(self, other): - for value in other: - self.add(value) + def update(self, *s: Iterable[_T]) -> None: + for iterable in s: + for value in iterable: + self.add(value) - def _bulk_replace(self, assoc_proxy, values): + def _bulk_replace(self, assoc_proxy: Any, values: Iterable[_T]) -> None: existing = set(self) constants = existing.intersection(values or ()) additions = set(values or ()).difference(constants) @@ -1449,56 +1796,64 @@ class _AssociationSet(_AssociationCollection): for member in removals: remover(member) - def __ior__(self, other): + def __ior__( + self: Self, other: AbstractSet[_S] + ) -> MutableSet[Union[_T, _S]]: if not collections._set_binops_check_strict(self, other): - return NotImplemented + raise NotImplementedError() for value in other: self.add(value) return self - def _set(self): + def _set(self) -> Set[_T]: return set(iter(self)) - def union(self, other): - return set(self).union(other) + def union(self, *s: Iterable[_S]) -> MutableSet[Union[_T, _S]]: + return set(self).union(*s) - __or__ = union + def __or__(self, __s: AbstractSet[_S]) -> MutableSet[Union[_T, _S]]: + return self.union(__s) - def difference(self, other): - return set(self).difference(other) + def difference(self, *s: Iterable[Any]) -> MutableSet[_T]: + return set(self).difference(*s) - __sub__ = difference + def __sub__(self, s: AbstractSet[Any]) -> MutableSet[_T]: + return self.difference(s) - def difference_update(self, other): - for value in other: - self.discard(value) + def difference_update(self, *s: Iterable[Any]) -> None: + for other in s: + for value in other: + self.discard(value) - def __isub__(self, other): - if not collections._set_binops_check_strict(self, other): - return NotImplemented - for value in other: + def __isub__(self: Self, s: AbstractSet[Any]) -> Self: + if not collections._set_binops_check_strict(self, s): + raise NotImplementedError() + for value in s: self.discard(value) return self - def intersection(self, other): - return set(self).intersection(other) + def intersection(self, *s: Iterable[Any]) -> MutableSet[_T]: + return set(self).intersection(*s) - __and__ = intersection + def __and__(self, s: AbstractSet[Any]) -> MutableSet[_T]: + return self.intersection(s) - def intersection_update(self, other): - want, have = self.intersection(other), set(self) + def intersection_update(self, *s: Iterable[Any]) -> None: + for other in s: + want, have = self.intersection(other), set(self) - remove, add = have - want, want - have + remove, add = have - want, want - have - for value in remove: - self.remove(value) - for value in add: - self.add(value) + for value in remove: + self.remove(value) + for value in add: + self.add(value) - def __iand__(self, other): - if not collections._set_binops_check_strict(self, other): - return NotImplemented - want, have = self.intersection(other), set(self) + def __iand__(self: Self, s: AbstractSet[Any]) -> Self: + if not collections._set_binops_check_strict(self, s): + raise NotImplementedError() + want = self.intersection(s) + have: Set[_T] = set(self) remove, add = have - want, want - have @@ -1508,12 +1863,13 @@ class _AssociationSet(_AssociationCollection): self.add(value) return self - def symmetric_difference(self, other): - return set(self).symmetric_difference(other) + def symmetric_difference(self, __s: Iterable[_T]) -> MutableSet[_T]: + return set(self).symmetric_difference(__s) - __xor__ = symmetric_difference + def __xor__(self, s: AbstractSet[_S]) -> MutableSet[Union[_T, _S]]: + return self.symmetric_difference(s) # type: ignore - def symmetric_difference_update(self, other): + def symmetric_difference_update(self, other: Iterable[Any]) -> None: want, have = self.symmetric_difference(other), set(self) remove, add = have - want, want - have @@ -1523,61 +1879,56 @@ class _AssociationSet(_AssociationCollection): for value in add: self.add(value) - def __ixor__(self, other): + def __ixor__(self, other: AbstractSet[_S]) -> MutableSet[Union[_T, _S]]: if not collections._set_binops_check_strict(self, other): - return NotImplemented - want, have = self.symmetric_difference(other), set(self) - - remove, add = have - want, want - have + raise NotImplementedError() - for value in remove: - self.remove(value) - for value in add: - self.add(value) + self.symmetric_difference_update(other) return self - def issubset(self, other): - return set(self).issubset(other) + def issubset(self, __s: Iterable[Any]) -> bool: + return set(self).issubset(__s) - def issuperset(self, other): - return set(self).issuperset(other) + def issuperset(self, __s: Iterable[Any]) -> bool: + return set(self).issuperset(__s) - def clear(self): + def clear(self) -> None: self.col.clear() - def copy(self): + def copy(self) -> AbstractSet[_T]: return set(self) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: return set(self) == other - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return set(self) != other - def __lt__(self, other): + def __lt__(self, other: AbstractSet[Any]) -> bool: return set(self) < other - def __le__(self, other): + def __le__(self, other: AbstractSet[Any]) -> bool: return set(self) <= other - def __gt__(self, other): + def __gt__(self, other: AbstractSet[Any]) -> bool: return set(self) > other - def __ge__(self, other): + def __ge__(self, other: AbstractSet[Any]) -> bool: return set(self) >= other - def __repr__(self): + def __repr__(self) -> str: return repr(set(self)) - def __hash__(self): + def __hash__(self) -> NoReturn: raise TypeError("%s objects are unhashable" % type(self).__name__) - for func_name, func in list(locals().items()): - if ( - callable(func) - and func.__name__ == func_name - and not func.__doc__ - and hasattr(set, func_name) - ): - func.__doc__ = getattr(set, func_name).__doc__ - del func_name, func + if not typing.TYPE_CHECKING: + for func_name, func in list(locals().items()): + if ( + callable(func) + and func.__name__ == func_name + and not func.__doc__ + and hasattr(set, func_name) + ): + func.__doc__ = getattr(set, func_name).__doc__ + del func_name, func diff --git a/lib/sqlalchemy/ext/hybrid.py b/lib/sqlalchemy/ext/hybrid.py index c7d9d4f887..dc34a2ef58 100644 --- a/lib/sqlalchemy/ext/hybrid.py +++ b/lib/sqlalchemy/ext/hybrid.py @@ -807,45 +807,51 @@ from typing import TypeVar from .. import util from ..orm import attributes +from ..orm import InspectionAttrExtensionType from ..orm import interfaces +from ..orm import ORMDescriptor + _T = TypeVar("_T", bound=Any) -HYBRID_METHOD = util.symbol("HYBRID_METHOD") -"""Symbol indicating an :class:`InspectionAttr` that's - of type :class:`.hybrid_method`. - Is assigned to the :attr:`.InspectionAttr.extension_type` - attribute. +class HybridExtensionType(InspectionAttrExtensionType): - .. seealso:: + HYBRID_METHOD = "HYBRID_METHOD" + """Symbol indicating an :class:`InspectionAttr` that's + of type :class:`.hybrid_method`. - :attr:`_orm.Mapper.all_orm_attributes` + Is assigned to the :attr:`.InspectionAttr.extension_type` + attribute. -""" + .. seealso:: -HYBRID_PROPERTY = util.symbol("HYBRID_PROPERTY") -"""Symbol indicating an :class:`InspectionAttr` that's - of type :class:`.hybrid_method`. + :attr:`_orm.Mapper.all_orm_attributes` - Is assigned to the :attr:`.InspectionAttr.extension_type` - attribute. + """ - .. seealso:: + HYBRID_PROPERTY = "HYBRID_PROPERTY" + """Symbol indicating an :class:`InspectionAttr` that's + of type :class:`.hybrid_method`. - :attr:`_orm.Mapper.all_orm_attributes` + Is assigned to the :attr:`.InspectionAttr.extension_type` + attribute. -""" + .. seealso:: + + :attr:`_orm.Mapper.all_orm_attributes` + + """ -class hybrid_method(interfaces.InspectionAttrInfo): +class hybrid_method(interfaces.InspectionAttrInfo, ORMDescriptor[_T]): """A decorator which allows definition of a Python object method with both instance-level and class-level behavior. """ is_attribute = True - extension_type = HYBRID_METHOD + extension_type = HybridExtensionType.HYBRID_METHOD def __init__(self, func, expr=None): """Create a new :class:`.hybrid_method`. @@ -890,7 +896,7 @@ class hybrid_property(interfaces.InspectionAttrInfo): """ is_attribute = True - extension_type = HYBRID_PROPERTY + extension_type = HybridExtensionType.HYBRID_PROPERTY def __init__( self, diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py index 5a8a0f6cf6..141702ae65 100644 --- a/lib/sqlalchemy/orm/__init__.py +++ b/lib/sqlalchemy/orm/__init__.py @@ -43,7 +43,11 @@ from ._orm_constructors import with_polymorphic as with_polymorphic from .attributes import AttributeEvent as AttributeEvent from .attributes import InstrumentedAttribute as InstrumentedAttribute from .attributes import QueryableAttribute as QueryableAttribute +from .base import class_mapper as class_mapper +from .base import InspectionAttrExtensionType as InspectionAttrExtensionType from .base import Mapped as Mapped +from .base import NotExtension as NotExtension +from .base import ORMDescriptor as ORMDescriptor from .context import QueryContext as QueryContext from .decl_api import add_mapped_attribute as add_mapped_attribute from .decl_api import as_declarative as as_declarative @@ -75,13 +79,11 @@ from .interfaces import InspectionAttrInfo as InspectionAttrInfo from .interfaces import MANYTOMANY as MANYTOMANY from .interfaces import MANYTOONE as MANYTOONE from .interfaces import MapperProperty as MapperProperty -from .interfaces import NOT_EXTENSION as NOT_EXTENSION from .interfaces import ONETOMANY as ONETOMANY from .interfaces import PropComparator as PropComparator from .interfaces import UserDefinedOption as UserDefinedOption from .loading import merge_frozen_result as merge_frozen_result from .loading import merge_result as merge_result -from .mapper import class_mapper as class_mapper from .mapper import configure_mappers as configure_mappers from .mapper import Mapper as Mapper from .mapper import reconstructor as reconstructor diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py index b9c881cfe0..c63a89c704 100644 --- a/lib/sqlalchemy/orm/base.py +++ b/lib/sqlalchemy/orm/base.py @@ -11,14 +11,17 @@ from __future__ import annotations +from enum import Enum import operator import typing from typing import Any from typing import Callable +from typing import Dict from typing import Generic from typing import Optional from typing import overload from typing import Tuple +from typing import Type from typing import TypeVar from typing import Union @@ -29,11 +32,13 @@ from .. import util from ..sql.elements import SQLCoreOperations from ..util.langhelpers import TypingOnly from ..util.typing import Concatenate +from ..util.typing import Literal from ..util.typing import ParamSpec - +from ..util.typing import Self if typing.TYPE_CHECKING: from .attributes import InstrumentedAttribute + from .mapper import Mapper _T = TypeVar("_T", bound=Any) @@ -223,16 +228,22 @@ MANYTOMANY = util.symbol( """, ) -NOT_EXTENSION = util.symbol( - "NOT_EXTENSION", + +class InspectionAttrExtensionType(Enum): + """Symbols indicating the type of extension that a + :class:`.InspectionAttr` is part of.""" + + +class NotExtension(InspectionAttrExtensionType): + NOT_EXTENSION = "not_extension" """Symbol indicating an :class:`InspectionAttr` that's not part of sqlalchemy.ext. Is assigned to the :attr:`.InspectionAttr.extension_type` attribute. - """, -) + """ + _never_set = frozenset([NEVER_SET]) @@ -455,7 +466,7 @@ def _inspect_mapped_class(class_, configure=False): return mapper -def class_mapper(class_, configure=True): +def class_mapper(class_: Type[_T], configure: bool = True) -> Mapper[_T]: """Given a class, return the primary :class:`_orm.Mapper` associated with the key. @@ -546,17 +557,15 @@ class InspectionAttr: """True if this object is an instance of :class:`_expression.ClauseElement`.""" - extension_type = NOT_EXTENSION + extension_type: InspectionAttrExtensionType = NotExtension.NOT_EXTENSION """The extension type, if any. - Defaults to :data:`.interfaces.NOT_EXTENSION` + Defaults to :attr:`.interfaces.NotExtension.NOT_EXTENSION` .. seealso:: - :data:`.HYBRID_METHOD` + :class:`.HybridExtensionType` - :data:`.HYBRID_PROPERTY` - - :data:`.ASSOCIATION_PROXY` + :class:`.AssociationProxyExtensionType` """ @@ -571,7 +580,7 @@ class InspectionAttrInfo(InspectionAttr): """ @util.memoized_property - def info(self): + def info(self) -> Dict[Any, Any]: """Info dictionary associated with the object, allowing user-defined data to be associated with this :class:`.InspectionAttr`. @@ -614,7 +623,35 @@ class SQLORMOperations(SQLCoreOperations[_T], TypingOnly): ... -class Mapped(Generic[_T], TypingOnly): +class ORMDescriptor(Generic[_T], TypingOnly): + """Represent any Python descriptor that provides a SQL expression + construct at the class level.""" + + __slots__ = () + + if typing.TYPE_CHECKING: + + @overload + def __get__(self: Self, instance: Any, owner: Literal[None]) -> Self: + ... + + @overload + def __get__( + self, instance: Literal[None], owner: Any + ) -> SQLORMOperations[_T]: + ... + + @overload + def __get__(self, instance: object, owner: Any) -> _T: + ... + + def __get__( + self, instance: object, owner: Any + ) -> Union[SQLORMOperations[_T], _T]: + ... + + +class Mapped(ORMDescriptor[_T], TypingOnly): """Represent an ORM mapped attribute on a mapped class. This class represents the complete descriptor interface for any class @@ -646,7 +683,7 @@ class Mapped(Generic[_T], TypingOnly): @overload def __get__( self, instance: None, owner: Any - ) -> "InstrumentedAttribute[_T]": + ) -> InstrumentedAttribute[_T]: ... @overload @@ -655,11 +692,11 @@ class Mapped(Generic[_T], TypingOnly): def __get__( self, instance: object, owner: Any - ) -> Union["InstrumentedAttribute[_T]", _T]: + ) -> Union[InstrumentedAttribute[_T], _T]: ... @classmethod - def _empty_constructor(cls, arg1: Any) -> "Mapped[_T]": + def _empty_constructor(cls, arg1: Any) -> Mapped[_T]: ... def __set__( diff --git a/lib/sqlalchemy/orm/clsregistry.py b/lib/sqlalchemy/orm/clsregistry.py index d0cb53e29b..fe6dbfdc9a 100644 --- a/lib/sqlalchemy/orm/clsregistry.py +++ b/lib/sqlalchemy/orm/clsregistry.py @@ -293,7 +293,7 @@ class _GetColumns: ) desc = mp.all_orm_descriptors[key] - if desc.extension_type is interfaces.NOT_EXTENSION: + if desc.extension_type is interfaces.NotExtension.NOT_EXTENSION: prop = desc.property if isinstance(prop, Synonym): key = prop.name diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py index 00ae9dac75..b1854de5a3 100644 --- a/lib/sqlalchemy/orm/collections.py +++ b/lib/sqlalchemy/orm/collections.py @@ -107,6 +107,7 @@ from __future__ import annotations import operator import threading import typing +from typing import Any import weakref from .. import exc as sa_exc @@ -1239,13 +1240,13 @@ def _dict_decorators(): _set_binop_bases = (set, frozenset) -def _set_binops_check_strict(self, obj): +def _set_binops_check_strict(self: Any, obj: Any) -> bool: """Allow only set, frozenset and self.__class__-derived objects in binops.""" return isinstance(obj, _set_binop_bases + (self.__class__,)) -def _set_binops_check_loose(self, obj): +def _set_binops_check_loose(self: Any, obj: Any) -> bool: """Allow anything set-like to participate in set binops.""" return ( isinstance(obj, _set_binop_bases + (self.__class__,)) diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index eed9735263..04fc07f61b 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -31,17 +31,19 @@ from typing import Union from . import exc as orm_exc from . import path_registry -from .base import _MappedAttribute # noqa -from .base import EXT_CONTINUE -from .base import EXT_SKIP -from .base import EXT_STOP -from .base import InspectionAttr # noqa -from .base import InspectionAttrInfo # noqa -from .base import MANYTOMANY -from .base import MANYTOONE -from .base import NOT_EXTENSION -from .base import ONETOMANY +from .base import _MappedAttribute as _MappedAttribute +from .base import EXT_CONTINUE as EXT_CONTINUE +from .base import EXT_SKIP as EXT_SKIP +from .base import EXT_STOP as EXT_STOP +from .base import InspectionAttr as InspectionAttr +from .base import InspectionAttrExtensionType as InspectionAttrExtensionType +from .base import InspectionAttrInfo as InspectionAttrInfo +from .base import MANYTOMANY as MANYTOMANY +from .base import MANYTOONE as MANYTOONE +from .base import NotExtension as NotExtension +from .base import ONETOMANY as ONETOMANY from .base import SQLORMOperations +from .. import ColumnElement from .. import inspect from .. import inspection from .. import util @@ -51,6 +53,7 @@ from ..sql import visitors from ..sql._typing import _ColumnsClauseElement from ..sql.base import ExecutableOption from ..sql.cache_key import HasCacheKey +from ..sql.elements import SQLCoreOperations from ..sql.schema import Column from ..sql.type_api import TypeEngine from ..util.typing import TypedDict @@ -60,22 +63,6 @@ if typing.TYPE_CHECKING: _T = TypeVar("_T", bound=Any) -__all__ = ( - "EXT_CONTINUE", - "EXT_STOP", - "EXT_SKIP", - "ONETOMANY", - "MANYTOMANY", - "MANYTOONE", - "NOT_EXTENSION", - "LoaderStrategy", - "MapperOption", - "LoaderOption", - "MapperProperty", - "PropComparator", - "StrategizedProperty", -) - class ORMStatementRole(roles.StatementRole): __slots__ = () @@ -190,6 +177,10 @@ class MapperProperty( """ + comparator: PropComparator[_T] + """The :class:`_orm.PropComparator` instance that implements SQL + expression construction on behalf of this mapped attribute.""" + @property def _links_to_entity(self): """True if this MapperProperty refers to a mapped entity. @@ -512,6 +503,11 @@ class PropComparator( } ) + def _criterion_exists( + self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs: Any + ) -> ColumnElement[Any]: + return self.prop.comparator._criterion_exists(criterion, **kwargs) + @property def adapter(self): """Produce a callable that adapts column expressions @@ -547,12 +543,12 @@ class PropComparator( def operate( self, op: operators.OperatorType, *other: Any, **kwargs: Any - ) -> "SQLORMOperations": + ) -> "SQLCoreOperations[Any]": ... def reverse_operate( self, op: operators.OperatorType, other: Any, **kwargs: Any - ) -> "SQLORMOperations": + ) -> "SQLCoreOperations[Any]": ... def of_type(self, class_) -> "SQLORMOperations[_T]": @@ -609,9 +605,11 @@ class PropComparator( """ return self.operate(operators.and_, *criteria) - def any(self, criterion=None, **kwargs) -> "SQLORMOperations[_T]": - r"""Return true if this collection contains any member that meets the - given criterion. + def any( + self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs + ) -> ColumnElement[bool]: + r"""Return a SQL expression representing true if this element + references a member which meets the given criterion. The usual implementation of ``any()`` is :meth:`.Relationship.Comparator.any`. @@ -627,9 +625,11 @@ class PropComparator( return self.operate(PropComparator.any_op, criterion, **kwargs) - def has(self, criterion=None, **kwargs) -> "SQLORMOperations[_T]": - r"""Return true if this element references a member which meets the - given criterion. + def has( + self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs + ) -> ColumnElement[bool]: + r"""Return a SQL expression representing true if this element + references a member which meets the given criterion. The usual implementation of ``has()`` is :meth:`.Relationship.Comparator.has`. diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 15e9b84311..5a34188a9c 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -21,6 +21,7 @@ from functools import reduce from itertools import chain import sys import threading +from typing import Any from typing import Generic from typing import Type from typing import TypeVar @@ -113,6 +114,9 @@ class Mapper( _dispose_called = False _ready_for_configure = False + class_: Type[_MC] + """The class to which this :class:`_orm.Mapper` is mapped.""" + @util.deprecated_params( non_primary=( "1.3", @@ -1984,10 +1988,12 @@ class Mapper( else: return False - def has_property(self, key): + def has_property(self, key: str) -> bool: return key in self._props - def get_property(self, key, _configure_mappers=True): + def get_property( + self, key: str, _configure_mappers: bool = True + ) -> MapperProperty[Any]: """return a MapperProperty associated with the given key.""" if _configure_mappers: @@ -2715,7 +2721,7 @@ class Mapper( else: return _state_mapper(state) is s - def isa(self, other): + def isa(self, other: Mapper[Any]) -> bool: """Return True if the this mapper inherits from the given mapper.""" m = self diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index 1b8f778c0a..b4697912bc 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -42,6 +42,7 @@ from .util import _orm_annotate from .util import _orm_deannotate from .util import CascadeOptions from .. import exc as sa_exc +from .. import Exists from .. import log from .. import schema from .. import sql @@ -52,6 +53,7 @@ from ..sql import expression from ..sql import operators from ..sql import roles from ..sql import visitors +from ..sql.elements import SQLCoreOperations from ..sql.util import _deep_deannotate from ..sql.util import _shallow_annotate from ..sql.util import adapt_criterion_to_null @@ -534,7 +536,11 @@ class Relationship( ) ) - def _criterion_exists(self, criterion=None, **kwargs): + def _criterion_exists( + self, + criterion: Optional[SQLCoreOperations[Any]] = None, + **kwargs: Any, + ) -> Exists[bool]: if getattr(self, "_of_type", None): info = inspect(self._of_type) target_mapper, to_selectable, is_aliased_class = ( @@ -1327,7 +1333,7 @@ class Relationship( return self.entity @util.memoized_property - def mapper(self) -> "Mapper": + def mapper(self) -> Mapper[_T]: """Return the targeted :class:`_orm.Mapper` for this :class:`.Relationship`. diff --git a/lib/sqlalchemy/sql/_elements_constructors.py b/lib/sqlalchemy/sql/_elements_constructors.py index 4f6ff06889..9d15cdcc3a 100644 --- a/lib/sqlalchemy/sql/_elements_constructors.py +++ b/lib/sqlalchemy/sql/_elements_constructors.py @@ -35,6 +35,7 @@ from .elements import FunctionFilter from .elements import Label from .elements import Null from .elements import Over +from .elements import SQLCoreOperations from .elements import TextClause from .elements import True_ from .elements import Tuple @@ -1228,7 +1229,7 @@ def nulls_last(column): return UnaryExpression._create_nulls_last(column) -def or_(*clauses): +def or_(*clauses: SQLCoreOperations) -> BooleanClauseList: """Produce a conjunction of expressions joined by ``OR``. E.g.:: diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index ac5dc46db1..4c38c4efab 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -44,6 +44,7 @@ from .base import SingletonConstant from .cache_key import MemoizedHasCacheKey from .cache_key import NO_CACHE from .coercions import _document_text_coercion # noqa +from .operators import ColumnOperators from .traversals import HasCopyInternals from .visitors import cloned_traverse from .visitors import InternalTraversal @@ -57,6 +58,7 @@ from ..util.langhelpers import TypingOnly if typing.TYPE_CHECKING: from decimal import Decimal + from .operators import OperatorType from .selectable import FromClause from .selectable import Select from .sqltypes import Boolean # noqa @@ -586,7 +588,9 @@ class CompilerColumnElement( __slots__ = () -class SQLCoreOperations(Generic[_T], TypingOnly): +class SQLCoreOperations( + Generic[_T], ColumnOperators["SQLCoreOperations"], TypingOnly +): __slots__ = () # annotations for comparison methods @@ -594,6 +598,16 @@ class SQLCoreOperations(Generic[_T], TypingOnly): # redefined with the specific types returned by ColumnElement hierarchies if typing.TYPE_CHECKING: + def operate( + self, op: OperatorType, *other: Any, **kwargs: Any + ) -> ColumnElement: + ... + + def reverse_operate( + self, op: OperatorType, other: Any, **kwargs: Any + ) -> ColumnElement: + ... + def op( self, opstring: Any, @@ -620,34 +634,34 @@ class SQLCoreOperations(Generic[_T], TypingOnly): def __invert__(self) -> "UnaryExpression[_T]": ... - def __lt__(self, other: Any) -> "BinaryExpression[bool]": + def __lt__(self, other: Any) -> "ColumnElement[bool]": ... - def __le__(self, other: Any) -> "BinaryExpression[bool]": + def __le__(self, other: Any) -> "ColumnElement[bool]": ... - def __eq__(self, other: Any) -> "BinaryExpression[bool]": + def __eq__(self, other: Any) -> "ColumnElement[bool]": # type: ignore[override] # noqa: E501 ... - def __ne__(self, other: Any) -> "BinaryExpression[bool]": + def __ne__(self, other: Any) -> "ColumnElement[bool]": # type: ignore[override] # noqa: E501 ... - def is_distinct_from(self, other: Any) -> "BinaryExpression[bool]": + def is_distinct_from(self, other: Any) -> "ColumnElement[bool]": ... - def is_not_distinct_from(self, other: Any) -> "BinaryExpression[bool]": + def is_not_distinct_from(self, other: Any) -> "ColumnElement[bool]": ... - def __gt__(self, other: Any) -> "BinaryExpression[bool]": + def __gt__(self, other: Any) -> "ColumnElement[bool]": ... - def __ge__(self, other: Any) -> "BinaryExpression[bool]": + def __ge__(self, other: Any) -> "ColumnElement[bool]": ... def __neg__(self) -> "UnaryExpression[_T]": ... - def __contains__(self, other: Any) -> "BinaryExpression[bool]": + def __contains__(self, other: Any) -> "ColumnElement[bool]": ... def __getitem__(self, index: Any) -> "ColumnElement": @@ -656,14 +670,14 @@ class SQLCoreOperations(Generic[_T], TypingOnly): @overload def concat( self: "SQLCoreOperations[_ST]", other: Any - ) -> "BinaryExpression[_ST]": + ) -> "ColumnElement[_ST]": ... @overload - def concat(self, other: Any) -> "BinaryExpression": + def concat(self, other: Any) -> "ColumnElement": ... - def concat(self, other: Any) -> "BinaryExpression": + def concat(self, other: Any) -> "ColumnElement": ... def like(self, other: Any, escape=None) -> "BinaryExpression[bool]": @@ -702,30 +716,26 @@ class SQLCoreOperations(Generic[_T], TypingOnly): def startswith( self, other: Any, escape=None, autoescape=False - ) -> "BinaryExpression[bool]": + ) -> "ColumnElement[bool]": ... def endswith( self, other: Any, escape=None, autoescape=False - ) -> "BinaryExpression[bool]": + ) -> "ColumnElement[bool]": ... - def contains( - self, other: Any, escape=None, autoescape=False - ) -> "BinaryExpression[bool]": + def contains(self, other: Any, **kw: Any) -> "ColumnElement[bool]": ... - def match(self, other: Any, **kwargs) -> "BinaryExpression[bool]": + def match(self, other: Any, **kwargs) -> "ColumnElement[bool]": ... - def regexp_match( - self, pattern, flags=None - ) -> "BinaryExpression[bool]": + def regexp_match(self, pattern, flags=None) -> "ColumnElement[bool]": ... def regexp_replace( self, pattern, replacement, flags=None - ) -> "BinaryExpression": + ) -> "ColumnElement": ... def desc(self) -> "UnaryExpression[_T]": @@ -745,7 +755,7 @@ class SQLCoreOperations(Generic[_T], TypingOnly): def between( self, cleft, cright, symmetric=False - ) -> "BinaryExpression[bool]": + ) -> "ColumnElement[bool]": ... def distinct(self: "SQLCoreOperations[_T]") -> "UnaryExpression[_T]": @@ -766,166 +776,166 @@ class SQLCoreOperations(Generic[_T], TypingOnly): def __add__( self: "Union[_SQO[_NT], _SQO[Optional[_NT]]]", other: "Union[_SQO[Optional[_NT]], _SQO[_NT], _NT]", - ) -> "BinaryExpression[_NT]": + ) -> "ColumnElement[_NT]": ... @overload def __add__( self: "Union[_SQO[_NT], _SQO[Optional[_NT]]]", other: Any, - ) -> "BinaryExpression[_NUMERIC]": + ) -> "ColumnElement[_NUMERIC]": ... @overload def __add__( self: "Union[_SQO[_ST], _SQO[Optional[_ST]]]", other: Any, - ) -> "BinaryExpression[_ST]": + ) -> "ColumnElement[_ST]": ... - def __add__(self, other: Any) -> "BinaryExpression": + def __add__(self, other: Any) -> "ColumnElement": ... @overload - def __radd__(self, other: Any) -> "BinaryExpression[_NUMERIC]": + def __radd__(self, other: Any) -> "ColumnElement[_NUMERIC]": ... @overload - def __radd__(self, other: Any) -> "BinaryExpression": + def __radd__(self, other: Any) -> "ColumnElement": ... - def __radd__(self, other: Any) -> "BinaryExpression": + def __radd__(self, other: Any) -> "ColumnElement": ... @overload def __sub__( self: "SQLCoreOperations[_NT]", other: "Union[SQLCoreOperations[_NT], _NT]", - ) -> "BinaryExpression[_NT]": + ) -> "ColumnElement[_NT]": ... @overload - def __sub__(self, other: Any) -> "BinaryExpression": + def __sub__(self, other: Any) -> "ColumnElement": ... - def __sub__(self, other: Any) -> "BinaryExpression": + def __sub__(self, other: Any) -> "ColumnElement": ... @overload def __rsub__( self: "SQLCoreOperations[_NT]", other: Any - ) -> "BinaryExpression[_NUMERIC]": + ) -> "ColumnElement[_NUMERIC]": ... @overload - def __rsub__(self, other: Any) -> "BinaryExpression": + def __rsub__(self, other: Any) -> "ColumnElement": ... - def __rsub__(self, other: Any) -> "BinaryExpression": + def __rsub__(self, other: Any) -> "ColumnElement": ... @overload def __mul__( self: "SQLCoreOperations[_NT]", other: Any - ) -> "BinaryExpression[_NUMERIC]": + ) -> "ColumnElement[_NUMERIC]": ... @overload - def __mul__(self, other: Any) -> "BinaryExpression": + def __mul__(self, other: Any) -> "ColumnElement": ... - def __mul__(self, other: Any) -> "BinaryExpression": + def __mul__(self, other: Any) -> "ColumnElement": ... @overload def __rmul__( self: "SQLCoreOperations[_NT]", other: Any - ) -> "BinaryExpression[_NUMERIC]": + ) -> "ColumnElement[_NUMERIC]": ... @overload - def __rmul__(self, other: Any) -> "BinaryExpression": + def __rmul__(self, other: Any) -> "ColumnElement": ... - def __rmul__(self, other: Any) -> "BinaryExpression": + def __rmul__(self, other: Any) -> "ColumnElement": ... @overload def __mod__( self: "SQLCoreOperations[_NT]", other: Any - ) -> "BinaryExpression[_NUMERIC]": + ) -> "ColumnElement[_NUMERIC]": ... @overload - def __mod__(self, other: Any) -> "BinaryExpression": + def __mod__(self, other: Any) -> "ColumnElement": ... - def __mod__(self, other: Any) -> "BinaryExpression": + def __mod__(self, other: Any) -> "ColumnElement": ... @overload def __rmod__( self: "SQLCoreOperations[_NT]", other: Any - ) -> "BinaryExpression[_NUMERIC]": + ) -> "ColumnElement[_NUMERIC]": ... @overload - def __rmod__(self, other: Any) -> "BinaryExpression": + def __rmod__(self, other: Any) -> "ColumnElement": ... - def __rmod__(self, other: Any) -> "BinaryExpression": + def __rmod__(self, other: Any) -> "ColumnElement": ... @overload def __truediv__( self: "SQLCoreOperations[_NT]", other: Any - ) -> "BinaryExpression[_NUMERIC]": + ) -> "ColumnElement[_NUMERIC]": ... @overload - def __truediv__(self, other: Any) -> "BinaryExpression": + def __truediv__(self, other: Any) -> "ColumnElement": ... - def __truediv__(self, other: Any) -> "BinaryExpression": + def __truediv__(self, other: Any) -> "ColumnElement": ... @overload def __rtruediv__( self: "SQLCoreOperations[_NT]", other: Any - ) -> "BinaryExpression[_NUMERIC]": + ) -> "ColumnElement[_NUMERIC]": ... @overload - def __rtruediv__(self, other: Any) -> "BinaryExpression": + def __rtruediv__(self, other: Any) -> "ColumnElement": ... - def __rtruediv__(self, other: Any) -> "BinaryExpression": + def __rtruediv__(self, other: Any) -> "ColumnElement": ... @overload def __floordiv__( self: "SQLCoreOperations[_NT]", other: Any - ) -> "BinaryExpression[_NUMERIC]": + ) -> "ColumnElement[_NUMERIC]": ... @overload - def __floordiv__(self, other: Any) -> "BinaryExpression": + def __floordiv__(self, other: Any) -> "ColumnElement": ... - def __floordiv__(self, other: Any) -> "BinaryExpression": + def __floordiv__(self, other: Any) -> "ColumnElement": ... @overload def __rfloordiv__( self: "SQLCoreOperations[_NT]", other: Any - ) -> "BinaryExpression[_NUMERIC]": + ) -> "ColumnElement[_NUMERIC]": ... @overload - def __rfloordiv__(self, other: Any) -> "BinaryExpression": + def __rfloordiv__(self, other: Any) -> "ColumnElement": ... - def __rfloordiv__(self, other: Any) -> "BinaryExpression": + def __rfloordiv__(self, other: Any) -> "ColumnElement": ... diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py index d4fa8042dd..f08e71bcd7 100644 --- a/lib/sqlalchemy/sql/operators.py +++ b/lib/sqlalchemy/sql/operators.py @@ -179,10 +179,10 @@ class Operators(Generic[_OP_RETURN]): precedence: int = 0, is_comparison: bool = False, return_type: Optional[ - Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"] + Union[Type["TypeEngine[Any]"], "TypeEngine[Any]"] ] = None, python_impl=None, - ) -> Callable[[Any], _OP_RETURN]: + ) -> Callable[[Any], Any]: """Produce a generic operator function. e.g.:: @@ -270,7 +270,7 @@ class Operators(Generic[_OP_RETURN]): def bool_op( self, opstring: Any, precedence: int = 0, python_impl=None - ) -> Callable[[Any], _OP_RETURN]: + ) -> Callable[[Any], Any]: """Return a custom boolean operator. This method is shorthand for calling @@ -1021,9 +1021,7 @@ class ColumnOperators(Operators[_OP_RETURN]): endswith_op, other, escape=escape, autoescape=autoescape ) - def contains( - self, other: Any, escape=None, autoescape=False - ) -> "ColumnOperators": + def contains(self, other: Any, **kw: Any) -> "ColumnOperators": r"""Implement the 'contains' operator. Produces a LIKE expression that tests against a match for the middle @@ -1101,9 +1099,7 @@ class ColumnOperators(Operators[_OP_RETURN]): """ - return self.operate( - contains_op, other, escape=escape, autoescape=autoescape - ) + return self.operate(contains_op, other, **kw) def match(self, other: Any, **kwargs) -> "ColumnOperators": """Implements a database-specific 'match' operator. diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 836c30af74..a5cbffb5e1 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -19,9 +19,11 @@ import itertools from operator import attrgetter import typing from typing import Any as TODO_Any +from typing import Any from typing import NamedTuple from typing import Optional from typing import Tuple +from typing import TypeVar from . import cache_key from . import coercions @@ -72,6 +74,8 @@ from .. import util and_ = BooleanClauseList.and_ +_T = TypeVar("_T", bound=Any) + class _OffsetLimitParam(BindParameter): inherit_cache = True @@ -5528,7 +5532,7 @@ class ScalarSelect(roles.InElementRole, Generative, Grouping): return self -class Exists(UnaryExpression): +class Exists(UnaryExpression[_T]): """Represent an ``EXISTS`` clause. See :func:`_sql.exists` for a description of usage. diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index 1e79fd5474..5674e19afe 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -1418,40 +1418,38 @@ def counter() -> Callable[[], int]: def duck_type_collection( - specimen: Union[object, Type[Any]], default: Optional[Type[Any]] = None -) -> Type[Any]: + specimen: Any, default: Optional[Type[Any]] = None +) -> Optional[Type[Any]]: """Given an instance or class, guess if it is or is acting as one of the basic collection types: list, set and dict. If the __emulates__ property is present, return that preferentially. """ - if typing.TYPE_CHECKING: - return object - else: - if hasattr(specimen, "__emulates__"): - # canonicalize set vs sets.Set to a standard: the builtin set - if specimen.__emulates__ is not None and issubclass( - specimen.__emulates__, set - ): - return set - else: - return specimen.__emulates__ - isa = isinstance(specimen, type) and issubclass or isinstance - if isa(specimen, list): - return list - elif isa(specimen, set): - return set - elif isa(specimen, dict): - return dict - - if hasattr(specimen, "append"): - return list - elif hasattr(specimen, "add"): + if hasattr(specimen, "__emulates__"): + # canonicalize set vs sets.Set to a standard: the builtin set + if specimen.__emulates__ is not None and issubclass( + specimen.__emulates__, set + ): return set - elif hasattr(specimen, "set"): - return dict else: - return default + return specimen.__emulates__ # type: ignore + + isa = isinstance(specimen, type) and issubclass or isinstance + if isa(specimen, list): + return list + elif isa(specimen, set): + return set + elif isa(specimen, dict): + return dict + + if hasattr(specimen, "append"): + return list + elif hasattr(specimen, "add"): + return set + elif hasattr(specimen, "set"): + return dict + else: + return default def assert_arg_type(arg: Any, argtype: Type[Any], name: str) -> Any: diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index ad9c8e5314..291061561d 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -19,6 +19,8 @@ from . import compat _T = TypeVar("_T", bound=Any) +Self = TypeVar("Self", bound=Any) + if compat.py310: # why they took until py310 to put this in stdlib is beyond me, # I've been wanting it since py27 @@ -26,6 +28,11 @@ if compat.py310: else: NoneType = type(None) # type: ignore +if typing.TYPE_CHECKING or compat.py38: + from typing import SupportsIndex as SupportsIndex +else: + from typing_extensions import SupportsIndex as SupportsIndex + if typing.TYPE_CHECKING or compat.py310: from typing import Annotated as Annotated else: diff --git a/pyproject.toml b/pyproject.toml index b2754b193d..cefda245f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -100,6 +100,7 @@ ignore_errors = true module = [ "sqlalchemy.connectors.*", "sqlalchemy.engine.*", + "sqlalchemy.ext.associationproxy", "sqlalchemy.pool.*", "sqlalchemy.event.*", "sqlalchemy.events", diff --git a/test/ext/mypy/plain_files/association_proxy_one.py b/test/ext/mypy/plain_files/association_proxy_one.py new file mode 100644 index 0000000000..c5c897956a --- /dev/null +++ b/test/ext/mypy/plain_files/association_proxy_one.py @@ -0,0 +1,47 @@ +import typing +from typing import Set + +from sqlalchemy import ForeignKey +from sqlalchemy import Integer +from sqlalchemy import String +from sqlalchemy.ext.associationproxy import association_proxy +from sqlalchemy.ext.associationproxy import AssociationProxy +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import relationship + + +class Base(DeclarativeBase): + pass + + +class User(Base): + __tablename__ = "user" + + id = mapped_column(Integer, primary_key=True) + name = mapped_column(String, nullable=False) + + addresses: Mapped[Set["Address"]] = relationship() + + email_addresses: AssociationProxy[Set[str]] = association_proxy( + "addresses", "email" + ) + + +class Address(Base): + __tablename__ = "address" + + id = mapped_column(Integer, primary_key=True) + user_id = mapped_column(ForeignKey("user.id")) + email = mapped_column(String, nullable=False) + + +u1 = User() + +if typing.TYPE_CHECKING: + # EXPECTED_TYPE: sqlalchemy.*.associationproxy.AssociationProxyInstance\[builtins.set\*\[builtins.str\]\] + reveal_type(User.email_addresses) + + # EXPECTED_TYPE: builtins.set\*\[builtins.str\] + reveal_type(u1.email_addresses) diff --git a/test/ext/mypy/plain_files/sql_operations.py b/test/ext/mypy/plain_files/sql_operations.py new file mode 100644 index 0000000000..2cee1ddcab --- /dev/null +++ b/test/ext/mypy/plain_files/sql_operations.py @@ -0,0 +1,20 @@ +import typing + +from sqlalchemy import column +from sqlalchemy import Integer + +# builtin.pyi stubs define object.__eq__() as returning bool, which +# can't be overridden (it's final). So for us to type `__eq__()` and +# `__ne__()`, we have to use type: ignore[override]. Test if does this mean +# the typing tools don't know the type, or if they just ignore the error. +# (it's fortunately the former) +expr1 = column("x", Integer) == 10 + + +if typing.TYPE_CHECKING: + + # as far as if this is ColumnElement, BinaryElement, SQLCoreOperations, + # that might change. main thing is it's SomeSQLColThing[bool] and + # not 'bool' or 'Any'. + # EXPECTED_TYPE: sqlalchemy..*ColumnElement\[builtins.bool\] + reveal_type(expr1) diff --git a/test/ext/test_associationproxy.py b/test/ext/test_associationproxy.py index 583898ce9f..484aed7953 100644 --- a/test/ext/test_associationproxy.py +++ b/test/ext/test_associationproxy.py @@ -1,3 +1,4 @@ +from collections import abc import copy import pickle from unittest.mock import call @@ -256,6 +257,23 @@ class _CollectionOperations(fixtures.MappedTest): ) cls.mapper_registry.map_imperatively(Child, children_table) + def test_abc(self): + Parent = self.classes.Parent + + p1 = Parent("x") + + collection_class = self.collection_class or list + + for abc_ in (abc.Set, abc.MutableMapping, abc.MutableSequence): + if issubclass(collection_class, abc_): + break + else: + abc_ = None + + if abc_: + p1 = Parent("x") + assert isinstance(p1.children, abc_) + def roundtrip(self, obj): if obj not in self.session: self.session.add(obj) @@ -512,6 +530,10 @@ class CustomDictTest(_CollectionOperations): p1.children["b"] = "proxied" + eq_(list(p1.children.keys()), ["a", "b"]) + eq_(list(p1.children.items()), [("a", "regular"), ("b", "proxied")]) + eq_(list(p1.children.values()), ["regular", "proxied"]) + self.assert_("proxied" in list(p1.children.values())) self.assert_("b" in p1.children) self.assert_("proxied" not in p1._children) @@ -2364,7 +2386,8 @@ class DictOfTupleUpdateTest(fixtures.MappedTest): a1 = self.classes.A() assert_raises_message( ValueError, - "dictionary update sequence requires " "2-element tuples", + "dictionary update sequence element #1 has length 5; " + "2 is required", a1.elements.update, (("B", 3), "elem2"), ) @@ -2373,7 +2396,7 @@ class DictOfTupleUpdateTest(fixtures.MappedTest): a1 = self.classes.A() assert_raises_message( TypeError, - "update expected at most 1 arguments, got 2", + "update expected at most 1 arguments?, got 2", a1.elements.update, (("B", 3), "elem2"), (("C", 4), "elem3"), diff --git a/test/orm/test_inspect.py b/test/orm/test_inspect.py index 6933777373..c6fc3ace19 100644 --- a/test/orm/test_inspect.py +++ b/test/orm/test_inspect.py @@ -235,16 +235,15 @@ class TestORMInspection(_fixtures.FixtureTest): def test_extension_types(self): from sqlalchemy.ext.associationproxy import ( association_proxy, - ASSOCIATION_PROXY, + AssociationProxyExtensionType, ) from sqlalchemy.ext.hybrid import ( hybrid_property, hybrid_method, - HYBRID_PROPERTY, - HYBRID_METHOD, + HybridExtensionType, ) from sqlalchemy import Table, MetaData, Integer, Column - from sqlalchemy.orm.interfaces import NOT_EXTENSION + from sqlalchemy.orm.interfaces import NotExtension class SomeClass(self.classes.User): some_assoc = association_proxy("addresses", "email_address") @@ -290,15 +289,15 @@ class TestORMInspection(_fixtures.FixtureTest): for k, v in list(insp.all_orm_descriptors.items()) ), { - "id": NOT_EXTENSION, - "name": NOT_EXTENSION, - "name_syn": NOT_EXTENSION, - "addresses": NOT_EXTENSION, - "orders": NOT_EXTENSION, - "upper_name": HYBRID_PROPERTY, - "foo": HYBRID_PROPERTY, - "conv": HYBRID_METHOD, - "some_assoc": ASSOCIATION_PROXY, + "id": NotExtension.NOT_EXTENSION, + "name": NotExtension.NOT_EXTENSION, + "name_syn": NotExtension.NOT_EXTENSION, + "addresses": NotExtension.NOT_EXTENSION, + "orders": NotExtension.NOT_EXTENSION, + "upper_name": HybridExtensionType.HYBRID_PROPERTY, + "foo": HybridExtensionType.HYBRID_PROPERTY, + "conv": HybridExtensionType.HYBRID_METHOD, + "some_assoc": AssociationProxyExtensionType.ASSOCIATION_PROXY, }, ) is_(