From: Gleb Kisenkov Date: Mon, 5 Dec 2022 13:45:25 +0000 (-0500) Subject: Type annotations for sqlalchemy.ext.automap X-Git-Tag: rel_2_0_0b4~6 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=422d8d3bcbf2b60f053ab76c3fc29f33242ccf4b;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Type annotations for sqlalchemy.ext.automap An attempt to annotate `lib/sqlalchemy/ext/automap.py` with type hints (issue [#6810](https://github.com/sqlalchemy/sqlalchemy/issues/6810#issuecomment-1127062951)). More info on how I approach it could be found in [the earlier PR](https://github.com/sqlalchemy/sqlalchemy/pull/8775). This pull request is: - [ ] A documentation / typographical error fix - Good to go, no issue or tests are needed - [ ] A short code fix - please include the issue number, and create an issue if none exists, which must include a complete example of the issue. one line code fixes without an issue and demonstration will not be accepted. - Please include: `Fixes: #` in the commit message - please include tests. one line code fixes without tests will not be accepted. - [x] A new feature implementation - please include the issue number, and create an issue if none exists, which must include a complete example of how the feature would look. - Please include: `Fixes: #` in the commit message - please include tests. **Have a nice day!** Closes: #8874 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/8874 Pull-request-sha: 834d58d77c2cfd09ea874f01eb7b75a2ea0db7cd Change-Id: Ie64b2be7a51ddc83ef8f23385fb63db5b5c1bc17 --- diff --git a/lib/sqlalchemy/ext/automap.py b/lib/sqlalchemy/ext/automap.py index 6eb30ba4c6..217113309f 100644 --- a/lib/sqlalchemy/ext/automap.py +++ b/lib/sqlalchemy/ext/automap.py @@ -4,7 +4,6 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors r"""Define an extension to the :mod:`sqlalchemy.ext.declarative` system which automatically generates mapped classes and relationships from a database @@ -572,6 +571,22 @@ be applied as:: """ # noqa +from __future__ import annotations + +from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import List +from typing import NoReturn +from typing import Optional +from typing import overload +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + from .. import util from ..orm import backref from ..orm import declarative_base as _declarative_base @@ -582,9 +597,36 @@ from ..orm.decl_base import _DeferredMapperConfig from ..orm.mapper import _CONFIGURE_MUTEX from ..schema import ForeignKeyConstraint from ..sql import and_ +from ..util.typing import Protocol + +if TYPE_CHECKING: + from ..engine.base import Engine + from ..orm.base import RelationshipDirection + from ..orm.relationships import ORMBackrefArgument + from ..orm.relationships import Relationship + from ..sql.elements import quoted_name + from ..sql.schema import Column + from ..sql.schema import Table + from ..util import immutabledict + from ..util import Properties + + +_KT = TypeVar("_KT", bound=Any) +_VT = TypeVar("_VT", bound=Any) + +class ClassnameForTableType(Protocol): + def __call__( + self, base: Type[Any], tablename: quoted_name, table: Table + ) -> str: + ... -def classname_for_table(base, tablename, table): + +def classname_for_table( + base: Type[Any], + tablename: quoted_name, + table: Table, +) -> str: """Return the class name that should be used, given the name of a table. @@ -617,7 +659,23 @@ def classname_for_table(base, tablename, table): return str(tablename) -def name_for_scalar_relationship(base, local_cls, referred_cls, constraint): +class NameForScalarRelationshipType(Protocol): + def __call__( + self, + base: Type[Any], + local_cls: Type[Any], + referred_cls: Type[Any], + constraint: ForeignKeyConstraint, + ) -> str: + ... + + +def name_for_scalar_relationship( + base: Type[Any], + local_cls: Type[Any], + referred_cls: Type[Any], + constraint: ForeignKeyConstraint, +) -> str: """Return the attribute name that should be used to refer from one class to another, for a scalar object reference. @@ -642,9 +700,23 @@ def name_for_scalar_relationship(base, local_cls, referred_cls, constraint): return referred_cls.__name__.lower() +class NameForCollectionRelationshipType(Protocol): + def __call__( + self, + base: Type[Any], + local_cls: Type[Any], + referred_cls: Type[Any], + constraint: ForeignKeyConstraint, + ) -> str: + ... + + def name_for_collection_relationship( - base, local_cls, referred_cls, constraint -): + base: Type[Any], + local_cls: Type[Any], + referred_cls: Type[Any], + constraint: ForeignKeyConstraint, +) -> str: """Return the attribute name that should be used to refer from one class to another, for a collection reference. @@ -670,9 +742,85 @@ def name_for_collection_relationship( return referred_cls.__name__.lower() + "_collection" +class GenerateRelationshipType(Protocol): + @overload + def __call__( + self, + base: Type[Any], + direction: RelationshipDirection, + return_fn: Callable[..., Relationship[Any]], + attrname: str, + local_cls: Type[Any], + referred_cls: Type[Any], + **kw: Any, + ) -> Relationship[Any]: + ... + + @overload + def __call__( + self, + base: Type[Any], + direction: RelationshipDirection, + return_fn: Callable[..., ORMBackrefArgument], + attrname: str, + local_cls: Type[Any], + referred_cls: Type[Any], + **kw: Any, + ) -> ORMBackrefArgument: + ... + + def __call__( + self, + base: Type[Any], + direction: RelationshipDirection, + return_fn: Union[ + Callable[..., Relationship[Any]], Callable[..., ORMBackrefArgument] + ], + attrname: str, + local_cls: Type[Any], + referred_cls: Type[Any], + **kw: Any, + ) -> Union[ORMBackrefArgument, Relationship[Any]]: + ... + + +@overload +def generate_relationship( + base: Type[Any], + direction: RelationshipDirection, + return_fn: Callable[..., Relationship[Any]], + attrname: str, + local_cls: Type[Any], + referred_cls: Type[Any], + **kw: Any, +) -> Relationship[Any]: + ... + + +@overload def generate_relationship( - base, direction, return_fn, attrname, local_cls, referred_cls, **kw -): + base: Type[Any], + direction: RelationshipDirection, + return_fn: Callable[..., ORMBackrefArgument], + attrname: str, + local_cls: Type[Any], + referred_cls: Type[Any], + **kw: Any, +) -> ORMBackrefArgument: + ... + + +def generate_relationship( + base: Type[Any], + direction: RelationshipDirection, + return_fn: Union[ + Callable[..., Relationship[Any]], Callable[..., ORMBackrefArgument] + ], + attrname: str, + local_cls: Type[Any], + referred_cls: Type[Any], + **kw: Any, +) -> Union[Relationship[Any], ORMBackrefArgument]: r"""Generate a :func:`_orm.relationship` or :func:`.backref` on behalf of two mapped classes. @@ -721,6 +869,7 @@ def generate_relationship( by the :paramref:`.generate_relationship.return_fn` parameter. """ + if return_fn is backref: return return_fn(attrname, **kw) elif return_fn is relationship: @@ -748,7 +897,7 @@ class AutomapBase: __abstract__ = True - classes = None + classes: Optional[Properties[Type[Any]]] = None """An instance of :class:`.util.Properties` containing classes. This object behaves much like the ``.c`` collection on a table. Classes @@ -781,18 +930,24 @@ class AutomapBase: ), ) def prepare( - cls, - autoload_with=None, - engine=None, - reflect=False, - schema=None, - classname_for_table=None, - collection_class=None, - name_for_scalar_relationship=None, - name_for_collection_relationship=None, - generate_relationship=None, - reflection_options=util.EMPTY_DICT, - ): + cls: Type[Any], + autoload_with: Optional[Engine] = None, + engine: Optional[Any] = None, + reflect: bool = False, + schema: Optional[str] = None, + classname_for_table: Optional[ClassnameForTableType] = None, + collection_class: Optional[Any] = None, + name_for_scalar_relationship: Optional[ + NameForScalarRelationshipType + ] = None, + name_for_collection_relationship: Optional[ + NameForCollectionRelationshipType + ] = None, + generate_relationship: Optional[GenerateRelationshipType] = None, + reflection_options: Union[ + Dict[_KT, _VT], immutabledict[_KT, _VT] + ] = util.EMPTY_DICT, + ) -> None: """Extract mapped classes and relationships from the :class:`_schema.MetaData` and perform mappings. @@ -874,6 +1029,7 @@ class AutomapBase: autoload_with = engine if reflect: + assert autoload_with opts = dict( schema=schema, extend_existing=True, @@ -884,18 +1040,30 @@ class AutomapBase: cls.metadata.reflect(autoload_with, **opts) with _CONFIGURE_MUTEX: - table_to_map_config = { - m.local_table: m + table_to_map_config: Union[ + Dict[Optional[Table], _DeferredMapperConfig], + Dict[Table, _DeferredMapperConfig], + ] = { + cast("Table", m.local_table): m for m in _DeferredMapperConfig.classes_for_base( cls, sort=False ) } - many_to_many = [] + many_to_many: list[ + tuple[ + Table, + Table, + list[ForeignKeyConstraint], + Table, + ] + ] = [] for table in cls.metadata.tables.values(): lcl_m2m, rem_m2m, m2m_const = _is_many_to_many(cls, table) if lcl_m2m is not None: + assert rem_m2m is not None + assert m2m_const is not None many_to_many.append((lcl_m2m, rem_m2m, m2m_const, table)) elif not table.primary_key: continue @@ -961,7 +1129,7 @@ class AutomapBase: """ @classmethod - def _sa_raise_deferred_config(cls): + def _sa_raise_deferred_config(cls) -> NoReturn: raise orm_exc.UnmappedClassError( cls, msg="Class %s is a subclass of AutomapBase. " @@ -971,7 +1139,9 @@ class AutomapBase: ) -def automap_base(declarative_base=None, **kw): +def automap_base( + declarative_base: Optional[Type[Any]] = None, **kw: Any +) -> Any: r"""Produce a declarative automap base. This function produces a new base class that is a product of the @@ -1003,7 +1173,11 @@ def automap_base(declarative_base=None, **kw): ) -def _is_many_to_many(automap_base, table): +def _is_many_to_many( + automap_base: Type[Any], table: Table +) -> Tuple[ + Optional[Table], Optional[Table], Optional[list[ForeignKeyConstraint]] +]: fk_constraints = [ const for const in table.constraints @@ -1012,7 +1186,7 @@ def _is_many_to_many(automap_base, table): if len(fk_constraints) != 2: return None, None, None - cols = sum( + cols: list[Column[Any]] = sum( [ [fk.parent for fk in fk_constraint.elements] for fk_constraint in fk_constraints @@ -1031,16 +1205,21 @@ def _is_many_to_many(automap_base, table): def _relationships_for_fks( - automap_base, - map_config, - table_to_map_config, - collection_class, - name_for_scalar_relationship, - name_for_collection_relationship, - generate_relationship, -): - local_table = map_config.local_table - local_cls = map_config.cls # derived from a weakref, may be None + automap_base: Type[Any], + map_config: _DeferredMapperConfig, + table_to_map_config: Union[ + Dict[Optional[Table], _DeferredMapperConfig], + Dict[Table, _DeferredMapperConfig], + ], + collection_class: type, + name_for_scalar_relationship: NameForScalarRelationshipType, + name_for_collection_relationship: NameForCollectionRelationshipType, + generate_relationship: GenerateRelationshipType, +) -> None: + local_table = cast("Optional[Table]", map_config.local_table) + local_cls = cast( + "Optional[Type[Any]]", map_config.cls + ) # derived from a weakref, may be None if local_table is None or local_cls is None: return @@ -1065,7 +1244,7 @@ def _relationships_for_fks( automap_base, referred_cls, local_cls, constraint ) - o2m_kws = {} + o2m_kws: dict[str, Union[str, bool]] = {} nullable = False not in {fk.parent.nullable for fk in fks} if not nullable: o2m_kws["cascade"] = "all, delete-orphan" @@ -1114,7 +1293,7 @@ def _relationships_for_fks( if not create_backref: referred_cfg.properties[ backref_name - ].back_populates = relationship_name + ].back_populates = relationship_name # type: ignore[union-attr] # noqa: E501 elif create_backref: rel = generate_relationship( automap_base, @@ -1132,21 +1311,24 @@ def _relationships_for_fks( referred_cfg.properties[backref_name] = rel map_config.properties[ relationship_name - ].back_populates = backref_name + ].back_populates = backref_name # type: ignore[union-attr] def _m2m_relationship( - automap_base, - lcl_m2m, - rem_m2m, - m2m_const, - table, - table_to_map_config, - collection_class, - name_for_scalar_relationship, - name_for_collection_relationship, - generate_relationship, -): + automap_base: Type[Any], + lcl_m2m: Table, + rem_m2m: Table, + m2m_const: List[ForeignKeyConstraint], + table: Table, + table_to_map_config: Union[ + Dict[Optional[Table], _DeferredMapperConfig], + Dict[Table, _DeferredMapperConfig], + ], + collection_class: type, + name_for_scalar_relationship: NameForCollectionRelationshipType, + name_for_collection_relationship: NameForCollectionRelationshipType, + generate_relationship: GenerateRelationshipType, +) -> None: map_config = table_to_map_config.get(lcl_m2m, None) referred_cfg = table_to_map_config.get(rem_m2m, None) @@ -1196,10 +1378,10 @@ def _m2m_relationship( secondary=table, primaryjoin=and_( fk.column == fk.parent for fk in m2m_const[0].elements - ), + ), # type: ignore [arg-type] secondaryjoin=and_( fk.column == fk.parent for fk in m2m_const[1].elements - ), + ), # type: ignore [arg-type] backref=backref_obj, collection_class=collection_class, ) @@ -1209,7 +1391,7 @@ def _m2m_relationship( if not create_backref: referred_cfg.properties[ backref_name - ].back_populates = relationship_name + ].back_populates = relationship_name # type: ignore[union-attr] # noqa: E501 elif create_backref: rel = generate_relationship( automap_base, @@ -1222,10 +1404,10 @@ def _m2m_relationship( secondary=table, primaryjoin=and_( fk.column == fk.parent for fk in m2m_const[1].elements - ), + ), # type: ignore [arg-type] secondaryjoin=and_( fk.column == fk.parent for fk in m2m_const[0].elements - ), + ), # type: ignore [arg-type] back_populates=relationship_name, collection_class=collection_class, ) @@ -1233,4 +1415,4 @@ def _m2m_relationship( referred_cfg.properties[backref_name] = rel map_config.properties[ relationship_name - ].back_populates = backref_name + ].back_populates = backref_name # type: ignore[union-attr] diff --git a/lib/sqlalchemy/orm/_orm_constructors.py b/lib/sqlalchemy/orm/_orm_constructors.py index cb28ac0600..7323f990e2 100644 --- a/lib/sqlalchemy/orm/_orm_constructors.py +++ b/lib/sqlalchemy/orm/_orm_constructors.py @@ -57,10 +57,10 @@ if TYPE_CHECKING: from .mapper import Mapper from .query import Query from .relationships import _LazyLoadArgumentType - from .relationships import _ORMBackrefArgument from .relationships import _ORMColCollectionArgument from .relationships import _ORMOrderByArgument from .relationships import _RelationshipJoinConditionArgument + from .relationships import ORMBackrefArgument from .session import _SessionBind from ..sql._typing import _ColumnExpressionArgument from ..sql._typing import _FromClauseArgument @@ -781,7 +781,7 @@ def relationship( secondaryjoin: Optional[_RelationshipJoinConditionArgument] = None, back_populates: Optional[str] = None, order_by: _ORMOrderByArgument = False, - backref: Optional[_ORMBackrefArgument] = None, + backref: Optional[ORMBackrefArgument] = None, overlaps: Optional[str] = None, post_update: bool = False, cascade: str = "save-update, merge", @@ -1898,7 +1898,7 @@ def dynamic_loader( return relationship(argument, **kw) -def backref(name: str, **kwargs: Any) -> _ORMBackrefArgument: +def backref(name: str, **kwargs: Any) -> ORMBackrefArgument: """When using the :paramref:`_orm.relationship.backref` parameter, provides specific parameters to be used when the new :func:`_orm.relationship` is generated. diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index 4a9bcd711e..2a659bdef3 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -171,7 +171,7 @@ _ORMOrderByArgument = Union[ Callable[[], Iterable[ColumnElement[Any]]], Iterable[Union[str, _ColumnExpressionArgument[Any]]], ] -_ORMBackrefArgument = Union[str, Tuple[str, Dict[str, Any]]] +ORMBackrefArgument = Union[str, Tuple[str, Dict[str, Any]]] _ORMColCollectionElement = Union[ ColumnClause[Any], _HasClauseElement, roles.DMLColumnRole @@ -366,7 +366,7 @@ class RelationshipProperty( secondaryjoin: Optional[_RelationshipJoinConditionArgument] = None, back_populates: Optional[str] = None, order_by: _ORMOrderByArgument = False, - backref: Optional[_ORMBackrefArgument] = None, + backref: Optional[ORMBackrefArgument] = None, overlaps: Optional[str] = None, post_update: bool = False, cascade: str = "save-update, merge",