]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Type annotations for sqlalchemy.ext.mutable
authorGleb Kisenkov <g.kisenkov@gmail.com>
Wed, 16 Nov 2022 15:23:06 +0000 (10:23 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 22 Nov 2022 15:39:41 +0000 (10:39 -0500)
The ``sqlalchemy.ext.mutable`` extension is now fully pep-484 typed. Huge
thanks to Gleb Kisenkov for their efforts on this.

Fixes: #8667
Closes: #8775
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/8775
Pull-request-sha: b907888ec67facc12dbdbccd6f2d9cd533b08a50

Change-Id: Id9224e03201e6970b1ec56eb546ece4b2f3e0edd

doc/build/changelog/unreleased_20/8667.rst [new file with mode: 0644]
lib/sqlalchemy/ext/mutable.py

diff --git a/doc/build/changelog/unreleased_20/8667.rst b/doc/build/changelog/unreleased_20/8667.rst
new file mode 100644 (file)
index 0000000..50dc5d8
--- /dev/null
@@ -0,0 +1,6 @@
+.. change::
+    :tags: bug, typing
+    :tickets: 8667
+
+    The ``sqlalchemy.ext.mutable`` extension is now fully pep-484 typed. Huge
+    thanks to Gleb Kisenkov for their efforts on this.
index 1f102cb36f914f5545ee3554659f1d852f8f0f64..f9ed17efc1829a55ff8b4047b39f692454649c10 100644 (file)
@@ -4,7 +4,6 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
 
 r"""Provide support for tracking of in-place changes to scalar values,
 which are propagated into ORM change events on owning parent objects.
@@ -355,16 +354,47 @@ pickling process of the parent's object-relational state so that the
 
 """  # noqa: E501
 
+from __future__ import annotations
+
 from collections import defaultdict
+from typing import AbstractSet
+from typing import Any
+from typing import Dict
+from typing import Iterable
+from typing import List
+from typing import Optional
+from typing import overload
+from typing import Set
+from typing import Tuple
+from typing import TypeVar
+from typing import Union
 import weakref
+from weakref import WeakKeyDictionary
 
 from .. import event
 from .. import inspect
 from .. import types
 from ..orm import Mapper
+from ..orm._typing import _ExternalEntityType
+from ..orm._typing import _O
+from ..orm._typing import _T
+from ..orm.attributes import AttributeEventToken
 from ..orm.attributes import flag_modified
+from ..orm.attributes import InstrumentedAttribute
+from ..orm.attributes import QueryableAttribute
+from ..orm.context import QueryContext
+from ..orm.decl_api import DeclarativeAttributeIntercept
+from ..orm.state import InstanceState
+from ..orm.unitofwork import UOWTransaction
 from ..sql.base import SchemaEventTarget
+from ..sql.schema import Column
+from ..sql.type_api import TypeEngine
 from ..util import memoized_property
+from ..util.typing import SupportsIndex
+from ..util.typing import TypeGuard
+
+_KT = TypeVar("_KT")  # Key type.
+_VT = TypeVar("_VT")  # Value type.
 
 
 class MutableBase:
@@ -374,7 +404,7 @@ class MutableBase:
     """
 
     @memoized_property
-    def _parents(self):
+    def _parents(self) -> WeakKeyDictionary[Any, Any]:
         """Dictionary of parent object's :class:`.InstanceState`->attribute
         name on the parent.
 
@@ -391,7 +421,7 @@ class MutableBase:
         return weakref.WeakKeyDictionary()
 
     @classmethod
-    def coerce(cls, key, value):
+    def coerce(cls, key: str, value: Any) -> Optional[Any]:
         """Given a value, coerce it into the target type.
 
         Can be overridden by custom subclasses to coerce incoming
@@ -420,7 +450,7 @@ class MutableBase:
         raise ValueError(msg % (key, type(value)))
 
     @classmethod
-    def _get_listen_keys(cls, attribute):
+    def _get_listen_keys(cls, attribute: QueryableAttribute[Any]) -> Set[str]:
         """Given a descriptor attribute, return a ``set()`` of the attribute
         keys which indicate a change in the state of this attribute.
 
@@ -441,7 +471,12 @@ class MutableBase:
         return {attribute.key}
 
     @classmethod
-    def _listen_on_attribute(cls, attribute, coerce, parent_cls):
+    def _listen_on_attribute(
+        cls,
+        attribute: QueryableAttribute[Any],
+        coerce: bool,
+        parent_cls: _ExternalEntityType[Any],
+    ) -> None:
         """Establish this type as a mutation listener for the given
         mapped descriptor.
 
@@ -455,7 +490,7 @@ class MutableBase:
 
         listen_keys = cls._get_listen_keys(attribute)
 
-        def load(state, *args):
+        def load(state: InstanceState[_O], *args: Any) -> None:
             """Listen for objects loaded or refreshed.
 
             Wrap the target data member's value with
@@ -469,11 +504,20 @@ class MutableBase:
                     state.dict[key] = val
                 val._parents[state] = key
 
-        def load_attrs(state, ctx, attrs):
+        def load_attrs(
+            state: InstanceState[_O],
+            ctx: Union[object, QueryContext, UOWTransaction],
+            attrs: Iterable[Any],
+        ) -> None:
             if not attrs or listen_keys.intersection(attrs):
                 load(state)
 
-        def set_(target, value, oldvalue, initiator):
+        def set_(
+            target: InstanceState[_O],
+            value: MutableBase | None,
+            oldvalue: MutableBase | None,
+            initiator: AttributeEventToken,
+        ) -> MutableBase | None:
             """Listen for set/replace events on the target
             data member.
 
@@ -493,14 +537,18 @@ class MutableBase:
                 oldvalue._parents.pop(inspect(target), None)
             return value
 
-        def pickle(state, state_dict):
+        def pickle(
+            state: InstanceState[_O], state_dict: Dict[str, Any]
+        ) -> None:
             val = state.dict.get(key, None)
             if val is not None:
                 if "ext.mutable.values" not in state_dict:
                     state_dict["ext.mutable.values"] = defaultdict(list)
                 state_dict["ext.mutable.values"][key].append(val)
 
-        def unpickle(state, state_dict):
+        def unpickle(
+            state: InstanceState[_O], state_dict: Dict[str, Any]
+        ) -> None:
             if "ext.mutable.values" in state_dict:
                 collection = state_dict["ext.mutable.values"]
                 if isinstance(collection, list):
@@ -543,14 +591,16 @@ class Mutable(MutableBase):
 
     """
 
-    def changed(self):
+    def changed(self) -> None:
         """Subclasses should call this method whenever change events occur."""
 
         for parent, key in self._parents.items():
             flag_modified(parent.obj(), key)
 
     @classmethod
-    def associate_with_attribute(cls, attribute):
+    def associate_with_attribute(
+        cls, attribute: InstrumentedAttribute[_O]
+    ) -> None:
         """Establish this type as a mutation listener for the given
         mapped descriptor.
 
@@ -558,7 +608,7 @@ class Mutable(MutableBase):
         cls._listen_on_attribute(attribute, True, attribute.class_)
 
     @classmethod
-    def associate_with(cls, sqltype):
+    def associate_with(cls, sqltype: type) -> None:
         """Associate this wrapper with all future mapped columns
         of the given type.
 
@@ -575,7 +625,7 @@ class Mutable(MutableBase):
 
         """
 
-        def listen_for_type(mapper, class_):
+        def listen_for_type(mapper: Mapper[_O], class_: type) -> None:
             if mapper.non_primary:
                 return
             for prop in mapper.column_attrs:
@@ -585,7 +635,7 @@ class Mutable(MutableBase):
         event.listen(Mapper, "mapper_configured", listen_for_type)
 
     @classmethod
-    def as_mutable(cls, sqltype):
+    def as_mutable(cls, sqltype: TypeEngine[_T]) -> TypeEngine[_T]:
         """Associate a SQL type with this mutable Python type.
 
         This establishes listeners that will detect ORM mappings against
@@ -625,21 +675,27 @@ class Mutable(MutableBase):
         if isinstance(sqltype, SchemaEventTarget):
 
             @event.listens_for(sqltype, "before_parent_attach")
-            def _add_column_memo(sqltyp, parent):
+            def _add_column_memo(
+                sqltyp: TypeEngine[Any],
+                parent: Column[_T],
+            ) -> None:
                 parent.info["_ext_mutable_orig_type"] = sqltyp
 
             schema_event_check = True
         else:
             schema_event_check = False
 
-        def listen_for_type(mapper, class_):
+        def listen_for_type(
+            mapper: Mapper[_T],
+            class_: Union[DeclarativeAttributeIntercept, type],
+        ) -> None:
             if mapper.non_primary:
                 return
             for prop in mapper.column_attrs:
                 if (
                     schema_event_check
                     and hasattr(prop.expression, "info")
-                    and prop.expression.info.get("_ext_mutable_orig_type")
+                    and prop.expression.info.get("_ext_mutable_orig_type")  # type: ignore # noqa: E501 # TODO: https://github.com/python/mypy/issues/1424#issuecomment-1272354487
                     is sqltype
                 ) or (prop.columns[0].type is sqltype):
                     cls.associate_with_attribute(getattr(class_, prop.key))
@@ -659,10 +715,10 @@ class MutableComposite(MutableBase):
     """
 
     @classmethod
-    def _get_listen_keys(cls, attribute):
+    def _get_listen_keys(cls, attribute: QueryableAttribute[_O]) -> Set[str]:
         return {attribute.key}.union(attribute.property._attribute_keys)
 
-    def changed(self):
+    def changed(self) -> None:
         """Subclasses should call this method whenever change events occur."""
 
         for parent, key in self._parents.items():
@@ -675,8 +731,8 @@ class MutableComposite(MutableBase):
                 setattr(parent.obj(), attr_name, value)
 
 
-def _setup_composite_listener():
-    def _listen_for_type(mapper, class_):
+def _setup_composite_listener() -> None:
+    def _listen_for_type(mapper: Mapper[_T], class_: type) -> None:
         for prop in mapper.iterate_properties:
             if (
                 hasattr(prop, "composite_class")
@@ -694,7 +750,7 @@ def _setup_composite_listener():
 _setup_composite_listener()
 
 
-class MutableDict(Mutable, dict):
+class MutableDict(Mutable, Dict[_KT, _VT]):
     """A dictionary type that implements :class:`.Mutable`.
 
     The :class:`.MutableDict` object implements a dictionary that will
@@ -717,41 +773,69 @@ class MutableDict(Mutable, dict):
 
     """
 
-    def __setitem__(self, key, value):
+    def __setitem__(self, key: _KT, value: _VT) -> None:
         """Detect dictionary set events and emit change events."""
-        dict.__setitem__(self, key, value)
+        super().__setitem__(key, value)
         self.changed()
 
-    def setdefault(self, key, value):
-        result = dict.setdefault(self, key, value)
+    def _exists(self, value: _T | None) -> TypeGuard[_T]:
+        return value is not None
+
+    def _is_none(self, value: _T | None) -> TypeGuard[None]:
+        return value is None
+
+    @overload
+    def setdefault(self, key: _KT) -> _VT | None:
+        ...
+
+    @overload
+    def setdefault(self, key: _KT, value: _VT) -> _VT:
+        ...
+
+    def setdefault(self, key: _KT, value: _VT | None = None) -> _VT | None:
+        if self._exists(value):
+            result = super().setdefault(key, value)
+        else:
+            result = super().setdefault(key)  # type: ignore[call-arg]
         self.changed()
         return result
 
-    def __delitem__(self, key):
+    def __delitem__(self, key: _KT) -> None:
         """Detect dictionary del events and emit change events."""
-        dict.__delitem__(self, key)
+        super().__delitem__(key)
         self.changed()
 
-    def update(self, *a, **kw):
-        dict.update(self, *a, **kw)
+    def update(self, *a: Any, **kw: _VT) -> None:
+        super().update(*a, **kw)
         self.changed()
 
-    def pop(self, *arg):
-        result = dict.pop(self, *arg)
+    @overload
+    def pop(self, __key: _KT) -> _VT:
+        ...
+
+    @overload
+    def pop(self, __key: _KT, __default: _VT | _T) -> _VT | _T:
+        ...
+
+    def pop(self, __key: _KT, __default: _VT | _T | None = None) -> _VT | _T:
+        if self._exists(__default):
+            result = super().pop(__key, __default)
+        else:
+            result = super().pop(__key)
         self.changed()
         return result
 
-    def popitem(self):
-        result = dict.popitem(self)
+    def popitem(self) -> Tuple[_KT, _VT]:
+        result = super().popitem()
         self.changed()
         return result
 
-    def clear(self):
-        dict.clear(self)
+    def clear(self) -> None:
+        super().clear()
         self.changed()
 
     @classmethod
-    def coerce(cls, key, value):
+    def coerce(cls, key: str, value: Any) -> MutableDict[_KT, _VT] | None:
         """Convert plain dictionary to instance of this class."""
         if not isinstance(value, cls):
             if isinstance(value, dict):
@@ -760,14 +844,16 @@ class MutableDict(Mutable, dict):
         else:
             return value
 
-    def __getstate__(self):
+    def __getstate__(self) -> dict[_KT, _VT]:
         return dict(self)
 
-    def __setstate__(self, state):
+    def __setstate__(
+        self, state: Union[Dict[str, int], Dict[str, str]]
+    ) -> None:
         self.update(state)
 
 
-class MutableList(Mutable, list):
+class MutableList(Mutable, List[_T]):
     """A list type that implements :class:`.Mutable`.
 
     The :class:`.MutableList` object implements a list that will
@@ -792,83 +878,88 @@ class MutableList(Mutable, list):
 
     """
 
-    def __reduce_ex__(self, proto):
+    def __reduce_ex__(
+        self, proto: SupportsIndex
+    ) -> Tuple[type, Tuple[List[int]]]:
         return (self.__class__, (list(self),))
 
     # needed for backwards compatibility with
     # older pickles
-    def __setstate__(self, state):
+    def __setstate__(self, state: Iterable[_T]) -> None:
         self[:] = state
 
-    def __setitem__(self, index, value):
-        """Detect list set events and emit change events."""
-        list.__setitem__(self, index, value)
-        self.changed()
+    def is_scalar(self, value: _T | Iterable[_T]) -> TypeGuard[_T]:
+        return not isinstance(value, Iterable)
 
-    def __setslice__(self, start, end, value):
-        """Detect list set events and emit change events."""
-        list.__setslice__(self, start, end, value)
-        self.changed()
+    def is_iterable(self, value: _T | Iterable[_T]) -> TypeGuard[Iterable[_T]]:
+        return isinstance(value, Iterable)
 
-    def __delitem__(self, index):
-        """Detect list del events and emit change events."""
-        list.__delitem__(self, index)
+    def __setitem__(
+        self, index: SupportsIndex | slice, value: _T | Iterable[_T]
+    ) -> None:
+        """Detect list set events and emit change events."""
+        if isinstance(index, SupportsIndex) and self.is_scalar(value):
+            super().__setitem__(index, value)
+        elif isinstance(index, slice) and self.is_iterable(value):
+            super().__setitem__(index, value)
         self.changed()
 
-    def __delslice__(self, start, end):
+    def __delitem__(self, index: SupportsIndex | slice) -> None:
         """Detect list del events and emit change events."""
-        list.__delslice__(self, start, end)
+        super().__delitem__(index)
         self.changed()
 
-    def pop(self, *arg):
-        result = list.pop(self, *arg)
+    def pop(self, *arg: SupportsIndex) -> _T:
+        result = super().pop(*arg)
         self.changed()
         return result
 
-    def append(self, x):
-        list.append(self, x)
+    def append(self, x: _T) -> None:
+        super().append(x)
         self.changed()
 
-    def extend(self, x):
-        list.extend(self, x)
+    def extend(self, x: Iterable[_T]) -> None:
+        super().extend(x)
         self.changed()
 
-    def __iadd__(self, x):
+    def __iadd__(self, x: Iterable[_T]) -> MutableList[_T]:  # type: ignore[override,misc] # noqa: E501
         self.extend(x)
         return self
 
-    def insert(self, i, x):
-        list.insert(self, i, x)
+    def insert(self, i: SupportsIndex, x: _T) -> None:
+        super().insert(i, x)
         self.changed()
 
-    def remove(self, i):
-        list.remove(self, i)
+    def remove(self, i: _T) -> None:
+        super().remove(i)
         self.changed()
 
-    def clear(self):
-        list.clear(self)
+    def clear(self) -> None:
+        super().clear()
         self.changed()
 
-    def sort(self, **kw):
-        list.sort(self, **kw)
+    def sort(self, **kw: Any) -> None:
+        super().sort(**kw)
         self.changed()
 
-    def reverse(self):
-        list.reverse(self)
+    def reverse(self) -> None:
+        super().reverse()
         self.changed()
 
     @classmethod
-    def coerce(cls, index, value):
+    def coerce(
+        cls, key: str, value: MutableList[_T] | _T
+    ) -> Optional[MutableList[_T]]:
         """Convert plain list to instance of this class."""
         if not isinstance(value, cls):
             if isinstance(value, list):
                 return cls(value)
-            return Mutable.coerce(index, value)
+            return Mutable.coerce(key, value)
         else:
             return value
 
 
-class MutableSet(Mutable, set):
+class MutableSet(Mutable, Set[_T]):
     """A set type that implements :class:`.Mutable`.
 
     The :class:`.MutableSet` object implements a set that will
@@ -894,61 +985,61 @@ class MutableSet(Mutable, set):
 
     """
 
-    def update(self, *arg):
-        set.update(self, *arg)
+    def update(self, *arg: Iterable[_T]) -> None:
+        super().update(*arg)
         self.changed()
 
-    def intersection_update(self, *arg):
-        set.intersection_update(self, *arg)
+    def intersection_update(self, *arg: Iterable[Any]) -> None:
+        super().intersection_update(*arg)
         self.changed()
 
-    def difference_update(self, *arg):
-        set.difference_update(self, *arg)
+    def difference_update(self, *arg: Iterable[Any]) -> None:
+        super().difference_update(*arg)
         self.changed()
 
-    def symmetric_difference_update(self, *arg):
-        set.symmetric_difference_update(self, *arg)
+    def symmetric_difference_update(self, *arg: Iterable[_T]) -> None:
+        super().symmetric_difference_update(*arg)
         self.changed()
 
-    def __ior__(self, other):
+    def __ior__(self, other: AbstractSet[_T]) -> MutableSet[_T]:  # type: ignore[override,misc] # noqa: E501
         self.update(other)
         return self
 
-    def __iand__(self, other):
+    def __iand__(self, other: AbstractSet[object]) -> MutableSet[_T]:
         self.intersection_update(other)
         return self
 
-    def __ixor__(self, other):
+    def __ixor__(self, other: AbstractSet[_T]) -> MutableSet[_T]:  # type: ignore[override,misc] # noqa: E501
         self.symmetric_difference_update(other)
         return self
 
-    def __isub__(self, other):
+    def __isub__(self, other: AbstractSet[object]) -> MutableSet[_T]:  # type: ignore[misc] # noqa: E501
         self.difference_update(other)
         return self
 
-    def add(self, elem):
-        set.add(self, elem)
+    def add(self, elem: _T) -> None:
+        super().add(elem)
         self.changed()
 
-    def remove(self, elem):
-        set.remove(self, elem)
+    def remove(self, elem: _T) -> None:
+        super().remove(elem)
         self.changed()
 
-    def discard(self, elem):
-        set.discard(self, elem)
+    def discard(self, elem: _T) -> None:
+        super().discard(elem)
         self.changed()
 
-    def pop(self, *arg):
-        result = set.pop(self, *arg)
+    def pop(self, *arg: Any) -> _T:
+        result = super().pop(*arg)
         self.changed()
         return result
 
-    def clear(self):
-        set.clear(self)
+    def clear(self) -> None:
+        super().clear()
         self.changed()
 
     @classmethod
-    def coerce(cls, index, value):
+    def coerce(cls, index: str, value: Any) -> Optional[MutableSet[_T]]:
         """Convert plain set to instance of this class."""
         if not isinstance(value, cls):
             if isinstance(value, set):
@@ -957,11 +1048,13 @@ class MutableSet(Mutable, set):
         else:
             return value
 
-    def __getstate__(self):
+    def __getstate__(self) -> set[_T]:
         return set(self)
 
-    def __setstate__(self, state):
+    def __setstate__(self, state: Iterable[_T]) -> None:
         self.update(state)
 
-    def __reduce_ex__(self, proto):
+    def __reduce_ex__(
+        self, proto: SupportsIndex
+    ) -> Tuple[type, Tuple[List[int]]]:
         return (self.__class__, (list(self),))