]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Corrected some autogenerated type hints here and there
authorGleb Kisenkov <g.kisenkov@godeltech.com>
Thu, 10 Nov 2022 16:07:53 +0000 (17:07 +0100)
committerGleb Kisenkov <g.kisenkov@godeltech.com>
Thu, 10 Nov 2022 16:07:53 +0000 (17:07 +0100)
lib/sqlalchemy/ext/mutable.py

index 6ceb8a1d5d85822e0d49efced08053fe2b6c8c9b..15ca466c76c209d3d7b29607a8af1ceddd6cfb6c 100644 (file)
@@ -4,7 +4,6 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
 
 r"""Provide support for tracking of in-place changes to scalar values,
 which are propagated into ORM change events on owning parent objects.
@@ -358,38 +357,44 @@ pickling process of the parent's object-relational state so that the
 from __future__ import annotations
 
 from collections import defaultdict
+from typing import AbstractSet
 from typing import Any
 from typing import Dict
+from typing import Iterable
 from typing import List
 from typing import Optional
 from typing import Self
 from typing import Set
+from typing import SupportsIndex
 from typing import Tuple
+from typing import TypeVar
 from typing import Union
 import weakref
 from weakref import WeakKeyDictionary
 
-from sqlalchemy.orm.attributes import AttributeEventToken
-from sqlalchemy.orm.attributes import create_proxied_attribute
-from sqlalchemy.orm.attributes import InstrumentedAttribute
-from sqlalchemy.orm.base import LoaderCallableStatus
-from sqlalchemy.orm.context import QueryContext
-from sqlalchemy.orm.decl_api import DeclarativeAttributeIntercept
-from sqlalchemy.orm.state import InstanceState
-from sqlalchemy.orm.unitofwork import UOWTransaction
-from sqlalchemy.sql.schema import Column
-from sqlalchemy.sql.sqltypes import JSON
-from sqlalchemy.sql.sqltypes import PickleType
-from test.ext.test_mutable import CustomMutableAssociationScalarJSONTest
-from test.ext.test_mutable import MutableColumnCopyJSONTest
 from .. import event
 from .. import inspect
 from .. import types
 from ..orm import Mapper
+from ..orm._typing import _ExternalEntityType
+from ..orm._typing import _O
+from ..orm._typing import _T
+from ..orm.attributes import AttributeEventToken
 from ..orm.attributes import flag_modified
+from ..orm.attributes import InstrumentedAttribute
+from ..orm.attributes import QueryableAttribute
+from ..orm.context import QueryContext
+from ..orm.decl_api import DeclarativeAttributeIntercept
+from ..orm.state import InstanceState
+from ..orm.unitofwork import UOWTransaction
 from ..sql.base import SchemaEventTarget
+from ..sql.schema import Column
+from ..sql.type_api import TypeEngine
 from ..util import memoized_property
 
+_KT = TypeVar("_KT")  # Key type.
+_VT = TypeVar("_VT")  # Value type.
+
 
 class MutableBase:
     """Common base class to :class:`.Mutable`
@@ -398,7 +403,7 @@ class MutableBase:
     """
 
     @memoized_property
-    def _parents(self) -> WeakKeyDictionary:
+    def _parents(self) -> WeakKeyDictionary[Any, Any]:
         """Dictionary of parent object's :class:`.InstanceState`->attribute
         name on the parent.
 
@@ -444,7 +449,7 @@ class MutableBase:
         raise ValueError(msg % (key, type(value)))
 
     @classmethod
-    def _get_listen_keys(cls, attribute: InstrumentedAttribute) -> Set[str]:
+    def _get_listen_keys(cls, attribute: QueryableAttribute[Any]) -> Set[str]:
         """Given a descriptor attribute, return a ``set()`` of the attribute
         keys which indicate a change in the state of this attribute.
 
@@ -467,11 +472,9 @@ class MutableBase:
     @classmethod
     def _listen_on_attribute(
         cls,
-        attribute: Union[
-            InstrumentedAttribute, create_proxied_attribute.Proxy
-        ],
+        attribute: QueryableAttribute[Any],
         coerce: bool,
-        parent_cls: Union[DeclarativeAttributeIntercept, type],
+        parent_cls: _ExternalEntityType[Any],
     ) -> None:
         """Establish this type as a mutation listener for the given
         mapped descriptor.
@@ -486,7 +489,7 @@ class MutableBase:
 
         listen_keys = cls._get_listen_keys(attribute)
 
-        def load(state: InstanceState, *args: Any) -> None:
+        def load(state: InstanceState[_O], *args: Any) -> None:
             """Listen for objects loaded or refreshed.
 
             Wrap the target data member's value with
@@ -501,27 +504,19 @@ class MutableBase:
                 val._parents[state] = key
 
         def load_attrs(
-            state: InstanceState,
+            state: InstanceState[_O],
             ctx: Union[object, QueryContext, UOWTransaction],
-            attrs: Union[List[str], frozenset],
+            attrs: Iterable[Any],
         ) -> None:
             if not attrs or listen_keys.intersection(attrs):
                 load(state)
 
         def set_(
-            target: InstanceState,
+            target: InstanceState[_O],
             value: Any,
-            oldvalue: Union[
-                MutableDict,
-                LoaderCallableStatus,
-                CustomMutableAssociationScalarJSONTest._type_fixture.CustomMutableDict,  # noqa: E501
-            ],
+            oldvalue: Any,
             initiator: AttributeEventToken,
-        ) -> Union[
-            None,
-            MutableDict,
-            CustomMutableAssociationScalarJSONTest._type_fixture.CustomMutableDict,  # noqa: E501
-        ]:
+        ) -> Any:
             """Listen for set/replace events on the target
             data member.
 
@@ -541,14 +536,18 @@ class MutableBase:
                 oldvalue._parents.pop(inspect(target), None)
             return value
 
-        def pickle(state: InstanceState, state_dict: Dict[str, Any]) -> None:
+        def pickle(
+            state: InstanceState[_O], state_dict: Dict[str, Any]
+        ) -> None:
             val = state.dict.get(key, None)
             if val is not None:
                 if "ext.mutable.values" not in state_dict:
                     state_dict["ext.mutable.values"] = defaultdict(list)
                 state_dict["ext.mutable.values"][key].append(val)
 
-        def unpickle(state: InstanceState, state_dict: Dict[str, Any]) -> None:
+        def unpickle(
+            state: InstanceState[_O], state_dict: Dict[str, Any]
+        ) -> None:
             if "ext.mutable.values" in state_dict:
                 collection = state_dict["ext.mutable.values"]
                 if isinstance(collection, list):
@@ -599,7 +598,7 @@ class Mutable(MutableBase):
 
     @classmethod
     def associate_with_attribute(
-        cls, attribute: InstrumentedAttribute
+        cls, attribute: InstrumentedAttribute[_O]
     ) -> None:
         """Establish this type as a mutation listener for the given
         mapped descriptor.
@@ -625,7 +624,7 @@ class Mutable(MutableBase):
 
         """
 
-        def listen_for_type(mapper: Mapper, class_: type) -> None:
+        def listen_for_type(mapper: Mapper[_O], class_: type) -> None:
             if mapper.non_primary:
                 return
             for prop in mapper.column_attrs:
@@ -635,13 +634,7 @@ class Mutable(MutableBase):
         event.listen(Mapper, "mapper_configured", listen_for_type)
 
     @classmethod
-    def as_mutable(
-        cls, sqltype: type
-    ) -> Union[
-        JSON,
-        PickleType,
-        MutableColumnCopyJSONTest.define_tables.JSONEncodedDict,
-    ]:
+    def as_mutable(cls, sqltype: TypeEngine[_T]) -> TypeEngine[_T]:
         """Associate a SQL type with this mutable Python type.
 
         This establishes listeners that will detect ORM mappings against
@@ -682,11 +675,8 @@ class Mutable(MutableBase):
 
             @event.listens_for(sqltype, "before_parent_attach")
             def _add_column_memo(
-                sqltyp: Union[
-                    PickleType,
-                    MutableColumnCopyJSONTest.define_tables.JSONEncodedDict,
-                ],
-                parent: Column,
+                sqltyp: TypeEngine[Any],
+                parent: Column[_T],
             ) -> None:
                 parent.info["_ext_mutable_orig_type"] = sqltyp
 
@@ -695,7 +685,8 @@ class Mutable(MutableBase):
             schema_event_check = False
 
         def listen_for_type(
-            mapper: Mapper, class_: Union[DeclarativeAttributeIntercept, type]
+            mapper: Mapper[_T],
+            class_: Union[DeclarativeAttributeIntercept, type],
         ) -> None:
             if mapper.non_primary:
                 return
@@ -703,7 +694,7 @@ class Mutable(MutableBase):
                 if (
                     schema_event_check
                     and hasattr(prop.expression, "info")
-                    and prop.expression.info.get("_ext_mutable_orig_type")
+                    and prop.expression.info.get("_ext_mutable_orig_type")  # type: ignore # noqa: E501 # TODO: https://github.com/python/mypy/issues/1424#issuecomment-1272354487
                     is sqltype
                 ) or (prop.columns[0].type is sqltype):
                     cls.associate_with_attribute(getattr(class_, prop.key))
@@ -723,9 +714,7 @@ class MutableComposite(MutableBase):
     """
 
     @classmethod
-    def _get_listen_keys(
-        cls, attribute: create_proxied_attribute.Proxy
-    ) -> Set[str]:
+    def _get_listen_keys(cls, attribute: QueryableAttribute[_O]) -> Set[str]:
         return {attribute.key}.union(attribute.property._attribute_keys)
 
     def changed(self) -> None:
@@ -742,7 +731,7 @@ class MutableComposite(MutableBase):
 
 
 def _setup_composite_listener() -> None:
-    def _listen_for_type(mapper: Mapper, class_: type) -> None:
+    def _listen_for_type(mapper: Mapper[_T], class_: type) -> None:
         for prop in mapper.iterate_properties:
             if (
                 hasattr(prop, "composite_class")
@@ -760,7 +749,7 @@ def _setup_composite_listener() -> None:
 _setup_composite_listener()
 
 
-class MutableDict(Mutable, dict):
+class MutableDict(Mutable, dict[_KT, _VT]):
     """A dictionary type that implements :class:`.Mutable`.
 
     The :class:`.MutableDict` object implements a dictionary that will
@@ -783,32 +772,32 @@ class MutableDict(Mutable, dict):
 
     """
 
-    def __setitem__(self, key: str, value: Union[int, str]) -> None:
+    def __setitem__(self: Self, key: _KT, value: _VT) -> None:
         """Detect dictionary set events and emit change events."""
-        dict.__setitem__(self, key, value)
+        super().__setitem__(key, value)
         self.changed()
 
-    def setdefault(self, key: str, value: str) -> str:
-        result = dict.setdefault(self, key, value)
+    def setdefault(self, key: _KT, value: _VT) -> _VT:
+        result = super().setdefault(key, value)
         self.changed()
         return result
 
-    def __delitem__(self, key):
+    def __delitem__(self, key: _KT):
         """Detect dictionary del events and emit change events."""
-        dict.__delitem__(self, key)
+        super().__delitem__(key)
         self.changed()
 
-    def update(self, *a: Any, **kw: Any) -> None:
-        dict.update(self, *a, **kw)
+    def update(self, *a: Any, **kw: _VT) -> None:
+        super().update(*a, **kw)
         self.changed()
 
-    def pop(self, *arg: str) -> str:
-        result = dict.pop(self, *arg)
+    def pop(self, *arg: _KT) -> _VT:
+        result = super().pop(*arg)
         self.changed()
         return result
 
-    def popitem(self) -> Tuple[str, str]:
-        result = dict.popitem(self)
+    def popitem(self) -> Tuple[_KT, _VT]:
+        result = super().popitem()
         self.changed()
         return result
 
@@ -817,13 +806,7 @@ class MutableDict(Mutable, dict):
         self.changed()
 
     @classmethod
-    def coerce(
-        cls, key: str, value: Any
-    ) -> Union[
-        None,
-        Self,
-        CustomMutableAssociationScalarJSONTest._type_fixture.CustomMutableDict,
-    ]:
+    def coerce(cls, key: str, value: Any) -> Self | None:
         """Convert plain dictionary to instance of this class."""
         if not isinstance(value, cls):
             if isinstance(value, dict):
@@ -832,7 +815,7 @@ class MutableDict(Mutable, dict):
         else:
             return value
 
-    def __getstate__(self) -> Union[Dict[str, int], Dict[str, str]]:
+    def __getstate__(self) -> dict[_KT, _VT]:
         return dict(self)
 
     def __setstate__(
@@ -841,7 +824,7 @@ class MutableDict(Mutable, dict):
         self.update(state)
 
 
-class MutableList(Mutable, list):
+class MutableList(Mutable, list[_T]):
     """A list type that implements :class:`.Mutable`.
 
     The :class:`.MutableList` object implements a list that will
@@ -866,69 +849,61 @@ class MutableList(Mutable, list):
 
     """
 
-    def __reduce_ex__(self, proto: int) -> Tuple[type, Tuple[List[int]]]:
+    def __reduce_ex__(
+        self, proto: SupportsIndex
+    ) -> Tuple[type, Tuple[List[int]]]:
         return (self.__class__, (list(self),))
 
     # needed for backwards compatibility with
     # older pickles
-    def __setstate__(self, state: List[int]) -> None:
+    def __setstate__(self, state: Iterable[_T]) -> None:
         self[:] = state
 
     def __setitem__(
         self,
-        index: Union[int, slice],
-        value: Union[List[int], Tuple[int, int], int],
+        index: Union[SupportsIndex, slice],
+        value: Union[_T, Iterable[_T]],
     ) -> None:
         """Detect list set events and emit change events."""
-        list.__setitem__(self, index, value)
-        self.changed()
-
-    def __setslice__(self, start, end, value):
-        """Detect list set events and emit change events."""
-        list.__setslice__(self, start, end, value)
-        self.changed()
-
-    def __delitem__(self, index: slice) -> None:
-        """Detect list del events and emit change events."""
-        list.__delitem__(self, index)
+        super().__setitem__(index, value)
         self.changed()
 
-    def __delslice__(self, start, end):
+    def __delitem__(self, index: SupportsIndex | slice) -> None:
         """Detect list del events and emit change events."""
-        list.__delslice__(self, start, end)
+        super().__delitem__(index)
         self.changed()
 
-    def pop(self, *arg: int) -> int:
-        result = list.pop(self, *arg)
+    def pop(self, *arg: int) -> _T:
+        result = super().pop(*arg)
         self.changed()
         return result
 
-    def append(self, x: int) -> None:
-        list.append(self, x)
+    def append(self, x: _T) -> None:
+        super().append(x)
         self.changed()
 
-    def extend(self, x: List[int]) -> None:
-        list.extend(self, x)
+    def extend(self, x: Iterable[_T]) -> None:
+        super().extend(x)
         self.changed()
 
-    def __iadd__(self, x: List[int]) -> Self:
+    def __iadd__(self, x: Iterable[_T]) -> Self:
         self.extend(x)
         return self
 
-    def insert(self, i: int, x: int) -> None:
-        list.insert(self, i, x)
+    def insert(self, i: SupportsIndex, x: _T) -> None:
+        super().insert(i, x)
         self.changed()
 
-    def remove(self, i: int) -> None:
-        list.remove(self, i)
+    def remove(self, i: _T) -> None:
+        super().remove(i)
         self.changed()
 
     def clear(self) -> None:
-        list.clear(self)
+        super().clear()
         self.changed()
 
     def sort(self, **kw: Any) -> None:
-        list.sort(self, **kw)
+        super().sort(**kw)
         self.changed()
 
     def reverse(self) -> None:
@@ -946,7 +921,7 @@ class MutableList(Mutable, list):
             return value
 
 
-class MutableSet(Mutable, set):
+class MutableSet(Mutable, set[_T]):
     """A set type that implements :class:`.Mutable`.
 
     The :class:`.MutableSet` object implements a set that will
@@ -972,57 +947,57 @@ class MutableSet(Mutable, set):
 
     """
 
-    def update(self, *arg: Set[int]) -> None:
-        set.update(self, *arg)
+    def update(self, *arg: Iterable[_T]) -> None:
+        super().update(*arg)
         self.changed()
 
-    def intersection_update(self, *arg: Set[int]) -> None:
-        set.intersection_update(self, *arg)
+    def intersection_update(self, *arg: Iterable[Any]) -> None:
+        super().intersection_update(*arg)
         self.changed()
 
-    def difference_update(self, *arg: Set[int]) -> None:
-        set.difference_update(self, *arg)
+    def difference_update(self, *arg: Iterable[Any]) -> None:
+        super().difference_update(*arg)
         self.changed()
 
-    def symmetric_difference_update(self, *arg: Set[int]) -> None:
-        set.symmetric_difference_update(self, *arg)
+    def symmetric_difference_update(self, *arg: Iterable[_T]) -> None:
+        super().symmetric_difference_update(*arg)
         self.changed()
 
-    def __ior__(self, other: Set[int]) -> Self:
+    def __ior__(self, other: Iterable[_T]) -> Self:
         self.update(other)
         return self
 
-    def __iand__(self, other: Set[int]) -> Self:
+    def __iand__(self, other: AbstractSet[object]) -> Self:
         self.intersection_update(other)
         return self
 
-    def __ixor__(self, other: Set[int]) -> Self:
+    def __ixor__(self, other: AbstractSet[_T]) -> Self:
         self.symmetric_difference_update(other)
         return self
 
-    def __isub__(self, other: Set[int]) -> Self:
+    def __isub__(self, other: AbstractSet[object]) -> Self:
         self.difference_update(other)
         return self
 
-    def add(self, elem: int) -> None:
-        set.add(self, elem)
+    def add(self, elem: _T) -> None:
+        super().add(elem)
         self.changed()
 
-    def remove(self, elem: int) -> None:
-        set.remove(self, elem)
+    def remove(self, elem: _T) -> None:
+        super().remove(elem)
         self.changed()
 
-    def discard(self, elem: int) -> None:
-        set.discard(self, elem)
+    def discard(self, elem: _T) -> None:
+        super().discard(elem)
         self.changed()
 
-    def pop(self, *arg: Any) -> int:
-        result = set.pop(self, *arg)
+    def pop(self, *arg: Any) -> _T:
+        result = super().pop(*arg)
         self.changed()
         return result
 
     def clear(self) -> None:
-        set.clear(self)
+        super().clear()
         self.changed()
 
     @classmethod
@@ -1035,11 +1010,13 @@ class MutableSet(Mutable, set):
         else:
             return value
 
-    def __getstate__(self):
+    def __getstate__(self) -> set[_T]:
         return set(self)
 
-    def __setstate__(self, state):
+    def __setstate__(self, state: Iterable[_T]) -> None:
         self.update(state)
 
-    def __reduce_ex__(self, proto: int) -> Tuple[type, Tuple[List[int]]]:
+    def __reduce_ex__(
+        self, proto: SupportsIndex
+    ) -> Tuple[type, Tuple[List[int]]]:
         return (self.__class__, (list(self),))