From b650ad8263da58fbdc54f8fab7cadca927e3edc9 Mon Sep 17 00:00:00 2001 From: Gleb Kisenkov Date: Tue, 29 Nov 2022 18:54:59 +0100 Subject: [PATCH] All mypy errors are cleared --- lib/sqlalchemy/ext/automap.py | 107 +++++++++++++++++++++------------- 1 file changed, 65 insertions(+), 42 deletions(-) diff --git a/lib/sqlalchemy/ext/automap.py b/lib/sqlalchemy/ext/automap.py index b7ef48a38e..1b3b2547f0 100644 --- a/lib/sqlalchemy/ext/automap.py +++ b/lib/sqlalchemy/ext/automap.py @@ -575,6 +575,7 @@ from __future__ import annotations from typing import Any from typing import Callable +from typing import cast from typing import Dict from typing import List from typing import Optional @@ -595,7 +596,7 @@ from ..orm import interfaces from ..orm import RelaionshipConstructorType from ..orm import relationship from ..orm.base import RelationshipDirection -from ..orm.decl_api import DeclarativeMeta +from ..orm.decl_api import DeclarativeBase from ..orm.decl_base import _DeferredMapperConfig from ..orm.mapper import _CONFIGURE_MUTEX from ..orm.relationships import _ORMBackrefArgument @@ -603,6 +604,7 @@ from ..orm.relationships import Relationship from ..schema import ForeignKeyConstraint from ..sql import and_ from ..sql.elements import quoted_name +from ..sql.schema import Column from ..sql.schema import Table from ..util._py_collections import immutabledict from ..util.typing import Protocol @@ -614,13 +616,15 @@ _VT = TypeVar("_VT", bound=Any) class ClassnameForTableType(Protocol): def __call__( - self, base: DeclarativeMeta, tablename: quoted_name, table: Table + self, base: Type[Any], tablename: quoted_name, table: Table ) -> str: ... def classname_for_table( - base: DeclarativeMeta, tablename: quoted_name, table: Table + base: Type[Any], + tablename: quoted_name, + table: Table, ) -> str: """Return the class name that should be used, given the name of a table. @@ -657,18 +661,18 @@ def classname_for_table( class NameForScalarRelationshipType(Protocol): def __call__( self, - base: DeclarativeMeta, - local_cls: DeclarativeMeta, - referred_cls: DeclarativeMeta, + base: Type[Any], + local_cls: Type[Any], + referred_cls: Type[Any], constraint: ForeignKeyConstraint, ) -> str: ... def name_for_scalar_relationship( - base: DeclarativeMeta, - local_cls: DeclarativeMeta, - referred_cls: DeclarativeMeta, + base: Type[Any], + local_cls: Type[Any], + referred_cls: Type[Any], constraint: ForeignKeyConstraint, ) -> str: """Return the attribute name that should be used to refer from one @@ -698,18 +702,18 @@ def name_for_scalar_relationship( class NameForCollectionRelationshipType(Protocol): def __call__( self, - base: DeclarativeMeta, - local_cls: DeclarativeMeta, - referred_cls: DeclarativeMeta, + base: Type[Any], + local_cls: Type[Any], + referred_cls: Type[Any], constraint: ForeignKeyConstraint, ) -> str: ... def name_for_collection_relationship( - base: DeclarativeMeta, - local_cls: DeclarativeMeta, - referred_cls: DeclarativeMeta, + base: Type[Any], + local_cls: Type[Any], + referred_cls: Type[Any], constraint: ForeignKeyConstraint, ) -> str: """Return the attribute name that should be used to refer from one @@ -740,24 +744,24 @@ def name_for_collection_relationship( class GenerateRelationshipType(Protocol): def __call__( self, - base: DeclarativeMeta, + base: Type[Any], direction: RelationshipDirection, return_fn: BackrefConstructorType | RelaionshipConstructorType, attrname: str, - local_cls: DeclarativeMeta, - referred_cls: DeclarativeMeta, + local_cls: Type[Any], + referred_cls: Type[Any], **kw: Any, ) -> _ORMBackrefArgument | Relationship[Any]: ... def generate_relationship( - base: DeclarativeMeta, + base: Type[Any], direction: RelationshipDirection, return_fn: BackrefConstructorType | RelaionshipConstructorType, attrname: str, - local_cls: DeclarativeMeta, - referred_cls: DeclarativeMeta, + local_cls: Type[Any], + referred_cls: Type[Any], **kw: Any, ) -> _ORMBackrefArgument | Relationship[Any]: r"""Generate a :func:`_orm.relationship` or :func:`.backref` @@ -879,7 +883,7 @@ class AutomapBase: ), ) def prepare( - cls, + cls: Type[Any], autoload_with: Optional[Engine] = None, engine: Optional[Any] = None, reflect: bool = False, @@ -978,6 +982,7 @@ class AutomapBase: autoload_with = engine if reflect: + assert autoload_with opts = dict( schema=schema, extend_existing=True, @@ -988,18 +993,30 @@ class AutomapBase: cls.metadata.reflect(autoload_with, **opts) with _CONFIGURE_MUTEX: - table_to_map_config = { - m.local_table: m + table_to_map_config: Union[ + Dict[Optional[Table], _DeferredMapperConfig], + Dict[Table, _DeferredMapperConfig], + ] = { + cast(Table, m.local_table): m for m in _DeferredMapperConfig.classes_for_base( cls, sort=False ) } - many_to_many = [] + many_to_many: list[ + tuple[ + Table, + Table, + list[ForeignKeyConstraint], + Table, + ] + ] = [] for table in cls.metadata.tables.values(): lcl_m2m, rem_m2m, m2m_const = _is_many_to_many(cls, table) if lcl_m2m is not None: + assert rem_m2m is not None + assert m2m_const is not None many_to_many.append((lcl_m2m, rem_m2m, m2m_const, table)) elif not table.primary_key: continue @@ -1076,8 +1093,8 @@ class AutomapBase: def automap_base( - declarative_base: Optional[Any] = None, **kw: Any -) -> DeclarativeMeta: + declarative_base: Optional[Type[DeclarativeBase]] = None, **kw: Any +) -> Any: r"""Produce a declarative automap base. This function produces a new base class that is a product of the @@ -1110,7 +1127,7 @@ def automap_base( def _is_many_to_many( - automap_base: Type[AutomapBase], table: Table + automap_base: Type[Any], table: Table ) -> Tuple[ Optional[Table], Optional[Table], Optional[list[ForeignKeyConstraint]] ]: @@ -1122,7 +1139,7 @@ def _is_many_to_many( if len(fk_constraints) != 2: return None, None, None - cols = sum( + cols: list[Column[Any]] = sum( [ [fk.parent for fk in fk_constraint.elements] for fk_constraint in fk_constraints @@ -1141,7 +1158,7 @@ def _is_many_to_many( def _relationships_for_fks( - automap_base: DeclarativeMeta, + automap_base: Type[Any], map_config: _DeferredMapperConfig, table_to_map_config: Union[ Dict[Optional[Table], _DeferredMapperConfig], @@ -1152,8 +1169,10 @@ def _relationships_for_fks( name_for_collection_relationship: NameForCollectionRelationshipType, generate_relationship: GenerateRelationshipType, ) -> None: - local_table = map_config.local_table - local_cls = map_config.cls # derived from a weakref, may be None + local_table = cast(Union[Table, None], map_config.local_table) + local_cls = cast( + Union[Type[Any], None], map_config.cls + ) # derived from a weakref, may be None if local_table is None or local_cls is None: return @@ -1178,7 +1197,7 @@ def _relationships_for_fks( automap_base, referred_cls, local_cls, constraint ) - o2m_kws = {} + o2m_kws: dict[str, str | bool] = {} nullable = False not in {fk.parent.nullable for fk in fks} if not nullable: o2m_kws["cascade"] = "all, delete-orphan" @@ -1223,11 +1242,12 @@ def _relationships_for_fks( remote_side=[fk.column for fk in constraint.elements], ) if rel is not None: + rel = cast(Relationship[Any], rel) map_config.properties[relationship_name] = rel if not create_backref: referred_cfg.properties[ backref_name - ].back_populates = relationship_name + ].back_populates = relationship_name # type: ignore[union-attr] # noqa: E501 elif create_backref: rel = generate_relationship( automap_base, @@ -1242,14 +1262,15 @@ def _relationships_for_fks( **o2m_kws, ) if rel is not None: + rel = cast(Relationship[Any], rel) referred_cfg.properties[backref_name] = rel map_config.properties[ relationship_name - ].back_populates = backref_name + ].back_populates = backref_name # type: ignore[union-attr] def _m2m_relationship( - automap_base: DeclarativeMeta, + automap_base: Type[Any], lcl_m2m: Table, rem_m2m: Table, m2m_const: List[ForeignKeyConstraint], @@ -1312,20 +1333,21 @@ def _m2m_relationship( secondary=table, primaryjoin=and_( fk.column == fk.parent for fk in m2m_const[0].elements - ), + ), # type: ignore [arg-type] secondaryjoin=and_( fk.column == fk.parent for fk in m2m_const[1].elements - ), + ), # type: ignore [arg-type] backref=backref_obj, collection_class=collection_class, ) if rel is not None: + rel = cast(Relationship[Any], rel) map_config.properties[relationship_name] = rel if not create_backref: referred_cfg.properties[ backref_name - ].back_populates = relationship_name + ].back_populates = relationship_name # type: ignore[union-attr] # noqa: E501 elif create_backref: rel = generate_relationship( automap_base, @@ -1338,15 +1360,16 @@ def _m2m_relationship( secondary=table, primaryjoin=and_( fk.column == fk.parent for fk in m2m_const[1].elements - ), + ), # type: ignore [arg-type] secondaryjoin=and_( fk.column == fk.parent for fk in m2m_const[0].elements - ), + ), # type: ignore [arg-type] back_populates=relationship_name, collection_class=collection_class, ) if rel is not None: + rel = cast(Relationship[Any], rel) referred_cfg.properties[backref_name] = rel map_config.properties[ relationship_name - ].back_populates = backref_name + ].back_populates = backref_name # type: ignore[union-attr] -- 2.47.2