From 173e4164834e7ac5c77184a425a32f9afd088af4 Mon Sep 17 00:00:00 2001 From: Maksim Latysh Date: Tue, 24 Jan 2023 11:03:44 -0500 Subject: [PATCH] Type annotations for sqlalchemy.orm.mapped_collection ### Description An attempt to annotate lib/sqlalchemy/orm/mapped_collection.py with type hints (issue https://github.com/sqlalchemy/sqlalchemy/issues/6810) ### Checklist This pull request is: - [ ] A documentation / typographical error fix - Good to go, no issue or tests are needed - [ ] A short code fix - please include the issue number, and create an issue if none exists, which must include a complete example of the issue. one line code fixes without an issue and demonstration will not be accepted. - Please include: `Fixes: #` in the commit message - please include tests. one line code fixes without tests will not be accepted. - [ ] A new feature implementation - please include the issue number, and create an issue if none exists, which must include a complete example of how the feature would look. - Please include: `Fixes: #` in the commit message - please include tests. Closes: #9140 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/9140 Pull-request-sha: facb4717134943dd651905f7c72618eb66a9eca5 Change-Id: I0fb80e2ea7ed2247c494487fb6c8d72efb4e9802 --- lib/sqlalchemy/orm/mapped_collection.py | 136 +++++++++++++++++------- 1 file changed, 97 insertions(+), 39 deletions(-) diff --git a/lib/sqlalchemy/orm/mapped_collection.py b/lib/sqlalchemy/orm/mapped_collection.py index 8a65f847ae..a2b085c76d 100644 --- a/lib/sqlalchemy/orm/mapped_collection.py +++ b/lib/sqlalchemy/orm/mapped_collection.py @@ -4,14 +4,15 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors from __future__ import annotations from typing import Any from typing import Callable from typing import Dict +from typing import Generic from typing import Type +from typing import TYPE_CHECKING from typing import TypeVar from . import base @@ -22,11 +23,24 @@ from ..sql import coercions from ..sql import expression from ..sql import roles +if TYPE_CHECKING: + from typing import List + from typing import Optional + from typing import Sequence + from typing import Tuple + from typing import Union + + from . import AttributeEventToken + from . import Mapper + from ..sql.elements import ColumnElement + _KT = TypeVar("_KT", bound=Any) _VT = TypeVar("_VT", bound=Any) +_F = TypeVar("_F", bound=Callable[[Any], Any]) -class _PlainColumnGetter: + +class _PlainColumnGetter(Generic[_KT]): """Plain column getter, stores collection of Column objects directly. @@ -38,21 +52,26 @@ class _PlainColumnGetter: __slots__ = ("cols", "composite") - def __init__(self, cols): + def __init__(self, cols: Sequence[ColumnElement[_KT]]) -> None: self.cols = cols self.composite = len(cols) > 1 - def __reduce__(self): + def __reduce__( + self, + ) -> Tuple[ + Type[_SerializableColumnGetterV2[_KT]], + Tuple[Sequence[Tuple[Optional[str], Optional[str]]]], + ]: return _SerializableColumnGetterV2._reduce_from_cols(self.cols) - def _cols(self, mapper): + def _cols(self, mapper: Mapper[_KT]) -> Sequence[ColumnElement[_KT]]: return self.cols - def __call__(self, value): + def __call__(self, value: _KT) -> Union[_KT, Tuple[_KT, ...]]: state = base.instance_state(value) m = base._state_mapper(state) - key = [ + key: List[_KT] = [ m._get_state_attr_by_column(state, state.dict, col) for col in self._cols(m) ] @@ -62,7 +81,7 @@ class _PlainColumnGetter: return key[0] -class _SerializableColumnGetterV2(_PlainColumnGetter): +class _SerializableColumnGetterV2(_PlainColumnGetter[_KT]): """Updated serializable getter which deals with multi-table mapped classes. @@ -76,38 +95,52 @@ class _SerializableColumnGetterV2(_PlainColumnGetter): __slots__ = ("colkeys",) - def __init__(self, colkeys): + def __init__( + self, colkeys: Sequence[Tuple[Optional[str], Optional[str]]] + ) -> None: self.colkeys = colkeys self.composite = len(colkeys) > 1 - def __reduce__(self): + def __reduce__( + self, + ) -> Tuple[ + Type[_SerializableColumnGetterV2[_KT]], + Tuple[Sequence[Tuple[Optional[str], Optional[str]]]], + ]: return self.__class__, (self.colkeys,) @classmethod - def _reduce_from_cols(cls, cols): - def _table_key(c): + def _reduce_from_cols( + cls, cols: Sequence[ColumnElement[_KT]] + ) -> Tuple[ + Type[_SerializableColumnGetterV2[_KT]], + Tuple[Sequence[Tuple[Optional[str], Optional[str]]]], + ]: + def _table_key(c: ColumnElement[_KT]) -> Optional[str]: if not isinstance(c.table, expression.TableClause): return None else: - return c.table.key + return c.table.key # type: ignore colkeys = [(c.key, _table_key(c)) for c in cols] return _SerializableColumnGetterV2, (colkeys,) - def _cols(self, mapper): - cols = [] + def _cols(self, mapper: Mapper[_KT]) -> Sequence[ColumnElement[_KT]]: + cols: List[ColumnElement[_KT]] = [] metadata = getattr(mapper.local_table, "metadata", None) for (ckey, tkey) in self.colkeys: if tkey is None or metadata is None or tkey not in metadata: - cols.append(mapper.local_table.c[ckey]) + cols.append(mapper.local_table.c[ckey]) # type: ignore else: cols.append(metadata.tables[tkey].c[ckey]) return cols def column_keyed_dict( - mapping_spec, *, ignore_unpopulated_attribute: bool = False -): + mapping_spec: Union[Type[_KT], Callable[[_KT], _VT]], + *, + ignore_unpopulated_attribute: bool = False, +) -> Type[KeyFuncDict[_KT, _KT]]: """A dictionary-based collection type with column-based keying. .. versionchanged:: 2.0 Renamed :data:`.column_mapped_collection` to @@ -155,7 +188,8 @@ def column_keyed_dict( ] keyfunc = _PlainColumnGetter(cols) return _mapped_collection_cls( - keyfunc, ignore_unpopulated_attribute=ignore_unpopulated_attribute + keyfunc, + ignore_unpopulated_attribute=ignore_unpopulated_attribute, ) @@ -169,13 +203,13 @@ class _AttrGetter: dict_ = base.instance_dict(mapped_object) return dict_.get(self.attr_name, base.NO_VALUE) - def __reduce__(self): + def __reduce__(self) -> Tuple[Type[_AttrGetter], Tuple[str]]: return _AttrGetter, (self.attr_name,) def attribute_keyed_dict( attr_name: str, *, ignore_unpopulated_attribute: bool = False -) -> Type[KeyFuncDict]: +) -> Type[KeyFuncDict[_KT, _KT]]: """A dictionary-based collection type with attribute-based keying. .. versionchanged:: 2.0 Renamed :data:`.attribute_mapped_collection` to @@ -223,7 +257,7 @@ def attribute_keyed_dict( def keyfunc_mapping( - keyfunc: Callable[[Any], _KT], + keyfunc: _F, *, ignore_unpopulated_attribute: bool = False, ) -> Type[KeyFuncDict[_KT, Any]]: @@ -297,7 +331,12 @@ class KeyFuncDict(Dict[_KT, _VT]): """ - def __init__(self, keyfunc, *, ignore_unpopulated_attribute=False): + def __init__( + self, + keyfunc: _F, + *, + ignore_unpopulated_attribute: bool = False, + ) -> None: """Create a new collection with keying provided by keyfunc. keyfunc may be any callable that takes an object and returns an object @@ -315,21 +354,30 @@ class KeyFuncDict(Dict[_KT, _VT]): self.ignore_unpopulated_attribute = ignore_unpopulated_attribute @classmethod - def _unreduce(cls, keyfunc, values): - mp = KeyFuncDict(keyfunc) + def _unreduce( + cls, keyfunc: _F, values: Dict[_KT, _KT] + ) -> "KeyFuncDict[_KT, _KT]": + mp: KeyFuncDict[_KT, _KT] = KeyFuncDict(keyfunc) mp.update(values) return mp - def __reduce__(self): + def __reduce__( + self, + ) -> Tuple[ + Callable[[_KT, _KT], KeyFuncDict[_KT, _KT]], + Tuple[Any, Union[Dict[_KT, _KT], Dict[_KT, _KT]]], + ]: return (KeyFuncDict._unreduce, (self.keyfunc, dict(self))) - def _raise_for_unpopulated(self, value, initiator): + def _raise_for_unpopulated( + self, value: _KT, initiator: Optional[AttributeEventToken] + ) -> None: mapper = base.instance_state(value).mapper if initiator is None: relationship = "unknown relationship" else: - relationship = mapper.attrs[initiator.key] + relationship = f"{mapper.attrs[initiator.key]}" raise sa_exc.InvalidRequestError( f"In event triggered from population of attribute {relationship} " @@ -345,9 +393,13 @@ class KeyFuncDict(Dict[_KT, _VT]): f"parameter on the mapped collection factory." ) - @collection.appender - @collection.internally_instrumented - def set(self, value, _sa_initiator=None): + @collection.appender # type: ignore[misc] + @collection.internally_instrumented # type: ignore[misc] + def set( + self, + value: _KT, + _sa_initiator: Optional[AttributeEventToken] = None, + ) -> None: """Add an item by value, consulting the keyfunc for the key.""" key = self.keyfunc(value) @@ -358,11 +410,15 @@ class KeyFuncDict(Dict[_KT, _VT]): else: return - self.__setitem__(key, value, _sa_initiator) + self.__setitem__(key, value, _sa_initiator) # type: ignore[call-arg] - @collection.remover - @collection.internally_instrumented - def remove(self, value, _sa_initiator=None): + @collection.remover # type: ignore[misc] + @collection.internally_instrumented # type: ignore[misc] + def remove( + self, + value: _KT, + _sa_initiator: Optional[AttributeEventToken] = None, + ) -> None: """Remove an item by value, consulting the keyfunc for the key.""" key = self.keyfunc(value) @@ -381,12 +437,14 @@ class KeyFuncDict(Dict[_KT, _VT]): "based on mutable properties or properties that only obtain " "values after flush?" % (value, self[key], key) ) - self.__delitem__(key, _sa_initiator) + self.__delitem__(key, _sa_initiator) # type: ignore[call-arg] -def _mapped_collection_cls(keyfunc, ignore_unpopulated_attribute): - class _MKeyfuncMapped(KeyFuncDict): - def __init__(self): +def _mapped_collection_cls( + keyfunc: _F, ignore_unpopulated_attribute: bool +) -> Type[KeyFuncDict[_KT, _KT]]: + class _MKeyfuncMapped(KeyFuncDict[_KT, _KT]): + def __init__(self) -> None: super().__init__( keyfunc, ignore_unpopulated_attribute=ignore_unpopulated_attribute, -- 2.47.2