]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Auto runtime types + minor manual updates
authorGleb Kisenkov <g.kisenkov@godeltech.com>
Tue, 8 Nov 2022 15:45:37 +0000 (16:45 +0100)
committerGleb Kisenkov <g.kisenkov@godeltech.com>
Tue, 8 Nov 2022 15:45:37 +0000 (16:45 +0100)
lib/sqlalchemy/ext/mutable.py

index 1f102cb36f914f5545ee3554659f1d852f8f0f64..6ceb8a1d5d85822e0d49efced08053fe2b6c8c9b 100644 (file)
@@ -355,9 +355,33 @@ pickling process of the parent's object-relational state so that the
 
 """  # noqa: E501
 
+from __future__ import annotations
+
 from collections import defaultdict
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Self
+from typing import Set
+from typing import Tuple
+from typing import Union
 import weakref
-
+from weakref import WeakKeyDictionary
+
+from sqlalchemy.orm.attributes import AttributeEventToken
+from sqlalchemy.orm.attributes import create_proxied_attribute
+from sqlalchemy.orm.attributes import InstrumentedAttribute
+from sqlalchemy.orm.base import LoaderCallableStatus
+from sqlalchemy.orm.context import QueryContext
+from sqlalchemy.orm.decl_api import DeclarativeAttributeIntercept
+from sqlalchemy.orm.state import InstanceState
+from sqlalchemy.orm.unitofwork import UOWTransaction
+from sqlalchemy.sql.schema import Column
+from sqlalchemy.sql.sqltypes import JSON
+from sqlalchemy.sql.sqltypes import PickleType
+from test.ext.test_mutable import CustomMutableAssociationScalarJSONTest
+from test.ext.test_mutable import MutableColumnCopyJSONTest
 from .. import event
 from .. import inspect
 from .. import types
@@ -374,7 +398,7 @@ class MutableBase:
     """
 
     @memoized_property
-    def _parents(self):
+    def _parents(self) -> WeakKeyDictionary:
         """Dictionary of parent object's :class:`.InstanceState`->attribute
         name on the parent.
 
@@ -391,7 +415,7 @@ class MutableBase:
         return weakref.WeakKeyDictionary()
 
     @classmethod
-    def coerce(cls, key, value):
+    def coerce(cls, key: str, value: Any) -> Optional[Any]:
         """Given a value, coerce it into the target type.
 
         Can be overridden by custom subclasses to coerce incoming
@@ -420,7 +444,7 @@ class MutableBase:
         raise ValueError(msg % (key, type(value)))
 
     @classmethod
-    def _get_listen_keys(cls, attribute):
+    def _get_listen_keys(cls, attribute: InstrumentedAttribute) -> Set[str]:
         """Given a descriptor attribute, return a ``set()`` of the attribute
         keys which indicate a change in the state of this attribute.
 
@@ -441,7 +465,14 @@ class MutableBase:
         return {attribute.key}
 
     @classmethod
-    def _listen_on_attribute(cls, attribute, coerce, parent_cls):
+    def _listen_on_attribute(
+        cls,
+        attribute: Union[
+            InstrumentedAttribute, create_proxied_attribute.Proxy
+        ],
+        coerce: bool,
+        parent_cls: Union[DeclarativeAttributeIntercept, type],
+    ) -> None:
         """Establish this type as a mutation listener for the given
         mapped descriptor.
 
@@ -455,7 +486,7 @@ class MutableBase:
 
         listen_keys = cls._get_listen_keys(attribute)
 
-        def load(state, *args):
+        def load(state: InstanceState, *args: Any) -> None:
             """Listen for objects loaded or refreshed.
 
             Wrap the target data member's value with
@@ -469,11 +500,28 @@ class MutableBase:
                     state.dict[key] = val
                 val._parents[state] = key
 
-        def load_attrs(state, ctx, attrs):
+        def load_attrs(
+            state: InstanceState,
+            ctx: Union[object, QueryContext, UOWTransaction],
+            attrs: Union[List[str], frozenset],
+        ) -> None:
             if not attrs or listen_keys.intersection(attrs):
                 load(state)
 
-        def set_(target, value, oldvalue, initiator):
+        def set_(
+            target: InstanceState,
+            value: Any,
+            oldvalue: Union[
+                MutableDict,
+                LoaderCallableStatus,
+                CustomMutableAssociationScalarJSONTest._type_fixture.CustomMutableDict,  # noqa: E501
+            ],
+            initiator: AttributeEventToken,
+        ) -> Union[
+            None,
+            MutableDict,
+            CustomMutableAssociationScalarJSONTest._type_fixture.CustomMutableDict,  # noqa: E501
+        ]:
             """Listen for set/replace events on the target
             data member.
 
@@ -493,14 +541,14 @@ class MutableBase:
                 oldvalue._parents.pop(inspect(target), None)
             return value
 
-        def pickle(state, state_dict):
+        def pickle(state: InstanceState, state_dict: Dict[str, Any]) -> None:
             val = state.dict.get(key, None)
             if val is not None:
                 if "ext.mutable.values" not in state_dict:
                     state_dict["ext.mutable.values"] = defaultdict(list)
                 state_dict["ext.mutable.values"][key].append(val)
 
-        def unpickle(state, state_dict):
+        def unpickle(state: InstanceState, state_dict: Dict[str, Any]) -> None:
             if "ext.mutable.values" in state_dict:
                 collection = state_dict["ext.mutable.values"]
                 if isinstance(collection, list):
@@ -543,14 +591,16 @@ class Mutable(MutableBase):
 
     """
 
-    def changed(self):
+    def changed(self) -> None:
         """Subclasses should call this method whenever change events occur."""
 
         for parent, key in self._parents.items():
             flag_modified(parent.obj(), key)
 
     @classmethod
-    def associate_with_attribute(cls, attribute):
+    def associate_with_attribute(
+        cls, attribute: InstrumentedAttribute
+    ) -> None:
         """Establish this type as a mutation listener for the given
         mapped descriptor.
 
@@ -558,7 +608,7 @@ class Mutable(MutableBase):
         cls._listen_on_attribute(attribute, True, attribute.class_)
 
     @classmethod
-    def associate_with(cls, sqltype):
+    def associate_with(cls, sqltype: type) -> None:
         """Associate this wrapper with all future mapped columns
         of the given type.
 
@@ -575,7 +625,7 @@ class Mutable(MutableBase):
 
         """
 
-        def listen_for_type(mapper, class_):
+        def listen_for_type(mapper: Mapper, class_: type) -> None:
             if mapper.non_primary:
                 return
             for prop in mapper.column_attrs:
@@ -585,7 +635,13 @@ class Mutable(MutableBase):
         event.listen(Mapper, "mapper_configured", listen_for_type)
 
     @classmethod
-    def as_mutable(cls, sqltype):
+    def as_mutable(
+        cls, sqltype: type
+    ) -> Union[
+        JSON,
+        PickleType,
+        MutableColumnCopyJSONTest.define_tables.JSONEncodedDict,
+    ]:
         """Associate a SQL type with this mutable Python type.
 
         This establishes listeners that will detect ORM mappings against
@@ -625,14 +681,22 @@ class Mutable(MutableBase):
         if isinstance(sqltype, SchemaEventTarget):
 
             @event.listens_for(sqltype, "before_parent_attach")
-            def _add_column_memo(sqltyp, parent):
+            def _add_column_memo(
+                sqltyp: Union[
+                    PickleType,
+                    MutableColumnCopyJSONTest.define_tables.JSONEncodedDict,
+                ],
+                parent: Column,
+            ) -> None:
                 parent.info["_ext_mutable_orig_type"] = sqltyp
 
             schema_event_check = True
         else:
             schema_event_check = False
 
-        def listen_for_type(mapper, class_):
+        def listen_for_type(
+            mapper: Mapper, class_: Union[DeclarativeAttributeIntercept, type]
+        ) -> None:
             if mapper.non_primary:
                 return
             for prop in mapper.column_attrs:
@@ -659,10 +723,12 @@ class MutableComposite(MutableBase):
     """
 
     @classmethod
-    def _get_listen_keys(cls, attribute):
+    def _get_listen_keys(
+        cls, attribute: create_proxied_attribute.Proxy
+    ) -> Set[str]:
         return {attribute.key}.union(attribute.property._attribute_keys)
 
-    def changed(self):
+    def changed(self) -> None:
         """Subclasses should call this method whenever change events occur."""
 
         for parent, key in self._parents.items():
@@ -675,8 +741,8 @@ class MutableComposite(MutableBase):
                 setattr(parent.obj(), attr_name, value)
 
 
-def _setup_composite_listener():
-    def _listen_for_type(mapper, class_):
+def _setup_composite_listener() -> None:
+    def _listen_for_type(mapper: Mapper, class_: type) -> None:
         for prop in mapper.iterate_properties:
             if (
                 hasattr(prop, "composite_class")
@@ -717,12 +783,12 @@ class MutableDict(Mutable, dict):
 
     """
 
-    def __setitem__(self, key, value):
+    def __setitem__(self, key: str, value: Union[int, str]) -> None:
         """Detect dictionary set events and emit change events."""
         dict.__setitem__(self, key, value)
         self.changed()
 
-    def setdefault(self, key, value):
+    def setdefault(self, key: str, value: str) -> str:
         result = dict.setdefault(self, key, value)
         self.changed()
         return result
@@ -732,26 +798,32 @@ class MutableDict(Mutable, dict):
         dict.__delitem__(self, key)
         self.changed()
 
-    def update(self, *a, **kw):
+    def update(self, *a: Any, **kw: Any) -> None:
         dict.update(self, *a, **kw)
         self.changed()
 
-    def pop(self, *arg):
+    def pop(self, *arg: str) -> str:
         result = dict.pop(self, *arg)
         self.changed()
         return result
 
-    def popitem(self):
+    def popitem(self) -> Tuple[str, str]:
         result = dict.popitem(self)
         self.changed()
         return result
 
-    def clear(self):
+    def clear(self) -> None:
         dict.clear(self)
         self.changed()
 
     @classmethod
-    def coerce(cls, key, value):
+    def coerce(
+        cls, key: str, value: Any
+    ) -> Union[
+        None,
+        Self,
+        CustomMutableAssociationScalarJSONTest._type_fixture.CustomMutableDict,
+    ]:
         """Convert plain dictionary to instance of this class."""
         if not isinstance(value, cls):
             if isinstance(value, dict):
@@ -760,10 +832,12 @@ class MutableDict(Mutable, dict):
         else:
             return value
 
-    def __getstate__(self):
+    def __getstate__(self) -> Union[Dict[str, int], Dict[str, str]]:
         return dict(self)
 
-    def __setstate__(self, state):
+    def __setstate__(
+        self, state: Union[Dict[str, int], Dict[str, str]]
+    ) -> None:
         self.update(state)
 
 
@@ -792,15 +866,19 @@ class MutableList(Mutable, list):
 
     """
 
-    def __reduce_ex__(self, proto):
+    def __reduce_ex__(self, proto: int) -> Tuple[type, Tuple[List[int]]]:
         return (self.__class__, (list(self),))
 
     # needed for backwards compatibility with
     # older pickles
-    def __setstate__(self, state):
+    def __setstate__(self, state: List[int]) -> None:
         self[:] = state
 
-    def __setitem__(self, index, value):
+    def __setitem__(
+        self,
+        index: Union[int, slice],
+        value: Union[List[int], Tuple[int, int], int],
+    ) -> None:
         """Detect list set events and emit change events."""
         list.__setitem__(self, index, value)
         self.changed()
@@ -810,7 +888,7 @@ class MutableList(Mutable, list):
         list.__setslice__(self, start, end, value)
         self.changed()
 
-    def __delitem__(self, index):
+    def __delitem__(self, index: slice) -> None:
         """Detect list del events and emit change events."""
         list.__delitem__(self, index)
         self.changed()
@@ -820,45 +898,45 @@ class MutableList(Mutable, list):
         list.__delslice__(self, start, end)
         self.changed()
 
-    def pop(self, *arg):
+    def pop(self, *arg: int) -> int:
         result = list.pop(self, *arg)
         self.changed()
         return result
 
-    def append(self, x):
+    def append(self, x: int) -> None:
         list.append(self, x)
         self.changed()
 
-    def extend(self, x):
+    def extend(self, x: List[int]) -> None:
         list.extend(self, x)
         self.changed()
 
-    def __iadd__(self, x):
+    def __iadd__(self, x: List[int]) -> Self:
         self.extend(x)
         return self
 
-    def insert(self, i, x):
+    def insert(self, i: int, x: int) -> None:
         list.insert(self, i, x)
         self.changed()
 
-    def remove(self, i):
+    def remove(self, i: int) -> None:
         list.remove(self, i)
         self.changed()
 
-    def clear(self):
+    def clear(self) -> None:
         list.clear(self)
         self.changed()
 
-    def sort(self, **kw):
+    def sort(self, **kw: Any) -> None:
         list.sort(self, **kw)
         self.changed()
 
-    def reverse(self):
+    def reverse(self) -> None:
         list.reverse(self)
         self.changed()
 
     @classmethod
-    def coerce(cls, index, value):
+    def coerce(cls, index: str, value: Any) -> Optional[Self]:
         """Convert plain list to instance of this class."""
         if not isinstance(value, cls):
             if isinstance(value, list):
@@ -894,61 +972,61 @@ class MutableSet(Mutable, set):
 
     """
 
-    def update(self, *arg):
+    def update(self, *arg: Set[int]) -> None:
         set.update(self, *arg)
         self.changed()
 
-    def intersection_update(self, *arg):
+    def intersection_update(self, *arg: Set[int]) -> None:
         set.intersection_update(self, *arg)
         self.changed()
 
-    def difference_update(self, *arg):
+    def difference_update(self, *arg: Set[int]) -> None:
         set.difference_update(self, *arg)
         self.changed()
 
-    def symmetric_difference_update(self, *arg):
+    def symmetric_difference_update(self, *arg: Set[int]) -> None:
         set.symmetric_difference_update(self, *arg)
         self.changed()
 
-    def __ior__(self, other):
+    def __ior__(self, other: Set[int]) -> Self:
         self.update(other)
         return self
 
-    def __iand__(self, other):
+    def __iand__(self, other: Set[int]) -> Self:
         self.intersection_update(other)
         return self
 
-    def __ixor__(self, other):
+    def __ixor__(self, other: Set[int]) -> Self:
         self.symmetric_difference_update(other)
         return self
 
-    def __isub__(self, other):
+    def __isub__(self, other: Set[int]) -> Self:
         self.difference_update(other)
         return self
 
-    def add(self, elem):
+    def add(self, elem: int) -> None:
         set.add(self, elem)
         self.changed()
 
-    def remove(self, elem):
+    def remove(self, elem: int) -> None:
         set.remove(self, elem)
         self.changed()
 
-    def discard(self, elem):
+    def discard(self, elem: int) -> None:
         set.discard(self, elem)
         self.changed()
 
-    def pop(self, *arg):
+    def pop(self, *arg: Any) -> int:
         result = set.pop(self, *arg)
         self.changed()
         return result
 
-    def clear(self):
+    def clear(self) -> None:
         set.clear(self)
         self.changed()
 
     @classmethod
-    def coerce(cls, index, value):
+    def coerce(cls, index: str, value: Any) -> Optional[Self]:
         """Convert plain set to instance of this class."""
         if not isinstance(value, cls):
             if isinstance(value, set):
@@ -963,5 +1041,5 @@ class MutableSet(Mutable, set):
     def __setstate__(self, state):
         self.update(state)
 
-    def __reduce_ex__(self, proto):
+    def __reduce_ex__(self, proto: int) -> Tuple[type, Tuple[List[int]]]:
         return (self.__class__, (list(self),))