]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
All mypy errors are cleared
authorGleb Kisenkov <g.kisenkov@godeltech.com>
Tue, 29 Nov 2022 17:54:59 +0000 (18:54 +0100)
committerGleb Kisenkov <g.kisenkov@godeltech.com>
Tue, 29 Nov 2022 17:54:59 +0000 (18:54 +0100)
lib/sqlalchemy/ext/automap.py

index b7ef48a38ee5b2452e47dad392a0655c6b79b86a..1b3b2547f085e4da91f6f3293632866e10d585ad 100644 (file)
@@ -575,6 +575,7 @@ from __future__ import annotations
 
 from typing import Any
 from typing import Callable
+from typing import cast
 from typing import Dict
 from typing import List
 from typing import Optional
@@ -595,7 +596,7 @@ from ..orm import interfaces
 from ..orm import RelaionshipConstructorType
 from ..orm import relationship
 from ..orm.base import RelationshipDirection
-from ..orm.decl_api import DeclarativeMeta
+from ..orm.decl_api import DeclarativeBase
 from ..orm.decl_base import _DeferredMapperConfig
 from ..orm.mapper import _CONFIGURE_MUTEX
 from ..orm.relationships import _ORMBackrefArgument
@@ -603,6 +604,7 @@ from ..orm.relationships import Relationship
 from ..schema import ForeignKeyConstraint
 from ..sql import and_
 from ..sql.elements import quoted_name
+from ..sql.schema import Column
 from ..sql.schema import Table
 from ..util._py_collections import immutabledict
 from ..util.typing import Protocol
@@ -614,13 +616,15 @@ _VT = TypeVar("_VT", bound=Any)
 
 class ClassnameForTableType(Protocol):
     def __call__(
-        self, base: DeclarativeMeta, tablename: quoted_name, table: Table
+        self, base: Type[Any], tablename: quoted_name, table: Table
     ) -> str:
         ...
 
 
 def classname_for_table(
-    base: DeclarativeMeta, tablename: quoted_name, table: Table
+    base: Type[Any],
+    tablename: quoted_name,
+    table: Table,
 ) -> str:
     """Return the class name that should be used, given the name
     of a table.
@@ -657,18 +661,18 @@ def classname_for_table(
 class NameForScalarRelationshipType(Protocol):
     def __call__(
         self,
-        base: DeclarativeMeta,
-        local_cls: DeclarativeMeta,
-        referred_cls: DeclarativeMeta,
+        base: Type[Any],
+        local_cls: Type[Any],
+        referred_cls: Type[Any],
         constraint: ForeignKeyConstraint,
     ) -> str:
         ...
 
 
 def name_for_scalar_relationship(
-    base: DeclarativeMeta,
-    local_cls: DeclarativeMeta,
-    referred_cls: DeclarativeMeta,
+    base: Type[Any],
+    local_cls: Type[Any],
+    referred_cls: Type[Any],
     constraint: ForeignKeyConstraint,
 ) -> str:
     """Return the attribute name that should be used to refer from one
@@ -698,18 +702,18 @@ def name_for_scalar_relationship(
 class NameForCollectionRelationshipType(Protocol):
     def __call__(
         self,
-        base: DeclarativeMeta,
-        local_cls: DeclarativeMeta,
-        referred_cls: DeclarativeMeta,
+        base: Type[Any],
+        local_cls: Type[Any],
+        referred_cls: Type[Any],
         constraint: ForeignKeyConstraint,
     ) -> str:
         ...
 
 
 def name_for_collection_relationship(
-    base: DeclarativeMeta,
-    local_cls: DeclarativeMeta,
-    referred_cls: DeclarativeMeta,
+    base: Type[Any],
+    local_cls: Type[Any],
+    referred_cls: Type[Any],
     constraint: ForeignKeyConstraint,
 ) -> str:
     """Return the attribute name that should be used to refer from one
@@ -740,24 +744,24 @@ def name_for_collection_relationship(
 class GenerateRelationshipType(Protocol):
     def __call__(
         self,
-        base: DeclarativeMeta,
+        base: Type[Any],
         direction: RelationshipDirection,
         return_fn: BackrefConstructorType | RelaionshipConstructorType,
         attrname: str,
-        local_cls: DeclarativeMeta,
-        referred_cls: DeclarativeMeta,
+        local_cls: Type[Any],
+        referred_cls: Type[Any],
         **kw: Any,
     ) -> _ORMBackrefArgument | Relationship[Any]:
         ...
 
 
 def generate_relationship(
-    base: DeclarativeMeta,
+    base: Type[Any],
     direction: RelationshipDirection,
     return_fn: BackrefConstructorType | RelaionshipConstructorType,
     attrname: str,
-    local_cls: DeclarativeMeta,
-    referred_cls: DeclarativeMeta,
+    local_cls: Type[Any],
+    referred_cls: Type[Any],
     **kw: Any,
 ) -> _ORMBackrefArgument | Relationship[Any]:
     r"""Generate a :func:`_orm.relationship` or :func:`.backref`
@@ -879,7 +883,7 @@ class AutomapBase:
         ),
     )
     def prepare(
-        cls,
+        cls: Type[Any],
         autoload_with: Optional[Engine] = None,
         engine: Optional[Any] = None,
         reflect: bool = False,
@@ -978,6 +982,7 @@ class AutomapBase:
             autoload_with = engine
 
         if reflect:
+            assert autoload_with
             opts = dict(
                 schema=schema,
                 extend_existing=True,
@@ -988,18 +993,30 @@ class AutomapBase:
             cls.metadata.reflect(autoload_with, **opts)
 
         with _CONFIGURE_MUTEX:
-            table_to_map_config = {
-                m.local_table: m
+            table_to_map_config: Union[
+                Dict[Optional[Table], _DeferredMapperConfig],
+                Dict[Table, _DeferredMapperConfig],
+            ] = {
+                cast(Table, m.local_table): m
                 for m in _DeferredMapperConfig.classes_for_base(
                     cls, sort=False
                 )
             }
 
-            many_to_many = []
+            many_to_many: list[
+                tuple[
+                    Table,
+                    Table,
+                    list[ForeignKeyConstraint],
+                    Table,
+                ]
+            ] = []
 
             for table in cls.metadata.tables.values():
                 lcl_m2m, rem_m2m, m2m_const = _is_many_to_many(cls, table)
                 if lcl_m2m is not None:
+                    assert rem_m2m is not None
+                    assert m2m_const is not None
                     many_to_many.append((lcl_m2m, rem_m2m, m2m_const, table))
                 elif not table.primary_key:
                     continue
@@ -1076,8 +1093,8 @@ class AutomapBase:
 
 
 def automap_base(
-    declarative_base: Optional[Any] = None, **kw: Any
-) -> DeclarativeMeta:
+    declarative_base: Optional[Type[DeclarativeBase]] = None, **kw: Any
+) -> Any:
     r"""Produce a declarative automap base.
 
     This function produces a new base class that is a product of the
@@ -1110,7 +1127,7 @@ def automap_base(
 
 
 def _is_many_to_many(
-    automap_base: Type[AutomapBase], table: Table
+    automap_base: Type[Any], table: Table
 ) -> Tuple[
     Optional[Table], Optional[Table], Optional[list[ForeignKeyConstraint]]
 ]:
@@ -1122,7 +1139,7 @@ def _is_many_to_many(
     if len(fk_constraints) != 2:
         return None, None, None
 
-    cols = sum(
+    cols: list[Column[Any]] = sum(
         [
             [fk.parent for fk in fk_constraint.elements]
             for fk_constraint in fk_constraints
@@ -1141,7 +1158,7 @@ def _is_many_to_many(
 
 
 def _relationships_for_fks(
-    automap_base: DeclarativeMeta,
+    automap_base: Type[Any],
     map_config: _DeferredMapperConfig,
     table_to_map_config: Union[
         Dict[Optional[Table], _DeferredMapperConfig],
@@ -1152,8 +1169,10 @@ def _relationships_for_fks(
     name_for_collection_relationship: NameForCollectionRelationshipType,
     generate_relationship: GenerateRelationshipType,
 ) -> None:
-    local_table = map_config.local_table
-    local_cls = map_config.cls  # derived from a weakref, may be None
+    local_table = cast(Union[Table, None], map_config.local_table)
+    local_cls = cast(
+        Union[Type[Any], None], map_config.cls
+    )  # derived from a weakref, may be None
 
     if local_table is None or local_cls is None:
         return
@@ -1178,7 +1197,7 @@ def _relationships_for_fks(
                 automap_base, referred_cls, local_cls, constraint
             )
 
-            o2m_kws = {}
+            o2m_kws: dict[str, str | bool] = {}
             nullable = False not in {fk.parent.nullable for fk in fks}
             if not nullable:
                 o2m_kws["cascade"] = "all, delete-orphan"
@@ -1223,11 +1242,12 @@ def _relationships_for_fks(
                     remote_side=[fk.column for fk in constraint.elements],
                 )
                 if rel is not None:
+                    rel = cast(Relationship[Any], rel)
                     map_config.properties[relationship_name] = rel
                     if not create_backref:
                         referred_cfg.properties[
                             backref_name
-                        ].back_populates = relationship_name
+                        ].back_populates = relationship_name  # type: ignore[union-attr] # noqa: E501
             elif create_backref:
                 rel = generate_relationship(
                     automap_base,
@@ -1242,14 +1262,15 @@ def _relationships_for_fks(
                     **o2m_kws,
                 )
                 if rel is not None:
+                    rel = cast(Relationship[Any], rel)
                     referred_cfg.properties[backref_name] = rel
                     map_config.properties[
                         relationship_name
-                    ].back_populates = backref_name
+                    ].back_populates = backref_name  # type: ignore[union-attr]
 
 
 def _m2m_relationship(
-    automap_base: DeclarativeMeta,
+    automap_base: Type[Any],
     lcl_m2m: Table,
     rem_m2m: Table,
     m2m_const: List[ForeignKeyConstraint],
@@ -1312,20 +1333,21 @@ def _m2m_relationship(
             secondary=table,
             primaryjoin=and_(
                 fk.column == fk.parent for fk in m2m_const[0].elements
-            ),
+            ),  # type: ignore [arg-type]
             secondaryjoin=and_(
                 fk.column == fk.parent for fk in m2m_const[1].elements
-            ),
+            ),  # type: ignore [arg-type]
             backref=backref_obj,
             collection_class=collection_class,
         )
         if rel is not None:
+            rel = cast(Relationship[Any], rel)
             map_config.properties[relationship_name] = rel
 
             if not create_backref:
                 referred_cfg.properties[
                     backref_name
-                ].back_populates = relationship_name
+                ].back_populates = relationship_name  # type: ignore[union-attr] # noqa: E501
     elif create_backref:
         rel = generate_relationship(
             automap_base,
@@ -1338,15 +1360,16 @@ def _m2m_relationship(
             secondary=table,
             primaryjoin=and_(
                 fk.column == fk.parent for fk in m2m_const[1].elements
-            ),
+            ),  # type: ignore [arg-type]
             secondaryjoin=and_(
                 fk.column == fk.parent for fk in m2m_const[0].elements
-            ),
+            ),  # type: ignore [arg-type]
             back_populates=relationship_name,
             collection_class=collection_class,
         )
         if rel is not None:
+            rel = cast(Relationship[Any], rel)
             referred_cfg.properties[backref_name] = rel
             map_config.properties[
                 relationship_name
-            ].back_populates = backref_name
+            ].back_populates = backref_name  # type: ignore[union-attr]