]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Type annotations for sqlalchemy.orm.mapped_collection
authorMaksim Latysh <m.latysh@godeltech.com>
Tue, 24 Jan 2023 16:03:44 +0000 (11:03 -0500)
committersqla-tester <sqla-tester@sqlalchemy.org>
Tue, 24 Jan 2023 16:03:44 +0000 (11:03 -0500)
<!-- Provide a general summary of your proposed changes in the Title field above -->

### Description
<!-- Describe your changes in detail -->
An attempt to annotate lib/sqlalchemy/orm/mapped_collection.py with type hints (issue https://github.com/sqlalchemy/sqlalchemy/issues/6810)

### Checklist
<!-- go over following points. check them with an `x` if they do apply, (they turn into clickable checkboxes once the PR is submitted, so no need to do everything at once)

-->

This pull request is:

- [ ] A documentation / typographical error fix
- Good to go, no issue or tests are needed
- [ ] A short code fix
- please include the issue number, and create an issue if none exists, which
  must include a complete example of the issue.  one line code fixes without an
  issue and demonstration will not be accepted.
- Please include: `Fixes: #<issue number>` in the commit message
- please include tests.   one line code fixes without tests will not be accepted.
- [ ] A new feature implementation
- please include the issue number, and create an issue if none exists, which must
  include a complete example of how the feature would look.
- Please include: `Fixes: #<issue number>` in the commit message
- please include tests.

Closes: #9140
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/9140
Pull-request-sha: facb4717134943dd651905f7c72618eb66a9eca5

Change-Id: I0fb80e2ea7ed2247c494487fb6c8d72efb4e9802

lib/sqlalchemy/orm/mapped_collection.py

index 8a65f847aeff66c4de9c0e642d2ba5e41b56b156..a2b085c76d70f2efda8d5b2eac52a97284326cb2 100644 (file)
@@ -4,14 +4,15 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
 
 from __future__ import annotations
 
 from typing import Any
 from typing import Callable
 from typing import Dict
+from typing import Generic
 from typing import Type
+from typing import TYPE_CHECKING
 from typing import TypeVar
 
 from . import base
@@ -22,11 +23,24 @@ from ..sql import coercions
 from ..sql import expression
 from ..sql import roles
 
+if TYPE_CHECKING:
+    from typing import List
+    from typing import Optional
+    from typing import Sequence
+    from typing import Tuple
+    from typing import Union
+
+    from . import AttributeEventToken
+    from . import Mapper
+    from ..sql.elements import ColumnElement
+
 _KT = TypeVar("_KT", bound=Any)
 _VT = TypeVar("_VT", bound=Any)
 
+_F = TypeVar("_F", bound=Callable[[Any], Any])
 
-class _PlainColumnGetter:
+
+class _PlainColumnGetter(Generic[_KT]):
     """Plain column getter, stores collection of Column objects
     directly.
 
@@ -38,21 +52,26 @@ class _PlainColumnGetter:
 
     __slots__ = ("cols", "composite")
 
-    def __init__(self, cols):
+    def __init__(self, cols: Sequence[ColumnElement[_KT]]) -> None:
         self.cols = cols
         self.composite = len(cols) > 1
 
-    def __reduce__(self):
+    def __reduce__(
+        self,
+    ) -> Tuple[
+        Type[_SerializableColumnGetterV2[_KT]],
+        Tuple[Sequence[Tuple[Optional[str], Optional[str]]]],
+    ]:
         return _SerializableColumnGetterV2._reduce_from_cols(self.cols)
 
-    def _cols(self, mapper):
+    def _cols(self, mapper: Mapper[_KT]) -> Sequence[ColumnElement[_KT]]:
         return self.cols
 
-    def __call__(self, value):
+    def __call__(self, value: _KT) -> Union[_KT, Tuple[_KT, ...]]:
         state = base.instance_state(value)
         m = base._state_mapper(state)
 
-        key = [
+        key: List[_KT] = [
             m._get_state_attr_by_column(state, state.dict, col)
             for col in self._cols(m)
         ]
@@ -62,7 +81,7 @@ class _PlainColumnGetter:
             return key[0]
 
 
-class _SerializableColumnGetterV2(_PlainColumnGetter):
+class _SerializableColumnGetterV2(_PlainColumnGetter[_KT]):
     """Updated serializable getter which deals with
     multi-table mapped classes.
 
@@ -76,38 +95,52 @@ class _SerializableColumnGetterV2(_PlainColumnGetter):
 
     __slots__ = ("colkeys",)
 
-    def __init__(self, colkeys):
+    def __init__(
+        self, colkeys: Sequence[Tuple[Optional[str], Optional[str]]]
+    ) -> None:
         self.colkeys = colkeys
         self.composite = len(colkeys) > 1
 
-    def __reduce__(self):
+    def __reduce__(
+        self,
+    ) -> Tuple[
+        Type[_SerializableColumnGetterV2[_KT]],
+        Tuple[Sequence[Tuple[Optional[str], Optional[str]]]],
+    ]:
         return self.__class__, (self.colkeys,)
 
     @classmethod
-    def _reduce_from_cols(cls, cols):
-        def _table_key(c):
+    def _reduce_from_cols(
+        cls, cols: Sequence[ColumnElement[_KT]]
+    ) -> Tuple[
+        Type[_SerializableColumnGetterV2[_KT]],
+        Tuple[Sequence[Tuple[Optional[str], Optional[str]]]],
+    ]:
+        def _table_key(c: ColumnElement[_KT]) -> Optional[str]:
             if not isinstance(c.table, expression.TableClause):
                 return None
             else:
-                return c.table.key
+                return c.table.key  # type: ignore
 
         colkeys = [(c.key, _table_key(c)) for c in cols]
         return _SerializableColumnGetterV2, (colkeys,)
 
-    def _cols(self, mapper):
-        cols = []
+    def _cols(self, mapper: Mapper[_KT]) -> Sequence[ColumnElement[_KT]]:
+        cols: List[ColumnElement[_KT]] = []
         metadata = getattr(mapper.local_table, "metadata", None)
         for (ckey, tkey) in self.colkeys:
             if tkey is None or metadata is None or tkey not in metadata:
-                cols.append(mapper.local_table.c[ckey])
+                cols.append(mapper.local_table.c[ckey])  # type: ignore
             else:
                 cols.append(metadata.tables[tkey].c[ckey])
         return cols
 
 
 def column_keyed_dict(
-    mapping_spec, *, ignore_unpopulated_attribute: bool = False
-):
+    mapping_spec: Union[Type[_KT], Callable[[_KT], _VT]],
+    *,
+    ignore_unpopulated_attribute: bool = False,
+) -> Type[KeyFuncDict[_KT, _KT]]:
     """A dictionary-based collection type with column-based keying.
 
     .. versionchanged:: 2.0 Renamed :data:`.column_mapped_collection` to
@@ -155,7 +188,8 @@ def column_keyed_dict(
     ]
     keyfunc = _PlainColumnGetter(cols)
     return _mapped_collection_cls(
-        keyfunc, ignore_unpopulated_attribute=ignore_unpopulated_attribute
+        keyfunc,
+        ignore_unpopulated_attribute=ignore_unpopulated_attribute,
     )
 
 
@@ -169,13 +203,13 @@ class _AttrGetter:
         dict_ = base.instance_dict(mapped_object)
         return dict_.get(self.attr_name, base.NO_VALUE)
 
-    def __reduce__(self):
+    def __reduce__(self) -> Tuple[Type[_AttrGetter], Tuple[str]]:
         return _AttrGetter, (self.attr_name,)
 
 
 def attribute_keyed_dict(
     attr_name: str, *, ignore_unpopulated_attribute: bool = False
-) -> Type[KeyFuncDict]:
+) -> Type[KeyFuncDict[_KT, _KT]]:
     """A dictionary-based collection type with attribute-based keying.
 
     .. versionchanged:: 2.0 Renamed :data:`.attribute_mapped_collection` to
@@ -223,7 +257,7 @@ def attribute_keyed_dict(
 
 
 def keyfunc_mapping(
-    keyfunc: Callable[[Any], _KT],
+    keyfunc: _F,
     *,
     ignore_unpopulated_attribute: bool = False,
 ) -> Type[KeyFuncDict[_KT, Any]]:
@@ -297,7 +331,12 @@ class KeyFuncDict(Dict[_KT, _VT]):
 
     """
 
-    def __init__(self, keyfunc, *, ignore_unpopulated_attribute=False):
+    def __init__(
+        self,
+        keyfunc: _F,
+        *,
+        ignore_unpopulated_attribute: bool = False,
+    ) -> None:
         """Create a new collection with keying provided by keyfunc.
 
         keyfunc may be any callable that takes an object and returns an object
@@ -315,21 +354,30 @@ class KeyFuncDict(Dict[_KT, _VT]):
         self.ignore_unpopulated_attribute = ignore_unpopulated_attribute
 
     @classmethod
-    def _unreduce(cls, keyfunc, values):
-        mp = KeyFuncDict(keyfunc)
+    def _unreduce(
+        cls, keyfunc: _F, values: Dict[_KT, _KT]
+    ) -> "KeyFuncDict[_KT, _KT]":
+        mp: KeyFuncDict[_KT, _KT] = KeyFuncDict(keyfunc)
         mp.update(values)
         return mp
 
-    def __reduce__(self):
+    def __reduce__(
+        self,
+    ) -> Tuple[
+        Callable[[_KT, _KT], KeyFuncDict[_KT, _KT]],
+        Tuple[Any, Union[Dict[_KT, _KT], Dict[_KT, _KT]]],
+    ]:
         return (KeyFuncDict._unreduce, (self.keyfunc, dict(self)))
 
-    def _raise_for_unpopulated(self, value, initiator):
+    def _raise_for_unpopulated(
+        self, value: _KT, initiator: Optional[AttributeEventToken]
+    ) -> None:
         mapper = base.instance_state(value).mapper
 
         if initiator is None:
             relationship = "unknown relationship"
         else:
-            relationship = mapper.attrs[initiator.key]
+            relationship = f"{mapper.attrs[initiator.key]}"
 
         raise sa_exc.InvalidRequestError(
             f"In event triggered from population of attribute {relationship} "
@@ -345,9 +393,13 @@ class KeyFuncDict(Dict[_KT, _VT]):
             f"parameter on the mapped collection factory."
         )
 
-    @collection.appender
-    @collection.internally_instrumented
-    def set(self, value, _sa_initiator=None):
+    @collection.appender  # type: ignore[misc]
+    @collection.internally_instrumented  # type: ignore[misc]
+    def set(
+        self,
+        value: _KT,
+        _sa_initiator: Optional[AttributeEventToken] = None,
+    ) -> None:
         """Add an item by value, consulting the keyfunc for the key."""
 
         key = self.keyfunc(value)
@@ -358,11 +410,15 @@ class KeyFuncDict(Dict[_KT, _VT]):
             else:
                 return
 
-        self.__setitem__(key, value, _sa_initiator)
+        self.__setitem__(key, value, _sa_initiator)  # type: ignore[call-arg]
 
-    @collection.remover
-    @collection.internally_instrumented
-    def remove(self, value, _sa_initiator=None):
+    @collection.remover  # type: ignore[misc]
+    @collection.internally_instrumented  # type: ignore[misc]
+    def remove(
+        self,
+        value: _KT,
+        _sa_initiator: Optional[AttributeEventToken] = None,
+    ) -> None:
         """Remove an item by value, consulting the keyfunc for the key."""
 
         key = self.keyfunc(value)
@@ -381,12 +437,14 @@ class KeyFuncDict(Dict[_KT, _VT]):
                 "based on mutable properties or properties that only obtain "
                 "values after flush?" % (value, self[key], key)
             )
-        self.__delitem__(key, _sa_initiator)
+        self.__delitem__(key, _sa_initiator)  # type: ignore[call-arg]
 
 
-def _mapped_collection_cls(keyfunc, ignore_unpopulated_attribute):
-    class _MKeyfuncMapped(KeyFuncDict):
-        def __init__(self):
+def _mapped_collection_cls(
+    keyfunc: _F, ignore_unpopulated_attribute: bool
+) -> Type[KeyFuncDict[_KT, _KT]]:
+    class _MKeyfuncMapped(KeyFuncDict[_KT, _KT]):
+        def __init__(self) -> None:
             super().__init__(
                 keyfunc,
                 ignore_unpopulated_attribute=ignore_unpopulated_attribute,