From 2c9c85a5c995b233b3860111b15f19157bbc43f9 Mon Sep 17 00:00:00 2001 From: Gleb Kisenkov Date: Wed, 28 Dec 2022 14:23:23 -0500 Subject: [PATCH] Type annotations for sqlalchemy.orm.events ### Description An attempt to annotate `lib/sqlalchemy/orm/events.py` with type hints (issue #6810). ### Checklist This pull request is: - [ ] A documentation / typographical error fix - Good to go, no issue or tests are needed - [ ] A short code fix - please include the issue number, and create an issue if none exists, which must include a complete example of the issue. one line code fixes without an issue and demonstration will not be accepted. - Please include: `Fixes: #` in the commit message - please include tests. one line code fixes without tests will not be accepted. - [x] A new feature implementation - please include the issue number, and create an issue if none exists, which must include a complete example of how the feature would look. - Please include: `Fixes: #` in the commit message - please include tests. **Have a nice day!** Closes: #9025 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/9025 Pull-request-sha: a3fd2c0c3790164c433305ccc7ac6b73e813e037 Change-Id: I0808b6485504615fa20691dc8f4631d38bc89ab3 --- lib/sqlalchemy/orm/events.py | 496 +++++++++++++++++++++++++---------- 1 file changed, 352 insertions(+), 144 deletions(-) diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py index 32de155a15..b182b91ca2 100644 --- a/lib/sqlalchemy/orm/events.py +++ b/lib/sqlalchemy/orm/events.py @@ -4,7 +4,6 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors """ORM event interfaces. @@ -12,9 +11,18 @@ from __future__ import annotations from typing import Any +from typing import Callable +from typing import Collection +from typing import Dict +from typing import Generic +from typing import Iterable from typing import Optional +from typing import Sequence +from typing import Set from typing import Type from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union import weakref from . import instrumentation @@ -23,6 +31,10 @@ from . import mapperlib from .attributes import QueryableAttribute from .base import _mapper_or_none from .base import NO_KEY +from .instrumentation import ClassManager +from .instrumentation import InstrumentationFactory +from .query import BulkDelete +from .query import BulkUpdate from .query import Query from .scoping import scoped_session from .session import Session @@ -30,14 +42,38 @@ from .session import sessionmaker from .. import event from .. import exc from .. import util +from ..event import EventTarget +from ..event.registry import _ET from ..util.compat import inspect_getfullargspec if TYPE_CHECKING: - from ._typing import _O - from .instrumentation import ClassManager - + from weakref import ReferenceType -class InstrumentationEvents(event.Events): + from ._typing import _InstanceDict + from ._typing import _InternalEntityType + from ._typing import _O + from ._typing import _T + from .attributes import Event + from .base import EventConstants + from .session import ORMExecuteState + from .session import SessionTransaction + from .unitofwork import UOWTransaction + from ..engine import Connection + from ..event.base import _Dispatch + from ..event.base import _HasEventsDispatch + from ..event.registry import _EventKey + from ..orm.collections import CollectionAdapter + from ..orm.context import QueryContext + from ..orm.decl_api import DeclarativeAttributeIntercept + from ..orm.decl_api import DeclarativeMeta + from ..orm.mapper import Mapper + from ..orm.state import InstanceState + +_KT = TypeVar("_KT", bound=Any) +_ET2 = TypeVar("_ET2", bound=EventTarget) + + +class InstrumentationEvents(event.Events[InstrumentationFactory]): """Events related to class instrumentation events. The listeners here support being established against @@ -60,24 +96,38 @@ class InstrumentationEvents(event.Events): """ _target_class_doc = "SomeBaseClass" - _dispatch_target = instrumentation.InstrumentationFactory + _dispatch_target = InstrumentationFactory @classmethod - def _accept_with(cls, target, identifier): + def _accept_with( + cls, + target: Union[ + InstrumentationFactory, + Type[InstrumentationFactory], + ], + identifier: str, + ) -> Optional[ + Union[ + InstrumentationFactory, + Type[InstrumentationFactory], + ] + ]: if isinstance(target, type): - return _InstrumentationEventsHold(target) + return _InstrumentationEventsHold(target) # type: ignore [return-value] # noqa: E501 else: return None @classmethod - def _listen(cls, event_key, propagate=True, **kw): + def _listen( + cls, event_key: _EventKey[_T], propagate: bool = True, **kw: Any + ) -> None: target, identifier, fn = ( event_key.dispatch_target, event_key.identifier, event_key._listen_fn, ) - def listen(target_cls, *arg): + def listen(target_cls: type, *arg: Any) -> Optional[Any]: listen_cls = target() # if weakref were collected, however this is not something @@ -91,9 +141,11 @@ class InstrumentationEvents(event.Events): return fn(target_cls, *arg) elif not propagate and target_cls is listen_cls: return fn(target_cls, *arg) + else: + return None - def remove(ref): - key = event.registry._EventKey( + def remove(ref: ReferenceType[_T]) -> None: + key = event.registry._EventKey( # type: ignore [type-var] None, identifier, listen, @@ -110,11 +162,11 @@ class InstrumentationEvents(event.Events): ).with_wrapper(listen).base_listen(**kw) @classmethod - def _clear(cls): + def _clear(cls) -> None: super()._clear() instrumentation._instrumentation_factory.dispatch._clear() - def class_instrument(self, cls): + def class_instrument(self, cls: ClassManager[_O]) -> None: """Called after the given class is instrumented. To get at the :class:`.ClassManager`, use @@ -122,7 +174,7 @@ class InstrumentationEvents(event.Events): """ - def class_uninstrument(self, cls): + def class_uninstrument(self, cls: ClassManager[_O]) -> None: """Called before the given class is uninstrumented. To get at the :class:`.ClassManager`, use @@ -130,7 +182,9 @@ class InstrumentationEvents(event.Events): """ - def attribute_instrument(self, cls, key, inst): + def attribute_instrument( + self, cls: ClassManager[_O], key: _KT, inst: _O + ) -> None: """Called when an attribute is instrumented.""" @@ -140,13 +194,13 @@ class _InstrumentationEventsHold: """ - def __init__(self, class_): + def __init__(self, class_: type) -> None: self.class_ = class_ dispatch = event.dispatcher(InstrumentationEvents) -class InstanceEvents(event.Events): +class InstanceEvents(event.Events[ClassManager[Any]]): """Define events specific to object lifecycle. e.g.:: @@ -196,56 +250,69 @@ class InstanceEvents(event.Events): _target_class_doc = "SomeClass" - _dispatch_target = instrumentation.ClassManager + _dispatch_target = ClassManager @classmethod - def _new_classmanager_instance(cls, class_, classmanager): + def _new_classmanager_instance( + cls, + class_: Union[DeclarativeAttributeIntercept, DeclarativeMeta, type], + classmanager: ClassManager[_O], + ) -> None: _InstanceEventsHold.populate(class_, classmanager) @classmethod @util.preload_module("sqlalchemy.orm") - def _accept_with(cls, target, identifier): + def _accept_with( + cls, + target: Union[ + ClassManager[Any], + Type[ClassManager[Any]], + ], + identifier: str, + ) -> Optional[Union[ClassManager[Any], Type[ClassManager[Any]]]]: orm = util.preloaded.orm - if isinstance(target, instrumentation.ClassManager): + if isinstance(target, ClassManager): return target elif isinstance(target, mapperlib.Mapper): return target.class_manager - elif target is orm.mapper: + elif target is orm.mapper: # type: ignore [attr-defined] util.warn_deprecated( "The `sqlalchemy.orm.mapper()` symbol is deprecated and " "will be removed in a future release. For the mapper-wide " "event target, use the 'sqlalchemy.orm.Mapper' class.", "2.0", ) - return instrumentation.ClassManager + return ClassManager elif isinstance(target, type): if issubclass(target, mapperlib.Mapper): - return instrumentation.ClassManager + return ClassManager else: manager = instrumentation.opt_manager_of_class(target) if manager: return manager else: - return _InstanceEventsHold(target) + return _InstanceEventsHold(target) # type: ignore [return-value] # noqa: E501 return None @classmethod def _listen( cls, - event_key, - raw=False, - propagate=False, - restore_load_context=False, - **kw, - ): + event_key: _EventKey[ClassManager[Any]], + raw: bool = False, + propagate: bool = False, + restore_load_context: bool = False, + **kw: Any, + ) -> None: target, fn = (event_key.dispatch_target, event_key._listen_fn) if not raw or restore_load_context: - def wrap(state, *arg, **kw): + def wrap( + state: InstanceState[_O], *arg: Any, **kw: Any + ) -> Optional[Any]: if not raw: - target = state.obj() + target: Any = state.obj() else: target = state if restore_load_context: @@ -265,11 +332,11 @@ class InstanceEvents(event.Events): event_key.with_dispatch_target(mgr).base_listen(propagate=True) @classmethod - def _clear(cls): + def _clear(cls) -> None: super()._clear() _InstanceEventsHold._clear() - def first_init(self, manager, cls): + def first_init(self, manager: ClassManager[_O], cls: Type[_O]) -> None: """Called when the first instance of a particular mapping is called. This event is called when the ``__init__`` method of a class @@ -279,7 +346,7 @@ class InstanceEvents(event.Events): """ - def init(self, target, args, kwargs): + def init(self, target: _O, args: Any, kwargs: Any) -> None: """Receive an instance when its constructor is called. This method is only called during a userland construction of @@ -310,7 +377,7 @@ class InstanceEvents(event.Events): """ - def init_failure(self, target, args, kwargs): + def init_failure(self, target: _O, args: Any, kwargs: Any) -> None: """Receive an instance when its constructor has been called, and raised an exception. @@ -343,7 +410,9 @@ class InstanceEvents(event.Events): """ - def _sa_event_merge_wo_load(self, target, context): + def _sa_event_merge_wo_load( + self, target: _O, context: QueryContext + ) -> None: """receive an object instance after it was the subject of a merge() call, when load=False was passed. @@ -360,7 +429,7 @@ class InstanceEvents(event.Events): """ - def load(self, target, context): + def load(self, target: _O, context: QueryContext) -> None: """Receive an object instance after it has been created via ``__new__``, and after initial attribute population has occurred. @@ -435,7 +504,9 @@ class InstanceEvents(event.Events): """ - def refresh(self, target, context, attrs): + def refresh( + self, target: _O, context: QueryContext, attrs: Optional[Iterable[str]] + ) -> None: """Receive an object instance after one or more attributes have been refreshed from a query. @@ -467,7 +538,12 @@ class InstanceEvents(event.Events): """ - def refresh_flush(self, target, flush_context, attrs): + def refresh_flush( + self, + target: _O, + flush_context: UOWTransaction, + attrs: Optional[Iterable[str]], + ) -> None: """Receive an object instance after one or more attributes that contain a column-level default or onupdate handler have been refreshed during persistence of the object's state. @@ -509,7 +585,7 @@ class InstanceEvents(event.Events): """ - def expire(self, target, attrs): + def expire(self, target: _O, attrs: Optional[Iterable[str]]) -> None: """Receive an object instance after its attributes or some subset have been expired. @@ -526,7 +602,7 @@ class InstanceEvents(event.Events): """ - def pickle(self, target, state_dict): + def pickle(self, target: _O, state_dict: _InstanceDict) -> None: """Receive an object instance when its associated state is being pickled. @@ -540,7 +616,7 @@ class InstanceEvents(event.Events): """ - def unpickle(self, target, state_dict): + def unpickle(self, target: _O, state_dict: _InstanceDict) -> None: """Receive an object instance after its associated state has been unpickled. @@ -555,7 +631,7 @@ class InstanceEvents(event.Events): """ -class _EventsHold(event.RefCollection): +class _EventsHold(event.RefCollection[_ET]): """Hold onto listeners against unmapped, uninstrumented classes. Establish _listen() for that class' mapper/instrumentation when @@ -563,20 +639,30 @@ class _EventsHold(event.RefCollection): """ - def __init__(self, class_): + all_holds: weakref.WeakKeyDictionary[Any, Any] + + def __init__( + self, + class_: Union[DeclarativeAttributeIntercept, DeclarativeMeta, type], + ) -> None: self.class_ = class_ @classmethod - def _clear(cls): + def _clear(cls) -> None: cls.all_holds.clear() - class HoldEvents: - _dispatch_target = None + class HoldEvents(Generic[_ET2]): + _dispatch_target: Optional[Type[_ET2]] = None @classmethod def _listen( - cls, event_key, raw=False, propagate=False, retval=False, **kw - ): + cls, + event_key: _EventKey[_ET2], + raw: bool = False, + propagate: bool = False, + retval: bool = False, + **kw: Any, + ) -> None: target = event_key.dispatch_target if target.class_ in target.all_holds: @@ -606,7 +692,7 @@ class _EventsHold(event.RefCollection): raw=raw, propagate=False, retval=retval, **kw ) - def remove(self, event_key): + def remove(self, event_key: _EventKey[_ET]) -> None: target = event_key.dispatch_target if isinstance(target, _EventsHold): @@ -614,7 +700,11 @@ class _EventsHold(event.RefCollection): del collection[event_key._key] @classmethod - def populate(cls, class_, subject): + def populate( + cls, + class_: Union[DeclarativeAttributeIntercept, DeclarativeMeta, type], + subject: Union[ClassManager[_O], Mapper[_O]], + ) -> None: for subclass in class_.__mro__: if subclass in cls.all_holds: collection = cls.all_holds[subclass] @@ -636,19 +726,21 @@ class _EventsHold(event.RefCollection): ) -class _InstanceEventsHold(_EventsHold): - all_holds = weakref.WeakKeyDictionary() +class _InstanceEventsHold(_EventsHold[_ET]): + all_holds: weakref.WeakKeyDictionary[ + Any, Any + ] = weakref.WeakKeyDictionary() def resolve(self, class_: Type[_O]) -> Optional[ClassManager[_O]]: return instrumentation.opt_manager_of_class(class_) - class HoldInstanceEvents(_EventsHold.HoldEvents, InstanceEvents): + class HoldInstanceEvents(_EventsHold.HoldEvents[_ET], InstanceEvents): # type: ignore [misc] # noqa: E501 pass dispatch = event.dispatcher(HoldInstanceEvents) -class MapperEvents(event.Events): +class MapperEvents(event.Events[mapperlib.Mapper[Any]]): """Define events specific to mappings. e.g.:: @@ -718,15 +810,23 @@ class MapperEvents(event.Events): _dispatch_target = mapperlib.Mapper @classmethod - def _new_mapper_instance(cls, class_, mapper): + def _new_mapper_instance( + cls, + class_: Union[DeclarativeAttributeIntercept, DeclarativeMeta, type], + mapper: Mapper[_O], + ) -> None: _MapperEventsHold.populate(class_, mapper) @classmethod @util.preload_module("sqlalchemy.orm") - def _accept_with(cls, target, identifier): + def _accept_with( + cls, + target: Union[mapperlib.Mapper[Any], Type[mapperlib.Mapper[Any]]], + identifier: str, + ) -> Optional[Union[mapperlib.Mapper[Any], Type[mapperlib.Mapper[Any]]]]: orm = util.preloaded.orm - if target is orm.mapper: + if target is orm.mapper: # type: ignore [attr-defined] util.warn_deprecated( "The `sqlalchemy.orm.mapper()` symbol is deprecated and " "will be removed in a future release. For the mapper-wide " @@ -748,8 +848,13 @@ class MapperEvents(event.Events): @classmethod def _listen( - cls, event_key, raw=False, retval=False, propagate=False, **kw - ): + cls, + event_key: _EventKey[_ET], + raw: bool = False, + retval: bool = False, + propagate: bool = False, + **kw: Any, + ) -> None: target, identifier, fn = ( event_key.dispatch_target, event_key.identifier, @@ -776,10 +881,10 @@ class MapperEvents(event.Events): except ValueError: target_index = None - def wrap(*arg, **kw): + def wrap(*arg: Any, **kw: Any) -> Any: if not raw and target_index is not None: - arg = list(arg) - arg[target_index] = arg[target_index].obj() + arg = list(arg) # type: ignore [assignment] + arg[target_index] = arg[target_index].obj() # type: ignore [index] # noqa: E501 if not retval: fn(*arg, **kw) return interfaces.EXT_CONTINUE @@ -797,11 +902,11 @@ class MapperEvents(event.Events): event_key.base_listen(**kw) @classmethod - def _clear(cls): + def _clear(cls) -> None: super()._clear() _MapperEventsHold._clear() - def instrument_class(self, mapper, class_): + def instrument_class(self, mapper: Mapper[_O], class_: Type[_O]) -> None: r"""Receive a class when the mapper is first constructed, before instrumentation is applied to the mapped class. @@ -824,7 +929,9 @@ class MapperEvents(event.Events): """ - def before_mapper_configured(self, mapper, class_): + def before_mapper_configured( + self, mapper: Mapper[_O], class_: Type[_O] + ) -> None: """Called right before a specific mapper is to be configured. This event is intended to allow a specific mapper to be skipped during @@ -872,7 +979,7 @@ class MapperEvents(event.Events): """ - def mapper_configured(self, mapper, class_): + def mapper_configured(self, mapper: Mapper[_O], class_: Type[_O]) -> None: r"""Called when a specific mapper has completed its own configuration within the scope of the :func:`.configure_mappers` call. @@ -926,7 +1033,7 @@ class MapperEvents(event.Events): """ # TODO: need coverage for this event - def before_configured(self): + def before_configured(self) -> None: """Called before a series of mappers have been configured. The :meth:`.MapperEvents.before_configured` event is invoked @@ -981,7 +1088,7 @@ class MapperEvents(event.Events): """ - def after_configured(self): + def after_configured(self) -> None: """Called after a series of mappers have been configured. The :meth:`.MapperEvents.after_configured` event is invoked @@ -1034,7 +1141,9 @@ class MapperEvents(event.Events): """ - def before_insert(self, mapper, connection, target): + def before_insert( + self, mapper: Mapper[_O], connection: Connection, target: _O + ) -> None: """Receive an object instance before an INSERT statement is emitted corresponding to that instance. @@ -1080,7 +1189,9 @@ class MapperEvents(event.Events): """ - def after_insert(self, mapper, connection, target): + def after_insert( + self, mapper: Mapper[_O], connection: Connection, target: _O + ) -> None: """Receive an object instance after an INSERT statement is emitted corresponding to that instance. @@ -1126,7 +1237,9 @@ class MapperEvents(event.Events): """ - def before_update(self, mapper, connection, target): + def before_update( + self, mapper: Mapper[_O], connection: Connection, target: _O + ) -> None: """Receive an object instance before an UPDATE statement is emitted corresponding to that instance. @@ -1191,7 +1304,9 @@ class MapperEvents(event.Events): """ - def after_update(self, mapper, connection, target): + def after_update( + self, mapper: Mapper[_O], connection: Connection, target: _O + ) -> None: """Receive an object instance after an UPDATE statement is emitted corresponding to that instance. @@ -1255,7 +1370,9 @@ class MapperEvents(event.Events): """ - def before_delete(self, mapper, connection, target): + def before_delete( + self, mapper: Mapper[_O], connection: Connection, target: _O + ) -> None: """Receive an object instance before a DELETE statement is emitted corresponding to that instance. @@ -1295,7 +1412,9 @@ class MapperEvents(event.Events): """ - def after_delete(self, mapper, connection, target): + def after_delete( + self, mapper: Mapper[_O], connection: Connection, target: _O + ) -> None: """Receive an object instance after a DELETE statement has been emitted corresponding to that instance. @@ -1336,19 +1455,21 @@ class MapperEvents(event.Events): """ -class _MapperEventsHold(_EventsHold): +class _MapperEventsHold(_EventsHold[_ET]): all_holds = weakref.WeakKeyDictionary() - def resolve(self, class_): + def resolve( + self, class_: Union[Type[_T], _InternalEntityType[_T]] + ) -> Optional[Mapper[_T]]: return _mapper_or_none(class_) - class HoldMapperEvents(_EventsHold.HoldEvents, MapperEvents): + class HoldMapperEvents(_EventsHold.HoldEvents[_ET], MapperEvents): # type: ignore [misc] # noqa: E501 pass dispatch = event.dispatcher(HoldMapperEvents) -_sessionevents_lifecycle_event_names = set() +_sessionevents_lifecycle_event_names: Set[str] = set() class SessionEvents(event.Events[Session]): @@ -1396,12 +1517,16 @@ class SessionEvents(event.Events[Session]): _dispatch_target = Session - def _lifecycle_event(fn): + def _lifecycle_event( # type: ignore [misc] + fn: Callable[[SessionEvents, Session, Any], None] + ) -> Callable[[SessionEvents, Session, Any], None]: _sessionevents_lifecycle_event_names.add(fn.__name__) return fn @classmethod - def _accept_with(cls, target, identifier): + def _accept_with( # type: ignore [return] + cls, target: Any, identifier: str + ) -> Union[Session, type]: if isinstance(target, scoped_session): target = target.session_factory @@ -1427,7 +1552,7 @@ class SessionEvents(event.Events[Session]): target._no_async_engine_events() else: # allows alternate SessionEvents-like-classes to be consulted - return event.Events._accept_with(target, identifier) + return event.Events._accept_with(target, identifier) # type: ignore [return-value] # noqa: E501 @classmethod def _listen( @@ -1447,15 +1572,20 @@ class SessionEvents(event.Events[Session]): fn = event_key._listen_fn - def wrap(session, state, *arg, **kw): + def wrap( + session: Session, + state: InstanceState[_O], + *arg: Any, + **kw: Any, + ) -> Optional[Any]: if not raw: target = state.obj() if target is None: # existing behavior is that if the object is # garbage collected, no event is emitted - return + return None else: - target = state + target = state # type: ignore [assignment] if restore_load_context: runid = state.runid try: @@ -1468,7 +1598,7 @@ class SessionEvents(event.Events[Session]): event_key.base_listen(**kw) - def do_orm_execute(self, orm_execute_state): + def do_orm_execute(self, orm_execute_state: ORMExecuteState) -> None: """Intercept statement executions that occur on behalf of an ORM :class:`.Session` object. @@ -1541,7 +1671,9 @@ class SessionEvents(event.Events[Session]): """ - def after_transaction_create(self, session, transaction): + def after_transaction_create( + self, session: Session, transaction: SessionTransaction + ) -> None: """Execute when a new :class:`.SessionTransaction` is created. This event differs from :meth:`~.SessionEvents.after_begin` @@ -1583,7 +1715,9 @@ class SessionEvents(event.Events[Session]): """ - def after_transaction_end(self, session, transaction): + def after_transaction_end( + self, session: Session, transaction: SessionTransaction + ) -> None: """Execute when the span of a :class:`.SessionTransaction` ends. This event differs from :meth:`~.SessionEvents.after_commit` @@ -1622,7 +1756,7 @@ class SessionEvents(event.Events[Session]): """ - def before_commit(self, session): + def before_commit(self, session: Session) -> None: """Execute before commit is called. .. note:: @@ -1650,7 +1784,7 @@ class SessionEvents(event.Events[Session]): """ - def after_commit(self, session): + def after_commit(self, session: Session) -> None: """Execute after a commit has occurred. .. note:: @@ -1686,7 +1820,7 @@ class SessionEvents(event.Events[Session]): """ - def after_rollback(self, session): + def after_rollback(self, session: Session) -> None: """Execute after a real DBAPI rollback has occurred. Note that this event only fires when the *actual* rollback against @@ -1704,7 +1838,9 @@ class SessionEvents(event.Events[Session]): """ - def after_soft_rollback(self, session, previous_transaction): + def after_soft_rollback( + self, session: Session, previous_transaction: SessionTransaction + ) -> None: """Execute after any rollback has occurred, including "soft" rollbacks that don't actually emit at the DBAPI level. @@ -1730,7 +1866,12 @@ class SessionEvents(event.Events[Session]): """ - def before_flush(self, session, flush_context, instances): + def before_flush( + self, + session: Session, + flush_context: UOWTransaction, + instances: Optional[Sequence[_O]], + ) -> None: """Execute before flush process has started. :param session: The target :class:`.Session`. @@ -1750,7 +1891,9 @@ class SessionEvents(event.Events[Session]): """ - def after_flush(self, session, flush_context): + def after_flush( + self, session: Session, flush_context: UOWTransaction + ) -> None: """Execute after flush has completed, but before commit has been called. @@ -1781,7 +1924,9 @@ class SessionEvents(event.Events[Session]): """ - def after_flush_postexec(self, session, flush_context): + def after_flush_postexec( + self, session: Session, flush_context: UOWTransaction + ) -> None: """Execute after flush has completed, and after the post-exec state occurs. @@ -1805,7 +1950,12 @@ class SessionEvents(event.Events[Session]): """ - def after_begin(self, session, transaction, connection): + def after_begin( + self, + session: Session, + transaction: SessionTransaction, + connection: Connection, + ) -> None: """Execute after a transaction is begun on a connection :param session: The target :class:`.Session`. @@ -1826,7 +1976,7 @@ class SessionEvents(event.Events[Session]): """ @_lifecycle_event - def before_attach(self, session, instance): + def before_attach(self, session: Session, instance: _O) -> None: """Execute before an instance is attached to a session. This is called before an add, delete or merge causes @@ -1841,7 +1991,7 @@ class SessionEvents(event.Events[Session]): """ @_lifecycle_event - def after_attach(self, session, instance): + def after_attach(self, session: Session, instance: _O) -> None: """Execute after an instance is attached to a session. This is called after an add, delete or merge. @@ -1875,7 +2025,7 @@ class SessionEvents(event.Events[Session]): update_context.result, ), ) - def after_bulk_update(self, update_context): + def after_bulk_update(self, update_context: _O) -> None: """Event for after the legacy :meth:`_orm.Query.update` method has been called. @@ -1921,7 +2071,7 @@ class SessionEvents(event.Events[Session]): delete_context.result, ), ) - def after_bulk_delete(self, delete_context): + def after_bulk_delete(self, delete_context: _O) -> None: """Event for after the legacy :meth:`_orm.Query.delete` method has been called. @@ -1956,7 +2106,7 @@ class SessionEvents(event.Events[Session]): """ @_lifecycle_event - def transient_to_pending(self, session, instance): + def transient_to_pending(self, session: Session, instance: _O) -> None: """Intercept the "transient to pending" transition for a specific object. @@ -1978,7 +2128,7 @@ class SessionEvents(event.Events[Session]): """ @_lifecycle_event - def pending_to_transient(self, session, instance): + def pending_to_transient(self, session: Session, instance: _O) -> None: """Intercept the "pending to transient" transition for a specific object. @@ -2000,7 +2150,7 @@ class SessionEvents(event.Events[Session]): """ @_lifecycle_event - def persistent_to_transient(self, session, instance): + def persistent_to_transient(self, session: Session, instance: _O) -> None: """Intercept the "persistent to transient" transition for a specific object. @@ -2021,7 +2171,7 @@ class SessionEvents(event.Events[Session]): """ @_lifecycle_event - def pending_to_persistent(self, session, instance): + def pending_to_persistent(self, session: Session, instance: _O) -> None: """Intercept the "pending to persistent"" transition for a specific object. @@ -2044,7 +2194,7 @@ class SessionEvents(event.Events[Session]): """ @_lifecycle_event - def detached_to_persistent(self, session, instance): + def detached_to_persistent(self, session: Session, instance: _O) -> None: """Intercept the "detached to persistent" transition for a specific object. @@ -2081,7 +2231,7 @@ class SessionEvents(event.Events[Session]): """ @_lifecycle_event - def loaded_as_persistent(self, session, instance): + def loaded_as_persistent(self, session: Session, instance: _O) -> None: """Intercept the "loaded as persistent" transition for a specific object. @@ -2117,7 +2267,7 @@ class SessionEvents(event.Events[Session]): """ @_lifecycle_event - def persistent_to_deleted(self, session, instance): + def persistent_to_deleted(self, session: Session, instance: _O) -> None: """Intercept the "persistent to deleted" transition for a specific object. @@ -2150,7 +2300,7 @@ class SessionEvents(event.Events[Session]): """ @_lifecycle_event - def deleted_to_persistent(self, session, instance): + def deleted_to_persistent(self, session: Session, instance: _O) -> None: """Intercept the "deleted to persistent" transition for a specific object. @@ -2168,7 +2318,7 @@ class SessionEvents(event.Events[Session]): """ @_lifecycle_event - def deleted_to_detached(self, session, instance): + def deleted_to_detached(self, session: Session, instance: _O) -> None: """Intercept the "deleted to detached" transition for a specific object. @@ -2192,7 +2342,7 @@ class SessionEvents(event.Events[Session]): """ @_lifecycle_event - def persistent_to_detached(self, session, instance): + def persistent_to_detached(self, session: Session, instance: _O) -> None: """Intercept the "persistent to detached" transition for a specific object. @@ -2224,7 +2374,7 @@ class SessionEvents(event.Events[Session]): """ -class AttributeEvents(event.Events): +class AttributeEvents(event.Events[QueryableAttribute[Any]]): r"""Define events for object attributes. These are typically defined on the class-bound descriptor for the @@ -2297,13 +2447,19 @@ class AttributeEvents(event.Events): _dispatch_target = QueryableAttribute @staticmethod - def _set_dispatch(cls, dispatch_cls): + def _set_dispatch( + cls: Type[_HasEventsDispatch[Any]], dispatch_cls: Type[_Dispatch[Any]] + ) -> _Dispatch[Any]: dispatch = event.Events._set_dispatch(cls, dispatch_cls) dispatch_cls._active_history = False return dispatch @classmethod - def _accept_with(cls, target, identifier): + def _accept_with( + cls, + target: Union[QueryableAttribute[Any], Type[QueryableAttribute[Any]]], + identifier: str, + ) -> Union[QueryableAttribute[Any], Type[QueryableAttribute[Any]]]: # TODO: coverage if isinstance(target, interfaces.MapperProperty): return getattr(target.parent.class_, target.key) @@ -2311,15 +2467,15 @@ class AttributeEvents(event.Events): return target @classmethod - def _listen( + def _listen( # type: ignore [override] cls, - event_key, - active_history=False, - raw=False, - retval=False, - propagate=False, - include_key=False, - ): + event_key: _EventKey[QueryableAttribute[Any]], + active_history: bool = False, + raw: bool = False, + retval: bool = False, + propagate: bool = False, + include_key: bool = False, + ) -> None: target, fn = event_key.dispatch_target, event_key._listen_fn @@ -2328,9 +2484,9 @@ class AttributeEvents(event.Events): if not raw or not retval or not include_key: - def wrap(target, *arg, **kw): + def wrap(target: InstanceState[_O], *arg: Any, **kw: Any) -> Any: if not raw: - target = target.obj() + target = target.obj() # type: ignore [assignment] if not retval: if arg: value = arg[0] @@ -2354,14 +2510,21 @@ class AttributeEvents(event.Events): if propagate: manager = instrumentation.manager_of_class(target.class_) - for mgr in manager.subclass_managers(True): + for mgr in manager.subclass_managers(True): # type: ignore [no-untyped-call] # noqa: E501 event_key.with_dispatch_target(mgr[target.key]).base_listen( propagate=True ) if active_history: mgr[target.key].dispatch._active_history = True - def append(self, target, value, initiator, *, key=NO_KEY): + def append( + self, + target: _O, + value: _T, + initiator: Event, + *, + key: EventConstants = NO_KEY, + ) -> Optional[_T]: """Receive a collection append event. The append event is invoked for each element as it is appended @@ -2405,7 +2568,14 @@ class AttributeEvents(event.Events): """ - def append_wo_mutation(self, target, value, initiator, *, key=NO_KEY): + def append_wo_mutation( + self, + target: _O, + value: _T, + initiator: Event, + *, + key: EventConstants = NO_KEY, + ) -> None: """Receive a collection append event where the collection was not actually mutated. @@ -2447,7 +2617,14 @@ class AttributeEvents(event.Events): """ - def bulk_replace(self, target, values, initiator, *, keys=None): + def bulk_replace( + self, + target: _O, + values: Iterable[_T], + initiator: Event, + *, + keys: Optional[Iterable[EventConstants]] = None, + ) -> None: """Receive a collection 'bulk replace' event. This event is invoked for a sequence of values as they are incoming @@ -2510,7 +2687,14 @@ class AttributeEvents(event.Events): """ - def remove(self, target, value, initiator, *, key=NO_KEY): + def remove( + self, + target: _O, + value: _T, + initiator: Event, + *, + key: EventConstants = NO_KEY, + ) -> None: """Receive a collection remove event. :param target: the object instance receiving the event. @@ -2548,7 +2732,9 @@ class AttributeEvents(event.Events): """ - def set(self, target, value, oldvalue, initiator): + def set( + self, target: _O, value: _T, oldvalue: _T, initiator: Event + ) -> None: """Receive a scalar set event. :param target: the object instance receiving the event. @@ -2584,7 +2770,9 @@ class AttributeEvents(event.Events): """ - def init_scalar(self, target, value, dict_): + def init_scalar( + self, target: _O, value: _T, dict_: Dict[Any, Any] + ) -> None: r"""Receive a scalar "init" event. This event is invoked when an uninitialized, unpersisted scalar @@ -2706,7 +2894,12 @@ class AttributeEvents(event.Events): """ - def init_collection(self, target, collection, collection_adapter): + def init_collection( + self, + target: _O, + collection: Type[Collection[Any]], + collection_adapter: CollectionAdapter, + ) -> None: """Receive a 'collection init' event. This event is triggered for a collection-based attribute, when @@ -2747,7 +2940,12 @@ class AttributeEvents(event.Events): """ - def dispose_collection(self, target, collection, collection_adapter): + def dispose_collection( + self, + target: _O, + collection: Collection[Any], + collection_adapter: CollectionAdapter, + ) -> None: """Receive a 'collection dispose' event. This event is triggered for a collection-based attribute when @@ -2774,7 +2972,7 @@ class AttributeEvents(event.Events): """ - def modified(self, target, initiator): + def modified(self, target: _O, initiator: Event) -> None: """Receive a 'modified' event. This event is triggered when the :func:`.attributes.flag_modified` @@ -2798,7 +2996,7 @@ class AttributeEvents(event.Events): """ -class QueryEvents(event.Events): +class QueryEvents(event.Events[Query[Any]]): """Represent events within the construction of a :class:`_query.Query` object. @@ -2817,7 +3015,7 @@ class QueryEvents(event.Events): _target_class_doc = "SomeQuery" _dispatch_target = Query - def before_compile(self, query): + def before_compile(self, query: Query[Any]) -> None: """Receive the :class:`_query.Query` object before it is composed into a core :class:`_expression.Select` object. @@ -2883,7 +3081,9 @@ class QueryEvents(event.Events): """ - def before_compile_update(self, query, update_context): + def before_compile_update( + self, query: Query[Any], update_context: BulkUpdate + ) -> None: """Allow modifications to the :class:`_query.Query` object within :meth:`_query.Query.update`. @@ -2933,7 +3133,9 @@ class QueryEvents(event.Events): """ - def before_compile_delete(self, query, delete_context): + def before_compile_delete( + self, query: Query[Any], delete_context: BulkDelete + ) -> None: """Allow modifications to the :class:`_query.Query` object within :meth:`_query.Query.delete`. @@ -2973,12 +3175,18 @@ class QueryEvents(event.Events): """ @classmethod - def _listen(cls, event_key, retval=False, bake_ok=False, **kw): + def _listen( + cls, + event_key: _EventKey[_ET], + retval: bool = False, + bake_ok: bool = False, + **kw: Any, + ) -> None: fn = event_key._listen_fn if not retval: - def wrap(*arg, **kw): + def wrap(*arg: Any, **kw: Any) -> Any: if not retval: query = arg[0] fn(*arg, **kw) @@ -2989,11 +3197,11 @@ class QueryEvents(event.Events): event_key = event_key.with_wrapper(wrap) else: # don't assume we can apply an attribute to the callable - def wrap(*arg, **kw): + def wrap(*arg: Any, **kw: Any) -> Any: return fn(*arg, **kw) event_key = event_key.with_wrapper(wrap) - wrap._bake_ok = bake_ok + wrap._bake_ok = bake_ok # type: ignore [attr-defined] event_key.base_listen(**kw) -- 2.47.2