]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Type annotations for sqlalchemy.ext.automap
authorGleb Kisenkov <g.kisenkov@godeltech.com>
Mon, 5 Dec 2022 13:45:25 +0000 (08:45 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 5 Dec 2022 14:45:14 +0000 (09:45 -0500)
An attempt to annotate `lib/sqlalchemy/ext/automap.py` with type hints (issue [#6810](https://github.com/sqlalchemy/sqlalchemy/issues/6810#issuecomment-1127062951)).

More info on how I approach it could be found in [the earlier PR](https://github.com/sqlalchemy/sqlalchemy/pull/8775).

This pull request is:

- [ ] A documentation / typographical error fix
  - Good to go, no issue or tests are needed
- [ ] A short code fix
  - please include the issue number, and create an issue if none exists, which
    must include a complete example of the issue. one line code fixes without an
    issue and demonstration will not be accepted.
  - Please include: `Fixes: #<issue number>` in the commit message
  - please include tests. one line code fixes without tests will not be accepted.
- [x] A new feature implementation
  - please include the issue number, and create an issue if none exists, which must
    include a complete example of how the feature would look.
  - Please include: `Fixes: #<issue number>` in the commit message
  - please include tests.

**Have a nice day!**

Closes: #8874
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/8874
Pull-request-sha: 834d58d77c2cfd09ea874f01eb7b75a2ea0db7cd

Change-Id: Ie64b2be7a51ddc83ef8f23385fb63db5b5c1bc17

lib/sqlalchemy/ext/automap.py
lib/sqlalchemy/orm/_orm_constructors.py
lib/sqlalchemy/orm/relationships.py

index 6eb30ba4c6eae36ac01c006d31fb09b9070c2236..217113309f465dde74d2e5d57608e748f9766a34 100644 (file)
@@ -4,7 +4,6 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
 
 r"""Define an extension to the :mod:`sqlalchemy.ext.declarative` system
 which automatically generates mapped classes and relationships from a database
@@ -572,6 +571,22 @@ be applied as::
 
 
 """  # noqa
+from __future__ import annotations
+
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Dict
+from typing import List
+from typing import NoReturn
+from typing import Optional
+from typing import overload
+from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
+from typing import TypeVar
+from typing import Union
+
 from .. import util
 from ..orm import backref
 from ..orm import declarative_base as _declarative_base
@@ -582,9 +597,36 @@ from ..orm.decl_base import _DeferredMapperConfig
 from ..orm.mapper import _CONFIGURE_MUTEX
 from ..schema import ForeignKeyConstraint
 from ..sql import and_
+from ..util.typing import Protocol
+
+if TYPE_CHECKING:
+    from ..engine.base import Engine
+    from ..orm.base import RelationshipDirection
+    from ..orm.relationships import ORMBackrefArgument
+    from ..orm.relationships import Relationship
+    from ..sql.elements import quoted_name
+    from ..sql.schema import Column
+    from ..sql.schema import Table
+    from ..util import immutabledict
+    from ..util import Properties
+
+
+_KT = TypeVar("_KT", bound=Any)
+_VT = TypeVar("_VT", bound=Any)
+
 
+class ClassnameForTableType(Protocol):
+    def __call__(
+        self, base: Type[Any], tablename: quoted_name, table: Table
+    ) -> str:
+        ...
 
-def classname_for_table(base, tablename, table):
+
+def classname_for_table(
+    base: Type[Any],
+    tablename: quoted_name,
+    table: Table,
+) -> str:
     """Return the class name that should be used, given the name
     of a table.
 
@@ -617,7 +659,23 @@ def classname_for_table(base, tablename, table):
     return str(tablename)
 
 
-def name_for_scalar_relationship(base, local_cls, referred_cls, constraint):
+class NameForScalarRelationshipType(Protocol):
+    def __call__(
+        self,
+        base: Type[Any],
+        local_cls: Type[Any],
+        referred_cls: Type[Any],
+        constraint: ForeignKeyConstraint,
+    ) -> str:
+        ...
+
+
+def name_for_scalar_relationship(
+    base: Type[Any],
+    local_cls: Type[Any],
+    referred_cls: Type[Any],
+    constraint: ForeignKeyConstraint,
+) -> str:
     """Return the attribute name that should be used to refer from one
     class to another, for a scalar object reference.
 
@@ -642,9 +700,23 @@ def name_for_scalar_relationship(base, local_cls, referred_cls, constraint):
     return referred_cls.__name__.lower()
 
 
+class NameForCollectionRelationshipType(Protocol):
+    def __call__(
+        self,
+        base: Type[Any],
+        local_cls: Type[Any],
+        referred_cls: Type[Any],
+        constraint: ForeignKeyConstraint,
+    ) -> str:
+        ...
+
+
 def name_for_collection_relationship(
-    base, local_cls, referred_cls, constraint
-):
+    base: Type[Any],
+    local_cls: Type[Any],
+    referred_cls: Type[Any],
+    constraint: ForeignKeyConstraint,
+) -> str:
     """Return the attribute name that should be used to refer from one
     class to another, for a collection reference.
 
@@ -670,9 +742,85 @@ def name_for_collection_relationship(
     return referred_cls.__name__.lower() + "_collection"
 
 
+class GenerateRelationshipType(Protocol):
+    @overload
+    def __call__(
+        self,
+        base: Type[Any],
+        direction: RelationshipDirection,
+        return_fn: Callable[..., Relationship[Any]],
+        attrname: str,
+        local_cls: Type[Any],
+        referred_cls: Type[Any],
+        **kw: Any,
+    ) -> Relationship[Any]:
+        ...
+
+    @overload
+    def __call__(
+        self,
+        base: Type[Any],
+        direction: RelationshipDirection,
+        return_fn: Callable[..., ORMBackrefArgument],
+        attrname: str,
+        local_cls: Type[Any],
+        referred_cls: Type[Any],
+        **kw: Any,
+    ) -> ORMBackrefArgument:
+        ...
+
+    def __call__(
+        self,
+        base: Type[Any],
+        direction: RelationshipDirection,
+        return_fn: Union[
+            Callable[..., Relationship[Any]], Callable[..., ORMBackrefArgument]
+        ],
+        attrname: str,
+        local_cls: Type[Any],
+        referred_cls: Type[Any],
+        **kw: Any,
+    ) -> Union[ORMBackrefArgument, Relationship[Any]]:
+        ...
+
+
+@overload
+def generate_relationship(
+    base: Type[Any],
+    direction: RelationshipDirection,
+    return_fn: Callable[..., Relationship[Any]],
+    attrname: str,
+    local_cls: Type[Any],
+    referred_cls: Type[Any],
+    **kw: Any,
+) -> Relationship[Any]:
+    ...
+
+
+@overload
 def generate_relationship(
-    base, direction, return_fn, attrname, local_cls, referred_cls, **kw
-):
+    base: Type[Any],
+    direction: RelationshipDirection,
+    return_fn: Callable[..., ORMBackrefArgument],
+    attrname: str,
+    local_cls: Type[Any],
+    referred_cls: Type[Any],
+    **kw: Any,
+) -> ORMBackrefArgument:
+    ...
+
+
+def generate_relationship(
+    base: Type[Any],
+    direction: RelationshipDirection,
+    return_fn: Union[
+        Callable[..., Relationship[Any]], Callable[..., ORMBackrefArgument]
+    ],
+    attrname: str,
+    local_cls: Type[Any],
+    referred_cls: Type[Any],
+    **kw: Any,
+) -> Union[Relationship[Any], ORMBackrefArgument]:
     r"""Generate a :func:`_orm.relationship` or :func:`.backref`
     on behalf of two
     mapped classes.
@@ -721,6 +869,7 @@ def generate_relationship(
      by the :paramref:`.generate_relationship.return_fn` parameter.
 
     """
+
     if return_fn is backref:
         return return_fn(attrname, **kw)
     elif return_fn is relationship:
@@ -748,7 +897,7 @@ class AutomapBase:
 
     __abstract__ = True
 
-    classes = None
+    classes: Optional[Properties[Type[Any]]] = None
     """An instance of :class:`.util.Properties` containing classes.
 
     This object behaves much like the ``.c`` collection on a table.  Classes
@@ -781,18 +930,24 @@ class AutomapBase:
         ),
     )
     def prepare(
-        cls,
-        autoload_with=None,
-        engine=None,
-        reflect=False,
-        schema=None,
-        classname_for_table=None,
-        collection_class=None,
-        name_for_scalar_relationship=None,
-        name_for_collection_relationship=None,
-        generate_relationship=None,
-        reflection_options=util.EMPTY_DICT,
-    ):
+        cls: Type[Any],
+        autoload_with: Optional[Engine] = None,
+        engine: Optional[Any] = None,
+        reflect: bool = False,
+        schema: Optional[str] = None,
+        classname_for_table: Optional[ClassnameForTableType] = None,
+        collection_class: Optional[Any] = None,
+        name_for_scalar_relationship: Optional[
+            NameForScalarRelationshipType
+        ] = None,
+        name_for_collection_relationship: Optional[
+            NameForCollectionRelationshipType
+        ] = None,
+        generate_relationship: Optional[GenerateRelationshipType] = None,
+        reflection_options: Union[
+            Dict[_KT, _VT], immutabledict[_KT, _VT]
+        ] = util.EMPTY_DICT,
+    ) -> None:
         """Extract mapped classes and relationships from the
         :class:`_schema.MetaData` and
         perform mappings.
@@ -874,6 +1029,7 @@ class AutomapBase:
             autoload_with = engine
 
         if reflect:
+            assert autoload_with
             opts = dict(
                 schema=schema,
                 extend_existing=True,
@@ -884,18 +1040,30 @@ class AutomapBase:
             cls.metadata.reflect(autoload_with, **opts)
 
         with _CONFIGURE_MUTEX:
-            table_to_map_config = {
-                m.local_table: m
+            table_to_map_config: Union[
+                Dict[Optional[Table], _DeferredMapperConfig],
+                Dict[Table, _DeferredMapperConfig],
+            ] = {
+                cast("Table", m.local_table): m
                 for m in _DeferredMapperConfig.classes_for_base(
                     cls, sort=False
                 )
             }
 
-            many_to_many = []
+            many_to_many: list[
+                tuple[
+                    Table,
+                    Table,
+                    list[ForeignKeyConstraint],
+                    Table,
+                ]
+            ] = []
 
             for table in cls.metadata.tables.values():
                 lcl_m2m, rem_m2m, m2m_const = _is_many_to_many(cls, table)
                 if lcl_m2m is not None:
+                    assert rem_m2m is not None
+                    assert m2m_const is not None
                     many_to_many.append((lcl_m2m, rem_m2m, m2m_const, table))
                 elif not table.primary_key:
                     continue
@@ -961,7 +1129,7 @@ class AutomapBase:
     """
 
     @classmethod
-    def _sa_raise_deferred_config(cls):
+    def _sa_raise_deferred_config(cls) -> NoReturn:
         raise orm_exc.UnmappedClassError(
             cls,
             msg="Class %s is a subclass of AutomapBase.  "
@@ -971,7 +1139,9 @@ class AutomapBase:
         )
 
 
-def automap_base(declarative_base=None, **kw):
+def automap_base(
+    declarative_base: Optional[Type[Any]] = None, **kw: Any
+) -> Any:
     r"""Produce a declarative automap base.
 
     This function produces a new base class that is a product of the
@@ -1003,7 +1173,11 @@ def automap_base(declarative_base=None, **kw):
     )
 
 
-def _is_many_to_many(automap_base, table):
+def _is_many_to_many(
+    automap_base: Type[Any], table: Table
+) -> Tuple[
+    Optional[Table], Optional[Table], Optional[list[ForeignKeyConstraint]]
+]:
     fk_constraints = [
         const
         for const in table.constraints
@@ -1012,7 +1186,7 @@ def _is_many_to_many(automap_base, table):
     if len(fk_constraints) != 2:
         return None, None, None
 
-    cols = sum(
+    cols: list[Column[Any]] = sum(
         [
             [fk.parent for fk in fk_constraint.elements]
             for fk_constraint in fk_constraints
@@ -1031,16 +1205,21 @@ def _is_many_to_many(automap_base, table):
 
 
 def _relationships_for_fks(
-    automap_base,
-    map_config,
-    table_to_map_config,
-    collection_class,
-    name_for_scalar_relationship,
-    name_for_collection_relationship,
-    generate_relationship,
-):
-    local_table = map_config.local_table
-    local_cls = map_config.cls  # derived from a weakref, may be None
+    automap_base: Type[Any],
+    map_config: _DeferredMapperConfig,
+    table_to_map_config: Union[
+        Dict[Optional[Table], _DeferredMapperConfig],
+        Dict[Table, _DeferredMapperConfig],
+    ],
+    collection_class: type,
+    name_for_scalar_relationship: NameForScalarRelationshipType,
+    name_for_collection_relationship: NameForCollectionRelationshipType,
+    generate_relationship: GenerateRelationshipType,
+) -> None:
+    local_table = cast("Optional[Table]", map_config.local_table)
+    local_cls = cast(
+        "Optional[Type[Any]]", map_config.cls
+    )  # derived from a weakref, may be None
 
     if local_table is None or local_cls is None:
         return
@@ -1065,7 +1244,7 @@ def _relationships_for_fks(
                 automap_base, referred_cls, local_cls, constraint
             )
 
-            o2m_kws = {}
+            o2m_kws: dict[str, Union[str, bool]] = {}
             nullable = False not in {fk.parent.nullable for fk in fks}
             if not nullable:
                 o2m_kws["cascade"] = "all, delete-orphan"
@@ -1114,7 +1293,7 @@ def _relationships_for_fks(
                     if not create_backref:
                         referred_cfg.properties[
                             backref_name
-                        ].back_populates = relationship_name
+                        ].back_populates = relationship_name  # type: ignore[union-attr] # noqa: E501
             elif create_backref:
                 rel = generate_relationship(
                     automap_base,
@@ -1132,21 +1311,24 @@ def _relationships_for_fks(
                     referred_cfg.properties[backref_name] = rel
                     map_config.properties[
                         relationship_name
-                    ].back_populates = backref_name
+                    ].back_populates = backref_name  # type: ignore[union-attr]
 
 
 def _m2m_relationship(
-    automap_base,
-    lcl_m2m,
-    rem_m2m,
-    m2m_const,
-    table,
-    table_to_map_config,
-    collection_class,
-    name_for_scalar_relationship,
-    name_for_collection_relationship,
-    generate_relationship,
-):
+    automap_base: Type[Any],
+    lcl_m2m: Table,
+    rem_m2m: Table,
+    m2m_const: List[ForeignKeyConstraint],
+    table: Table,
+    table_to_map_config: Union[
+        Dict[Optional[Table], _DeferredMapperConfig],
+        Dict[Table, _DeferredMapperConfig],
+    ],
+    collection_class: type,
+    name_for_scalar_relationship: NameForCollectionRelationshipType,
+    name_for_collection_relationship: NameForCollectionRelationshipType,
+    generate_relationship: GenerateRelationshipType,
+) -> None:
 
     map_config = table_to_map_config.get(lcl_m2m, None)
     referred_cfg = table_to_map_config.get(rem_m2m, None)
@@ -1196,10 +1378,10 @@ def _m2m_relationship(
             secondary=table,
             primaryjoin=and_(
                 fk.column == fk.parent for fk in m2m_const[0].elements
-            ),
+            ),  # type: ignore [arg-type]
             secondaryjoin=and_(
                 fk.column == fk.parent for fk in m2m_const[1].elements
-            ),
+            ),  # type: ignore [arg-type]
             backref=backref_obj,
             collection_class=collection_class,
         )
@@ -1209,7 +1391,7 @@ def _m2m_relationship(
             if not create_backref:
                 referred_cfg.properties[
                     backref_name
-                ].back_populates = relationship_name
+                ].back_populates = relationship_name  # type: ignore[union-attr] # noqa: E501
     elif create_backref:
         rel = generate_relationship(
             automap_base,
@@ -1222,10 +1404,10 @@ def _m2m_relationship(
             secondary=table,
             primaryjoin=and_(
                 fk.column == fk.parent for fk in m2m_const[1].elements
-            ),
+            ),  # type: ignore [arg-type]
             secondaryjoin=and_(
                 fk.column == fk.parent for fk in m2m_const[0].elements
-            ),
+            ),  # type: ignore [arg-type]
             back_populates=relationship_name,
             collection_class=collection_class,
         )
@@ -1233,4 +1415,4 @@ def _m2m_relationship(
             referred_cfg.properties[backref_name] = rel
             map_config.properties[
                 relationship_name
-            ].back_populates = backref_name
+            ].back_populates = backref_name  # type: ignore[union-attr]
index cb28ac0600d77b0d93a697f8368d81bbb14ea6d3..7323f990e24454435fd4f2a8b0a62467571d23e1 100644 (file)
@@ -57,10 +57,10 @@ if TYPE_CHECKING:
     from .mapper import Mapper
     from .query import Query
     from .relationships import _LazyLoadArgumentType
-    from .relationships import _ORMBackrefArgument
     from .relationships import _ORMColCollectionArgument
     from .relationships import _ORMOrderByArgument
     from .relationships import _RelationshipJoinConditionArgument
+    from .relationships import ORMBackrefArgument
     from .session import _SessionBind
     from ..sql._typing import _ColumnExpressionArgument
     from ..sql._typing import _FromClauseArgument
@@ -781,7 +781,7 @@ def relationship(
     secondaryjoin: Optional[_RelationshipJoinConditionArgument] = None,
     back_populates: Optional[str] = None,
     order_by: _ORMOrderByArgument = False,
-    backref: Optional[_ORMBackrefArgument] = None,
+    backref: Optional[ORMBackrefArgument] = None,
     overlaps: Optional[str] = None,
     post_update: bool = False,
     cascade: str = "save-update, merge",
@@ -1898,7 +1898,7 @@ def dynamic_loader(
     return relationship(argument, **kw)
 
 
-def backref(name: str, **kwargs: Any) -> _ORMBackrefArgument:
+def backref(name: str, **kwargs: Any) -> ORMBackrefArgument:
     """When using the :paramref:`_orm.relationship.backref` parameter,
     provides specific parameters to be used when the new
     :func:`_orm.relationship` is generated.
index 4a9bcd711ea141c04dbbf526ce06cb7be1f03dc7..2a659bdef3b3988cd1b0e505b65ae3bb00543001 100644 (file)
@@ -171,7 +171,7 @@ _ORMOrderByArgument = Union[
     Callable[[], Iterable[ColumnElement[Any]]],
     Iterable[Union[str, _ColumnExpressionArgument[Any]]],
 ]
-_ORMBackrefArgument = Union[str, Tuple[str, Dict[str, Any]]]
+ORMBackrefArgument = Union[str, Tuple[str, Dict[str, Any]]]
 
 _ORMColCollectionElement = Union[
     ColumnClause[Any], _HasClauseElement, roles.DMLColumnRole
@@ -366,7 +366,7 @@ class RelationshipProperty(
         secondaryjoin: Optional[_RelationshipJoinConditionArgument] = None,
         back_populates: Optional[str] = None,
         order_by: _ORMOrderByArgument = False,
-        backref: Optional[_ORMBackrefArgument] = None,
+        backref: Optional[ORMBackrefArgument] = None,
         overlaps: Optional[str] = None,
         post_update: bool = False,
         cascade: str = "save-update, merge",