#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
+# mypy: allow-untyped-defs
"""ORM event interfaces.
from __future__ import annotations
from typing import Any
+from typing import Callable
+from typing import Generic
from typing import Optional
+from typing import Set
from typing import Type
from typing import TYPE_CHECKING
+from typing import TypeVar
import weakref
from . import instrumentation
from .. import event
from .. import exc
from .. import util
+from ..event import EventTarget
from ..util.compat import inspect_getfullargspec
-from sqlalchemy.orm.session import Session
-from sqlalchemy.orm.state import InstanceState
-from sqlalchemy.orm.base import EventConstants
-from sqlalchemy.orm.query import Query
-from typing import Union
-from weakref import ReferenceType
-from sqlalchemy.event.registry import _EventKey
-from sqlalchemy.orm.decl_api import DeclarativeAttributeIntercept
-from sqlalchemy.orm.decl_api import DeclarativeMeta
-from sqlalchemy.orm.instrumentation import ClassManager
-from sqlalchemy.orm.mapper import Mapper
-from sqlalchemy.orm.attributes import InstrumentedAttribute
if TYPE_CHECKING:
+ from typing import Union
+ from weakref import ReferenceType
+
+ from ._typing import _InternalEntityType
from ._typing import _O
+ from ._typing import _T
from .instrumentation import ClassManager
+ from ..event.registry import _ET
+ from ..event.registry import _EventKey
+ from ..orm.attributes import InstrumentedAttribute
+ from ..orm.base import EventConstants
+ from ..orm.decl_api import DeclarativeAttributeIntercept
+ from ..orm.decl_api import DeclarativeMeta
+ from ..orm.mapper import Mapper
+ from ..orm.state import InstanceState
+
+_ET2 = TypeVar("_ET2", bound=EventTarget)
-class InstrumentationEvents(event.Events):
+class InstrumentationEvents(event.Events[_T]):
"""Events related to class instrumentation events.
The listeners here support being established against
"""
_target_class_doc = "SomeBaseClass"
- _dispatch_target = instrumentation.InstrumentationFactory
+ _dispatch_target = instrumentation.InstrumentationFactory # type: ignore [assignment] # noqa: E501
@classmethod
- def _accept_with(cls, target: type, identifier: str) -> _InstrumentationEventsHold:
+ def _accept_with(
+ cls, target: Union[_ET, Type[_ET]], identifier: str
+ ) -> Optional[Union[_ET, Type[_ET]]]:
if isinstance(target, type):
- return _InstrumentationEventsHold(target)
+ return _InstrumentationEventsHold(target) # type: ignore [return-value] # noqa: E501
else:
return None
@classmethod
- def _listen(cls, event_key: _EventKey, propagate: bool = True, **kw: Any) -> None:
+ def _listen(
+ cls, event_key: _EventKey[_T], propagate: bool = True, **kw: Any
+ ) -> None:
target, identifier, fn = (
event_key.dispatch_target,
event_key.identifier,
event_key._listen_fn,
)
- def listen(target_cls: Union[int, type], *arg: Any) -> Optional[Any]:
+ def listen(target_cls: type, *arg: Any) -> Optional[Any]:
listen_cls = target()
# if weakref were collected, however this is not something
return fn(target_cls, *arg)
elif not propagate and target_cls is listen_cls:
return fn(target_cls, *arg)
+ else:
+ return None
- def remove(ref: ReferenceType) -> None:
- key = event.registry._EventKey(
+ def remove(ref: ReferenceType[_T]) -> None:
+ key = event.registry._EventKey( # type: ignore [type-var]
None,
identifier,
listen,
dispatch = event.dispatcher(InstrumentationEvents)
-class InstanceEvents(event.Events):
+class InstanceEvents(event.Events[_ET]):
"""Define events specific to object lifecycle.
e.g.::
_target_class_doc = "SomeClass"
- _dispatch_target = instrumentation.ClassManager
+ _dispatch_target = instrumentation.ClassManager # type: ignore [assignment] # noqa: E501
@classmethod
- def _new_classmanager_instance(cls, class_: Union[DeclarativeAttributeIntercept, DeclarativeMeta, type], classmanager: ClassManager) -> None:
+ def _new_classmanager_instance(
+ cls,
+ class_: Union[DeclarativeAttributeIntercept, DeclarativeMeta, type],
+ classmanager: ClassManager[_O],
+ ) -> None:
_InstanceEventsHold.populate(class_, classmanager)
@classmethod
@util.preload_module("sqlalchemy.orm")
- def _accept_with(cls, target: Any, identifier: str) -> Union[_InstanceEventsHold, ClassManager, type]:
+ def _accept_with(
+ cls, target: Union[_ET, Type[_ET]], identifier: str
+ ) -> Optional[Union[_ET, Type[_ET]]]:
orm = util.preloaded.orm
if isinstance(target, instrumentation.ClassManager):
@classmethod
def _listen(
cls,
- event_key: _EventKey,
+ event_key: _EventKey[_ET],
raw: bool = False,
propagate: bool = False,
restore_load_context: bool = False,
if not raw or restore_load_context:
- def wrap(state: InstanceState, *arg: Any, **kw: Any) -> Optional[Any]:
+ def wrap(
+ state: InstanceState[_O], *arg: Any, **kw: Any
+ ) -> Optional[Any]:
if not raw:
target = state.obj()
else:
"""
-class _EventsHold(event.RefCollection):
+class _EventsHold(event.RefCollection[_ET]):
"""Hold onto listeners against unmapped, uninstrumented classes.
Establish _listen() for that class' mapper/instrumentation when
"""
- def __init__(self, class_: Union[DeclarativeAttributeIntercept, DeclarativeMeta, type]) -> None:
+ def __init__(
+ self,
+ class_: Union[DeclarativeAttributeIntercept, DeclarativeMeta, type],
+ ) -> None:
self.class_ = class_
@classmethod
def _clear(cls) -> None:
cls.all_holds.clear()
- class HoldEvents:
- _dispatch_target = None
+ class HoldEvents(Generic[_ET2]):
+ _dispatch_target: Optional[Type[_ET2]] = None
@classmethod
def _listen(
- cls, event_key: _EventKey, raw: bool = False, propagate: bool = False, retval: bool = False, **kw: Any
+ cls,
+ event_key: _EventKey[_ET2],
+ raw: bool = False,
+ propagate: bool = False,
+ retval: bool = False,
+ **kw: Any,
) -> None:
target = event_key.dispatch_target
raw=raw, propagate=False, retval=retval, **kw
)
- def remove(self, event_key: _EventKey) -> None:
+ def remove(self, event_key: _EventKey[_ET]) -> None:
target = event_key.dispatch_target
if isinstance(target, _EventsHold):
del collection[event_key._key]
@classmethod
- def populate(cls, class_: Union[DeclarativeAttributeIntercept, DeclarativeMeta, type], subject: Union[ClassManager, Mapper]) -> None:
+ def populate(
+ cls,
+ class_: Union[DeclarativeAttributeIntercept, DeclarativeMeta, type],
+ subject: Union[ClassManager[_O], Mapper[_O]],
+ ) -> None:
for subclass in class_.__mro__:
if subclass in cls.all_holds:
collection = cls.all_holds[subclass]
)
-class _InstanceEventsHold(_EventsHold):
+class _InstanceEventsHold(_EventsHold[_ET]):
all_holds = weakref.WeakKeyDictionary()
def resolve(self, class_: Type[_O]) -> Optional[ClassManager[_O]]:
return instrumentation.opt_manager_of_class(class_)
- class HoldInstanceEvents(_EventsHold.HoldEvents, InstanceEvents):
+ class HoldInstanceEvents(_EventsHold.HoldEvents[_ET], InstanceEvents[_ET]):
pass
dispatch = event.dispatcher(HoldInstanceEvents)
-class MapperEvents(event.Events):
+class MapperEvents(event.Events[_ET]):
"""Define events specific to mappings.
e.g.::
_dispatch_target = mapperlib.Mapper
@classmethod
- def _new_mapper_instance(cls, class_: Union[DeclarativeAttributeIntercept, DeclarativeMeta, type], mapper: Mapper) -> None:
+ def _new_mapper_instance(
+ cls,
+ class_: Union[DeclarativeAttributeIntercept, DeclarativeMeta, type],
+ mapper: Mapper[_O],
+ ) -> None:
_MapperEventsHold.populate(class_, mapper)
@classmethod
@util.preload_module("sqlalchemy.orm")
- def _accept_with(cls, target: Union[DeclarativeMeta, Mapper, type], identifier: str) -> Union[_MapperEventsHold, Mapper, type]:
+ def _accept_with(
+ cls, target: Union[DeclarativeMeta, Mapper[_O], type], identifier: str
+ ) -> Union[_MapperEventsHold[_ET], Mapper[_O], type]:
orm = util.preloaded.orm
if target is orm.mapper:
@classmethod
def _listen(
- cls, event_key: _EventKey, raw: bool = False, retval: bool = False, propagate: bool = False, **kw: Any
+ cls,
+ event_key: _EventKey[_ET],
+ raw: bool = False,
+ retval: bool = False,
+ propagate: bool = False,
+ **kw: Any,
) -> None:
target, identifier, fn = (
event_key.dispatch_target,
"""
-class _MapperEventsHold(_EventsHold):
+class _MapperEventsHold(_EventsHold[_ET]):
all_holds = weakref.WeakKeyDictionary()
- def resolve(self, class_: Union[DeclarativeMeta, type]) -> Optional[Mapper]:
+ def resolve(
+ self, class_: Union[Type[_T], _InternalEntityType[_T]]
+ ) -> Optional[Mapper[_T]]:
return _mapper_or_none(class_)
- class HoldMapperEvents(_EventsHold.HoldEvents, MapperEvents):
+ class HoldMapperEvents(_EventsHold.HoldEvents[_ET], MapperEvents[_ET]):
pass
dispatch = event.dispatcher(HoldMapperEvents)
-_sessionevents_lifecycle_event_names = set()
+_sessionevents_lifecycle_event_names: Set[str] = set()
class SessionEvents(event.Events[Session]):
_dispatch_target = Session
- def _lifecycle_event(fn):
+ def _lifecycle_event(
+ fn: Callable[[SessionEvents, Session, Any], None]
+ ) -> Callable[[Session, Any], None]:
_sessionevents_lifecycle_event_names.add(fn.__name__)
return fn
@classmethod
- def _accept_with(cls, target: Any, identifier: str) -> Union[Session, type]:
+ def _accept_with(
+ cls, target: Any, identifier: str
+ ) -> Union[Session, type]:
if isinstance(target, scoped_session):
target = target.session_factory
fn = event_key._listen_fn
- def wrap(session: Session, state: InstanceState, *arg: Any, **kw: Any) -> Optional[Any]:
+ def wrap(
+ session: Session,
+ state: InstanceState[_O],
+ *arg: Any,
+ **kw: Any,
+ ) -> Optional[Any]:
if not raw:
target = state.obj()
if target is None:
"""
-class AttributeEvents(event.Events):
+class AttributeEvents(event.Events[_ET]):
r"""Define events for object attributes.
These are typically defined on the class-bound descriptor for the
"""
_target_class_doc = "SomeClass.some_attribute"
- _dispatch_target = QueryableAttribute
+ _dispatch_target = QueryableAttribute # type: ignore [assignment]
@staticmethod
def _set_dispatch(cls, dispatch_cls):
return dispatch
@classmethod
- def _accept_with(cls, target: InstrumentedAttribute, identifier: str) -> InstrumentedAttribute:
+ def _accept_with(
+ cls, target: InstrumentedAttribute[_T], identifier: str
+ ) -> InstrumentedAttribute[_T]:
# TODO: coverage
if isinstance(target, interfaces.MapperProperty):
return getattr(target.parent.class_, target.key)
@classmethod
def _listen(
cls,
- event_key: _EventKey,
+ event_key: _EventKey[_ET],
active_history: bool = False,
raw: bool = False,
retval: bool = False,
if not raw or not retval or not include_key:
- def wrap(target: InstanceState, *arg: Any, **kw: Any) -> str:
+ def wrap(target: InstanceState[_O], *arg: Any, **kw: Any) -> Any:
if not raw:
target = target.obj()
if not retval:
"""
-class QueryEvents(event.Events):
+class QueryEvents(event.Events[_ET]):
"""Represent events within the construction of a :class:`_query.Query`
object.
"""
_target_class_doc = "SomeQuery"
- _dispatch_target = Query
+ _dispatch_target = Query # type: ignore [assignment]
def before_compile(self, query):
"""Receive the :class:`_query.Query`
"""
@classmethod
- def _listen(cls, event_key: _EventKey, retval: bool = False, bake_ok: bool = False, **kw: Any) -> None:
+ def _listen(
+ cls,
+ event_key: _EventKey[_ET],
+ retval: bool = False,
+ bake_ok: bool = False,
+ **kw: Any,
+ ) -> None:
fn = event_key._listen_fn
if not retval:
- def wrap(*arg: Query, **kw: Any) -> Query:
+ def wrap(*arg: Any, **kw: Any) -> Any:
if not retval:
query = arg[0]
fn(*arg, **kw)
event_key = event_key.with_wrapper(wrap)
else:
# don't assume we can apply an attribute to the callable
- def wrap(*arg: Any, **kw: Any) -> Query:
+ def wrap(*arg: Any, **kw: Any) -> Any:
return fn(*arg, **kw)
event_key = event_key.with_wrapper(wrap)
- wrap._bake_ok = bake_ok
+ wrap._bake_ok = bake_ok # type: ignore [attr-defined]
event_key.base_listen(**kw)