From 61cd65977216b68cc09e85ec731e5c5963d8fc0d Mon Sep 17 00:00:00 2001 From: Gleb Kisenkov Date: Thu, 10 Nov 2022 17:07:53 +0100 Subject: [PATCH] Corrected some autogenerated type hints here and there --- lib/sqlalchemy/ext/mutable.py | 243 +++++++++++++++------------------- 1 file changed, 110 insertions(+), 133 deletions(-) diff --git a/lib/sqlalchemy/ext/mutable.py b/lib/sqlalchemy/ext/mutable.py index 6ceb8a1d5d..15ca466c76 100644 --- a/lib/sqlalchemy/ext/mutable.py +++ b/lib/sqlalchemy/ext/mutable.py @@ -4,7 +4,6 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors r"""Provide support for tracking of in-place changes to scalar values, which are propagated into ORM change events on owning parent objects. @@ -358,38 +357,44 @@ pickling process of the parent's object-relational state so that the from __future__ import annotations from collections import defaultdict +from typing import AbstractSet from typing import Any from typing import Dict +from typing import Iterable from typing import List from typing import Optional from typing import Self from typing import Set +from typing import SupportsIndex from typing import Tuple +from typing import TypeVar from typing import Union import weakref from weakref import WeakKeyDictionary -from sqlalchemy.orm.attributes import AttributeEventToken -from sqlalchemy.orm.attributes import create_proxied_attribute -from sqlalchemy.orm.attributes import InstrumentedAttribute -from sqlalchemy.orm.base import LoaderCallableStatus -from sqlalchemy.orm.context import QueryContext -from sqlalchemy.orm.decl_api import DeclarativeAttributeIntercept -from sqlalchemy.orm.state import InstanceState -from sqlalchemy.orm.unitofwork import UOWTransaction -from sqlalchemy.sql.schema import Column -from sqlalchemy.sql.sqltypes import JSON -from sqlalchemy.sql.sqltypes import PickleType -from test.ext.test_mutable import CustomMutableAssociationScalarJSONTest -from test.ext.test_mutable import MutableColumnCopyJSONTest from .. import event from .. import inspect from .. import types from ..orm import Mapper +from ..orm._typing import _ExternalEntityType +from ..orm._typing import _O +from ..orm._typing import _T +from ..orm.attributes import AttributeEventToken from ..orm.attributes import flag_modified +from ..orm.attributes import InstrumentedAttribute +from ..orm.attributes import QueryableAttribute +from ..orm.context import QueryContext +from ..orm.decl_api import DeclarativeAttributeIntercept +from ..orm.state import InstanceState +from ..orm.unitofwork import UOWTransaction from ..sql.base import SchemaEventTarget +from ..sql.schema import Column +from ..sql.type_api import TypeEngine from ..util import memoized_property +_KT = TypeVar("_KT") # Key type. +_VT = TypeVar("_VT") # Value type. + class MutableBase: """Common base class to :class:`.Mutable` @@ -398,7 +403,7 @@ class MutableBase: """ @memoized_property - def _parents(self) -> WeakKeyDictionary: + def _parents(self) -> WeakKeyDictionary[Any, Any]: """Dictionary of parent object's :class:`.InstanceState`->attribute name on the parent. @@ -444,7 +449,7 @@ class MutableBase: raise ValueError(msg % (key, type(value))) @classmethod - def _get_listen_keys(cls, attribute: InstrumentedAttribute) -> Set[str]: + def _get_listen_keys(cls, attribute: QueryableAttribute[Any]) -> Set[str]: """Given a descriptor attribute, return a ``set()`` of the attribute keys which indicate a change in the state of this attribute. @@ -467,11 +472,9 @@ class MutableBase: @classmethod def _listen_on_attribute( cls, - attribute: Union[ - InstrumentedAttribute, create_proxied_attribute.Proxy - ], + attribute: QueryableAttribute[Any], coerce: bool, - parent_cls: Union[DeclarativeAttributeIntercept, type], + parent_cls: _ExternalEntityType[Any], ) -> None: """Establish this type as a mutation listener for the given mapped descriptor. @@ -486,7 +489,7 @@ class MutableBase: listen_keys = cls._get_listen_keys(attribute) - def load(state: InstanceState, *args: Any) -> None: + def load(state: InstanceState[_O], *args: Any) -> None: """Listen for objects loaded or refreshed. Wrap the target data member's value with @@ -501,27 +504,19 @@ class MutableBase: val._parents[state] = key def load_attrs( - state: InstanceState, + state: InstanceState[_O], ctx: Union[object, QueryContext, UOWTransaction], - attrs: Union[List[str], frozenset], + attrs: Iterable[Any], ) -> None: if not attrs or listen_keys.intersection(attrs): load(state) def set_( - target: InstanceState, + target: InstanceState[_O], value: Any, - oldvalue: Union[ - MutableDict, - LoaderCallableStatus, - CustomMutableAssociationScalarJSONTest._type_fixture.CustomMutableDict, # noqa: E501 - ], + oldvalue: Any, initiator: AttributeEventToken, - ) -> Union[ - None, - MutableDict, - CustomMutableAssociationScalarJSONTest._type_fixture.CustomMutableDict, # noqa: E501 - ]: + ) -> Any: """Listen for set/replace events on the target data member. @@ -541,14 +536,18 @@ class MutableBase: oldvalue._parents.pop(inspect(target), None) return value - def pickle(state: InstanceState, state_dict: Dict[str, Any]) -> None: + def pickle( + state: InstanceState[_O], state_dict: Dict[str, Any] + ) -> None: val = state.dict.get(key, None) if val is not None: if "ext.mutable.values" not in state_dict: state_dict["ext.mutable.values"] = defaultdict(list) state_dict["ext.mutable.values"][key].append(val) - def unpickle(state: InstanceState, state_dict: Dict[str, Any]) -> None: + def unpickle( + state: InstanceState[_O], state_dict: Dict[str, Any] + ) -> None: if "ext.mutable.values" in state_dict: collection = state_dict["ext.mutable.values"] if isinstance(collection, list): @@ -599,7 +598,7 @@ class Mutable(MutableBase): @classmethod def associate_with_attribute( - cls, attribute: InstrumentedAttribute + cls, attribute: InstrumentedAttribute[_O] ) -> None: """Establish this type as a mutation listener for the given mapped descriptor. @@ -625,7 +624,7 @@ class Mutable(MutableBase): """ - def listen_for_type(mapper: Mapper, class_: type) -> None: + def listen_for_type(mapper: Mapper[_O], class_: type) -> None: if mapper.non_primary: return for prop in mapper.column_attrs: @@ -635,13 +634,7 @@ class Mutable(MutableBase): event.listen(Mapper, "mapper_configured", listen_for_type) @classmethod - def as_mutable( - cls, sqltype: type - ) -> Union[ - JSON, - PickleType, - MutableColumnCopyJSONTest.define_tables.JSONEncodedDict, - ]: + def as_mutable(cls, sqltype: TypeEngine[_T]) -> TypeEngine[_T]: """Associate a SQL type with this mutable Python type. This establishes listeners that will detect ORM mappings against @@ -682,11 +675,8 @@ class Mutable(MutableBase): @event.listens_for(sqltype, "before_parent_attach") def _add_column_memo( - sqltyp: Union[ - PickleType, - MutableColumnCopyJSONTest.define_tables.JSONEncodedDict, - ], - parent: Column, + sqltyp: TypeEngine[Any], + parent: Column[_T], ) -> None: parent.info["_ext_mutable_orig_type"] = sqltyp @@ -695,7 +685,8 @@ class Mutable(MutableBase): schema_event_check = False def listen_for_type( - mapper: Mapper, class_: Union[DeclarativeAttributeIntercept, type] + mapper: Mapper[_T], + class_: Union[DeclarativeAttributeIntercept, type], ) -> None: if mapper.non_primary: return @@ -703,7 +694,7 @@ class Mutable(MutableBase): if ( schema_event_check and hasattr(prop.expression, "info") - and prop.expression.info.get("_ext_mutable_orig_type") + and prop.expression.info.get("_ext_mutable_orig_type") # type: ignore # noqa: E501 # TODO: https://github.com/python/mypy/issues/1424#issuecomment-1272354487 is sqltype ) or (prop.columns[0].type is sqltype): cls.associate_with_attribute(getattr(class_, prop.key)) @@ -723,9 +714,7 @@ class MutableComposite(MutableBase): """ @classmethod - def _get_listen_keys( - cls, attribute: create_proxied_attribute.Proxy - ) -> Set[str]: + def _get_listen_keys(cls, attribute: QueryableAttribute[_O]) -> Set[str]: return {attribute.key}.union(attribute.property._attribute_keys) def changed(self) -> None: @@ -742,7 +731,7 @@ class MutableComposite(MutableBase): def _setup_composite_listener() -> None: - def _listen_for_type(mapper: Mapper, class_: type) -> None: + def _listen_for_type(mapper: Mapper[_T], class_: type) -> None: for prop in mapper.iterate_properties: if ( hasattr(prop, "composite_class") @@ -760,7 +749,7 @@ def _setup_composite_listener() -> None: _setup_composite_listener() -class MutableDict(Mutable, dict): +class MutableDict(Mutable, dict[_KT, _VT]): """A dictionary type that implements :class:`.Mutable`. The :class:`.MutableDict` object implements a dictionary that will @@ -783,32 +772,32 @@ class MutableDict(Mutable, dict): """ - def __setitem__(self, key: str, value: Union[int, str]) -> None: + def __setitem__(self: Self, key: _KT, value: _VT) -> None: """Detect dictionary set events and emit change events.""" - dict.__setitem__(self, key, value) + super().__setitem__(key, value) self.changed() - def setdefault(self, key: str, value: str) -> str: - result = dict.setdefault(self, key, value) + def setdefault(self, key: _KT, value: _VT) -> _VT: + result = super().setdefault(key, value) self.changed() return result - def __delitem__(self, key): + def __delitem__(self, key: _KT): """Detect dictionary del events and emit change events.""" - dict.__delitem__(self, key) + super().__delitem__(key) self.changed() - def update(self, *a: Any, **kw: Any) -> None: - dict.update(self, *a, **kw) + def update(self, *a: Any, **kw: _VT) -> None: + super().update(*a, **kw) self.changed() - def pop(self, *arg: str) -> str: - result = dict.pop(self, *arg) + def pop(self, *arg: _KT) -> _VT: + result = super().pop(*arg) self.changed() return result - def popitem(self) -> Tuple[str, str]: - result = dict.popitem(self) + def popitem(self) -> Tuple[_KT, _VT]: + result = super().popitem() self.changed() return result @@ -817,13 +806,7 @@ class MutableDict(Mutable, dict): self.changed() @classmethod - def coerce( - cls, key: str, value: Any - ) -> Union[ - None, - Self, - CustomMutableAssociationScalarJSONTest._type_fixture.CustomMutableDict, - ]: + def coerce(cls, key: str, value: Any) -> Self | None: """Convert plain dictionary to instance of this class.""" if not isinstance(value, cls): if isinstance(value, dict): @@ -832,7 +815,7 @@ class MutableDict(Mutable, dict): else: return value - def __getstate__(self) -> Union[Dict[str, int], Dict[str, str]]: + def __getstate__(self) -> dict[_KT, _VT]: return dict(self) def __setstate__( @@ -841,7 +824,7 @@ class MutableDict(Mutable, dict): self.update(state) -class MutableList(Mutable, list): +class MutableList(Mutable, list[_T]): """A list type that implements :class:`.Mutable`. The :class:`.MutableList` object implements a list that will @@ -866,69 +849,61 @@ class MutableList(Mutable, list): """ - def __reduce_ex__(self, proto: int) -> Tuple[type, Tuple[List[int]]]: + def __reduce_ex__( + self, proto: SupportsIndex + ) -> Tuple[type, Tuple[List[int]]]: return (self.__class__, (list(self),)) # needed for backwards compatibility with # older pickles - def __setstate__(self, state: List[int]) -> None: + def __setstate__(self, state: Iterable[_T]) -> None: self[:] = state def __setitem__( self, - index: Union[int, slice], - value: Union[List[int], Tuple[int, int], int], + index: Union[SupportsIndex, slice], + value: Union[_T, Iterable[_T]], ) -> None: """Detect list set events and emit change events.""" - list.__setitem__(self, index, value) - self.changed() - - def __setslice__(self, start, end, value): - """Detect list set events and emit change events.""" - list.__setslice__(self, start, end, value) - self.changed() - - def __delitem__(self, index: slice) -> None: - """Detect list del events and emit change events.""" - list.__delitem__(self, index) + super().__setitem__(index, value) self.changed() - def __delslice__(self, start, end): + def __delitem__(self, index: SupportsIndex | slice) -> None: """Detect list del events and emit change events.""" - list.__delslice__(self, start, end) + super().__delitem__(index) self.changed() - def pop(self, *arg: int) -> int: - result = list.pop(self, *arg) + def pop(self, *arg: int) -> _T: + result = super().pop(*arg) self.changed() return result - def append(self, x: int) -> None: - list.append(self, x) + def append(self, x: _T) -> None: + super().append(x) self.changed() - def extend(self, x: List[int]) -> None: - list.extend(self, x) + def extend(self, x: Iterable[_T]) -> None: + super().extend(x) self.changed() - def __iadd__(self, x: List[int]) -> Self: + def __iadd__(self, x: Iterable[_T]) -> Self: self.extend(x) return self - def insert(self, i: int, x: int) -> None: - list.insert(self, i, x) + def insert(self, i: SupportsIndex, x: _T) -> None: + super().insert(i, x) self.changed() - def remove(self, i: int) -> None: - list.remove(self, i) + def remove(self, i: _T) -> None: + super().remove(i) self.changed() def clear(self) -> None: - list.clear(self) + super().clear() self.changed() def sort(self, **kw: Any) -> None: - list.sort(self, **kw) + super().sort(**kw) self.changed() def reverse(self) -> None: @@ -946,7 +921,7 @@ class MutableList(Mutable, list): return value -class MutableSet(Mutable, set): +class MutableSet(Mutable, set[_T]): """A set type that implements :class:`.Mutable`. The :class:`.MutableSet` object implements a set that will @@ -972,57 +947,57 @@ class MutableSet(Mutable, set): """ - def update(self, *arg: Set[int]) -> None: - set.update(self, *arg) + def update(self, *arg: Iterable[_T]) -> None: + super().update(*arg) self.changed() - def intersection_update(self, *arg: Set[int]) -> None: - set.intersection_update(self, *arg) + def intersection_update(self, *arg: Iterable[Any]) -> None: + super().intersection_update(*arg) self.changed() - def difference_update(self, *arg: Set[int]) -> None: - set.difference_update(self, *arg) + def difference_update(self, *arg: Iterable[Any]) -> None: + super().difference_update(*arg) self.changed() - def symmetric_difference_update(self, *arg: Set[int]) -> None: - set.symmetric_difference_update(self, *arg) + def symmetric_difference_update(self, *arg: Iterable[_T]) -> None: + super().symmetric_difference_update(*arg) self.changed() - def __ior__(self, other: Set[int]) -> Self: + def __ior__(self, other: Iterable[_T]) -> Self: self.update(other) return self - def __iand__(self, other: Set[int]) -> Self: + def __iand__(self, other: AbstractSet[object]) -> Self: self.intersection_update(other) return self - def __ixor__(self, other: Set[int]) -> Self: + def __ixor__(self, other: AbstractSet[_T]) -> Self: self.symmetric_difference_update(other) return self - def __isub__(self, other: Set[int]) -> Self: + def __isub__(self, other: AbstractSet[object]) -> Self: self.difference_update(other) return self - def add(self, elem: int) -> None: - set.add(self, elem) + def add(self, elem: _T) -> None: + super().add(elem) self.changed() - def remove(self, elem: int) -> None: - set.remove(self, elem) + def remove(self, elem: _T) -> None: + super().remove(elem) self.changed() - def discard(self, elem: int) -> None: - set.discard(self, elem) + def discard(self, elem: _T) -> None: + super().discard(elem) self.changed() - def pop(self, *arg: Any) -> int: - result = set.pop(self, *arg) + def pop(self, *arg: Any) -> _T: + result = super().pop(*arg) self.changed() return result def clear(self) -> None: - set.clear(self) + super().clear() self.changed() @classmethod @@ -1035,11 +1010,13 @@ class MutableSet(Mutable, set): else: return value - def __getstate__(self): + def __getstate__(self) -> set[_T]: return set(self) - def __setstate__(self, state): + def __setstate__(self, state: Iterable[_T]) -> None: self.update(state) - def __reduce_ex__(self, proto: int) -> Tuple[type, Tuple[List[int]]]: + def __reduce_ex__( + self, proto: SupportsIndex + ) -> Tuple[type, Tuple[List[int]]]: return (self.__class__, (list(self),)) -- 2.47.3