#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
r"""Provide support for tracking of in-place changes to scalar values,
which are propagated into ORM change events on owning parent objects.
""" # noqa: E501
+from __future__ import annotations
+
from collections import defaultdict
+from typing import AbstractSet
+from typing import Any
+from typing import Dict
+from typing import Iterable
+from typing import List
+from typing import Optional
+from typing import overload
+from typing import Set
+from typing import Tuple
+from typing import TypeVar
+from typing import Union
import weakref
+from weakref import WeakKeyDictionary
from .. import event
from .. import inspect
from .. import types
from ..orm import Mapper
+from ..orm._typing import _ExternalEntityType
+from ..orm._typing import _O
+from ..orm._typing import _T
+from ..orm.attributes import AttributeEventToken
from ..orm.attributes import flag_modified
+from ..orm.attributes import InstrumentedAttribute
+from ..orm.attributes import QueryableAttribute
+from ..orm.context import QueryContext
+from ..orm.decl_api import DeclarativeAttributeIntercept
+from ..orm.state import InstanceState
+from ..orm.unitofwork import UOWTransaction
from ..sql.base import SchemaEventTarget
+from ..sql.schema import Column
+from ..sql.type_api import TypeEngine
from ..util import memoized_property
+from ..util.typing import SupportsIndex
+from ..util.typing import TypeGuard
+
+_KT = TypeVar("_KT") # Key type.
+_VT = TypeVar("_VT") # Value type.
class MutableBase:
"""
@memoized_property
- def _parents(self):
+ def _parents(self) -> WeakKeyDictionary[Any, Any]:
"""Dictionary of parent object's :class:`.InstanceState`->attribute
name on the parent.
return weakref.WeakKeyDictionary()
@classmethod
- def coerce(cls, key, value):
+ def coerce(cls, key: str, value: Any) -> Optional[Any]:
"""Given a value, coerce it into the target type.
Can be overridden by custom subclasses to coerce incoming
raise ValueError(msg % (key, type(value)))
@classmethod
- def _get_listen_keys(cls, attribute):
+ def _get_listen_keys(cls, attribute: QueryableAttribute[Any]) -> Set[str]:
"""Given a descriptor attribute, return a ``set()`` of the attribute
keys which indicate a change in the state of this attribute.
return {attribute.key}
@classmethod
- def _listen_on_attribute(cls, attribute, coerce, parent_cls):
+ def _listen_on_attribute(
+ cls,
+ attribute: QueryableAttribute[Any],
+ coerce: bool,
+ parent_cls: _ExternalEntityType[Any],
+ ) -> None:
"""Establish this type as a mutation listener for the given
mapped descriptor.
listen_keys = cls._get_listen_keys(attribute)
- def load(state, *args):
+ def load(state: InstanceState[_O], *args: Any) -> None:
"""Listen for objects loaded or refreshed.
Wrap the target data member's value with
state.dict[key] = val
val._parents[state] = key
- def load_attrs(state, ctx, attrs):
+ def load_attrs(
+ state: InstanceState[_O],
+ ctx: Union[object, QueryContext, UOWTransaction],
+ attrs: Iterable[Any],
+ ) -> None:
if not attrs or listen_keys.intersection(attrs):
load(state)
- def set_(target, value, oldvalue, initiator):
+ def set_(
+ target: InstanceState[_O],
+ value: MutableBase | None,
+ oldvalue: MutableBase | None,
+ initiator: AttributeEventToken,
+ ) -> MutableBase | None:
"""Listen for set/replace events on the target
data member.
oldvalue._parents.pop(inspect(target), None)
return value
- def pickle(state, state_dict):
+ def pickle(
+ state: InstanceState[_O], state_dict: Dict[str, Any]
+ ) -> None:
val = state.dict.get(key, None)
if val is not None:
if "ext.mutable.values" not in state_dict:
state_dict["ext.mutable.values"] = defaultdict(list)
state_dict["ext.mutable.values"][key].append(val)
- def unpickle(state, state_dict):
+ def unpickle(
+ state: InstanceState[_O], state_dict: Dict[str, Any]
+ ) -> None:
if "ext.mutable.values" in state_dict:
collection = state_dict["ext.mutable.values"]
if isinstance(collection, list):
"""
- def changed(self):
+ def changed(self) -> None:
"""Subclasses should call this method whenever change events occur."""
for parent, key in self._parents.items():
flag_modified(parent.obj(), key)
@classmethod
- def associate_with_attribute(cls, attribute):
+ def associate_with_attribute(
+ cls, attribute: InstrumentedAttribute[_O]
+ ) -> None:
"""Establish this type as a mutation listener for the given
mapped descriptor.
cls._listen_on_attribute(attribute, True, attribute.class_)
@classmethod
- def associate_with(cls, sqltype):
+ def associate_with(cls, sqltype: type) -> None:
"""Associate this wrapper with all future mapped columns
of the given type.
"""
- def listen_for_type(mapper, class_):
+ def listen_for_type(mapper: Mapper[_O], class_: type) -> None:
if mapper.non_primary:
return
for prop in mapper.column_attrs:
event.listen(Mapper, "mapper_configured", listen_for_type)
@classmethod
- def as_mutable(cls, sqltype):
+ def as_mutable(cls, sqltype: TypeEngine[_T]) -> TypeEngine[_T]:
"""Associate a SQL type with this mutable Python type.
This establishes listeners that will detect ORM mappings against
if isinstance(sqltype, SchemaEventTarget):
@event.listens_for(sqltype, "before_parent_attach")
- def _add_column_memo(sqltyp, parent):
+ def _add_column_memo(
+ sqltyp: TypeEngine[Any],
+ parent: Column[_T],
+ ) -> None:
parent.info["_ext_mutable_orig_type"] = sqltyp
schema_event_check = True
else:
schema_event_check = False
- def listen_for_type(mapper, class_):
+ def listen_for_type(
+ mapper: Mapper[_T],
+ class_: Union[DeclarativeAttributeIntercept, type],
+ ) -> None:
if mapper.non_primary:
return
for prop in mapper.column_attrs:
if (
schema_event_check
and hasattr(prop.expression, "info")
- and prop.expression.info.get("_ext_mutable_orig_type")
+ and prop.expression.info.get("_ext_mutable_orig_type") # type: ignore # noqa: E501 # TODO: https://github.com/python/mypy/issues/1424#issuecomment-1272354487
is sqltype
) or (prop.columns[0].type is sqltype):
cls.associate_with_attribute(getattr(class_, prop.key))
"""
@classmethod
- def _get_listen_keys(cls, attribute):
+ def _get_listen_keys(cls, attribute: QueryableAttribute[_O]) -> Set[str]:
return {attribute.key}.union(attribute.property._attribute_keys)
- def changed(self):
+ def changed(self) -> None:
"""Subclasses should call this method whenever change events occur."""
for parent, key in self._parents.items():
setattr(parent.obj(), attr_name, value)
-def _setup_composite_listener():
- def _listen_for_type(mapper, class_):
+def _setup_composite_listener() -> None:
+ def _listen_for_type(mapper: Mapper[_T], class_: type) -> None:
for prop in mapper.iterate_properties:
if (
hasattr(prop, "composite_class")
_setup_composite_listener()
-class MutableDict(Mutable, dict):
+class MutableDict(Mutable, Dict[_KT, _VT]):
"""A dictionary type that implements :class:`.Mutable`.
The :class:`.MutableDict` object implements a dictionary that will
"""
- def __setitem__(self, key, value):
+ def __setitem__(self, key: _KT, value: _VT) -> None:
"""Detect dictionary set events and emit change events."""
- dict.__setitem__(self, key, value)
+ super().__setitem__(key, value)
self.changed()
- def setdefault(self, key, value):
- result = dict.setdefault(self, key, value)
+ def _exists(self, value: _T | None) -> TypeGuard[_T]:
+ return value is not None
+
+ def _is_none(self, value: _T | None) -> TypeGuard[None]:
+ return value is None
+
+ @overload
+ def setdefault(self, key: _KT) -> _VT | None:
+ ...
+
+ @overload
+ def setdefault(self, key: _KT, value: _VT) -> _VT:
+ ...
+
+ def setdefault(self, key: _KT, value: _VT | None = None) -> _VT | None:
+ if self._exists(value):
+ result = super().setdefault(key, value)
+ else:
+ result = super().setdefault(key) # type: ignore[call-arg]
self.changed()
return result
- def __delitem__(self, key):
+ def __delitem__(self, key: _KT) -> None:
"""Detect dictionary del events and emit change events."""
- dict.__delitem__(self, key)
+ super().__delitem__(key)
self.changed()
- def update(self, *a, **kw):
- dict.update(self, *a, **kw)
+ def update(self, *a: Any, **kw: _VT) -> None:
+ super().update(*a, **kw)
self.changed()
- def pop(self, *arg):
- result = dict.pop(self, *arg)
+ @overload
+ def pop(self, __key: _KT) -> _VT:
+ ...
+
+ @overload
+ def pop(self, __key: _KT, __default: _VT | _T) -> _VT | _T:
+ ...
+
+ def pop(self, __key: _KT, __default: _VT | _T | None = None) -> _VT | _T:
+ if self._exists(__default):
+ result = super().pop(__key, __default)
+ else:
+ result = super().pop(__key)
self.changed()
return result
- def popitem(self):
- result = dict.popitem(self)
+ def popitem(self) -> Tuple[_KT, _VT]:
+ result = super().popitem()
self.changed()
return result
- def clear(self):
- dict.clear(self)
+ def clear(self) -> None:
+ super().clear()
self.changed()
@classmethod
- def coerce(cls, key, value):
+ def coerce(cls, key: str, value: Any) -> MutableDict[_KT, _VT] | None:
"""Convert plain dictionary to instance of this class."""
if not isinstance(value, cls):
if isinstance(value, dict):
else:
return value
- def __getstate__(self):
+ def __getstate__(self) -> dict[_KT, _VT]:
return dict(self)
- def __setstate__(self, state):
+ def __setstate__(
+ self, state: Union[Dict[str, int], Dict[str, str]]
+ ) -> None:
self.update(state)
-class MutableList(Mutable, list):
+class MutableList(Mutable, List[_T]):
"""A list type that implements :class:`.Mutable`.
The :class:`.MutableList` object implements a list that will
"""
- def __reduce_ex__(self, proto):
+ def __reduce_ex__(
+ self, proto: SupportsIndex
+ ) -> Tuple[type, Tuple[List[int]]]:
return (self.__class__, (list(self),))
# needed for backwards compatibility with
# older pickles
- def __setstate__(self, state):
+ def __setstate__(self, state: Iterable[_T]) -> None:
self[:] = state
- def __setitem__(self, index, value):
- """Detect list set events and emit change events."""
- list.__setitem__(self, index, value)
- self.changed()
+ def is_scalar(self, value: _T | Iterable[_T]) -> TypeGuard[_T]:
+ return not isinstance(value, Iterable)
- def __setslice__(self, start, end, value):
- """Detect list set events and emit change events."""
- list.__setslice__(self, start, end, value)
- self.changed()
+ def is_iterable(self, value: _T | Iterable[_T]) -> TypeGuard[Iterable[_T]]:
+ return isinstance(value, Iterable)
- def __delitem__(self, index):
- """Detect list del events and emit change events."""
- list.__delitem__(self, index)
+ def __setitem__(
+ self, index: SupportsIndex | slice, value: _T | Iterable[_T]
+ ) -> None:
+ """Detect list set events and emit change events."""
+ if isinstance(index, SupportsIndex) and self.is_scalar(value):
+ super().__setitem__(index, value)
+ elif isinstance(index, slice) and self.is_iterable(value):
+ super().__setitem__(index, value)
self.changed()
- def __delslice__(self, start, end):
+ def __delitem__(self, index: SupportsIndex | slice) -> None:
"""Detect list del events and emit change events."""
- list.__delslice__(self, start, end)
+ super().__delitem__(index)
self.changed()
- def pop(self, *arg):
- result = list.pop(self, *arg)
+ def pop(self, *arg: SupportsIndex) -> _T:
+ result = super().pop(*arg)
self.changed()
return result
- def append(self, x):
- list.append(self, x)
+ def append(self, x: _T) -> None:
+ super().append(x)
self.changed()
- def extend(self, x):
- list.extend(self, x)
+ def extend(self, x: Iterable[_T]) -> None:
+ super().extend(x)
self.changed()
- def __iadd__(self, x):
+ def __iadd__(self, x: Iterable[_T]) -> MutableList[_T]: # type: ignore[override,misc] # noqa: E501
self.extend(x)
return self
- def insert(self, i, x):
- list.insert(self, i, x)
+ def insert(self, i: SupportsIndex, x: _T) -> None:
+ super().insert(i, x)
self.changed()
- def remove(self, i):
- list.remove(self, i)
+ def remove(self, i: _T) -> None:
+ super().remove(i)
self.changed()
- def clear(self):
- list.clear(self)
+ def clear(self) -> None:
+ super().clear()
self.changed()
- def sort(self, **kw):
- list.sort(self, **kw)
+ def sort(self, **kw: Any) -> None:
+ super().sort(**kw)
self.changed()
- def reverse(self):
- list.reverse(self)
+ def reverse(self) -> None:
+ super().reverse()
self.changed()
@classmethod
- def coerce(cls, index, value):
+ def coerce(
+ cls, key: str, value: MutableList[_T] | _T
+ ) -> Optional[MutableList[_T]]:
"""Convert plain list to instance of this class."""
if not isinstance(value, cls):
if isinstance(value, list):
return cls(value)
- return Mutable.coerce(index, value)
+ return Mutable.coerce(key, value)
else:
return value
-class MutableSet(Mutable, set):
+class MutableSet(Mutable, Set[_T]):
"""A set type that implements :class:`.Mutable`.
The :class:`.MutableSet` object implements a set that will
"""
- def update(self, *arg):
- set.update(self, *arg)
+ def update(self, *arg: Iterable[_T]) -> None:
+ super().update(*arg)
self.changed()
- def intersection_update(self, *arg):
- set.intersection_update(self, *arg)
+ def intersection_update(self, *arg: Iterable[Any]) -> None:
+ super().intersection_update(*arg)
self.changed()
- def difference_update(self, *arg):
- set.difference_update(self, *arg)
+ def difference_update(self, *arg: Iterable[Any]) -> None:
+ super().difference_update(*arg)
self.changed()
- def symmetric_difference_update(self, *arg):
- set.symmetric_difference_update(self, *arg)
+ def symmetric_difference_update(self, *arg: Iterable[_T]) -> None:
+ super().symmetric_difference_update(*arg)
self.changed()
- def __ior__(self, other):
+ def __ior__(self, other: AbstractSet[_T]) -> MutableSet[_T]: # type: ignore[override,misc] # noqa: E501
self.update(other)
return self
- def __iand__(self, other):
+ def __iand__(self, other: AbstractSet[object]) -> MutableSet[_T]:
self.intersection_update(other)
return self
- def __ixor__(self, other):
+ def __ixor__(self, other: AbstractSet[_T]) -> MutableSet[_T]: # type: ignore[override,misc] # noqa: E501
self.symmetric_difference_update(other)
return self
- def __isub__(self, other):
+ def __isub__(self, other: AbstractSet[object]) -> MutableSet[_T]: # type: ignore[misc] # noqa: E501
self.difference_update(other)
return self
- def add(self, elem):
- set.add(self, elem)
+ def add(self, elem: _T) -> None:
+ super().add(elem)
self.changed()
- def remove(self, elem):
- set.remove(self, elem)
+ def remove(self, elem: _T) -> None:
+ super().remove(elem)
self.changed()
- def discard(self, elem):
- set.discard(self, elem)
+ def discard(self, elem: _T) -> None:
+ super().discard(elem)
self.changed()
- def pop(self, *arg):
- result = set.pop(self, *arg)
+ def pop(self, *arg: Any) -> _T:
+ result = super().pop(*arg)
self.changed()
return result
- def clear(self):
- set.clear(self)
+ def clear(self) -> None:
+ super().clear()
self.changed()
@classmethod
- def coerce(cls, index, value):
+ def coerce(cls, index: str, value: Any) -> Optional[MutableSet[_T]]:
"""Convert plain set to instance of this class."""
if not isinstance(value, cls):
if isinstance(value, set):
else:
return value
- def __getstate__(self):
+ def __getstate__(self) -> set[_T]:
return set(self)
- def __setstate__(self, state):
+ def __setstate__(self, state: Iterable[_T]) -> None:
self.update(state)
- def __reduce_ex__(self, proto):
+ def __reduce_ex__(
+ self, proto: SupportsIndex
+ ) -> Tuple[type, Tuple[List[int]]]:
return (self.__class__, (list(self),))