]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fixed all mypy errors
authorGleb Kisenkov <g.kisenkov@godeltech.com>
Tue, 15 Nov 2022 10:46:01 +0000 (11:46 +0100)
committerGleb Kisenkov <g.kisenkov@godeltech.com>
Tue, 15 Nov 2022 10:46:01 +0000 (11:46 +0100)
lib/sqlalchemy/ext/mutable.py

index 15ca466c76c209d3d7b29607a8af1ceddd6cfb6c..aa8fb8e46e20046980d37f746d5bf2a5c1fd6158 100644 (file)
@@ -363,10 +363,11 @@ from typing import Dict
 from typing import Iterable
 from typing import List
 from typing import Optional
-from typing import Self
+from typing import overload
 from typing import Set
 from typing import SupportsIndex
 from typing import Tuple
+from typing import TypeGuard
 from typing import TypeVar
 from typing import Union
 import weakref
@@ -772,17 +773,34 @@ class MutableDict(Mutable, dict[_KT, _VT]):
 
     """
 
-    def __setitem__(self: Self, key: _KT, value: _VT) -> None:
+    def __setitem__(self, key: _KT, value: _VT) -> None:
         """Detect dictionary set events and emit change events."""
         super().__setitem__(key, value)
         self.changed()
 
+    def _exists(self, value: _T | None) -> TypeGuard[_T]:
+        return value is not None
+
+    def _is_none(self, value: _T | None) -> TypeGuard[None]:
+        return value is None
+
+    @overload
+    def setdefault(self, key: _KT) -> _VT | None:
+        ...
+
+    @overload
     def setdefault(self, key: _KT, value: _VT) -> _VT:
-        result = super().setdefault(key, value)
+        ...
+
+    def setdefault(self, key: _KT, value: _VT | None = None) -> _VT | None:
+        if self._exists(value):
+            result = super().setdefault(key, value)
+        else:
+            result = super().setdefault(key)  # type: ignore[call-arg]
         self.changed()
         return result
 
-    def __delitem__(self, key: _KT):
+    def __delitem__(self, key: _KT) -> None:
         """Detect dictionary del events and emit change events."""
         super().__delitem__(key)
         self.changed()
@@ -791,8 +809,19 @@ class MutableDict(Mutable, dict[_KT, _VT]):
         super().update(*a, **kw)
         self.changed()
 
-    def pop(self, *arg: _KT) -> _VT:
-        result = super().pop(*arg)
+    @overload
+    def pop(self, __key: _KT) -> _VT:
+        ...
+
+    @overload
+    def pop(self, __key: _KT, __default: _VT | _T) -> _VT | _T:
+        ...
+
+    def pop(self, __key: _KT, __default: _VT | _T | None = None) -> _VT | _T:
+        if self._exists(__default):
+            result = super().pop(__key, __default)
+        else:
+            result = super().pop(__key)
         self.changed()
         return result
 
@@ -802,11 +831,11 @@ class MutableDict(Mutable, dict[_KT, _VT]):
         return result
 
     def clear(self) -> None:
-        dict.clear(self)
+        super().clear()
         self.changed()
 
     @classmethod
-    def coerce(cls, key: str, value: Any) -> Self | None:
+    def coerce(cls, key: str, value: Any) -> MutableDict[_KT, _VT] | None:
         """Convert plain dictionary to instance of this class."""
         if not isinstance(value, cls):
             if isinstance(value, dict):
@@ -859,13 +888,20 @@ class MutableList(Mutable, list[_T]):
     def __setstate__(self, state: Iterable[_T]) -> None:
         self[:] = state
 
+    def is_scalar(self, value: _T | Iterable[_T]) -> TypeGuard[_T]:
+        return not isinstance(value, Iterable)
+
+    def is_iterable(self, value: _T | Iterable[_T]) -> TypeGuard[Iterable[_T]]:
+        return isinstance(value, Iterable)
+
     def __setitem__(
-        self,
-        index: Union[SupportsIndex, slice],
-        value: Union[_T, Iterable[_T]],
+        self, index: SupportsIndex | slice, value: _T | Iterable[_T]
     ) -> None:
         """Detect list set events and emit change events."""
-        super().__setitem__(index, value)
+        if isinstance(index, SupportsIndex) and self.is_scalar(value):
+            super().__setitem__(index, value)
+        elif isinstance(index, slice) and self.is_iterable(value):
+            super().__setitem__(index, value)
         self.changed()
 
     def __delitem__(self, index: SupportsIndex | slice) -> None:
@@ -873,7 +909,7 @@ class MutableList(Mutable, list[_T]):
         super().__delitem__(index)
         self.changed()
 
-    def pop(self, *arg: int) -> _T:
+    def pop(self, *arg: SupportsIndex) -> _T:
         result = super().pop(*arg)
         self.changed()
         return result
@@ -886,7 +922,7 @@ class MutableList(Mutable, list[_T]):
         super().extend(x)
         self.changed()
 
-    def __iadd__(self, x: Iterable[_T]) -> Self:
+    def __iadd__(self, x: Iterable[_T]) -> MutableList[_T]:  # type: ignore[override,misc] # noqa: E501
         self.extend(x)
         return self
 
@@ -907,16 +943,18 @@ class MutableList(Mutable, list[_T]):
         self.changed()
 
     def reverse(self) -> None:
-        list.reverse(self)
+        super().reverse()
         self.changed()
 
     @classmethod
-    def coerce(cls, index: str, value: Any) -> Optional[Self]:
+    def coerce(
+        cls, key: str, value: MutableList[_T] | _T
+    ) -> Optional[MutableList[_T]]:
         """Convert plain list to instance of this class."""
         if not isinstance(value, cls):
             if isinstance(value, list):
                 return cls(value)
-            return Mutable.coerce(index, value)
+            return Mutable.coerce(key, value)
         else:
             return value
 
@@ -963,19 +1001,19 @@ class MutableSet(Mutable, set[_T]):
         super().symmetric_difference_update(*arg)
         self.changed()
 
-    def __ior__(self, other: Iterable[_T]) -> Self:
+    def __ior__(self, other: AbstractSet[_T]) -> MutableSet[_T]:  # type: ignore[override,misc] # noqa: E501
         self.update(other)
         return self
 
-    def __iand__(self, other: AbstractSet[object]) -> Self:
+    def __iand__(self, other: AbstractSet[object]) -> MutableSet[_T]:
         self.intersection_update(other)
         return self
 
-    def __ixor__(self, other: AbstractSet[_T]) -> Self:
+    def __ixor__(self, other: AbstractSet[_T]) -> MutableSet[_T]:  # type: ignore[override,misc] # noqa: E501
         self.symmetric_difference_update(other)
         return self
 
-    def __isub__(self, other: AbstractSet[object]) -> Self:
+    def __isub__(self, other: AbstractSet[object]) -> MutableSet[_T]:  # type: ignore[misc] # noqa: E501
         self.difference_update(other)
         return self
 
@@ -1001,7 +1039,7 @@ class MutableSet(Mutable, set[_T]):
         self.changed()
 
     @classmethod
-    def coerce(cls, index: str, value: Any) -> Optional[Self]:
+    def coerce(cls, index: str, value: Any) -> Optional[MutableSet[_T]]:
         """Convert plain set to instance of this class."""
         if not isinstance(value, cls):
             if isinstance(value, set):