]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Runtime types collected
authorGleb Kisenkov <g.kisenkov@gmail.com>
Mon, 12 Dec 2022 23:16:34 +0000 (00:16 +0100)
committerGleb Kisenkov <g.kisenkov@gmail.com>
Mon, 12 Dec 2022 23:16:34 +0000 (00:16 +0100)
lib/sqlalchemy/orm/events.py

index 32de155a15232b58cc6196b0977ac25df17ec6ac..e9f325ad337aa8138ab34b5f31c73cc9a65200fd 100644 (file)
@@ -4,7 +4,6 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
 
 """ORM event interfaces.
 
@@ -31,6 +30,18 @@ from .. import event
 from .. import exc
 from .. import util
 from ..util.compat import inspect_getfullargspec
+from sqlalchemy.orm.session import Session
+from sqlalchemy.orm.state import InstanceState
+from sqlalchemy.orm.base import EventConstants
+from sqlalchemy.orm.query import Query
+from typing import Union
+from weakref import ReferenceType
+from sqlalchemy.event.registry import _EventKey
+from sqlalchemy.orm.decl_api import DeclarativeAttributeIntercept
+from sqlalchemy.orm.decl_api import DeclarativeMeta
+from sqlalchemy.orm.instrumentation import ClassManager
+from sqlalchemy.orm.mapper import Mapper
+from sqlalchemy.orm.attributes import InstrumentedAttribute
 
 if TYPE_CHECKING:
     from ._typing import _O
@@ -63,21 +74,21 @@ class InstrumentationEvents(event.Events):
     _dispatch_target = instrumentation.InstrumentationFactory
 
     @classmethod
-    def _accept_with(cls, target, identifier):
+    def _accept_with(cls, target: type, identifier: str) -> _InstrumentationEventsHold:
         if isinstance(target, type):
             return _InstrumentationEventsHold(target)
         else:
             return None
 
     @classmethod
-    def _listen(cls, event_key, propagate=True, **kw):
+    def _listen(cls, event_key: _EventKey, propagate: bool = True, **kw: Any) -> None:
         target, identifier, fn = (
             event_key.dispatch_target,
             event_key.identifier,
             event_key._listen_fn,
         )
 
-        def listen(target_cls, *arg):
+        def listen(target_cls: Union[int, type], *arg: Any) -> Optional[Any]:
             listen_cls = target()
 
             # if weakref were collected, however this is not something
@@ -92,7 +103,7 @@ class InstrumentationEvents(event.Events):
             elif not propagate and target_cls is listen_cls:
                 return fn(target_cls, *arg)
 
-        def remove(ref):
+        def remove(ref: ReferenceType) -> None:
             key = event.registry._EventKey(
                 None,
                 identifier,
@@ -110,7 +121,7 @@ class InstrumentationEvents(event.Events):
         ).with_wrapper(listen).base_listen(**kw)
 
     @classmethod
-    def _clear(cls):
+    def _clear(cls) -> None:
         super()._clear()
         instrumentation._instrumentation_factory.dispatch._clear()
 
@@ -140,7 +151,7 @@ class _InstrumentationEventsHold:
 
     """
 
-    def __init__(self, class_):
+    def __init__(self, class_: type) -> None:
         self.class_ = class_
 
     dispatch = event.dispatcher(InstrumentationEvents)
@@ -199,12 +210,12 @@ class InstanceEvents(event.Events):
     _dispatch_target = instrumentation.ClassManager
 
     @classmethod
-    def _new_classmanager_instance(cls, class_, classmanager):
+    def _new_classmanager_instance(cls, class_: Union[DeclarativeAttributeIntercept, DeclarativeMeta, type], classmanager: ClassManager) -> None:
         _InstanceEventsHold.populate(class_, classmanager)
 
     @classmethod
     @util.preload_module("sqlalchemy.orm")
-    def _accept_with(cls, target, identifier):
+    def _accept_with(cls, target: Any, identifier: str) -> Union[_InstanceEventsHold, ClassManager, type]:
         orm = util.preloaded.orm
 
         if isinstance(target, instrumentation.ClassManager):
@@ -233,17 +244,17 @@ class InstanceEvents(event.Events):
     @classmethod
     def _listen(
         cls,
-        event_key,
-        raw=False,
-        propagate=False,
-        restore_load_context=False,
-        **kw,
-    ):
+        event_key: _EventKey,
+        raw: bool = False,
+        propagate: bool = False,
+        restore_load_context: bool = False,
+        **kw: Any,
+    ) -> None:
         target, fn = (event_key.dispatch_target, event_key._listen_fn)
 
         if not raw or restore_load_context:
 
-            def wrap(state, *arg, **kw):
+            def wrap(state: InstanceState, *arg: Any, **kw: Any) -> Optional[Any]:
                 if not raw:
                     target = state.obj()
                 else:
@@ -265,7 +276,7 @@ class InstanceEvents(event.Events):
                 event_key.with_dispatch_target(mgr).base_listen(propagate=True)
 
     @classmethod
-    def _clear(cls):
+    def _clear(cls) -> None:
         super()._clear()
         _InstanceEventsHold._clear()
 
@@ -563,11 +574,11 @@ class _EventsHold(event.RefCollection):
 
     """
 
-    def __init__(self, class_):
+    def __init__(self, class_: Union[DeclarativeAttributeIntercept, DeclarativeMeta, type]) -> None:
         self.class_ = class_
 
     @classmethod
-    def _clear(cls):
+    def _clear(cls) -> None:
         cls.all_holds.clear()
 
     class HoldEvents:
@@ -575,8 +586,8 @@ class _EventsHold(event.RefCollection):
 
         @classmethod
         def _listen(
-            cls, event_key, raw=False, propagate=False, retval=False, **kw
-        ):
+            cls, event_key: _EventKey, raw: bool = False, propagate: bool = False, retval: bool = False, **kw: Any
+        ) -> None:
             target = event_key.dispatch_target
 
             if target.class_ in target.all_holds:
@@ -606,7 +617,7 @@ class _EventsHold(event.RefCollection):
                             raw=raw, propagate=False, retval=retval, **kw
                         )
 
-    def remove(self, event_key):
+    def remove(self, event_key: _EventKey) -> None:
         target = event_key.dispatch_target
 
         if isinstance(target, _EventsHold):
@@ -614,7 +625,7 @@ class _EventsHold(event.RefCollection):
             del collection[event_key._key]
 
     @classmethod
-    def populate(cls, class_, subject):
+    def populate(cls, class_: Union[DeclarativeAttributeIntercept, DeclarativeMeta, type], subject: Union[ClassManager, Mapper]) -> None:
         for subclass in class_.__mro__:
             if subclass in cls.all_holds:
                 collection = cls.all_holds[subclass]
@@ -718,12 +729,12 @@ class MapperEvents(event.Events):
     _dispatch_target = mapperlib.Mapper
 
     @classmethod
-    def _new_mapper_instance(cls, class_, mapper):
+    def _new_mapper_instance(cls, class_: Union[DeclarativeAttributeIntercept, DeclarativeMeta, type], mapper: Mapper) -> None:
         _MapperEventsHold.populate(class_, mapper)
 
     @classmethod
     @util.preload_module("sqlalchemy.orm")
-    def _accept_with(cls, target, identifier):
+    def _accept_with(cls, target: Union[DeclarativeMeta, Mapper, type], identifier: str) -> Union[_MapperEventsHold, Mapper, type]:
         orm = util.preloaded.orm
 
         if target is orm.mapper:
@@ -748,8 +759,8 @@ class MapperEvents(event.Events):
 
     @classmethod
     def _listen(
-        cls, event_key, raw=False, retval=False, propagate=False, **kw
-    ):
+        cls, event_key: _EventKey, raw: bool = False, retval: bool = False, propagate: bool = False, **kw: Any
+    ) -> None:
         target, identifier, fn = (
             event_key.dispatch_target,
             event_key.identifier,
@@ -776,7 +787,7 @@ class MapperEvents(event.Events):
                 except ValueError:
                     target_index = None
 
-            def wrap(*arg, **kw):
+            def wrap(*arg: Any, **kw: Any) -> EventConstants:
                 if not raw and target_index is not None:
                     arg = list(arg)
                     arg[target_index] = arg[target_index].obj()
@@ -797,7 +808,7 @@ class MapperEvents(event.Events):
             event_key.base_listen(**kw)
 
     @classmethod
-    def _clear(cls):
+    def _clear(cls) -> None:
         super()._clear()
         _MapperEventsHold._clear()
 
@@ -1339,7 +1350,7 @@ class MapperEvents(event.Events):
 class _MapperEventsHold(_EventsHold):
     all_holds = weakref.WeakKeyDictionary()
 
-    def resolve(self, class_):
+    def resolve(self, class_: Union[DeclarativeMeta, type]) -> Optional[Mapper]:
         return _mapper_or_none(class_)
 
     class HoldMapperEvents(_EventsHold.HoldEvents, MapperEvents):
@@ -1401,7 +1412,7 @@ class SessionEvents(event.Events[Session]):
         return fn
 
     @classmethod
-    def _accept_with(cls, target, identifier):
+    def _accept_with(cls, target: Any, identifier: str) -> Union[Session, type]:
         if isinstance(target, scoped_session):
 
             target = target.session_factory
@@ -1447,7 +1458,7 @@ class SessionEvents(event.Events[Session]):
 
                 fn = event_key._listen_fn
 
-                def wrap(session, state, *arg, **kw):
+                def wrap(session: Session, state: InstanceState, *arg: Any, **kw: Any) -> Optional[Any]:
                     if not raw:
                         target = state.obj()
                         if target is None:
@@ -2303,7 +2314,7 @@ class AttributeEvents(event.Events):
         return dispatch
 
     @classmethod
-    def _accept_with(cls, target, identifier):
+    def _accept_with(cls, target: InstrumentedAttribute, identifier: str) -> InstrumentedAttribute:
         # TODO: coverage
         if isinstance(target, interfaces.MapperProperty):
             return getattr(target.parent.class_, target.key)
@@ -2313,13 +2324,13 @@ class AttributeEvents(event.Events):
     @classmethod
     def _listen(
         cls,
-        event_key,
-        active_history=False,
-        raw=False,
-        retval=False,
-        propagate=False,
-        include_key=False,
-    ):
+        event_key: _EventKey,
+        active_history: bool = False,
+        raw: bool = False,
+        retval: bool = False,
+        propagate: bool = False,
+        include_key: bool = False,
+    ) -> None:
 
         target, fn = event_key.dispatch_target, event_key._listen_fn
 
@@ -2328,7 +2339,7 @@ class AttributeEvents(event.Events):
 
         if not raw or not retval or not include_key:
 
-            def wrap(target, *arg, **kw):
+            def wrap(target: InstanceState, *arg: Any, **kw: Any) -> str:
                 if not raw:
                     target = target.obj()
                 if not retval:
@@ -2973,12 +2984,12 @@ class QueryEvents(event.Events):
         """
 
     @classmethod
-    def _listen(cls, event_key, retval=False, bake_ok=False, **kw):
+    def _listen(cls, event_key: _EventKey, retval: bool = False, bake_ok: bool = False, **kw: Any) -> None:
         fn = event_key._listen_fn
 
         if not retval:
 
-            def wrap(*arg, **kw):
+            def wrap(*arg: Query, **kw: Any) -> Query:
                 if not retval:
                     query = arg[0]
                     fn(*arg, **kw)
@@ -2989,7 +3000,7 @@ class QueryEvents(event.Events):
             event_key = event_key.with_wrapper(wrap)
         else:
             # don't assume we can apply an attribute to the callable
-            def wrap(*arg, **kw):
+            def wrap(*arg: Any, **kw: Any) -> Query:
                 return fn(*arg, **kw)
 
             event_key = event_key.with_wrapper(wrap)