From 804ad7ee5bd17fc89c4596eb2299bbb613350547 Mon Sep 17 00:00:00 2001 From: Gleb Kisenkov Date: Tue, 13 Dec 2022 00:16:34 +0100 Subject: [PATCH] Runtime types collected --- lib/sqlalchemy/orm/events.py | 99 ++++++++++++++++++++---------------- 1 file changed, 55 insertions(+), 44 deletions(-) diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py index 32de155a15..e9f325ad33 100644 --- a/lib/sqlalchemy/orm/events.py +++ b/lib/sqlalchemy/orm/events.py @@ -4,7 +4,6 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors """ORM event interfaces. @@ -31,6 +30,18 @@ from .. import event from .. import exc from .. import util from ..util.compat import inspect_getfullargspec +from sqlalchemy.orm.session import Session +from sqlalchemy.orm.state import InstanceState +from sqlalchemy.orm.base import EventConstants +from sqlalchemy.orm.query import Query +from typing import Union +from weakref import ReferenceType +from sqlalchemy.event.registry import _EventKey +from sqlalchemy.orm.decl_api import DeclarativeAttributeIntercept +from sqlalchemy.orm.decl_api import DeclarativeMeta +from sqlalchemy.orm.instrumentation import ClassManager +from sqlalchemy.orm.mapper import Mapper +from sqlalchemy.orm.attributes import InstrumentedAttribute if TYPE_CHECKING: from ._typing import _O @@ -63,21 +74,21 @@ class InstrumentationEvents(event.Events): _dispatch_target = instrumentation.InstrumentationFactory @classmethod - def _accept_with(cls, target, identifier): + def _accept_with(cls, target: type, identifier: str) -> _InstrumentationEventsHold: if isinstance(target, type): return _InstrumentationEventsHold(target) else: return None @classmethod - def _listen(cls, event_key, propagate=True, **kw): + def _listen(cls, event_key: _EventKey, propagate: bool = True, **kw: Any) -> None: target, identifier, fn = ( event_key.dispatch_target, event_key.identifier, event_key._listen_fn, ) - def listen(target_cls, *arg): + def listen(target_cls: Union[int, type], *arg: Any) -> Optional[Any]: listen_cls = target() # if weakref were collected, however this is not something @@ -92,7 +103,7 @@ class InstrumentationEvents(event.Events): elif not propagate and target_cls is listen_cls: return fn(target_cls, *arg) - def remove(ref): + def remove(ref: ReferenceType) -> None: key = event.registry._EventKey( None, identifier, @@ -110,7 +121,7 @@ class InstrumentationEvents(event.Events): ).with_wrapper(listen).base_listen(**kw) @classmethod - def _clear(cls): + def _clear(cls) -> None: super()._clear() instrumentation._instrumentation_factory.dispatch._clear() @@ -140,7 +151,7 @@ class _InstrumentationEventsHold: """ - def __init__(self, class_): + def __init__(self, class_: type) -> None: self.class_ = class_ dispatch = event.dispatcher(InstrumentationEvents) @@ -199,12 +210,12 @@ class InstanceEvents(event.Events): _dispatch_target = instrumentation.ClassManager @classmethod - def _new_classmanager_instance(cls, class_, classmanager): + def _new_classmanager_instance(cls, class_: Union[DeclarativeAttributeIntercept, DeclarativeMeta, type], classmanager: ClassManager) -> None: _InstanceEventsHold.populate(class_, classmanager) @classmethod @util.preload_module("sqlalchemy.orm") - def _accept_with(cls, target, identifier): + def _accept_with(cls, target: Any, identifier: str) -> Union[_InstanceEventsHold, ClassManager, type]: orm = util.preloaded.orm if isinstance(target, instrumentation.ClassManager): @@ -233,17 +244,17 @@ class InstanceEvents(event.Events): @classmethod def _listen( cls, - event_key, - raw=False, - propagate=False, - restore_load_context=False, - **kw, - ): + event_key: _EventKey, + raw: bool = False, + propagate: bool = False, + restore_load_context: bool = False, + **kw: Any, + ) -> None: target, fn = (event_key.dispatch_target, event_key._listen_fn) if not raw or restore_load_context: - def wrap(state, *arg, **kw): + def wrap(state: InstanceState, *arg: Any, **kw: Any) -> Optional[Any]: if not raw: target = state.obj() else: @@ -265,7 +276,7 @@ class InstanceEvents(event.Events): event_key.with_dispatch_target(mgr).base_listen(propagate=True) @classmethod - def _clear(cls): + def _clear(cls) -> None: super()._clear() _InstanceEventsHold._clear() @@ -563,11 +574,11 @@ class _EventsHold(event.RefCollection): """ - def __init__(self, class_): + def __init__(self, class_: Union[DeclarativeAttributeIntercept, DeclarativeMeta, type]) -> None: self.class_ = class_ @classmethod - def _clear(cls): + def _clear(cls) -> None: cls.all_holds.clear() class HoldEvents: @@ -575,8 +586,8 @@ class _EventsHold(event.RefCollection): @classmethod def _listen( - cls, event_key, raw=False, propagate=False, retval=False, **kw - ): + cls, event_key: _EventKey, raw: bool = False, propagate: bool = False, retval: bool = False, **kw: Any + ) -> None: target = event_key.dispatch_target if target.class_ in target.all_holds: @@ -606,7 +617,7 @@ class _EventsHold(event.RefCollection): raw=raw, propagate=False, retval=retval, **kw ) - def remove(self, event_key): + def remove(self, event_key: _EventKey) -> None: target = event_key.dispatch_target if isinstance(target, _EventsHold): @@ -614,7 +625,7 @@ class _EventsHold(event.RefCollection): del collection[event_key._key] @classmethod - def populate(cls, class_, subject): + def populate(cls, class_: Union[DeclarativeAttributeIntercept, DeclarativeMeta, type], subject: Union[ClassManager, Mapper]) -> None: for subclass in class_.__mro__: if subclass in cls.all_holds: collection = cls.all_holds[subclass] @@ -718,12 +729,12 @@ class MapperEvents(event.Events): _dispatch_target = mapperlib.Mapper @classmethod - def _new_mapper_instance(cls, class_, mapper): + def _new_mapper_instance(cls, class_: Union[DeclarativeAttributeIntercept, DeclarativeMeta, type], mapper: Mapper) -> None: _MapperEventsHold.populate(class_, mapper) @classmethod @util.preload_module("sqlalchemy.orm") - def _accept_with(cls, target, identifier): + def _accept_with(cls, target: Union[DeclarativeMeta, Mapper, type], identifier: str) -> Union[_MapperEventsHold, Mapper, type]: orm = util.preloaded.orm if target is orm.mapper: @@ -748,8 +759,8 @@ class MapperEvents(event.Events): @classmethod def _listen( - cls, event_key, raw=False, retval=False, propagate=False, **kw - ): + cls, event_key: _EventKey, raw: bool = False, retval: bool = False, propagate: bool = False, **kw: Any + ) -> None: target, identifier, fn = ( event_key.dispatch_target, event_key.identifier, @@ -776,7 +787,7 @@ class MapperEvents(event.Events): except ValueError: target_index = None - def wrap(*arg, **kw): + def wrap(*arg: Any, **kw: Any) -> EventConstants: if not raw and target_index is not None: arg = list(arg) arg[target_index] = arg[target_index].obj() @@ -797,7 +808,7 @@ class MapperEvents(event.Events): event_key.base_listen(**kw) @classmethod - def _clear(cls): + def _clear(cls) -> None: super()._clear() _MapperEventsHold._clear() @@ -1339,7 +1350,7 @@ class MapperEvents(event.Events): class _MapperEventsHold(_EventsHold): all_holds = weakref.WeakKeyDictionary() - def resolve(self, class_): + def resolve(self, class_: Union[DeclarativeMeta, type]) -> Optional[Mapper]: return _mapper_or_none(class_) class HoldMapperEvents(_EventsHold.HoldEvents, MapperEvents): @@ -1401,7 +1412,7 @@ class SessionEvents(event.Events[Session]): return fn @classmethod - def _accept_with(cls, target, identifier): + def _accept_with(cls, target: Any, identifier: str) -> Union[Session, type]: if isinstance(target, scoped_session): target = target.session_factory @@ -1447,7 +1458,7 @@ class SessionEvents(event.Events[Session]): fn = event_key._listen_fn - def wrap(session, state, *arg, **kw): + def wrap(session: Session, state: InstanceState, *arg: Any, **kw: Any) -> Optional[Any]: if not raw: target = state.obj() if target is None: @@ -2303,7 +2314,7 @@ class AttributeEvents(event.Events): return dispatch @classmethod - def _accept_with(cls, target, identifier): + def _accept_with(cls, target: InstrumentedAttribute, identifier: str) -> InstrumentedAttribute: # TODO: coverage if isinstance(target, interfaces.MapperProperty): return getattr(target.parent.class_, target.key) @@ -2313,13 +2324,13 @@ class AttributeEvents(event.Events): @classmethod def _listen( cls, - event_key, - active_history=False, - raw=False, - retval=False, - propagate=False, - include_key=False, - ): + event_key: _EventKey, + active_history: bool = False, + raw: bool = False, + retval: bool = False, + propagate: bool = False, + include_key: bool = False, + ) -> None: target, fn = event_key.dispatch_target, event_key._listen_fn @@ -2328,7 +2339,7 @@ class AttributeEvents(event.Events): if not raw or not retval or not include_key: - def wrap(target, *arg, **kw): + def wrap(target: InstanceState, *arg: Any, **kw: Any) -> str: if not raw: target = target.obj() if not retval: @@ -2973,12 +2984,12 @@ class QueryEvents(event.Events): """ @classmethod - def _listen(cls, event_key, retval=False, bake_ok=False, **kw): + def _listen(cls, event_key: _EventKey, retval: bool = False, bake_ok: bool = False, **kw: Any) -> None: fn = event_key._listen_fn if not retval: - def wrap(*arg, **kw): + def wrap(*arg: Query, **kw: Any) -> Query: if not retval: query = arg[0] fn(*arg, **kw) @@ -2989,7 +3000,7 @@ class QueryEvents(event.Events): event_key = event_key.with_wrapper(wrap) else: # don't assume we can apply an attribute to the callable - def wrap(*arg, **kw): + def wrap(*arg: Any, **kw: Any) -> Query: return fn(*arg, **kw) event_key = event_key.with_wrapper(wrap) -- 2.47.3