From: Gleb Kisenkov Date: Wed, 21 Dec 2022 22:47:29 +0000 (+0100) Subject: Typed generic Events X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=cb31d38a506efa6fd1c18d788e13b43b2d33c358;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Typed generic Events --- diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py index b96f5b8d14..346f43a2b9 100644 --- a/lib/sqlalchemy/orm/events.py +++ b/lib/sqlalchemy/orm/events.py @@ -47,8 +47,6 @@ if TYPE_CHECKING: from .instrumentation import ClassManager from ..event.registry import _ET from ..event.registry import _EventKey - from ..orm.attributes import InstrumentedAttribute - from ..orm.base import EventConstants from ..orm.decl_api import DeclarativeAttributeIntercept from ..orm.decl_api import DeclarativeMeta from ..orm.mapper import Mapper @@ -57,7 +55,9 @@ if TYPE_CHECKING: _ET2 = TypeVar("_ET2", bound=EventTarget) -class InstrumentationEvents(event.Events[_T]): +class InstrumentationEvents( + event.Events[instrumentation.InstrumentationFactory] +): """Events related to class instrumentation events. The listeners here support being established against @@ -80,12 +80,22 @@ class InstrumentationEvents(event.Events[_T]): """ _target_class_doc = "SomeBaseClass" - _dispatch_target = instrumentation.InstrumentationFactory # type: ignore [assignment] # noqa: E501 + _dispatch_target = instrumentation.InstrumentationFactory @classmethod def _accept_with( - cls, target: Union[_ET, Type[_ET]], identifier: str - ) -> Optional[Union[_ET, Type[_ET]]]: + cls, + target: Union[ + instrumentation.InstrumentationFactory, + Type[instrumentation.InstrumentationFactory], + ], + identifier: str, + ) -> Optional[ + Union[ + instrumentation.InstrumentationFactory, + Type[instrumentation.InstrumentationFactory], + ] + ]: if isinstance(target, type): return _InstrumentationEventsHold(target) # type: ignore [return-value] # noqa: E501 else: @@ -172,7 +182,7 @@ class _InstrumentationEventsHold: dispatch = event.dispatcher(InstrumentationEvents) -class InstanceEvents(event.Events[_ET]): +class InstanceEvents(event.Events[instrumentation.ClassManager[Any]]): """Define events specific to object lifecycle. e.g.:: @@ -222,7 +232,7 @@ class InstanceEvents(event.Events[_ET]): _target_class_doc = "SomeClass" - _dispatch_target = instrumentation.ClassManager # type: ignore [assignment] # noqa: E501 + _dispatch_target = instrumentation.ClassManager @classmethod def _new_classmanager_instance( @@ -235,15 +245,25 @@ class InstanceEvents(event.Events[_ET]): @classmethod @util.preload_module("sqlalchemy.orm") def _accept_with( - cls, target: Union[_ET, Type[_ET]], identifier: str - ) -> Optional[Union[_ET, Type[_ET]]]: + cls, + target: Union[ + instrumentation.ClassManager[Any], + Type[instrumentation.ClassManager[Any]], + ], + identifier: str, + ) -> Optional[ + Union[ + instrumentation.ClassManager[Any], + Type[instrumentation.ClassManager[Any]], + ] + ]: orm = util.preloaded.orm if isinstance(target, instrumentation.ClassManager): return target elif isinstance(target, mapperlib.Mapper): return target.class_manager - elif target is orm.mapper: + elif target is orm.mapper: # type: ignore [attr-defined] util.warn_deprecated( "The `sqlalchemy.orm.mapper()` symbol is deprecated and " "will be removed in a future release. For the mapper-wide " @@ -259,13 +279,13 @@ class InstanceEvents(event.Events[_ET]): if manager: return manager else: - return _InstanceEventsHold(target) + return _InstanceEventsHold(target) # type: ignore [return-value] # noqa: E501 return None @classmethod def _listen( cls, - event_key: _EventKey[_ET], + event_key: _EventKey[instrumentation.ClassManager[Any]], raw: bool = False, propagate: bool = False, restore_load_context: bool = False, @@ -279,7 +299,7 @@ class InstanceEvents(event.Events[_ET]): state: InstanceState[_O], *arg: Any, **kw: Any ) -> Optional[Any]: if not raw: - target = state.obj() + target: Any = state.obj() else: target = state if restore_load_context: @@ -597,6 +617,8 @@ class _EventsHold(event.RefCollection[_ET]): """ + all_holds: weakref.WeakKeyDictionary[Any, Any] + def __init__( self, class_: Union[DeclarativeAttributeIntercept, DeclarativeMeta, type], @@ -683,18 +705,20 @@ class _EventsHold(event.RefCollection[_ET]): class _InstanceEventsHold(_EventsHold[_ET]): - all_holds = weakref.WeakKeyDictionary() + all_holds: weakref.WeakKeyDictionary[ + Any, Any + ] = weakref.WeakKeyDictionary() def resolve(self, class_: Type[_O]) -> Optional[ClassManager[_O]]: return instrumentation.opt_manager_of_class(class_) - class HoldInstanceEvents(_EventsHold.HoldEvents[_ET], InstanceEvents[_ET]): + class HoldInstanceEvents(_EventsHold.HoldEvents[_ET], InstanceEvents): # type: ignore [misc] # noqa: E501 pass dispatch = event.dispatcher(HoldInstanceEvents) -class MapperEvents(event.Events[_ET]): +class MapperEvents(event.Events[mapperlib.Mapper[Any]]): """Define events specific to mappings. e.g.:: @@ -774,11 +798,13 @@ class MapperEvents(event.Events[_ET]): @classmethod @util.preload_module("sqlalchemy.orm") def _accept_with( - cls, target: Union[DeclarativeMeta, Mapper[_O], type], identifier: str - ) -> Union[_MapperEventsHold[_ET], Mapper[_O], type]: + cls, + target: Union[mapperlib.Mapper[Any], Type[mapperlib.Mapper[Any]]], + identifier: str, + ) -> Optional[Union[mapperlib.Mapper[Any], Type[mapperlib.Mapper[Any]]]]: orm = util.preloaded.orm - if target is orm.mapper: + if target is orm.mapper: # type: ignore [attr-defined] util.warn_deprecated( "The `sqlalchemy.orm.mapper()` symbol is deprecated and " "will be removed in a future release. For the mapper-wide " @@ -833,10 +859,10 @@ class MapperEvents(event.Events[_ET]): except ValueError: target_index = None - def wrap(*arg: Any, **kw: Any) -> EventConstants: + def wrap(*arg: Any, **kw: Any) -> Any: if not raw and target_index is not None: - arg = list(arg) - arg[target_index] = arg[target_index].obj() + arg = list(arg) # type: ignore [assignment] + arg[target_index] = arg[target_index].obj() # type: ignore [index] # noqa: E501 if not retval: fn(*arg, **kw) return interfaces.EXT_CONTINUE @@ -1401,7 +1427,7 @@ class _MapperEventsHold(_EventsHold[_ET]): ) -> Optional[Mapper[_T]]: return _mapper_or_none(class_) - class HoldMapperEvents(_EventsHold.HoldEvents[_ET], MapperEvents[_ET]): + class HoldMapperEvents(_EventsHold.HoldEvents[_ET], MapperEvents): # type: ignore [misc] # noqa: E501 pass dispatch = event.dispatcher(HoldMapperEvents) @@ -1455,14 +1481,14 @@ class SessionEvents(event.Events[Session]): _dispatch_target = Session - def _lifecycle_event( + def _lifecycle_event( # type: ignore [misc] fn: Callable[[SessionEvents, Session, Any], None] ) -> Callable[[Session, Any], None]: _sessionevents_lifecycle_event_names.add(fn.__name__) - return fn + return fn # type: ignore [return-value] @classmethod - def _accept_with( + def _accept_with( # type: ignore [return] cls, target: Any, identifier: str ) -> Union[Session, type]: if isinstance(target, scoped_session): @@ -1490,7 +1516,7 @@ class SessionEvents(event.Events[Session]): target._no_async_engine_events() else: # allows alternate SessionEvents-like-classes to be consulted - return event.Events._accept_with(target, identifier) + return event.Events._accept_with(target, identifier) # type: ignore [return-value] # noqa: E501 @classmethod def _listen( @@ -1521,9 +1547,9 @@ class SessionEvents(event.Events[Session]): if target is None: # existing behavior is that if the object is # garbage collected, no event is emitted - return + return None else: - target = state + target = state # type: ignore [assignment] if restore_load_context: runid = state.runid try: @@ -2292,7 +2318,7 @@ class SessionEvents(event.Events[Session]): """ -class AttributeEvents(event.Events[_ET]): +class AttributeEvents(event.Events[QueryableAttribute[Any]]): r"""Define events for object attributes. These are typically defined on the class-bound descriptor for the @@ -2362,7 +2388,7 @@ class AttributeEvents(event.Events[_ET]): """ _target_class_doc = "SomeClass.some_attribute" - _dispatch_target = QueryableAttribute # type: ignore [assignment] + _dispatch_target = QueryableAttribute @staticmethod def _set_dispatch(cls, dispatch_cls): @@ -2372,8 +2398,10 @@ class AttributeEvents(event.Events[_ET]): @classmethod def _accept_with( - cls, target: InstrumentedAttribute[_T], identifier: str - ) -> InstrumentedAttribute[_T]: + cls, + target: Union[QueryableAttribute[Any], Type[QueryableAttribute[Any]]], + identifier: str, + ) -> Union[QueryableAttribute[Any], Type[QueryableAttribute[Any]]]: # TODO: coverage if isinstance(target, interfaces.MapperProperty): return getattr(target.parent.class_, target.key) @@ -2381,9 +2409,9 @@ class AttributeEvents(event.Events[_ET]): return target @classmethod - def _listen( + def _listen( # type: ignore [override] cls, - event_key: _EventKey[_ET], + event_key: _EventKey[QueryableAttribute[Any]], active_history: bool = False, raw: bool = False, retval: bool = False, @@ -2400,7 +2428,7 @@ class AttributeEvents(event.Events[_ET]): def wrap(target: InstanceState[_O], *arg: Any, **kw: Any) -> Any: if not raw: - target = target.obj() + target = target.obj() # type: ignore [assignment] if not retval: if arg: value = arg[0] @@ -2424,7 +2452,7 @@ class AttributeEvents(event.Events[_ET]): if propagate: manager = instrumentation.manager_of_class(target.class_) - for mgr in manager.subclass_managers(True): + for mgr in manager.subclass_managers(True): # type: ignore [no-untyped-call] # noqa: E501 event_key.with_dispatch_target(mgr[target.key]).base_listen( propagate=True ) @@ -2868,7 +2896,7 @@ class AttributeEvents(event.Events[_ET]): """ -class QueryEvents(event.Events[_ET]): +class QueryEvents(event.Events[Query[Any]]): """Represent events within the construction of a :class:`_query.Query` object. @@ -2885,7 +2913,7 @@ class QueryEvents(event.Events[_ET]): """ _target_class_doc = "SomeQuery" - _dispatch_target = Query # type: ignore [assignment] + _dispatch_target = Query def before_compile(self, query): """Receive the :class:`_query.Query`