]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Type annotations for sqlalchemy.orm.events
authorGleb Kisenkov <g.kisenkov@gmail.com>
Wed, 28 Dec 2022 19:23:23 +0000 (14:23 -0500)
committersqla-tester <sqla-tester@sqlalchemy.org>
Wed, 28 Dec 2022 19:23:23 +0000 (14:23 -0500)
<!-- Provide a general summary of your proposed changes in the Title field above -->

### Description
An attempt to annotate `lib/sqlalchemy/orm/events.py` with type hints (issue #6810).

### Checklist
<!-- go over following points. check them with an `x` if they do apply, (they turn into clickable checkboxes once the PR is submitted, so no need to do everything at once)

-->

This pull request is:

- [ ] A documentation / typographical error fix
- Good to go, no issue or tests are needed
- [ ] A short code fix
- please include the issue number, and create an issue if none exists, which
  must include a complete example of the issue.  one line code fixes without an
  issue and demonstration will not be accepted.
- Please include: `Fixes: #<issue number>` in the commit message
- please include tests.   one line code fixes without tests will not be accepted.
- [x] A new feature implementation
- please include the issue number, and create an issue if none exists, which must
  include a complete example of how the feature would look.
- Please include: `Fixes: #<issue number>` in the commit message
- please include tests.

**Have a nice day!**

Closes: #9025
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/9025
Pull-request-sha: a3fd2c0c3790164c433305ccc7ac6b73e813e037

Change-Id: I0808b6485504615fa20691dc8f4631d38bc89ab3

lib/sqlalchemy/orm/events.py

index 32de155a15232b58cc6196b0977ac25df17ec6ac..b182b91ca2291090a90f513d82cdf826b23056b4 100644 (file)
@@ -4,7 +4,6 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
 
 """ORM event interfaces.
 
 from __future__ import annotations
 
 from typing import Any
+from typing import Callable
+from typing import Collection
+from typing import Dict
+from typing import Generic
+from typing import Iterable
 from typing import Optional
+from typing import Sequence
+from typing import Set
 from typing import Type
 from typing import TYPE_CHECKING
+from typing import TypeVar
+from typing import Union
 import weakref
 
 from . import instrumentation
@@ -23,6 +31,10 @@ from . import mapperlib
 from .attributes import QueryableAttribute
 from .base import _mapper_or_none
 from .base import NO_KEY
+from .instrumentation import ClassManager
+from .instrumentation import InstrumentationFactory
+from .query import BulkDelete
+from .query import BulkUpdate
 from .query import Query
 from .scoping import scoped_session
 from .session import Session
@@ -30,14 +42,38 @@ from .session import sessionmaker
 from .. import event
 from .. import exc
 from .. import util
+from ..event import EventTarget
+from ..event.registry import _ET
 from ..util.compat import inspect_getfullargspec
 
 if TYPE_CHECKING:
-    from ._typing import _O
-    from .instrumentation import ClassManager
-
+    from weakref import ReferenceType
 
-class InstrumentationEvents(event.Events):
+    from ._typing import _InstanceDict
+    from ._typing import _InternalEntityType
+    from ._typing import _O
+    from ._typing import _T
+    from .attributes import Event
+    from .base import EventConstants
+    from .session import ORMExecuteState
+    from .session import SessionTransaction
+    from .unitofwork import UOWTransaction
+    from ..engine import Connection
+    from ..event.base import _Dispatch
+    from ..event.base import _HasEventsDispatch
+    from ..event.registry import _EventKey
+    from ..orm.collections import CollectionAdapter
+    from ..orm.context import QueryContext
+    from ..orm.decl_api import DeclarativeAttributeIntercept
+    from ..orm.decl_api import DeclarativeMeta
+    from ..orm.mapper import Mapper
+    from ..orm.state import InstanceState
+
+_KT = TypeVar("_KT", bound=Any)
+_ET2 = TypeVar("_ET2", bound=EventTarget)
+
+
+class InstrumentationEvents(event.Events[InstrumentationFactory]):
     """Events related to class instrumentation events.
 
     The listeners here support being established against
@@ -60,24 +96,38 @@ class InstrumentationEvents(event.Events):
     """
 
     _target_class_doc = "SomeBaseClass"
-    _dispatch_target = instrumentation.InstrumentationFactory
+    _dispatch_target = InstrumentationFactory
 
     @classmethod
-    def _accept_with(cls, target, identifier):
+    def _accept_with(
+        cls,
+        target: Union[
+            InstrumentationFactory,
+            Type[InstrumentationFactory],
+        ],
+        identifier: str,
+    ) -> Optional[
+        Union[
+            InstrumentationFactory,
+            Type[InstrumentationFactory],
+        ]
+    ]:
         if isinstance(target, type):
-            return _InstrumentationEventsHold(target)
+            return _InstrumentationEventsHold(target)  # type: ignore [return-value] # noqa: E501
         else:
             return None
 
     @classmethod
-    def _listen(cls, event_key, propagate=True, **kw):
+    def _listen(
+        cls, event_key: _EventKey[_T], propagate: bool = True, **kw: Any
+    ) -> None:
         target, identifier, fn = (
             event_key.dispatch_target,
             event_key.identifier,
             event_key._listen_fn,
         )
 
-        def listen(target_cls, *arg):
+        def listen(target_cls: type, *arg: Any) -> Optional[Any]:
             listen_cls = target()
 
             # if weakref were collected, however this is not something
@@ -91,9 +141,11 @@ class InstrumentationEvents(event.Events):
                 return fn(target_cls, *arg)
             elif not propagate and target_cls is listen_cls:
                 return fn(target_cls, *arg)
+            else:
+                return None
 
-        def remove(ref):
-            key = event.registry._EventKey(
+        def remove(ref: ReferenceType[_T]) -> None:
+            key = event.registry._EventKey(  # type: ignore [type-var]
                 None,
                 identifier,
                 listen,
@@ -110,11 +162,11 @@ class InstrumentationEvents(event.Events):
         ).with_wrapper(listen).base_listen(**kw)
 
     @classmethod
-    def _clear(cls):
+    def _clear(cls) -> None:
         super()._clear()
         instrumentation._instrumentation_factory.dispatch._clear()
 
-    def class_instrument(self, cls):
+    def class_instrument(self, cls: ClassManager[_O]) -> None:
         """Called after the given class is instrumented.
 
         To get at the :class:`.ClassManager`, use
@@ -122,7 +174,7 @@ class InstrumentationEvents(event.Events):
 
         """
 
-    def class_uninstrument(self, cls):
+    def class_uninstrument(self, cls: ClassManager[_O]) -> None:
         """Called before the given class is uninstrumented.
 
         To get at the :class:`.ClassManager`, use
@@ -130,7 +182,9 @@ class InstrumentationEvents(event.Events):
 
         """
 
-    def attribute_instrument(self, cls, key, inst):
+    def attribute_instrument(
+        self, cls: ClassManager[_O], key: _KT, inst: _O
+    ) -> None:
         """Called when an attribute is instrumented."""
 
 
@@ -140,13 +194,13 @@ class _InstrumentationEventsHold:
 
     """
 
-    def __init__(self, class_):
+    def __init__(self, class_: type) -> None:
         self.class_ = class_
 
     dispatch = event.dispatcher(InstrumentationEvents)
 
 
-class InstanceEvents(event.Events):
+class InstanceEvents(event.Events[ClassManager[Any]]):
     """Define events specific to object lifecycle.
 
     e.g.::
@@ -196,56 +250,69 @@ class InstanceEvents(event.Events):
 
     _target_class_doc = "SomeClass"
 
-    _dispatch_target = instrumentation.ClassManager
+    _dispatch_target = ClassManager
 
     @classmethod
-    def _new_classmanager_instance(cls, class_, classmanager):
+    def _new_classmanager_instance(
+        cls,
+        class_: Union[DeclarativeAttributeIntercept, DeclarativeMeta, type],
+        classmanager: ClassManager[_O],
+    ) -> None:
         _InstanceEventsHold.populate(class_, classmanager)
 
     @classmethod
     @util.preload_module("sqlalchemy.orm")
-    def _accept_with(cls, target, identifier):
+    def _accept_with(
+        cls,
+        target: Union[
+            ClassManager[Any],
+            Type[ClassManager[Any]],
+        ],
+        identifier: str,
+    ) -> Optional[Union[ClassManager[Any], Type[ClassManager[Any]]]]:
         orm = util.preloaded.orm
 
-        if isinstance(target, instrumentation.ClassManager):
+        if isinstance(target, ClassManager):
             return target
         elif isinstance(target, mapperlib.Mapper):
             return target.class_manager
-        elif target is orm.mapper:
+        elif target is orm.mapper:  # type: ignore [attr-defined]
             util.warn_deprecated(
                 "The `sqlalchemy.orm.mapper()` symbol is deprecated and "
                 "will be removed in a future release. For the mapper-wide "
                 "event target, use the 'sqlalchemy.orm.Mapper' class.",
                 "2.0",
             )
-            return instrumentation.ClassManager
+            return ClassManager
         elif isinstance(target, type):
             if issubclass(target, mapperlib.Mapper):
-                return instrumentation.ClassManager
+                return ClassManager
             else:
                 manager = instrumentation.opt_manager_of_class(target)
                 if manager:
                     return manager
                 else:
-                    return _InstanceEventsHold(target)
+                    return _InstanceEventsHold(target)  # type: ignore [return-value] # noqa: E501
         return None
 
     @classmethod
     def _listen(
         cls,
-        event_key,
-        raw=False,
-        propagate=False,
-        restore_load_context=False,
-        **kw,
-    ):
+        event_key: _EventKey[ClassManager[Any]],
+        raw: bool = False,
+        propagate: bool = False,
+        restore_load_context: bool = False,
+        **kw: Any,
+    ) -> None:
         target, fn = (event_key.dispatch_target, event_key._listen_fn)
 
         if not raw or restore_load_context:
 
-            def wrap(state, *arg, **kw):
+            def wrap(
+                state: InstanceState[_O], *arg: Any, **kw: Any
+            ) -> Optional[Any]:
                 if not raw:
-                    target = state.obj()
+                    target: Any = state.obj()
                 else:
                     target = state
                 if restore_load_context:
@@ -265,11 +332,11 @@ class InstanceEvents(event.Events):
                 event_key.with_dispatch_target(mgr).base_listen(propagate=True)
 
     @classmethod
-    def _clear(cls):
+    def _clear(cls) -> None:
         super()._clear()
         _InstanceEventsHold._clear()
 
-    def first_init(self, manager, cls):
+    def first_init(self, manager: ClassManager[_O], cls: Type[_O]) -> None:
         """Called when the first instance of a particular mapping is called.
 
         This event is called when the ``__init__`` method of a class
@@ -279,7 +346,7 @@ class InstanceEvents(event.Events):
 
         """
 
-    def init(self, target, args, kwargs):
+    def init(self, target: _O, args: Any, kwargs: Any) -> None:
         """Receive an instance when its constructor is called.
 
         This method is only called during a userland construction of
@@ -310,7 +377,7 @@ class InstanceEvents(event.Events):
 
         """
 
-    def init_failure(self, target, args, kwargs):
+    def init_failure(self, target: _O, args: Any, kwargs: Any) -> None:
         """Receive an instance when its constructor has been called,
         and raised an exception.
 
@@ -343,7 +410,9 @@ class InstanceEvents(event.Events):
 
         """
 
-    def _sa_event_merge_wo_load(self, target, context):
+    def _sa_event_merge_wo_load(
+        self, target: _O, context: QueryContext
+    ) -> None:
         """receive an object instance after it was the subject of a merge()
         call, when load=False was passed.
 
@@ -360,7 +429,7 @@ class InstanceEvents(event.Events):
 
         """
 
-    def load(self, target, context):
+    def load(self, target: _O, context: QueryContext) -> None:
         """Receive an object instance after it has been created via
         ``__new__``, and after initial attribute population has
         occurred.
@@ -435,7 +504,9 @@ class InstanceEvents(event.Events):
 
         """
 
-    def refresh(self, target, context, attrs):
+    def refresh(
+        self, target: _O, context: QueryContext, attrs: Optional[Iterable[str]]
+    ) -> None:
         """Receive an object instance after one or more attributes have
         been refreshed from a query.
 
@@ -467,7 +538,12 @@ class InstanceEvents(event.Events):
 
         """
 
-    def refresh_flush(self, target, flush_context, attrs):
+    def refresh_flush(
+        self,
+        target: _O,
+        flush_context: UOWTransaction,
+        attrs: Optional[Iterable[str]],
+    ) -> None:
         """Receive an object instance after one or more attributes that
         contain a column-level default or onupdate handler have been refreshed
         during persistence of the object's state.
@@ -509,7 +585,7 @@ class InstanceEvents(event.Events):
 
         """
 
-    def expire(self, target, attrs):
+    def expire(self, target: _O, attrs: Optional[Iterable[str]]) -> None:
         """Receive an object instance after its attributes or some subset
         have been expired.
 
@@ -526,7 +602,7 @@ class InstanceEvents(event.Events):
 
         """
 
-    def pickle(self, target, state_dict):
+    def pickle(self, target: _O, state_dict: _InstanceDict) -> None:
         """Receive an object instance when its associated state is
         being pickled.
 
@@ -540,7 +616,7 @@ class InstanceEvents(event.Events):
 
         """
 
-    def unpickle(self, target, state_dict):
+    def unpickle(self, target: _O, state_dict: _InstanceDict) -> None:
         """Receive an object instance after its associated state has
         been unpickled.
 
@@ -555,7 +631,7 @@ class InstanceEvents(event.Events):
         """
 
 
-class _EventsHold(event.RefCollection):
+class _EventsHold(event.RefCollection[_ET]):
     """Hold onto listeners against unmapped, uninstrumented classes.
 
     Establish _listen() for that class' mapper/instrumentation when
@@ -563,20 +639,30 @@ class _EventsHold(event.RefCollection):
 
     """
 
-    def __init__(self, class_):
+    all_holds: weakref.WeakKeyDictionary[Any, Any]
+
+    def __init__(
+        self,
+        class_: Union[DeclarativeAttributeIntercept, DeclarativeMeta, type],
+    ) -> None:
         self.class_ = class_
 
     @classmethod
-    def _clear(cls):
+    def _clear(cls) -> None:
         cls.all_holds.clear()
 
-    class HoldEvents:
-        _dispatch_target = None
+    class HoldEvents(Generic[_ET2]):
+        _dispatch_target: Optional[Type[_ET2]] = None
 
         @classmethod
         def _listen(
-            cls, event_key, raw=False, propagate=False, retval=False, **kw
-        ):
+            cls,
+            event_key: _EventKey[_ET2],
+            raw: bool = False,
+            propagate: bool = False,
+            retval: bool = False,
+            **kw: Any,
+        ) -> None:
             target = event_key.dispatch_target
 
             if target.class_ in target.all_holds:
@@ -606,7 +692,7 @@ class _EventsHold(event.RefCollection):
                             raw=raw, propagate=False, retval=retval, **kw
                         )
 
-    def remove(self, event_key):
+    def remove(self, event_key: _EventKey[_ET]) -> None:
         target = event_key.dispatch_target
 
         if isinstance(target, _EventsHold):
@@ -614,7 +700,11 @@ class _EventsHold(event.RefCollection):
             del collection[event_key._key]
 
     @classmethod
-    def populate(cls, class_, subject):
+    def populate(
+        cls,
+        class_: Union[DeclarativeAttributeIntercept, DeclarativeMeta, type],
+        subject: Union[ClassManager[_O], Mapper[_O]],
+    ) -> None:
         for subclass in class_.__mro__:
             if subclass in cls.all_holds:
                 collection = cls.all_holds[subclass]
@@ -636,19 +726,21 @@ class _EventsHold(event.RefCollection):
                         )
 
 
-class _InstanceEventsHold(_EventsHold):
-    all_holds = weakref.WeakKeyDictionary()
+class _InstanceEventsHold(_EventsHold[_ET]):
+    all_holds: weakref.WeakKeyDictionary[
+        Any, Any
+    ] = weakref.WeakKeyDictionary()
 
     def resolve(self, class_: Type[_O]) -> Optional[ClassManager[_O]]:
         return instrumentation.opt_manager_of_class(class_)
 
-    class HoldInstanceEvents(_EventsHold.HoldEvents, InstanceEvents):
+    class HoldInstanceEvents(_EventsHold.HoldEvents[_ET], InstanceEvents):  # type: ignore [misc] # noqa: E501
         pass
 
     dispatch = event.dispatcher(HoldInstanceEvents)
 
 
-class MapperEvents(event.Events):
+class MapperEvents(event.Events[mapperlib.Mapper[Any]]):
     """Define events specific to mappings.
 
     e.g.::
@@ -718,15 +810,23 @@ class MapperEvents(event.Events):
     _dispatch_target = mapperlib.Mapper
 
     @classmethod
-    def _new_mapper_instance(cls, class_, mapper):
+    def _new_mapper_instance(
+        cls,
+        class_: Union[DeclarativeAttributeIntercept, DeclarativeMeta, type],
+        mapper: Mapper[_O],
+    ) -> None:
         _MapperEventsHold.populate(class_, mapper)
 
     @classmethod
     @util.preload_module("sqlalchemy.orm")
-    def _accept_with(cls, target, identifier):
+    def _accept_with(
+        cls,
+        target: Union[mapperlib.Mapper[Any], Type[mapperlib.Mapper[Any]]],
+        identifier: str,
+    ) -> Optional[Union[mapperlib.Mapper[Any], Type[mapperlib.Mapper[Any]]]]:
         orm = util.preloaded.orm
 
-        if target is orm.mapper:
+        if target is orm.mapper:  # type: ignore [attr-defined]
             util.warn_deprecated(
                 "The `sqlalchemy.orm.mapper()` symbol is deprecated and "
                 "will be removed in a future release. For the mapper-wide "
@@ -748,8 +848,13 @@ class MapperEvents(event.Events):
 
     @classmethod
     def _listen(
-        cls, event_key, raw=False, retval=False, propagate=False, **kw
-    ):
+        cls,
+        event_key: _EventKey[_ET],
+        raw: bool = False,
+        retval: bool = False,
+        propagate: bool = False,
+        **kw: Any,
+    ) -> None:
         target, identifier, fn = (
             event_key.dispatch_target,
             event_key.identifier,
@@ -776,10 +881,10 @@ class MapperEvents(event.Events):
                 except ValueError:
                     target_index = None
 
-            def wrap(*arg, **kw):
+            def wrap(*arg: Any, **kw: Any) -> Any:
                 if not raw and target_index is not None:
-                    arg = list(arg)
-                    arg[target_index] = arg[target_index].obj()
+                    arg = list(arg)  # type: ignore [assignment]
+                    arg[target_index] = arg[target_index].obj()  # type: ignore [index] # noqa: E501
                 if not retval:
                     fn(*arg, **kw)
                     return interfaces.EXT_CONTINUE
@@ -797,11 +902,11 @@ class MapperEvents(event.Events):
             event_key.base_listen(**kw)
 
     @classmethod
-    def _clear(cls):
+    def _clear(cls) -> None:
         super()._clear()
         _MapperEventsHold._clear()
 
-    def instrument_class(self, mapper, class_):
+    def instrument_class(self, mapper: Mapper[_O], class_: Type[_O]) -> None:
         r"""Receive a class when the mapper is first constructed,
         before instrumentation is applied to the mapped class.
 
@@ -824,7 +929,9 @@ class MapperEvents(event.Events):
 
         """
 
-    def before_mapper_configured(self, mapper, class_):
+    def before_mapper_configured(
+        self, mapper: Mapper[_O], class_: Type[_O]
+    ) -> None:
         """Called right before a specific mapper is to be configured.
 
         This event is intended to allow a specific mapper to be skipped during
@@ -872,7 +979,7 @@ class MapperEvents(event.Events):
 
         """
 
-    def mapper_configured(self, mapper, class_):
+    def mapper_configured(self, mapper: Mapper[_O], class_: Type[_O]) -> None:
         r"""Called when a specific mapper has completed its own configuration
         within the scope of the :func:`.configure_mappers` call.
 
@@ -926,7 +1033,7 @@ class MapperEvents(event.Events):
         """
         # TODO: need coverage for this event
 
-    def before_configured(self):
+    def before_configured(self) -> None:
         """Called before a series of mappers have been configured.
 
         The :meth:`.MapperEvents.before_configured` event is invoked
@@ -981,7 +1088,7 @@ class MapperEvents(event.Events):
 
         """
 
-    def after_configured(self):
+    def after_configured(self) -> None:
         """Called after a series of mappers have been configured.
 
         The :meth:`.MapperEvents.after_configured` event is invoked
@@ -1034,7 +1141,9 @@ class MapperEvents(event.Events):
 
         """
 
-    def before_insert(self, mapper, connection, target):
+    def before_insert(
+        self, mapper: Mapper[_O], connection: Connection, target: _O
+    ) -> None:
         """Receive an object instance before an INSERT statement
         is emitted corresponding to that instance.
 
@@ -1080,7 +1189,9 @@ class MapperEvents(event.Events):
 
         """
 
-    def after_insert(self, mapper, connection, target):
+    def after_insert(
+        self, mapper: Mapper[_O], connection: Connection, target: _O
+    ) -> None:
         """Receive an object instance after an INSERT statement
         is emitted corresponding to that instance.
 
@@ -1126,7 +1237,9 @@ class MapperEvents(event.Events):
 
         """
 
-    def before_update(self, mapper, connection, target):
+    def before_update(
+        self, mapper: Mapper[_O], connection: Connection, target: _O
+    ) -> None:
         """Receive an object instance before an UPDATE statement
         is emitted corresponding to that instance.
 
@@ -1191,7 +1304,9 @@ class MapperEvents(event.Events):
 
         """
 
-    def after_update(self, mapper, connection, target):
+    def after_update(
+        self, mapper: Mapper[_O], connection: Connection, target: _O
+    ) -> None:
         """Receive an object instance after an UPDATE statement
         is emitted corresponding to that instance.
 
@@ -1255,7 +1370,9 @@ class MapperEvents(event.Events):
 
         """
 
-    def before_delete(self, mapper, connection, target):
+    def before_delete(
+        self, mapper: Mapper[_O], connection: Connection, target: _O
+    ) -> None:
         """Receive an object instance before a DELETE statement
         is emitted corresponding to that instance.
 
@@ -1295,7 +1412,9 @@ class MapperEvents(event.Events):
 
         """
 
-    def after_delete(self, mapper, connection, target):
+    def after_delete(
+        self, mapper: Mapper[_O], connection: Connection, target: _O
+    ) -> None:
         """Receive an object instance after a DELETE statement
         has been emitted corresponding to that instance.
 
@@ -1336,19 +1455,21 @@ class MapperEvents(event.Events):
         """
 
 
-class _MapperEventsHold(_EventsHold):
+class _MapperEventsHold(_EventsHold[_ET]):
     all_holds = weakref.WeakKeyDictionary()
 
-    def resolve(self, class_):
+    def resolve(
+        self, class_: Union[Type[_T], _InternalEntityType[_T]]
+    ) -> Optional[Mapper[_T]]:
         return _mapper_or_none(class_)
 
-    class HoldMapperEvents(_EventsHold.HoldEvents, MapperEvents):
+    class HoldMapperEvents(_EventsHold.HoldEvents[_ET], MapperEvents):  # type: ignore [misc] # noqa: E501
         pass
 
     dispatch = event.dispatcher(HoldMapperEvents)
 
 
-_sessionevents_lifecycle_event_names = set()
+_sessionevents_lifecycle_event_names: Set[str] = set()
 
 
 class SessionEvents(event.Events[Session]):
@@ -1396,12 +1517,16 @@ class SessionEvents(event.Events[Session]):
 
     _dispatch_target = Session
 
-    def _lifecycle_event(fn):
+    def _lifecycle_event(  # type: ignore [misc]
+        fn: Callable[[SessionEvents, Session, Any], None]
+    ) -> Callable[[SessionEvents, Session, Any], None]:
         _sessionevents_lifecycle_event_names.add(fn.__name__)
         return fn
 
     @classmethod
-    def _accept_with(cls, target, identifier):
+    def _accept_with(  # type: ignore [return]
+        cls, target: Any, identifier: str
+    ) -> Union[Session, type]:
         if isinstance(target, scoped_session):
 
             target = target.session_factory
@@ -1427,7 +1552,7 @@ class SessionEvents(event.Events[Session]):
             target._no_async_engine_events()
         else:
             # allows alternate SessionEvents-like-classes to be consulted
-            return event.Events._accept_with(target, identifier)
+            return event.Events._accept_with(target, identifier)  # type: ignore [return-value] # noqa: E501
 
     @classmethod
     def _listen(
@@ -1447,15 +1572,20 @@ class SessionEvents(event.Events[Session]):
 
                 fn = event_key._listen_fn
 
-                def wrap(session, state, *arg, **kw):
+                def wrap(
+                    session: Session,
+                    state: InstanceState[_O],
+                    *arg: Any,
+                    **kw: Any,
+                ) -> Optional[Any]:
                     if not raw:
                         target = state.obj()
                         if target is None:
                             # existing behavior is that if the object is
                             # garbage collected, no event is emitted
-                            return
+                            return None
                     else:
-                        target = state
+                        target = state  # type: ignore [assignment]
                     if restore_load_context:
                         runid = state.runid
                     try:
@@ -1468,7 +1598,7 @@ class SessionEvents(event.Events[Session]):
 
         event_key.base_listen(**kw)
 
-    def do_orm_execute(self, orm_execute_state):
+    def do_orm_execute(self, orm_execute_state: ORMExecuteState) -> None:
         """Intercept statement executions that occur on behalf of an
         ORM :class:`.Session` object.
 
@@ -1541,7 +1671,9 @@ class SessionEvents(event.Events[Session]):
 
         """
 
-    def after_transaction_create(self, session, transaction):
+    def after_transaction_create(
+        self, session: Session, transaction: SessionTransaction
+    ) -> None:
         """Execute when a new :class:`.SessionTransaction` is created.
 
         This event differs from :meth:`~.SessionEvents.after_begin`
@@ -1583,7 +1715,9 @@ class SessionEvents(event.Events[Session]):
 
         """
 
-    def after_transaction_end(self, session, transaction):
+    def after_transaction_end(
+        self, session: Session, transaction: SessionTransaction
+    ) -> None:
         """Execute when the span of a :class:`.SessionTransaction` ends.
 
         This event differs from :meth:`~.SessionEvents.after_commit`
@@ -1622,7 +1756,7 @@ class SessionEvents(event.Events[Session]):
 
         """
 
-    def before_commit(self, session):
+    def before_commit(self, session: Session) -> None:
         """Execute before commit is called.
 
         .. note::
@@ -1650,7 +1784,7 @@ class SessionEvents(event.Events[Session]):
 
         """
 
-    def after_commit(self, session):
+    def after_commit(self, session: Session) -> None:
         """Execute after a commit has occurred.
 
         .. note::
@@ -1686,7 +1820,7 @@ class SessionEvents(event.Events[Session]):
 
         """
 
-    def after_rollback(self, session):
+    def after_rollback(self, session: Session) -> None:
         """Execute after a real DBAPI rollback has occurred.
 
         Note that this event only fires when the *actual* rollback against
@@ -1704,7 +1838,9 @@ class SessionEvents(event.Events[Session]):
 
         """
 
-    def after_soft_rollback(self, session, previous_transaction):
+    def after_soft_rollback(
+        self, session: Session, previous_transaction: SessionTransaction
+    ) -> None:
         """Execute after any rollback has occurred, including "soft"
         rollbacks that don't actually emit at the DBAPI level.
 
@@ -1730,7 +1866,12 @@ class SessionEvents(event.Events[Session]):
 
         """
 
-    def before_flush(self, session, flush_context, instances):
+    def before_flush(
+        self,
+        session: Session,
+        flush_context: UOWTransaction,
+        instances: Optional[Sequence[_O]],
+    ) -> None:
         """Execute before flush process has started.
 
         :param session: The target :class:`.Session`.
@@ -1750,7 +1891,9 @@ class SessionEvents(event.Events[Session]):
 
         """
 
-    def after_flush(self, session, flush_context):
+    def after_flush(
+        self, session: Session, flush_context: UOWTransaction
+    ) -> None:
         """Execute after flush has completed, but before commit has been
         called.
 
@@ -1781,7 +1924,9 @@ class SessionEvents(event.Events[Session]):
 
         """
 
-    def after_flush_postexec(self, session, flush_context):
+    def after_flush_postexec(
+        self, session: Session, flush_context: UOWTransaction
+    ) -> None:
         """Execute after flush has completed, and after the post-exec
         state occurs.
 
@@ -1805,7 +1950,12 @@ class SessionEvents(event.Events[Session]):
 
         """
 
-    def after_begin(self, session, transaction, connection):
+    def after_begin(
+        self,
+        session: Session,
+        transaction: SessionTransaction,
+        connection: Connection,
+    ) -> None:
         """Execute after a transaction is begun on a connection
 
         :param session: The target :class:`.Session`.
@@ -1826,7 +1976,7 @@ class SessionEvents(event.Events[Session]):
         """
 
     @_lifecycle_event
-    def before_attach(self, session, instance):
+    def before_attach(self, session: Session, instance: _O) -> None:
         """Execute before an instance is attached to a session.
 
         This is called before an add, delete or merge causes
@@ -1841,7 +1991,7 @@ class SessionEvents(event.Events[Session]):
         """
 
     @_lifecycle_event
-    def after_attach(self, session, instance):
+    def after_attach(self, session: Session, instance: _O) -> None:
         """Execute after an instance is attached to a session.
 
         This is called after an add, delete or merge.
@@ -1875,7 +2025,7 @@ class SessionEvents(event.Events[Session]):
             update_context.result,
         ),
     )
-    def after_bulk_update(self, update_context):
+    def after_bulk_update(self, update_context: _O) -> None:
         """Event for after the legacy :meth:`_orm.Query.update` method
         has been called.
 
@@ -1921,7 +2071,7 @@ class SessionEvents(event.Events[Session]):
             delete_context.result,
         ),
     )
-    def after_bulk_delete(self, delete_context):
+    def after_bulk_delete(self, delete_context: _O) -> None:
         """Event for after the legacy :meth:`_orm.Query.delete` method
         has been called.
 
@@ -1956,7 +2106,7 @@ class SessionEvents(event.Events[Session]):
         """
 
     @_lifecycle_event
-    def transient_to_pending(self, session, instance):
+    def transient_to_pending(self, session: Session, instance: _O) -> None:
         """Intercept the "transient to pending" transition for a specific
         object.
 
@@ -1978,7 +2128,7 @@ class SessionEvents(event.Events[Session]):
         """
 
     @_lifecycle_event
-    def pending_to_transient(self, session, instance):
+    def pending_to_transient(self, session: Session, instance: _O) -> None:
         """Intercept the "pending to transient" transition for a specific
         object.
 
@@ -2000,7 +2150,7 @@ class SessionEvents(event.Events[Session]):
         """
 
     @_lifecycle_event
-    def persistent_to_transient(self, session, instance):
+    def persistent_to_transient(self, session: Session, instance: _O) -> None:
         """Intercept the "persistent to transient" transition for a specific
         object.
 
@@ -2021,7 +2171,7 @@ class SessionEvents(event.Events[Session]):
         """
 
     @_lifecycle_event
-    def pending_to_persistent(self, session, instance):
+    def pending_to_persistent(self, session: Session, instance: _O) -> None:
         """Intercept the "pending to persistent"" transition for a specific
         object.
 
@@ -2044,7 +2194,7 @@ class SessionEvents(event.Events[Session]):
         """
 
     @_lifecycle_event
-    def detached_to_persistent(self, session, instance):
+    def detached_to_persistent(self, session: Session, instance: _O) -> None:
         """Intercept the "detached to persistent" transition for a specific
         object.
 
@@ -2081,7 +2231,7 @@ class SessionEvents(event.Events[Session]):
         """
 
     @_lifecycle_event
-    def loaded_as_persistent(self, session, instance):
+    def loaded_as_persistent(self, session: Session, instance: _O) -> None:
         """Intercept the "loaded as persistent" transition for a specific
         object.
 
@@ -2117,7 +2267,7 @@ class SessionEvents(event.Events[Session]):
         """
 
     @_lifecycle_event
-    def persistent_to_deleted(self, session, instance):
+    def persistent_to_deleted(self, session: Session, instance: _O) -> None:
         """Intercept the "persistent to deleted" transition for a specific
         object.
 
@@ -2150,7 +2300,7 @@ class SessionEvents(event.Events[Session]):
         """
 
     @_lifecycle_event
-    def deleted_to_persistent(self, session, instance):
+    def deleted_to_persistent(self, session: Session, instance: _O) -> None:
         """Intercept the "deleted to persistent" transition for a specific
         object.
 
@@ -2168,7 +2318,7 @@ class SessionEvents(event.Events[Session]):
         """
 
     @_lifecycle_event
-    def deleted_to_detached(self, session, instance):
+    def deleted_to_detached(self, session: Session, instance: _O) -> None:
         """Intercept the "deleted to detached" transition for a specific
         object.
 
@@ -2192,7 +2342,7 @@ class SessionEvents(event.Events[Session]):
         """
 
     @_lifecycle_event
-    def persistent_to_detached(self, session, instance):
+    def persistent_to_detached(self, session: Session, instance: _O) -> None:
         """Intercept the "persistent to detached" transition for a specific
         object.
 
@@ -2224,7 +2374,7 @@ class SessionEvents(event.Events[Session]):
         """
 
 
-class AttributeEvents(event.Events):
+class AttributeEvents(event.Events[QueryableAttribute[Any]]):
     r"""Define events for object attributes.
 
     These are typically defined on the class-bound descriptor for the
@@ -2297,13 +2447,19 @@ class AttributeEvents(event.Events):
     _dispatch_target = QueryableAttribute
 
     @staticmethod
-    def _set_dispatch(cls, dispatch_cls):
+    def _set_dispatch(
+        cls: Type[_HasEventsDispatch[Any]], dispatch_cls: Type[_Dispatch[Any]]
+    ) -> _Dispatch[Any]:
         dispatch = event.Events._set_dispatch(cls, dispatch_cls)
         dispatch_cls._active_history = False
         return dispatch
 
     @classmethod
-    def _accept_with(cls, target, identifier):
+    def _accept_with(
+        cls,
+        target: Union[QueryableAttribute[Any], Type[QueryableAttribute[Any]]],
+        identifier: str,
+    ) -> Union[QueryableAttribute[Any], Type[QueryableAttribute[Any]]]:
         # TODO: coverage
         if isinstance(target, interfaces.MapperProperty):
             return getattr(target.parent.class_, target.key)
@@ -2311,15 +2467,15 @@ class AttributeEvents(event.Events):
             return target
 
     @classmethod
-    def _listen(
+    def _listen(  # type: ignore [override]
         cls,
-        event_key,
-        active_history=False,
-        raw=False,
-        retval=False,
-        propagate=False,
-        include_key=False,
-    ):
+        event_key: _EventKey[QueryableAttribute[Any]],
+        active_history: bool = False,
+        raw: bool = False,
+        retval: bool = False,
+        propagate: bool = False,
+        include_key: bool = False,
+    ) -> None:
 
         target, fn = event_key.dispatch_target, event_key._listen_fn
 
@@ -2328,9 +2484,9 @@ class AttributeEvents(event.Events):
 
         if not raw or not retval or not include_key:
 
-            def wrap(target, *arg, **kw):
+            def wrap(target: InstanceState[_O], *arg: Any, **kw: Any) -> Any:
                 if not raw:
-                    target = target.obj()
+                    target = target.obj()  # type: ignore [assignment]
                 if not retval:
                     if arg:
                         value = arg[0]
@@ -2354,14 +2510,21 @@ class AttributeEvents(event.Events):
         if propagate:
             manager = instrumentation.manager_of_class(target.class_)
 
-            for mgr in manager.subclass_managers(True):
+            for mgr in manager.subclass_managers(True):  # type: ignore [no-untyped-call] # noqa: E501
                 event_key.with_dispatch_target(mgr[target.key]).base_listen(
                     propagate=True
                 )
                 if active_history:
                     mgr[target.key].dispatch._active_history = True
 
-    def append(self, target, value, initiator, *, key=NO_KEY):
+    def append(
+        self,
+        target: _O,
+        value: _T,
+        initiator: Event,
+        *,
+        key: EventConstants = NO_KEY,
+    ) -> Optional[_T]:
         """Receive a collection append event.
 
         The append event is invoked for each element as it is appended
@@ -2405,7 +2568,14 @@ class AttributeEvents(event.Events):
 
         """
 
-    def append_wo_mutation(self, target, value, initiator, *, key=NO_KEY):
+    def append_wo_mutation(
+        self,
+        target: _O,
+        value: _T,
+        initiator: Event,
+        *,
+        key: EventConstants = NO_KEY,
+    ) -> None:
         """Receive a collection append event where the collection was not
         actually mutated.
 
@@ -2447,7 +2617,14 @@ class AttributeEvents(event.Events):
 
         """
 
-    def bulk_replace(self, target, values, initiator, *, keys=None):
+    def bulk_replace(
+        self,
+        target: _O,
+        values: Iterable[_T],
+        initiator: Event,
+        *,
+        keys: Optional[Iterable[EventConstants]] = None,
+    ) -> None:
         """Receive a collection 'bulk replace' event.
 
         This event is invoked for a sequence of values as they are incoming
@@ -2510,7 +2687,14 @@ class AttributeEvents(event.Events):
 
         """
 
-    def remove(self, target, value, initiator, *, key=NO_KEY):
+    def remove(
+        self,
+        target: _O,
+        value: _T,
+        initiator: Event,
+        *,
+        key: EventConstants = NO_KEY,
+    ) -> None:
         """Receive a collection remove event.
 
         :param target: the object instance receiving the event.
@@ -2548,7 +2732,9 @@ class AttributeEvents(event.Events):
 
         """
 
-    def set(self, target, value, oldvalue, initiator):
+    def set(
+        self, target: _O, value: _T, oldvalue: _T, initiator: Event
+    ) -> None:
         """Receive a scalar set event.
 
         :param target: the object instance receiving the event.
@@ -2584,7 +2770,9 @@ class AttributeEvents(event.Events):
 
         """
 
-    def init_scalar(self, target, value, dict_):
+    def init_scalar(
+        self, target: _O, value: _T, dict_: Dict[Any, Any]
+    ) -> None:
         r"""Receive a scalar "init" event.
 
         This event is invoked when an uninitialized, unpersisted scalar
@@ -2706,7 +2894,12 @@ class AttributeEvents(event.Events):
 
         """
 
-    def init_collection(self, target, collection, collection_adapter):
+    def init_collection(
+        self,
+        target: _O,
+        collection: Type[Collection[Any]],
+        collection_adapter: CollectionAdapter,
+    ) -> None:
         """Receive a 'collection init' event.
 
         This event is triggered for a collection-based attribute, when
@@ -2747,7 +2940,12 @@ class AttributeEvents(event.Events):
 
         """
 
-    def dispose_collection(self, target, collection, collection_adapter):
+    def dispose_collection(
+        self,
+        target: _O,
+        collection: Collection[Any],
+        collection_adapter: CollectionAdapter,
+    ) -> None:
         """Receive a 'collection dispose' event.
 
         This event is triggered for a collection-based attribute when
@@ -2774,7 +2972,7 @@ class AttributeEvents(event.Events):
 
         """
 
-    def modified(self, target, initiator):
+    def modified(self, target: _O, initiator: Event) -> None:
         """Receive a 'modified' event.
 
         This event is triggered when the :func:`.attributes.flag_modified`
@@ -2798,7 +2996,7 @@ class AttributeEvents(event.Events):
         """
 
 
-class QueryEvents(event.Events):
+class QueryEvents(event.Events[Query[Any]]):
     """Represent events within the construction of a :class:`_query.Query`
     object.
 
@@ -2817,7 +3015,7 @@ class QueryEvents(event.Events):
     _target_class_doc = "SomeQuery"
     _dispatch_target = Query
 
-    def before_compile(self, query):
+    def before_compile(self, query: Query[Any]) -> None:
         """Receive the :class:`_query.Query`
         object before it is composed into a
         core :class:`_expression.Select` object.
@@ -2883,7 +3081,9 @@ class QueryEvents(event.Events):
 
         """
 
-    def before_compile_update(self, query, update_context):
+    def before_compile_update(
+        self, query: Query[Any], update_context: BulkUpdate
+    ) -> None:
         """Allow modifications to the :class:`_query.Query` object within
         :meth:`_query.Query.update`.
 
@@ -2933,7 +3133,9 @@ class QueryEvents(event.Events):
 
         """
 
-    def before_compile_delete(self, query, delete_context):
+    def before_compile_delete(
+        self, query: Query[Any], delete_context: BulkDelete
+    ) -> None:
         """Allow modifications to the :class:`_query.Query` object within
         :meth:`_query.Query.delete`.
 
@@ -2973,12 +3175,18 @@ class QueryEvents(event.Events):
         """
 
     @classmethod
-    def _listen(cls, event_key, retval=False, bake_ok=False, **kw):
+    def _listen(
+        cls,
+        event_key: _EventKey[_ET],
+        retval: bool = False,
+        bake_ok: bool = False,
+        **kw: Any,
+    ) -> None:
         fn = event_key._listen_fn
 
         if not retval:
 
-            def wrap(*arg, **kw):
+            def wrap(*arg: Any, **kw: Any) -> Any:
                 if not retval:
                     query = arg[0]
                     fn(*arg, **kw)
@@ -2989,11 +3197,11 @@ class QueryEvents(event.Events):
             event_key = event_key.with_wrapper(wrap)
         else:
             # don't assume we can apply an attribute to the callable
-            def wrap(*arg, **kw):
+            def wrap(*arg: Any, **kw: Any) -> Any:
                 return fn(*arg, **kw)
 
             event_key = event_key.with_wrapper(wrap)
 
-        wrap._bake_ok = bake_ok
+        wrap._bake_ok = bake_ok  # type: ignore [attr-defined]
 
         event_key.base_listen(**kw)