]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Typed generic Events
authorGleb Kisenkov <g.kisenkov@gmail.com>
Wed, 21 Dec 2022 22:47:29 +0000 (23:47 +0100)
committerGleb Kisenkov <g.kisenkov@gmail.com>
Wed, 21 Dec 2022 22:47:29 +0000 (23:47 +0100)
lib/sqlalchemy/orm/events.py

index b96f5b8d146f60449353ac09748ebf316c01689e..346f43a2b9b7ff2704418d96cea3f1aeada8c1e0 100644 (file)
@@ -47,8 +47,6 @@ if TYPE_CHECKING:
     from .instrumentation import ClassManager
     from ..event.registry import _ET
     from ..event.registry import _EventKey
-    from ..orm.attributes import InstrumentedAttribute
-    from ..orm.base import EventConstants
     from ..orm.decl_api import DeclarativeAttributeIntercept
     from ..orm.decl_api import DeclarativeMeta
     from ..orm.mapper import Mapper
@@ -57,7 +55,9 @@ if TYPE_CHECKING:
 _ET2 = TypeVar("_ET2", bound=EventTarget)
 
 
-class InstrumentationEvents(event.Events[_T]):
+class InstrumentationEvents(
+    event.Events[instrumentation.InstrumentationFactory]
+):
     """Events related to class instrumentation events.
 
     The listeners here support being established against
@@ -80,12 +80,22 @@ class InstrumentationEvents(event.Events[_T]):
     """
 
     _target_class_doc = "SomeBaseClass"
-    _dispatch_target = instrumentation.InstrumentationFactory  # type: ignore [assignment] # noqa: E501
+    _dispatch_target = instrumentation.InstrumentationFactory
 
     @classmethod
     def _accept_with(
-        cls, target: Union[_ET, Type[_ET]], identifier: str
-    ) -> Optional[Union[_ET, Type[_ET]]]:
+        cls,
+        target: Union[
+            instrumentation.InstrumentationFactory,
+            Type[instrumentation.InstrumentationFactory],
+        ],
+        identifier: str,
+    ) -> Optional[
+        Union[
+            instrumentation.InstrumentationFactory,
+            Type[instrumentation.InstrumentationFactory],
+        ]
+    ]:
         if isinstance(target, type):
             return _InstrumentationEventsHold(target)  # type: ignore [return-value] # noqa: E501
         else:
@@ -172,7 +182,7 @@ class _InstrumentationEventsHold:
     dispatch = event.dispatcher(InstrumentationEvents)
 
 
-class InstanceEvents(event.Events[_ET]):
+class InstanceEvents(event.Events[instrumentation.ClassManager[Any]]):
     """Define events specific to object lifecycle.
 
     e.g.::
@@ -222,7 +232,7 @@ class InstanceEvents(event.Events[_ET]):
 
     _target_class_doc = "SomeClass"
 
-    _dispatch_target = instrumentation.ClassManager  # type: ignore [assignment] # noqa: E501
+    _dispatch_target = instrumentation.ClassManager
 
     @classmethod
     def _new_classmanager_instance(
@@ -235,15 +245,25 @@ class InstanceEvents(event.Events[_ET]):
     @classmethod
     @util.preload_module("sqlalchemy.orm")
     def _accept_with(
-        cls, target: Union[_ET, Type[_ET]], identifier: str
-    ) -> Optional[Union[_ET, Type[_ET]]]:
+        cls,
+        target: Union[
+            instrumentation.ClassManager[Any],
+            Type[instrumentation.ClassManager[Any]],
+        ],
+        identifier: str,
+    ) -> Optional[
+        Union[
+            instrumentation.ClassManager[Any],
+            Type[instrumentation.ClassManager[Any]],
+        ]
+    ]:
         orm = util.preloaded.orm
 
         if isinstance(target, instrumentation.ClassManager):
             return target
         elif isinstance(target, mapperlib.Mapper):
             return target.class_manager
-        elif target is orm.mapper:
+        elif target is orm.mapper:  # type: ignore [attr-defined]
             util.warn_deprecated(
                 "The `sqlalchemy.orm.mapper()` symbol is deprecated and "
                 "will be removed in a future release. For the mapper-wide "
@@ -259,13 +279,13 @@ class InstanceEvents(event.Events[_ET]):
                 if manager:
                     return manager
                 else:
-                    return _InstanceEventsHold(target)
+                    return _InstanceEventsHold(target)  # type: ignore [return-value] # noqa: E501
         return None
 
     @classmethod
     def _listen(
         cls,
-        event_key: _EventKey[_ET],
+        event_key: _EventKey[instrumentation.ClassManager[Any]],
         raw: bool = False,
         propagate: bool = False,
         restore_load_context: bool = False,
@@ -279,7 +299,7 @@ class InstanceEvents(event.Events[_ET]):
                 state: InstanceState[_O], *arg: Any, **kw: Any
             ) -> Optional[Any]:
                 if not raw:
-                    target = state.obj()
+                    target: Any = state.obj()
                 else:
                     target = state
                 if restore_load_context:
@@ -597,6 +617,8 @@ class _EventsHold(event.RefCollection[_ET]):
 
     """
 
+    all_holds: weakref.WeakKeyDictionary[Any, Any]
+
     def __init__(
         self,
         class_: Union[DeclarativeAttributeIntercept, DeclarativeMeta, type],
@@ -683,18 +705,20 @@ class _EventsHold(event.RefCollection[_ET]):
 
 
 class _InstanceEventsHold(_EventsHold[_ET]):
-    all_holds = weakref.WeakKeyDictionary()
+    all_holds: weakref.WeakKeyDictionary[
+        Any, Any
+    ] = weakref.WeakKeyDictionary()
 
     def resolve(self, class_: Type[_O]) -> Optional[ClassManager[_O]]:
         return instrumentation.opt_manager_of_class(class_)
 
-    class HoldInstanceEvents(_EventsHold.HoldEvents[_ET], InstanceEvents[_ET]):
+    class HoldInstanceEvents(_EventsHold.HoldEvents[_ET], InstanceEvents):  # type: ignore [misc] # noqa: E501
         pass
 
     dispatch = event.dispatcher(HoldInstanceEvents)
 
 
-class MapperEvents(event.Events[_ET]):
+class MapperEvents(event.Events[mapperlib.Mapper[Any]]):
     """Define events specific to mappings.
 
     e.g.::
@@ -774,11 +798,13 @@ class MapperEvents(event.Events[_ET]):
     @classmethod
     @util.preload_module("sqlalchemy.orm")
     def _accept_with(
-        cls, target: Union[DeclarativeMeta, Mapper[_O], type], identifier: str
-    ) -> Union[_MapperEventsHold[_ET], Mapper[_O], type]:
+        cls,
+        target: Union[mapperlib.Mapper[Any], Type[mapperlib.Mapper[Any]]],
+        identifier: str,
+    ) -> Optional[Union[mapperlib.Mapper[Any], Type[mapperlib.Mapper[Any]]]]:
         orm = util.preloaded.orm
 
-        if target is orm.mapper:
+        if target is orm.mapper:  # type: ignore [attr-defined]
             util.warn_deprecated(
                 "The `sqlalchemy.orm.mapper()` symbol is deprecated and "
                 "will be removed in a future release. For the mapper-wide "
@@ -833,10 +859,10 @@ class MapperEvents(event.Events[_ET]):
                 except ValueError:
                     target_index = None
 
-            def wrap(*arg: Any, **kw: Any) -> EventConstants:
+            def wrap(*arg: Any, **kw: Any) -> Any:
                 if not raw and target_index is not None:
-                    arg = list(arg)
-                    arg[target_index] = arg[target_index].obj()
+                    arg = list(arg)  # type: ignore [assignment]
+                    arg[target_index] = arg[target_index].obj()  # type: ignore [index] # noqa: E501
                 if not retval:
                     fn(*arg, **kw)
                     return interfaces.EXT_CONTINUE
@@ -1401,7 +1427,7 @@ class _MapperEventsHold(_EventsHold[_ET]):
     ) -> Optional[Mapper[_T]]:
         return _mapper_or_none(class_)
 
-    class HoldMapperEvents(_EventsHold.HoldEvents[_ET], MapperEvents[_ET]):
+    class HoldMapperEvents(_EventsHold.HoldEvents[_ET], MapperEvents):  # type: ignore [misc] # noqa: E501
         pass
 
     dispatch = event.dispatcher(HoldMapperEvents)
@@ -1455,14 +1481,14 @@ class SessionEvents(event.Events[Session]):
 
     _dispatch_target = Session
 
-    def _lifecycle_event(
+    def _lifecycle_event(  # type: ignore [misc]
         fn: Callable[[SessionEvents, Session, Any], None]
     ) -> Callable[[Session, Any], None]:
         _sessionevents_lifecycle_event_names.add(fn.__name__)
-        return fn
+        return fn  # type: ignore [return-value]
 
     @classmethod
-    def _accept_with(
+    def _accept_with(  # type: ignore [return]
         cls, target: Any, identifier: str
     ) -> Union[Session, type]:
         if isinstance(target, scoped_session):
@@ -1490,7 +1516,7 @@ class SessionEvents(event.Events[Session]):
             target._no_async_engine_events()
         else:
             # allows alternate SessionEvents-like-classes to be consulted
-            return event.Events._accept_with(target, identifier)
+            return event.Events._accept_with(target, identifier)  # type: ignore [return-value] # noqa: E501
 
     @classmethod
     def _listen(
@@ -1521,9 +1547,9 @@ class SessionEvents(event.Events[Session]):
                         if target is None:
                             # existing behavior is that if the object is
                             # garbage collected, no event is emitted
-                            return
+                            return None
                     else:
-                        target = state
+                        target = state  # type: ignore [assignment]
                     if restore_load_context:
                         runid = state.runid
                     try:
@@ -2292,7 +2318,7 @@ class SessionEvents(event.Events[Session]):
         """
 
 
-class AttributeEvents(event.Events[_ET]):
+class AttributeEvents(event.Events[QueryableAttribute[Any]]):
     r"""Define events for object attributes.
 
     These are typically defined on the class-bound descriptor for the
@@ -2362,7 +2388,7 @@ class AttributeEvents(event.Events[_ET]):
     """
 
     _target_class_doc = "SomeClass.some_attribute"
-    _dispatch_target = QueryableAttribute  # type: ignore [assignment]
+    _dispatch_target = QueryableAttribute
 
     @staticmethod
     def _set_dispatch(cls, dispatch_cls):
@@ -2372,8 +2398,10 @@ class AttributeEvents(event.Events[_ET]):
 
     @classmethod
     def _accept_with(
-        cls, target: InstrumentedAttribute[_T], identifier: str
-    ) -> InstrumentedAttribute[_T]:
+        cls,
+        target: Union[QueryableAttribute[Any], Type[QueryableAttribute[Any]]],
+        identifier: str,
+    ) -> Union[QueryableAttribute[Any], Type[QueryableAttribute[Any]]]:
         # TODO: coverage
         if isinstance(target, interfaces.MapperProperty):
             return getattr(target.parent.class_, target.key)
@@ -2381,9 +2409,9 @@ class AttributeEvents(event.Events[_ET]):
             return target
 
     @classmethod
-    def _listen(
+    def _listen(  # type: ignore [override]
         cls,
-        event_key: _EventKey[_ET],
+        event_key: _EventKey[QueryableAttribute[Any]],
         active_history: bool = False,
         raw: bool = False,
         retval: bool = False,
@@ -2400,7 +2428,7 @@ class AttributeEvents(event.Events[_ET]):
 
             def wrap(target: InstanceState[_O], *arg: Any, **kw: Any) -> Any:
                 if not raw:
-                    target = target.obj()
+                    target = target.obj()  # type: ignore [assignment]
                 if not retval:
                     if arg:
                         value = arg[0]
@@ -2424,7 +2452,7 @@ class AttributeEvents(event.Events[_ET]):
         if propagate:
             manager = instrumentation.manager_of_class(target.class_)
 
-            for mgr in manager.subclass_managers(True):
+            for mgr in manager.subclass_managers(True):  # type: ignore [no-untyped-call] # noqa: E501
                 event_key.with_dispatch_target(mgr[target.key]).base_listen(
                     propagate=True
                 )
@@ -2868,7 +2896,7 @@ class AttributeEvents(event.Events[_ET]):
         """
 
 
-class QueryEvents(event.Events[_ET]):
+class QueryEvents(event.Events[Query[Any]]):
     """Represent events within the construction of a :class:`_query.Query`
     object.
 
@@ -2885,7 +2913,7 @@ class QueryEvents(event.Events[_ET]):
     """
 
     _target_class_doc = "SomeQuery"
-    _dispatch_target = Query  # type: ignore [assignment]
+    _dispatch_target = Query
 
     def before_compile(self, query):
         """Receive the :class:`_query.Query`