From: Gleb Kisenkov Date: Fri, 16 Dec 2022 17:46:08 +0000 (+0100) Subject: Covered low-hanging fruits of type refinement X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=896ae096608df533481ffd708f0ea36d01a129e4;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Covered low-hanging fruits of type refinement --- diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py index e9f325ad33..b96f5b8d14 100644 --- a/lib/sqlalchemy/orm/events.py +++ b/lib/sqlalchemy/orm/events.py @@ -4,6 +4,7 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: allow-untyped-defs """ORM event interfaces. @@ -11,9 +12,13 @@ from __future__ import annotations from typing import Any +from typing import Callable +from typing import Generic from typing import Optional +from typing import Set from typing import Type from typing import TYPE_CHECKING +from typing import TypeVar import weakref from . import instrumentation @@ -29,26 +34,30 @@ from .session import sessionmaker from .. import event from .. import exc from .. import util +from ..event import EventTarget from ..util.compat import inspect_getfullargspec -from sqlalchemy.orm.session import Session -from sqlalchemy.orm.state import InstanceState -from sqlalchemy.orm.base import EventConstants -from sqlalchemy.orm.query import Query -from typing import Union -from weakref import ReferenceType -from sqlalchemy.event.registry import _EventKey -from sqlalchemy.orm.decl_api import DeclarativeAttributeIntercept -from sqlalchemy.orm.decl_api import DeclarativeMeta -from sqlalchemy.orm.instrumentation import ClassManager -from sqlalchemy.orm.mapper import Mapper -from sqlalchemy.orm.attributes import InstrumentedAttribute if TYPE_CHECKING: + from typing import Union + from weakref import ReferenceType + + from ._typing import _InternalEntityType from ._typing import _O + from ._typing import _T from .instrumentation import ClassManager + from ..event.registry import _ET + from ..event.registry import _EventKey + from ..orm.attributes import InstrumentedAttribute + from ..orm.base import EventConstants + from ..orm.decl_api import DeclarativeAttributeIntercept + from ..orm.decl_api import DeclarativeMeta + from ..orm.mapper import Mapper + from ..orm.state import InstanceState + +_ET2 = TypeVar("_ET2", bound=EventTarget) -class InstrumentationEvents(event.Events): +class InstrumentationEvents(event.Events[_T]): """Events related to class instrumentation events. The listeners here support being established against @@ -71,24 +80,28 @@ class InstrumentationEvents(event.Events): """ _target_class_doc = "SomeBaseClass" - _dispatch_target = instrumentation.InstrumentationFactory + _dispatch_target = instrumentation.InstrumentationFactory # type: ignore [assignment] # noqa: E501 @classmethod - def _accept_with(cls, target: type, identifier: str) -> _InstrumentationEventsHold: + def _accept_with( + cls, target: Union[_ET, Type[_ET]], identifier: str + ) -> Optional[Union[_ET, Type[_ET]]]: if isinstance(target, type): - return _InstrumentationEventsHold(target) + return _InstrumentationEventsHold(target) # type: ignore [return-value] # noqa: E501 else: return None @classmethod - def _listen(cls, event_key: _EventKey, propagate: bool = True, **kw: Any) -> None: + def _listen( + cls, event_key: _EventKey[_T], propagate: bool = True, **kw: Any + ) -> None: target, identifier, fn = ( event_key.dispatch_target, event_key.identifier, event_key._listen_fn, ) - def listen(target_cls: Union[int, type], *arg: Any) -> Optional[Any]: + def listen(target_cls: type, *arg: Any) -> Optional[Any]: listen_cls = target() # if weakref were collected, however this is not something @@ -102,9 +115,11 @@ class InstrumentationEvents(event.Events): return fn(target_cls, *arg) elif not propagate and target_cls is listen_cls: return fn(target_cls, *arg) + else: + return None - def remove(ref: ReferenceType) -> None: - key = event.registry._EventKey( + def remove(ref: ReferenceType[_T]) -> None: + key = event.registry._EventKey( # type: ignore [type-var] None, identifier, listen, @@ -157,7 +172,7 @@ class _InstrumentationEventsHold: dispatch = event.dispatcher(InstrumentationEvents) -class InstanceEvents(event.Events): +class InstanceEvents(event.Events[_ET]): """Define events specific to object lifecycle. e.g.:: @@ -207,15 +222,21 @@ class InstanceEvents(event.Events): _target_class_doc = "SomeClass" - _dispatch_target = instrumentation.ClassManager + _dispatch_target = instrumentation.ClassManager # type: ignore [assignment] # noqa: E501 @classmethod - def _new_classmanager_instance(cls, class_: Union[DeclarativeAttributeIntercept, DeclarativeMeta, type], classmanager: ClassManager) -> None: + def _new_classmanager_instance( + cls, + class_: Union[DeclarativeAttributeIntercept, DeclarativeMeta, type], + classmanager: ClassManager[_O], + ) -> None: _InstanceEventsHold.populate(class_, classmanager) @classmethod @util.preload_module("sqlalchemy.orm") - def _accept_with(cls, target: Any, identifier: str) -> Union[_InstanceEventsHold, ClassManager, type]: + def _accept_with( + cls, target: Union[_ET, Type[_ET]], identifier: str + ) -> Optional[Union[_ET, Type[_ET]]]: orm = util.preloaded.orm if isinstance(target, instrumentation.ClassManager): @@ -244,7 +265,7 @@ class InstanceEvents(event.Events): @classmethod def _listen( cls, - event_key: _EventKey, + event_key: _EventKey[_ET], raw: bool = False, propagate: bool = False, restore_load_context: bool = False, @@ -254,7 +275,9 @@ class InstanceEvents(event.Events): if not raw or restore_load_context: - def wrap(state: InstanceState, *arg: Any, **kw: Any) -> Optional[Any]: + def wrap( + state: InstanceState[_O], *arg: Any, **kw: Any + ) -> Optional[Any]: if not raw: target = state.obj() else: @@ -566,7 +589,7 @@ class InstanceEvents(event.Events): """ -class _EventsHold(event.RefCollection): +class _EventsHold(event.RefCollection[_ET]): """Hold onto listeners against unmapped, uninstrumented classes. Establish _listen() for that class' mapper/instrumentation when @@ -574,19 +597,27 @@ class _EventsHold(event.RefCollection): """ - def __init__(self, class_: Union[DeclarativeAttributeIntercept, DeclarativeMeta, type]) -> None: + def __init__( + self, + class_: Union[DeclarativeAttributeIntercept, DeclarativeMeta, type], + ) -> None: self.class_ = class_ @classmethod def _clear(cls) -> None: cls.all_holds.clear() - class HoldEvents: - _dispatch_target = None + class HoldEvents(Generic[_ET2]): + _dispatch_target: Optional[Type[_ET2]] = None @classmethod def _listen( - cls, event_key: _EventKey, raw: bool = False, propagate: bool = False, retval: bool = False, **kw: Any + cls, + event_key: _EventKey[_ET2], + raw: bool = False, + propagate: bool = False, + retval: bool = False, + **kw: Any, ) -> None: target = event_key.dispatch_target @@ -617,7 +648,7 @@ class _EventsHold(event.RefCollection): raw=raw, propagate=False, retval=retval, **kw ) - def remove(self, event_key: _EventKey) -> None: + def remove(self, event_key: _EventKey[_ET]) -> None: target = event_key.dispatch_target if isinstance(target, _EventsHold): @@ -625,7 +656,11 @@ class _EventsHold(event.RefCollection): del collection[event_key._key] @classmethod - def populate(cls, class_: Union[DeclarativeAttributeIntercept, DeclarativeMeta, type], subject: Union[ClassManager, Mapper]) -> None: + def populate( + cls, + class_: Union[DeclarativeAttributeIntercept, DeclarativeMeta, type], + subject: Union[ClassManager[_O], Mapper[_O]], + ) -> None: for subclass in class_.__mro__: if subclass in cls.all_holds: collection = cls.all_holds[subclass] @@ -647,19 +682,19 @@ class _EventsHold(event.RefCollection): ) -class _InstanceEventsHold(_EventsHold): +class _InstanceEventsHold(_EventsHold[_ET]): all_holds = weakref.WeakKeyDictionary() def resolve(self, class_: Type[_O]) -> Optional[ClassManager[_O]]: return instrumentation.opt_manager_of_class(class_) - class HoldInstanceEvents(_EventsHold.HoldEvents, InstanceEvents): + class HoldInstanceEvents(_EventsHold.HoldEvents[_ET], InstanceEvents[_ET]): pass dispatch = event.dispatcher(HoldInstanceEvents) -class MapperEvents(event.Events): +class MapperEvents(event.Events[_ET]): """Define events specific to mappings. e.g.:: @@ -729,12 +764,18 @@ class MapperEvents(event.Events): _dispatch_target = mapperlib.Mapper @classmethod - def _new_mapper_instance(cls, class_: Union[DeclarativeAttributeIntercept, DeclarativeMeta, type], mapper: Mapper) -> None: + def _new_mapper_instance( + cls, + class_: Union[DeclarativeAttributeIntercept, DeclarativeMeta, type], + mapper: Mapper[_O], + ) -> None: _MapperEventsHold.populate(class_, mapper) @classmethod @util.preload_module("sqlalchemy.orm") - def _accept_with(cls, target: Union[DeclarativeMeta, Mapper, type], identifier: str) -> Union[_MapperEventsHold, Mapper, type]: + def _accept_with( + cls, target: Union[DeclarativeMeta, Mapper[_O], type], identifier: str + ) -> Union[_MapperEventsHold[_ET], Mapper[_O], type]: orm = util.preloaded.orm if target is orm.mapper: @@ -759,7 +800,12 @@ class MapperEvents(event.Events): @classmethod def _listen( - cls, event_key: _EventKey, raw: bool = False, retval: bool = False, propagate: bool = False, **kw: Any + cls, + event_key: _EventKey[_ET], + raw: bool = False, + retval: bool = False, + propagate: bool = False, + **kw: Any, ) -> None: target, identifier, fn = ( event_key.dispatch_target, @@ -1347,19 +1393,21 @@ class MapperEvents(event.Events): """ -class _MapperEventsHold(_EventsHold): +class _MapperEventsHold(_EventsHold[_ET]): all_holds = weakref.WeakKeyDictionary() - def resolve(self, class_: Union[DeclarativeMeta, type]) -> Optional[Mapper]: + def resolve( + self, class_: Union[Type[_T], _InternalEntityType[_T]] + ) -> Optional[Mapper[_T]]: return _mapper_or_none(class_) - class HoldMapperEvents(_EventsHold.HoldEvents, MapperEvents): + class HoldMapperEvents(_EventsHold.HoldEvents[_ET], MapperEvents[_ET]): pass dispatch = event.dispatcher(HoldMapperEvents) -_sessionevents_lifecycle_event_names = set() +_sessionevents_lifecycle_event_names: Set[str] = set() class SessionEvents(event.Events[Session]): @@ -1407,12 +1455,16 @@ class SessionEvents(event.Events[Session]): _dispatch_target = Session - def _lifecycle_event(fn): + def _lifecycle_event( + fn: Callable[[SessionEvents, Session, Any], None] + ) -> Callable[[Session, Any], None]: _sessionevents_lifecycle_event_names.add(fn.__name__) return fn @classmethod - def _accept_with(cls, target: Any, identifier: str) -> Union[Session, type]: + def _accept_with( + cls, target: Any, identifier: str + ) -> Union[Session, type]: if isinstance(target, scoped_session): target = target.session_factory @@ -1458,7 +1510,12 @@ class SessionEvents(event.Events[Session]): fn = event_key._listen_fn - def wrap(session: Session, state: InstanceState, *arg: Any, **kw: Any) -> Optional[Any]: + def wrap( + session: Session, + state: InstanceState[_O], + *arg: Any, + **kw: Any, + ) -> Optional[Any]: if not raw: target = state.obj() if target is None: @@ -2235,7 +2292,7 @@ class SessionEvents(event.Events[Session]): """ -class AttributeEvents(event.Events): +class AttributeEvents(event.Events[_ET]): r"""Define events for object attributes. These are typically defined on the class-bound descriptor for the @@ -2305,7 +2362,7 @@ class AttributeEvents(event.Events): """ _target_class_doc = "SomeClass.some_attribute" - _dispatch_target = QueryableAttribute + _dispatch_target = QueryableAttribute # type: ignore [assignment] @staticmethod def _set_dispatch(cls, dispatch_cls): @@ -2314,7 +2371,9 @@ class AttributeEvents(event.Events): return dispatch @classmethod - def _accept_with(cls, target: InstrumentedAttribute, identifier: str) -> InstrumentedAttribute: + def _accept_with( + cls, target: InstrumentedAttribute[_T], identifier: str + ) -> InstrumentedAttribute[_T]: # TODO: coverage if isinstance(target, interfaces.MapperProperty): return getattr(target.parent.class_, target.key) @@ -2324,7 +2383,7 @@ class AttributeEvents(event.Events): @classmethod def _listen( cls, - event_key: _EventKey, + event_key: _EventKey[_ET], active_history: bool = False, raw: bool = False, retval: bool = False, @@ -2339,7 +2398,7 @@ class AttributeEvents(event.Events): if not raw or not retval or not include_key: - def wrap(target: InstanceState, *arg: Any, **kw: Any) -> str: + def wrap(target: InstanceState[_O], *arg: Any, **kw: Any) -> Any: if not raw: target = target.obj() if not retval: @@ -2809,7 +2868,7 @@ class AttributeEvents(event.Events): """ -class QueryEvents(event.Events): +class QueryEvents(event.Events[_ET]): """Represent events within the construction of a :class:`_query.Query` object. @@ -2826,7 +2885,7 @@ class QueryEvents(event.Events): """ _target_class_doc = "SomeQuery" - _dispatch_target = Query + _dispatch_target = Query # type: ignore [assignment] def before_compile(self, query): """Receive the :class:`_query.Query` @@ -2984,12 +3043,18 @@ class QueryEvents(event.Events): """ @classmethod - def _listen(cls, event_key: _EventKey, retval: bool = False, bake_ok: bool = False, **kw: Any) -> None: + def _listen( + cls, + event_key: _EventKey[_ET], + retval: bool = False, + bake_ok: bool = False, + **kw: Any, + ) -> None: fn = event_key._listen_fn if not retval: - def wrap(*arg: Query, **kw: Any) -> Query: + def wrap(*arg: Any, **kw: Any) -> Any: if not retval: query = arg[0] fn(*arg, **kw) @@ -3000,11 +3065,11 @@ class QueryEvents(event.Events): event_key = event_key.with_wrapper(wrap) else: # don't assume we can apply an attribute to the callable - def wrap(*arg: Any, **kw: Any) -> Query: + def wrap(*arg: Any, **kw: Any) -> Any: return fn(*arg, **kw) event_key = event_key.with_wrapper(wrap) - wrap._bake_ok = bake_ok + wrap._bake_ok = bake_ok # type: ignore [attr-defined] event_key.base_listen(**kw)