From: Gleb Kisenkov Date: Wed, 16 Nov 2022 15:23:06 +0000 (-0500) Subject: Type annotations for sqlalchemy.ext.mutable X-Git-Tag: rel_2_0_0b4~42^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=ba0e508141206efc55cdab91df21c18e7dd63c80;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Type annotations for sqlalchemy.ext.mutable The ``sqlalchemy.ext.mutable`` extension is now fully pep-484 typed. Huge thanks to Gleb Kisenkov for their efforts on this. Fixes: #8667 Closes: #8775 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/8775 Pull-request-sha: b907888ec67facc12dbdbccd6f2d9cd533b08a50 Change-Id: Id9224e03201e6970b1ec56eb546ece4b2f3e0edd --- diff --git a/doc/build/changelog/unreleased_20/8667.rst b/doc/build/changelog/unreleased_20/8667.rst new file mode 100644 index 0000000000..50dc5d844f --- /dev/null +++ b/doc/build/changelog/unreleased_20/8667.rst @@ -0,0 +1,6 @@ +.. change:: + :tags: bug, typing + :tickets: 8667 + + The ``sqlalchemy.ext.mutable`` extension is now fully pep-484 typed. Huge + thanks to Gleb Kisenkov for their efforts on this. diff --git a/lib/sqlalchemy/ext/mutable.py b/lib/sqlalchemy/ext/mutable.py index 1f102cb36f..f9ed17efc1 100644 --- a/lib/sqlalchemy/ext/mutable.py +++ b/lib/sqlalchemy/ext/mutable.py @@ -4,7 +4,6 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors r"""Provide support for tracking of in-place changes to scalar values, which are propagated into ORM change events on owning parent objects. @@ -355,16 +354,47 @@ pickling process of the parent's object-relational state so that the """ # noqa: E501 +from __future__ import annotations + from collections import defaultdict +from typing import AbstractSet +from typing import Any +from typing import Dict +from typing import Iterable +from typing import List +from typing import Optional +from typing import overload +from typing import Set +from typing import Tuple +from typing import TypeVar +from typing import Union import weakref +from weakref import WeakKeyDictionary from .. import event from .. import inspect from .. import types from ..orm import Mapper +from ..orm._typing import _ExternalEntityType +from ..orm._typing import _O +from ..orm._typing import _T +from ..orm.attributes import AttributeEventToken from ..orm.attributes import flag_modified +from ..orm.attributes import InstrumentedAttribute +from ..orm.attributes import QueryableAttribute +from ..orm.context import QueryContext +from ..orm.decl_api import DeclarativeAttributeIntercept +from ..orm.state import InstanceState +from ..orm.unitofwork import UOWTransaction from ..sql.base import SchemaEventTarget +from ..sql.schema import Column +from ..sql.type_api import TypeEngine from ..util import memoized_property +from ..util.typing import SupportsIndex +from ..util.typing import TypeGuard + +_KT = TypeVar("_KT") # Key type. +_VT = TypeVar("_VT") # Value type. class MutableBase: @@ -374,7 +404,7 @@ class MutableBase: """ @memoized_property - def _parents(self): + def _parents(self) -> WeakKeyDictionary[Any, Any]: """Dictionary of parent object's :class:`.InstanceState`->attribute name on the parent. @@ -391,7 +421,7 @@ class MutableBase: return weakref.WeakKeyDictionary() @classmethod - def coerce(cls, key, value): + def coerce(cls, key: str, value: Any) -> Optional[Any]: """Given a value, coerce it into the target type. Can be overridden by custom subclasses to coerce incoming @@ -420,7 +450,7 @@ class MutableBase: raise ValueError(msg % (key, type(value))) @classmethod - def _get_listen_keys(cls, attribute): + def _get_listen_keys(cls, attribute: QueryableAttribute[Any]) -> Set[str]: """Given a descriptor attribute, return a ``set()`` of the attribute keys which indicate a change in the state of this attribute. @@ -441,7 +471,12 @@ class MutableBase: return {attribute.key} @classmethod - def _listen_on_attribute(cls, attribute, coerce, parent_cls): + def _listen_on_attribute( + cls, + attribute: QueryableAttribute[Any], + coerce: bool, + parent_cls: _ExternalEntityType[Any], + ) -> None: """Establish this type as a mutation listener for the given mapped descriptor. @@ -455,7 +490,7 @@ class MutableBase: listen_keys = cls._get_listen_keys(attribute) - def load(state, *args): + def load(state: InstanceState[_O], *args: Any) -> None: """Listen for objects loaded or refreshed. Wrap the target data member's value with @@ -469,11 +504,20 @@ class MutableBase: state.dict[key] = val val._parents[state] = key - def load_attrs(state, ctx, attrs): + def load_attrs( + state: InstanceState[_O], + ctx: Union[object, QueryContext, UOWTransaction], + attrs: Iterable[Any], + ) -> None: if not attrs or listen_keys.intersection(attrs): load(state) - def set_(target, value, oldvalue, initiator): + def set_( + target: InstanceState[_O], + value: MutableBase | None, + oldvalue: MutableBase | None, + initiator: AttributeEventToken, + ) -> MutableBase | None: """Listen for set/replace events on the target data member. @@ -493,14 +537,18 @@ class MutableBase: oldvalue._parents.pop(inspect(target), None) return value - def pickle(state, state_dict): + def pickle( + state: InstanceState[_O], state_dict: Dict[str, Any] + ) -> None: val = state.dict.get(key, None) if val is not None: if "ext.mutable.values" not in state_dict: state_dict["ext.mutable.values"] = defaultdict(list) state_dict["ext.mutable.values"][key].append(val) - def unpickle(state, state_dict): + def unpickle( + state: InstanceState[_O], state_dict: Dict[str, Any] + ) -> None: if "ext.mutable.values" in state_dict: collection = state_dict["ext.mutable.values"] if isinstance(collection, list): @@ -543,14 +591,16 @@ class Mutable(MutableBase): """ - def changed(self): + def changed(self) -> None: """Subclasses should call this method whenever change events occur.""" for parent, key in self._parents.items(): flag_modified(parent.obj(), key) @classmethod - def associate_with_attribute(cls, attribute): + def associate_with_attribute( + cls, attribute: InstrumentedAttribute[_O] + ) -> None: """Establish this type as a mutation listener for the given mapped descriptor. @@ -558,7 +608,7 @@ class Mutable(MutableBase): cls._listen_on_attribute(attribute, True, attribute.class_) @classmethod - def associate_with(cls, sqltype): + def associate_with(cls, sqltype: type) -> None: """Associate this wrapper with all future mapped columns of the given type. @@ -575,7 +625,7 @@ class Mutable(MutableBase): """ - def listen_for_type(mapper, class_): + def listen_for_type(mapper: Mapper[_O], class_: type) -> None: if mapper.non_primary: return for prop in mapper.column_attrs: @@ -585,7 +635,7 @@ class Mutable(MutableBase): event.listen(Mapper, "mapper_configured", listen_for_type) @classmethod - def as_mutable(cls, sqltype): + def as_mutable(cls, sqltype: TypeEngine[_T]) -> TypeEngine[_T]: """Associate a SQL type with this mutable Python type. This establishes listeners that will detect ORM mappings against @@ -625,21 +675,27 @@ class Mutable(MutableBase): if isinstance(sqltype, SchemaEventTarget): @event.listens_for(sqltype, "before_parent_attach") - def _add_column_memo(sqltyp, parent): + def _add_column_memo( + sqltyp: TypeEngine[Any], + parent: Column[_T], + ) -> None: parent.info["_ext_mutable_orig_type"] = sqltyp schema_event_check = True else: schema_event_check = False - def listen_for_type(mapper, class_): + def listen_for_type( + mapper: Mapper[_T], + class_: Union[DeclarativeAttributeIntercept, type], + ) -> None: if mapper.non_primary: return for prop in mapper.column_attrs: if ( schema_event_check and hasattr(prop.expression, "info") - and prop.expression.info.get("_ext_mutable_orig_type") + and prop.expression.info.get("_ext_mutable_orig_type") # type: ignore # noqa: E501 # TODO: https://github.com/python/mypy/issues/1424#issuecomment-1272354487 is sqltype ) or (prop.columns[0].type is sqltype): cls.associate_with_attribute(getattr(class_, prop.key)) @@ -659,10 +715,10 @@ class MutableComposite(MutableBase): """ @classmethod - def _get_listen_keys(cls, attribute): + def _get_listen_keys(cls, attribute: QueryableAttribute[_O]) -> Set[str]: return {attribute.key}.union(attribute.property._attribute_keys) - def changed(self): + def changed(self) -> None: """Subclasses should call this method whenever change events occur.""" for parent, key in self._parents.items(): @@ -675,8 +731,8 @@ class MutableComposite(MutableBase): setattr(parent.obj(), attr_name, value) -def _setup_composite_listener(): - def _listen_for_type(mapper, class_): +def _setup_composite_listener() -> None: + def _listen_for_type(mapper: Mapper[_T], class_: type) -> None: for prop in mapper.iterate_properties: if ( hasattr(prop, "composite_class") @@ -694,7 +750,7 @@ def _setup_composite_listener(): _setup_composite_listener() -class MutableDict(Mutable, dict): +class MutableDict(Mutable, Dict[_KT, _VT]): """A dictionary type that implements :class:`.Mutable`. The :class:`.MutableDict` object implements a dictionary that will @@ -717,41 +773,69 @@ class MutableDict(Mutable, dict): """ - def __setitem__(self, key, value): + def __setitem__(self, key: _KT, value: _VT) -> None: """Detect dictionary set events and emit change events.""" - dict.__setitem__(self, key, value) + super().__setitem__(key, value) self.changed() - def setdefault(self, key, value): - result = dict.setdefault(self, key, value) + def _exists(self, value: _T | None) -> TypeGuard[_T]: + return value is not None + + def _is_none(self, value: _T | None) -> TypeGuard[None]: + return value is None + + @overload + def setdefault(self, key: _KT) -> _VT | None: + ... + + @overload + def setdefault(self, key: _KT, value: _VT) -> _VT: + ... + + def setdefault(self, key: _KT, value: _VT | None = None) -> _VT | None: + if self._exists(value): + result = super().setdefault(key, value) + else: + result = super().setdefault(key) # type: ignore[call-arg] self.changed() return result - def __delitem__(self, key): + def __delitem__(self, key: _KT) -> None: """Detect dictionary del events and emit change events.""" - dict.__delitem__(self, key) + super().__delitem__(key) self.changed() - def update(self, *a, **kw): - dict.update(self, *a, **kw) + def update(self, *a: Any, **kw: _VT) -> None: + super().update(*a, **kw) self.changed() - def pop(self, *arg): - result = dict.pop(self, *arg) + @overload + def pop(self, __key: _KT) -> _VT: + ... + + @overload + def pop(self, __key: _KT, __default: _VT | _T) -> _VT | _T: + ... + + def pop(self, __key: _KT, __default: _VT | _T | None = None) -> _VT | _T: + if self._exists(__default): + result = super().pop(__key, __default) + else: + result = super().pop(__key) self.changed() return result - def popitem(self): - result = dict.popitem(self) + def popitem(self) -> Tuple[_KT, _VT]: + result = super().popitem() self.changed() return result - def clear(self): - dict.clear(self) + def clear(self) -> None: + super().clear() self.changed() @classmethod - def coerce(cls, key, value): + def coerce(cls, key: str, value: Any) -> MutableDict[_KT, _VT] | None: """Convert plain dictionary to instance of this class.""" if not isinstance(value, cls): if isinstance(value, dict): @@ -760,14 +844,16 @@ class MutableDict(Mutable, dict): else: return value - def __getstate__(self): + def __getstate__(self) -> dict[_KT, _VT]: return dict(self) - def __setstate__(self, state): + def __setstate__( + self, state: Union[Dict[str, int], Dict[str, str]] + ) -> None: self.update(state) -class MutableList(Mutable, list): +class MutableList(Mutable, List[_T]): """A list type that implements :class:`.Mutable`. The :class:`.MutableList` object implements a list that will @@ -792,83 +878,88 @@ class MutableList(Mutable, list): """ - def __reduce_ex__(self, proto): + def __reduce_ex__( + self, proto: SupportsIndex + ) -> Tuple[type, Tuple[List[int]]]: return (self.__class__, (list(self),)) # needed for backwards compatibility with # older pickles - def __setstate__(self, state): + def __setstate__(self, state: Iterable[_T]) -> None: self[:] = state - def __setitem__(self, index, value): - """Detect list set events and emit change events.""" - list.__setitem__(self, index, value) - self.changed() + def is_scalar(self, value: _T | Iterable[_T]) -> TypeGuard[_T]: + return not isinstance(value, Iterable) - def __setslice__(self, start, end, value): - """Detect list set events and emit change events.""" - list.__setslice__(self, start, end, value) - self.changed() + def is_iterable(self, value: _T | Iterable[_T]) -> TypeGuard[Iterable[_T]]: + return isinstance(value, Iterable) - def __delitem__(self, index): - """Detect list del events and emit change events.""" - list.__delitem__(self, index) + def __setitem__( + self, index: SupportsIndex | slice, value: _T | Iterable[_T] + ) -> None: + """Detect list set events and emit change events.""" + if isinstance(index, SupportsIndex) and self.is_scalar(value): + super().__setitem__(index, value) + elif isinstance(index, slice) and self.is_iterable(value): + super().__setitem__(index, value) self.changed() - def __delslice__(self, start, end): + def __delitem__(self, index: SupportsIndex | slice) -> None: """Detect list del events and emit change events.""" - list.__delslice__(self, start, end) + super().__delitem__(index) self.changed() - def pop(self, *arg): - result = list.pop(self, *arg) + def pop(self, *arg: SupportsIndex) -> _T: + result = super().pop(*arg) self.changed() return result - def append(self, x): - list.append(self, x) + def append(self, x: _T) -> None: + super().append(x) self.changed() - def extend(self, x): - list.extend(self, x) + def extend(self, x: Iterable[_T]) -> None: + super().extend(x) self.changed() - def __iadd__(self, x): + def __iadd__(self, x: Iterable[_T]) -> MutableList[_T]: # type: ignore[override,misc] # noqa: E501 self.extend(x) return self - def insert(self, i, x): - list.insert(self, i, x) + def insert(self, i: SupportsIndex, x: _T) -> None: + super().insert(i, x) self.changed() - def remove(self, i): - list.remove(self, i) + def remove(self, i: _T) -> None: + super().remove(i) self.changed() - def clear(self): - list.clear(self) + def clear(self) -> None: + super().clear() self.changed() - def sort(self, **kw): - list.sort(self, **kw) + def sort(self, **kw: Any) -> None: + super().sort(**kw) self.changed() - def reverse(self): - list.reverse(self) + def reverse(self) -> None: + super().reverse() self.changed() @classmethod - def coerce(cls, index, value): + def coerce( + cls, key: str, value: MutableList[_T] | _T + ) -> Optional[MutableList[_T]]: """Convert plain list to instance of this class.""" if not isinstance(value, cls): if isinstance(value, list): return cls(value) - return Mutable.coerce(index, value) + return Mutable.coerce(key, value) else: return value -class MutableSet(Mutable, set): +class MutableSet(Mutable, Set[_T]): """A set type that implements :class:`.Mutable`. The :class:`.MutableSet` object implements a set that will @@ -894,61 +985,61 @@ class MutableSet(Mutable, set): """ - def update(self, *arg): - set.update(self, *arg) + def update(self, *arg: Iterable[_T]) -> None: + super().update(*arg) self.changed() - def intersection_update(self, *arg): - set.intersection_update(self, *arg) + def intersection_update(self, *arg: Iterable[Any]) -> None: + super().intersection_update(*arg) self.changed() - def difference_update(self, *arg): - set.difference_update(self, *arg) + def difference_update(self, *arg: Iterable[Any]) -> None: + super().difference_update(*arg) self.changed() - def symmetric_difference_update(self, *arg): - set.symmetric_difference_update(self, *arg) + def symmetric_difference_update(self, *arg: Iterable[_T]) -> None: + super().symmetric_difference_update(*arg) self.changed() - def __ior__(self, other): + def __ior__(self, other: AbstractSet[_T]) -> MutableSet[_T]: # type: ignore[override,misc] # noqa: E501 self.update(other) return self - def __iand__(self, other): + def __iand__(self, other: AbstractSet[object]) -> MutableSet[_T]: self.intersection_update(other) return self - def __ixor__(self, other): + def __ixor__(self, other: AbstractSet[_T]) -> MutableSet[_T]: # type: ignore[override,misc] # noqa: E501 self.symmetric_difference_update(other) return self - def __isub__(self, other): + def __isub__(self, other: AbstractSet[object]) -> MutableSet[_T]: # type: ignore[misc] # noqa: E501 self.difference_update(other) return self - def add(self, elem): - set.add(self, elem) + def add(self, elem: _T) -> None: + super().add(elem) self.changed() - def remove(self, elem): - set.remove(self, elem) + def remove(self, elem: _T) -> None: + super().remove(elem) self.changed() - def discard(self, elem): - set.discard(self, elem) + def discard(self, elem: _T) -> None: + super().discard(elem) self.changed() - def pop(self, *arg): - result = set.pop(self, *arg) + def pop(self, *arg: Any) -> _T: + result = super().pop(*arg) self.changed() return result - def clear(self): - set.clear(self) + def clear(self) -> None: + super().clear() self.changed() @classmethod - def coerce(cls, index, value): + def coerce(cls, index: str, value: Any) -> Optional[MutableSet[_T]]: """Convert plain set to instance of this class.""" if not isinstance(value, cls): if isinstance(value, set): @@ -957,11 +1048,13 @@ class MutableSet(Mutable, set): else: return value - def __getstate__(self): + def __getstate__(self) -> set[_T]: return set(self) - def __setstate__(self, state): + def __setstate__(self, state: Iterable[_T]) -> None: self.update(state) - def __reduce_ex__(self, proto): + def __reduce_ex__( + self, proto: SupportsIndex + ) -> Tuple[type, Tuple[List[int]]]: return (self.__class__, (list(self),))