]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Covered low-hanging fruits of type refinement
authorGleb Kisenkov <g.kisenkov@gmail.com>
Fri, 16 Dec 2022 17:46:08 +0000 (18:46 +0100)
committerGleb Kisenkov <g.kisenkov@gmail.com>
Fri, 16 Dec 2022 17:46:08 +0000 (18:46 +0100)
lib/sqlalchemy/orm/events.py

index e9f325ad337aa8138ab34b5f31c73cc9a65200fd..b96f5b8d146f60449353ac09748ebf316c01689e 100644 (file)
@@ -4,6 +4,7 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
+# mypy: allow-untyped-defs
 
 """ORM event interfaces.
 
 from __future__ import annotations
 
 from typing import Any
+from typing import Callable
+from typing import Generic
 from typing import Optional
+from typing import Set
 from typing import Type
 from typing import TYPE_CHECKING
+from typing import TypeVar
 import weakref
 
 from . import instrumentation
@@ -29,26 +34,30 @@ from .session import sessionmaker
 from .. import event
 from .. import exc
 from .. import util
+from ..event import EventTarget
 from ..util.compat import inspect_getfullargspec
-from sqlalchemy.orm.session import Session
-from sqlalchemy.orm.state import InstanceState
-from sqlalchemy.orm.base import EventConstants
-from sqlalchemy.orm.query import Query
-from typing import Union
-from weakref import ReferenceType
-from sqlalchemy.event.registry import _EventKey
-from sqlalchemy.orm.decl_api import DeclarativeAttributeIntercept
-from sqlalchemy.orm.decl_api import DeclarativeMeta
-from sqlalchemy.orm.instrumentation import ClassManager
-from sqlalchemy.orm.mapper import Mapper
-from sqlalchemy.orm.attributes import InstrumentedAttribute
 
 if TYPE_CHECKING:
+    from typing import Union
+    from weakref import ReferenceType
+
+    from ._typing import _InternalEntityType
     from ._typing import _O
+    from ._typing import _T
     from .instrumentation import ClassManager
+    from ..event.registry import _ET
+    from ..event.registry import _EventKey
+    from ..orm.attributes import InstrumentedAttribute
+    from ..orm.base import EventConstants
+    from ..orm.decl_api import DeclarativeAttributeIntercept
+    from ..orm.decl_api import DeclarativeMeta
+    from ..orm.mapper import Mapper
+    from ..orm.state import InstanceState
+
+_ET2 = TypeVar("_ET2", bound=EventTarget)
 
 
-class InstrumentationEvents(event.Events):
+class InstrumentationEvents(event.Events[_T]):
     """Events related to class instrumentation events.
 
     The listeners here support being established against
@@ -71,24 +80,28 @@ class InstrumentationEvents(event.Events):
     """
 
     _target_class_doc = "SomeBaseClass"
-    _dispatch_target = instrumentation.InstrumentationFactory
+    _dispatch_target = instrumentation.InstrumentationFactory  # type: ignore [assignment] # noqa: E501
 
     @classmethod
-    def _accept_with(cls, target: type, identifier: str) -> _InstrumentationEventsHold:
+    def _accept_with(
+        cls, target: Union[_ET, Type[_ET]], identifier: str
+    ) -> Optional[Union[_ET, Type[_ET]]]:
         if isinstance(target, type):
-            return _InstrumentationEventsHold(target)
+            return _InstrumentationEventsHold(target)  # type: ignore [return-value] # noqa: E501
         else:
             return None
 
     @classmethod
-    def _listen(cls, event_key: _EventKey, propagate: bool = True, **kw: Any) -> None:
+    def _listen(
+        cls, event_key: _EventKey[_T], propagate: bool = True, **kw: Any
+    ) -> None:
         target, identifier, fn = (
             event_key.dispatch_target,
             event_key.identifier,
             event_key._listen_fn,
         )
 
-        def listen(target_cls: Union[int, type], *arg: Any) -> Optional[Any]:
+        def listen(target_cls: type, *arg: Any) -> Optional[Any]:
             listen_cls = target()
 
             # if weakref were collected, however this is not something
@@ -102,9 +115,11 @@ class InstrumentationEvents(event.Events):
                 return fn(target_cls, *arg)
             elif not propagate and target_cls is listen_cls:
                 return fn(target_cls, *arg)
+            else:
+                return None
 
-        def remove(ref: ReferenceType) -> None:
-            key = event.registry._EventKey(
+        def remove(ref: ReferenceType[_T]) -> None:
+            key = event.registry._EventKey(  # type: ignore [type-var]
                 None,
                 identifier,
                 listen,
@@ -157,7 +172,7 @@ class _InstrumentationEventsHold:
     dispatch = event.dispatcher(InstrumentationEvents)
 
 
-class InstanceEvents(event.Events):
+class InstanceEvents(event.Events[_ET]):
     """Define events specific to object lifecycle.
 
     e.g.::
@@ -207,15 +222,21 @@ class InstanceEvents(event.Events):
 
     _target_class_doc = "SomeClass"
 
-    _dispatch_target = instrumentation.ClassManager
+    _dispatch_target = instrumentation.ClassManager  # type: ignore [assignment] # noqa: E501
 
     @classmethod
-    def _new_classmanager_instance(cls, class_: Union[DeclarativeAttributeIntercept, DeclarativeMeta, type], classmanager: ClassManager) -> None:
+    def _new_classmanager_instance(
+        cls,
+        class_: Union[DeclarativeAttributeIntercept, DeclarativeMeta, type],
+        classmanager: ClassManager[_O],
+    ) -> None:
         _InstanceEventsHold.populate(class_, classmanager)
 
     @classmethod
     @util.preload_module("sqlalchemy.orm")
-    def _accept_with(cls, target: Any, identifier: str) -> Union[_InstanceEventsHold, ClassManager, type]:
+    def _accept_with(
+        cls, target: Union[_ET, Type[_ET]], identifier: str
+    ) -> Optional[Union[_ET, Type[_ET]]]:
         orm = util.preloaded.orm
 
         if isinstance(target, instrumentation.ClassManager):
@@ -244,7 +265,7 @@ class InstanceEvents(event.Events):
     @classmethod
     def _listen(
         cls,
-        event_key: _EventKey,
+        event_key: _EventKey[_ET],
         raw: bool = False,
         propagate: bool = False,
         restore_load_context: bool = False,
@@ -254,7 +275,9 @@ class InstanceEvents(event.Events):
 
         if not raw or restore_load_context:
 
-            def wrap(state: InstanceState, *arg: Any, **kw: Any) -> Optional[Any]:
+            def wrap(
+                state: InstanceState[_O], *arg: Any, **kw: Any
+            ) -> Optional[Any]:
                 if not raw:
                     target = state.obj()
                 else:
@@ -566,7 +589,7 @@ class InstanceEvents(event.Events):
         """
 
 
-class _EventsHold(event.RefCollection):
+class _EventsHold(event.RefCollection[_ET]):
     """Hold onto listeners against unmapped, uninstrumented classes.
 
     Establish _listen() for that class' mapper/instrumentation when
@@ -574,19 +597,27 @@ class _EventsHold(event.RefCollection):
 
     """
 
-    def __init__(self, class_: Union[DeclarativeAttributeIntercept, DeclarativeMeta, type]) -> None:
+    def __init__(
+        self,
+        class_: Union[DeclarativeAttributeIntercept, DeclarativeMeta, type],
+    ) -> None:
         self.class_ = class_
 
     @classmethod
     def _clear(cls) -> None:
         cls.all_holds.clear()
 
-    class HoldEvents:
-        _dispatch_target = None
+    class HoldEvents(Generic[_ET2]):
+        _dispatch_target: Optional[Type[_ET2]] = None
 
         @classmethod
         def _listen(
-            cls, event_key: _EventKey, raw: bool = False, propagate: bool = False, retval: bool = False, **kw: Any
+            cls,
+            event_key: _EventKey[_ET2],
+            raw: bool = False,
+            propagate: bool = False,
+            retval: bool = False,
+            **kw: Any,
         ) -> None:
             target = event_key.dispatch_target
 
@@ -617,7 +648,7 @@ class _EventsHold(event.RefCollection):
                             raw=raw, propagate=False, retval=retval, **kw
                         )
 
-    def remove(self, event_key: _EventKey) -> None:
+    def remove(self, event_key: _EventKey[_ET]) -> None:
         target = event_key.dispatch_target
 
         if isinstance(target, _EventsHold):
@@ -625,7 +656,11 @@ class _EventsHold(event.RefCollection):
             del collection[event_key._key]
 
     @classmethod
-    def populate(cls, class_: Union[DeclarativeAttributeIntercept, DeclarativeMeta, type], subject: Union[ClassManager, Mapper]) -> None:
+    def populate(
+        cls,
+        class_: Union[DeclarativeAttributeIntercept, DeclarativeMeta, type],
+        subject: Union[ClassManager[_O], Mapper[_O]],
+    ) -> None:
         for subclass in class_.__mro__:
             if subclass in cls.all_holds:
                 collection = cls.all_holds[subclass]
@@ -647,19 +682,19 @@ class _EventsHold(event.RefCollection):
                         )
 
 
-class _InstanceEventsHold(_EventsHold):
+class _InstanceEventsHold(_EventsHold[_ET]):
     all_holds = weakref.WeakKeyDictionary()
 
     def resolve(self, class_: Type[_O]) -> Optional[ClassManager[_O]]:
         return instrumentation.opt_manager_of_class(class_)
 
-    class HoldInstanceEvents(_EventsHold.HoldEvents, InstanceEvents):
+    class HoldInstanceEvents(_EventsHold.HoldEvents[_ET], InstanceEvents[_ET]):
         pass
 
     dispatch = event.dispatcher(HoldInstanceEvents)
 
 
-class MapperEvents(event.Events):
+class MapperEvents(event.Events[_ET]):
     """Define events specific to mappings.
 
     e.g.::
@@ -729,12 +764,18 @@ class MapperEvents(event.Events):
     _dispatch_target = mapperlib.Mapper
 
     @classmethod
-    def _new_mapper_instance(cls, class_: Union[DeclarativeAttributeIntercept, DeclarativeMeta, type], mapper: Mapper) -> None:
+    def _new_mapper_instance(
+        cls,
+        class_: Union[DeclarativeAttributeIntercept, DeclarativeMeta, type],
+        mapper: Mapper[_O],
+    ) -> None:
         _MapperEventsHold.populate(class_, mapper)
 
     @classmethod
     @util.preload_module("sqlalchemy.orm")
-    def _accept_with(cls, target: Union[DeclarativeMeta, Mapper, type], identifier: str) -> Union[_MapperEventsHold, Mapper, type]:
+    def _accept_with(
+        cls, target: Union[DeclarativeMeta, Mapper[_O], type], identifier: str
+    ) -> Union[_MapperEventsHold[_ET], Mapper[_O], type]:
         orm = util.preloaded.orm
 
         if target is orm.mapper:
@@ -759,7 +800,12 @@ class MapperEvents(event.Events):
 
     @classmethod
     def _listen(
-        cls, event_key: _EventKey, raw: bool = False, retval: bool = False, propagate: bool = False, **kw: Any
+        cls,
+        event_key: _EventKey[_ET],
+        raw: bool = False,
+        retval: bool = False,
+        propagate: bool = False,
+        **kw: Any,
     ) -> None:
         target, identifier, fn = (
             event_key.dispatch_target,
@@ -1347,19 +1393,21 @@ class MapperEvents(event.Events):
         """
 
 
-class _MapperEventsHold(_EventsHold):
+class _MapperEventsHold(_EventsHold[_ET]):
     all_holds = weakref.WeakKeyDictionary()
 
-    def resolve(self, class_: Union[DeclarativeMeta, type]) -> Optional[Mapper]:
+    def resolve(
+        self, class_: Union[Type[_T], _InternalEntityType[_T]]
+    ) -> Optional[Mapper[_T]]:
         return _mapper_or_none(class_)
 
-    class HoldMapperEvents(_EventsHold.HoldEvents, MapperEvents):
+    class HoldMapperEvents(_EventsHold.HoldEvents[_ET], MapperEvents[_ET]):
         pass
 
     dispatch = event.dispatcher(HoldMapperEvents)
 
 
-_sessionevents_lifecycle_event_names = set()
+_sessionevents_lifecycle_event_names: Set[str] = set()
 
 
 class SessionEvents(event.Events[Session]):
@@ -1407,12 +1455,16 @@ class SessionEvents(event.Events[Session]):
 
     _dispatch_target = Session
 
-    def _lifecycle_event(fn):
+    def _lifecycle_event(
+        fn: Callable[[SessionEvents, Session, Any], None]
+    ) -> Callable[[Session, Any], None]:
         _sessionevents_lifecycle_event_names.add(fn.__name__)
         return fn
 
     @classmethod
-    def _accept_with(cls, target: Any, identifier: str) -> Union[Session, type]:
+    def _accept_with(
+        cls, target: Any, identifier: str
+    ) -> Union[Session, type]:
         if isinstance(target, scoped_session):
 
             target = target.session_factory
@@ -1458,7 +1510,12 @@ class SessionEvents(event.Events[Session]):
 
                 fn = event_key._listen_fn
 
-                def wrap(session: Session, state: InstanceState, *arg: Any, **kw: Any) -> Optional[Any]:
+                def wrap(
+                    session: Session,
+                    state: InstanceState[_O],
+                    *arg: Any,
+                    **kw: Any,
+                ) -> Optional[Any]:
                     if not raw:
                         target = state.obj()
                         if target is None:
@@ -2235,7 +2292,7 @@ class SessionEvents(event.Events[Session]):
         """
 
 
-class AttributeEvents(event.Events):
+class AttributeEvents(event.Events[_ET]):
     r"""Define events for object attributes.
 
     These are typically defined on the class-bound descriptor for the
@@ -2305,7 +2362,7 @@ class AttributeEvents(event.Events):
     """
 
     _target_class_doc = "SomeClass.some_attribute"
-    _dispatch_target = QueryableAttribute
+    _dispatch_target = QueryableAttribute  # type: ignore [assignment]
 
     @staticmethod
     def _set_dispatch(cls, dispatch_cls):
@@ -2314,7 +2371,9 @@ class AttributeEvents(event.Events):
         return dispatch
 
     @classmethod
-    def _accept_with(cls, target: InstrumentedAttribute, identifier: str) -> InstrumentedAttribute:
+    def _accept_with(
+        cls, target: InstrumentedAttribute[_T], identifier: str
+    ) -> InstrumentedAttribute[_T]:
         # TODO: coverage
         if isinstance(target, interfaces.MapperProperty):
             return getattr(target.parent.class_, target.key)
@@ -2324,7 +2383,7 @@ class AttributeEvents(event.Events):
     @classmethod
     def _listen(
         cls,
-        event_key: _EventKey,
+        event_key: _EventKey[_ET],
         active_history: bool = False,
         raw: bool = False,
         retval: bool = False,
@@ -2339,7 +2398,7 @@ class AttributeEvents(event.Events):
 
         if not raw or not retval or not include_key:
 
-            def wrap(target: InstanceState, *arg: Any, **kw: Any) -> str:
+            def wrap(target: InstanceState[_O], *arg: Any, **kw: Any) -> Any:
                 if not raw:
                     target = target.obj()
                 if not retval:
@@ -2809,7 +2868,7 @@ class AttributeEvents(event.Events):
         """
 
 
-class QueryEvents(event.Events):
+class QueryEvents(event.Events[_ET]):
     """Represent events within the construction of a :class:`_query.Query`
     object.
 
@@ -2826,7 +2885,7 @@ class QueryEvents(event.Events):
     """
 
     _target_class_doc = "SomeQuery"
-    _dispatch_target = Query
+    _dispatch_target = Query  # type: ignore [assignment]
 
     def before_compile(self, query):
         """Receive the :class:`_query.Query`
@@ -2984,12 +3043,18 @@ class QueryEvents(event.Events):
         """
 
     @classmethod
-    def _listen(cls, event_key: _EventKey, retval: bool = False, bake_ok: bool = False, **kw: Any) -> None:
+    def _listen(
+        cls,
+        event_key: _EventKey[_ET],
+        retval: bool = False,
+        bake_ok: bool = False,
+        **kw: Any,
+    ) -> None:
         fn = event_key._listen_fn
 
         if not retval:
 
-            def wrap(*arg: Query, **kw: Any) -> Query:
+            def wrap(*arg: Any, **kw: Any) -> Any:
                 if not retval:
                     query = arg[0]
                     fn(*arg, **kw)
@@ -3000,11 +3065,11 @@ class QueryEvents(event.Events):
             event_key = event_key.with_wrapper(wrap)
         else:
             # don't assume we can apply an attribute to the callable
-            def wrap(*arg: Any, **kw: Any) -> Query:
+            def wrap(*arg: Any, **kw: Any) -> Any:
                 return fn(*arg, **kw)
 
             event_key = event_key.with_wrapper(wrap)
 
-        wrap._bake_ok = bake_ok
+        wrap._bake_ok = bake_ok  # type: ignore [attr-defined]
 
         event_key.base_listen(**kw)