From 8bdfdcc6d7ae5464a8bf08bdde1cf7f2841730af Mon Sep 17 00:00:00 2001 From: Gleb Kisenkov Date: Tue, 8 Nov 2022 16:45:37 +0100 Subject: [PATCH] Auto runtime types + minor manual updates --- lib/sqlalchemy/ext/mutable.py | 196 ++++++++++++++++++++++++---------- 1 file changed, 137 insertions(+), 59 deletions(-) diff --git a/lib/sqlalchemy/ext/mutable.py b/lib/sqlalchemy/ext/mutable.py index 1f102cb36f..6ceb8a1d5d 100644 --- a/lib/sqlalchemy/ext/mutable.py +++ b/lib/sqlalchemy/ext/mutable.py @@ -355,9 +355,33 @@ pickling process of the parent's object-relational state so that the """ # noqa: E501 +from __future__ import annotations + from collections import defaultdict +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Self +from typing import Set +from typing import Tuple +from typing import Union import weakref - +from weakref import WeakKeyDictionary + +from sqlalchemy.orm.attributes import AttributeEventToken +from sqlalchemy.orm.attributes import create_proxied_attribute +from sqlalchemy.orm.attributes import InstrumentedAttribute +from sqlalchemy.orm.base import LoaderCallableStatus +from sqlalchemy.orm.context import QueryContext +from sqlalchemy.orm.decl_api import DeclarativeAttributeIntercept +from sqlalchemy.orm.state import InstanceState +from sqlalchemy.orm.unitofwork import UOWTransaction +from sqlalchemy.sql.schema import Column +from sqlalchemy.sql.sqltypes import JSON +from sqlalchemy.sql.sqltypes import PickleType +from test.ext.test_mutable import CustomMutableAssociationScalarJSONTest +from test.ext.test_mutable import MutableColumnCopyJSONTest from .. import event from .. import inspect from .. import types @@ -374,7 +398,7 @@ class MutableBase: """ @memoized_property - def _parents(self): + def _parents(self) -> WeakKeyDictionary: """Dictionary of parent object's :class:`.InstanceState`->attribute name on the parent. @@ -391,7 +415,7 @@ class MutableBase: return weakref.WeakKeyDictionary() @classmethod - def coerce(cls, key, value): + def coerce(cls, key: str, value: Any) -> Optional[Any]: """Given a value, coerce it into the target type. Can be overridden by custom subclasses to coerce incoming @@ -420,7 +444,7 @@ class MutableBase: raise ValueError(msg % (key, type(value))) @classmethod - def _get_listen_keys(cls, attribute): + def _get_listen_keys(cls, attribute: InstrumentedAttribute) -> Set[str]: """Given a descriptor attribute, return a ``set()`` of the attribute keys which indicate a change in the state of this attribute. @@ -441,7 +465,14 @@ class MutableBase: return {attribute.key} @classmethod - def _listen_on_attribute(cls, attribute, coerce, parent_cls): + def _listen_on_attribute( + cls, + attribute: Union[ + InstrumentedAttribute, create_proxied_attribute.Proxy + ], + coerce: bool, + parent_cls: Union[DeclarativeAttributeIntercept, type], + ) -> None: """Establish this type as a mutation listener for the given mapped descriptor. @@ -455,7 +486,7 @@ class MutableBase: listen_keys = cls._get_listen_keys(attribute) - def load(state, *args): + def load(state: InstanceState, *args: Any) -> None: """Listen for objects loaded or refreshed. Wrap the target data member's value with @@ -469,11 +500,28 @@ class MutableBase: state.dict[key] = val val._parents[state] = key - def load_attrs(state, ctx, attrs): + def load_attrs( + state: InstanceState, + ctx: Union[object, QueryContext, UOWTransaction], + attrs: Union[List[str], frozenset], + ) -> None: if not attrs or listen_keys.intersection(attrs): load(state) - def set_(target, value, oldvalue, initiator): + def set_( + target: InstanceState, + value: Any, + oldvalue: Union[ + MutableDict, + LoaderCallableStatus, + CustomMutableAssociationScalarJSONTest._type_fixture.CustomMutableDict, # noqa: E501 + ], + initiator: AttributeEventToken, + ) -> Union[ + None, + MutableDict, + CustomMutableAssociationScalarJSONTest._type_fixture.CustomMutableDict, # noqa: E501 + ]: """Listen for set/replace events on the target data member. @@ -493,14 +541,14 @@ class MutableBase: oldvalue._parents.pop(inspect(target), None) return value - def pickle(state, state_dict): + def pickle(state: InstanceState, state_dict: Dict[str, Any]) -> None: val = state.dict.get(key, None) if val is not None: if "ext.mutable.values" not in state_dict: state_dict["ext.mutable.values"] = defaultdict(list) state_dict["ext.mutable.values"][key].append(val) - def unpickle(state, state_dict): + def unpickle(state: InstanceState, state_dict: Dict[str, Any]) -> None: if "ext.mutable.values" in state_dict: collection = state_dict["ext.mutable.values"] if isinstance(collection, list): @@ -543,14 +591,16 @@ class Mutable(MutableBase): """ - def changed(self): + def changed(self) -> None: """Subclasses should call this method whenever change events occur.""" for parent, key in self._parents.items(): flag_modified(parent.obj(), key) @classmethod - def associate_with_attribute(cls, attribute): + def associate_with_attribute( + cls, attribute: InstrumentedAttribute + ) -> None: """Establish this type as a mutation listener for the given mapped descriptor. @@ -558,7 +608,7 @@ class Mutable(MutableBase): cls._listen_on_attribute(attribute, True, attribute.class_) @classmethod - def associate_with(cls, sqltype): + def associate_with(cls, sqltype: type) -> None: """Associate this wrapper with all future mapped columns of the given type. @@ -575,7 +625,7 @@ class Mutable(MutableBase): """ - def listen_for_type(mapper, class_): + def listen_for_type(mapper: Mapper, class_: type) -> None: if mapper.non_primary: return for prop in mapper.column_attrs: @@ -585,7 +635,13 @@ class Mutable(MutableBase): event.listen(Mapper, "mapper_configured", listen_for_type) @classmethod - def as_mutable(cls, sqltype): + def as_mutable( + cls, sqltype: type + ) -> Union[ + JSON, + PickleType, + MutableColumnCopyJSONTest.define_tables.JSONEncodedDict, + ]: """Associate a SQL type with this mutable Python type. This establishes listeners that will detect ORM mappings against @@ -625,14 +681,22 @@ class Mutable(MutableBase): if isinstance(sqltype, SchemaEventTarget): @event.listens_for(sqltype, "before_parent_attach") - def _add_column_memo(sqltyp, parent): + def _add_column_memo( + sqltyp: Union[ + PickleType, + MutableColumnCopyJSONTest.define_tables.JSONEncodedDict, + ], + parent: Column, + ) -> None: parent.info["_ext_mutable_orig_type"] = sqltyp schema_event_check = True else: schema_event_check = False - def listen_for_type(mapper, class_): + def listen_for_type( + mapper: Mapper, class_: Union[DeclarativeAttributeIntercept, type] + ) -> None: if mapper.non_primary: return for prop in mapper.column_attrs: @@ -659,10 +723,12 @@ class MutableComposite(MutableBase): """ @classmethod - def _get_listen_keys(cls, attribute): + def _get_listen_keys( + cls, attribute: create_proxied_attribute.Proxy + ) -> Set[str]: return {attribute.key}.union(attribute.property._attribute_keys) - def changed(self): + def changed(self) -> None: """Subclasses should call this method whenever change events occur.""" for parent, key in self._parents.items(): @@ -675,8 +741,8 @@ class MutableComposite(MutableBase): setattr(parent.obj(), attr_name, value) -def _setup_composite_listener(): - def _listen_for_type(mapper, class_): +def _setup_composite_listener() -> None: + def _listen_for_type(mapper: Mapper, class_: type) -> None: for prop in mapper.iterate_properties: if ( hasattr(prop, "composite_class") @@ -717,12 +783,12 @@ class MutableDict(Mutable, dict): """ - def __setitem__(self, key, value): + def __setitem__(self, key: str, value: Union[int, str]) -> None: """Detect dictionary set events and emit change events.""" dict.__setitem__(self, key, value) self.changed() - def setdefault(self, key, value): + def setdefault(self, key: str, value: str) -> str: result = dict.setdefault(self, key, value) self.changed() return result @@ -732,26 +798,32 @@ class MutableDict(Mutable, dict): dict.__delitem__(self, key) self.changed() - def update(self, *a, **kw): + def update(self, *a: Any, **kw: Any) -> None: dict.update(self, *a, **kw) self.changed() - def pop(self, *arg): + def pop(self, *arg: str) -> str: result = dict.pop(self, *arg) self.changed() return result - def popitem(self): + def popitem(self) -> Tuple[str, str]: result = dict.popitem(self) self.changed() return result - def clear(self): + def clear(self) -> None: dict.clear(self) self.changed() @classmethod - def coerce(cls, key, value): + def coerce( + cls, key: str, value: Any + ) -> Union[ + None, + Self, + CustomMutableAssociationScalarJSONTest._type_fixture.CustomMutableDict, + ]: """Convert plain dictionary to instance of this class.""" if not isinstance(value, cls): if isinstance(value, dict): @@ -760,10 +832,12 @@ class MutableDict(Mutable, dict): else: return value - def __getstate__(self): + def __getstate__(self) -> Union[Dict[str, int], Dict[str, str]]: return dict(self) - def __setstate__(self, state): + def __setstate__( + self, state: Union[Dict[str, int], Dict[str, str]] + ) -> None: self.update(state) @@ -792,15 +866,19 @@ class MutableList(Mutable, list): """ - def __reduce_ex__(self, proto): + def __reduce_ex__(self, proto: int) -> Tuple[type, Tuple[List[int]]]: return (self.__class__, (list(self),)) # needed for backwards compatibility with # older pickles - def __setstate__(self, state): + def __setstate__(self, state: List[int]) -> None: self[:] = state - def __setitem__(self, index, value): + def __setitem__( + self, + index: Union[int, slice], + value: Union[List[int], Tuple[int, int], int], + ) -> None: """Detect list set events and emit change events.""" list.__setitem__(self, index, value) self.changed() @@ -810,7 +888,7 @@ class MutableList(Mutable, list): list.__setslice__(self, start, end, value) self.changed() - def __delitem__(self, index): + def __delitem__(self, index: slice) -> None: """Detect list del events and emit change events.""" list.__delitem__(self, index) self.changed() @@ -820,45 +898,45 @@ class MutableList(Mutable, list): list.__delslice__(self, start, end) self.changed() - def pop(self, *arg): + def pop(self, *arg: int) -> int: result = list.pop(self, *arg) self.changed() return result - def append(self, x): + def append(self, x: int) -> None: list.append(self, x) self.changed() - def extend(self, x): + def extend(self, x: List[int]) -> None: list.extend(self, x) self.changed() - def __iadd__(self, x): + def __iadd__(self, x: List[int]) -> Self: self.extend(x) return self - def insert(self, i, x): + def insert(self, i: int, x: int) -> None: list.insert(self, i, x) self.changed() - def remove(self, i): + def remove(self, i: int) -> None: list.remove(self, i) self.changed() - def clear(self): + def clear(self) -> None: list.clear(self) self.changed() - def sort(self, **kw): + def sort(self, **kw: Any) -> None: list.sort(self, **kw) self.changed() - def reverse(self): + def reverse(self) -> None: list.reverse(self) self.changed() @classmethod - def coerce(cls, index, value): + def coerce(cls, index: str, value: Any) -> Optional[Self]: """Convert plain list to instance of this class.""" if not isinstance(value, cls): if isinstance(value, list): @@ -894,61 +972,61 @@ class MutableSet(Mutable, set): """ - def update(self, *arg): + def update(self, *arg: Set[int]) -> None: set.update(self, *arg) self.changed() - def intersection_update(self, *arg): + def intersection_update(self, *arg: Set[int]) -> None: set.intersection_update(self, *arg) self.changed() - def difference_update(self, *arg): + def difference_update(self, *arg: Set[int]) -> None: set.difference_update(self, *arg) self.changed() - def symmetric_difference_update(self, *arg): + def symmetric_difference_update(self, *arg: Set[int]) -> None: set.symmetric_difference_update(self, *arg) self.changed() - def __ior__(self, other): + def __ior__(self, other: Set[int]) -> Self: self.update(other) return self - def __iand__(self, other): + def __iand__(self, other: Set[int]) -> Self: self.intersection_update(other) return self - def __ixor__(self, other): + def __ixor__(self, other: Set[int]) -> Self: self.symmetric_difference_update(other) return self - def __isub__(self, other): + def __isub__(self, other: Set[int]) -> Self: self.difference_update(other) return self - def add(self, elem): + def add(self, elem: int) -> None: set.add(self, elem) self.changed() - def remove(self, elem): + def remove(self, elem: int) -> None: set.remove(self, elem) self.changed() - def discard(self, elem): + def discard(self, elem: int) -> None: set.discard(self, elem) self.changed() - def pop(self, *arg): + def pop(self, *arg: Any) -> int: result = set.pop(self, *arg) self.changed() return result - def clear(self): + def clear(self) -> None: set.clear(self) self.changed() @classmethod - def coerce(cls, index, value): + def coerce(cls, index: str, value: Any) -> Optional[Self]: """Convert plain set to instance of this class.""" if not isinstance(value, cls): if isinstance(value, set): @@ -963,5 +1041,5 @@ class MutableSet(Mutable, set): def __setstate__(self, state): self.update(state) - def __reduce_ex__(self, proto): + def __reduce_ex__(self, proto: int) -> Tuple[type, Tuple[List[int]]]: return (self.__class__, (list(self),)) -- 2.47.3