From c25fda3e4d0888f080c93b915a67201e8ff81995 Mon Sep 17 00:00:00 2001 From: Gleb Kisenkov Date: Sun, 27 Nov 2022 16:01:34 +0100 Subject: [PATCH] Added protocols, refined some of the runtime-caught type hints --- lib/sqlalchemy/ext/automap.py | 118 +++++++++++++++++++----- lib/sqlalchemy/orm/__init__.py | 6 ++ lib/sqlalchemy/orm/_orm_constructors.py | 54 +++++++++++ test/conftest.py | 26 ------ 4 files changed, 154 insertions(+), 50 deletions(-) diff --git a/lib/sqlalchemy/ext/automap.py b/lib/sqlalchemy/ext/automap.py index d5832e4ef2..b7ef48a38e 100644 --- a/lib/sqlalchemy/ext/automap.py +++ b/lib/sqlalchemy/ext/automap.py @@ -579,27 +579,44 @@ from typing import Dict from typing import List from typing import Optional from typing import Tuple +from typing import Type +from typing import TypeVar from typing import Union from mypy_extensions import NoReturn -from sqlalchemy.engine.base import Engine -from sqlalchemy.orm.base import RelationshipDirection -from sqlalchemy.orm.decl_api import DeclarativeMeta -from sqlalchemy.orm.relationships import Relationship -from sqlalchemy.sql.elements import quoted_name -from sqlalchemy.sql.schema import Table -from sqlalchemy.util._py_collections import immutabledict from .. import util +from ..engine.base import Engine from ..orm import backref +from ..orm import BackrefConstructorType from ..orm import declarative_base as _declarative_base from ..orm import exc as orm_exc from ..orm import interfaces +from ..orm import RelaionshipConstructorType from ..orm import relationship +from ..orm.base import RelationshipDirection +from ..orm.decl_api import DeclarativeMeta from ..orm.decl_base import _DeferredMapperConfig from ..orm.mapper import _CONFIGURE_MUTEX +from ..orm.relationships import _ORMBackrefArgument +from ..orm.relationships import Relationship from ..schema import ForeignKeyConstraint from ..sql import and_ +from ..sql.elements import quoted_name +from ..sql.schema import Table +from ..util._py_collections import immutabledict +from ..util.typing import Protocol +from ..util.typing import TypeGuard + +_KT = TypeVar("_KT", bound=Any) +_VT = TypeVar("_VT", bound=Any) + + +class ClassnameForTableType(Protocol): + def __call__( + self, base: DeclarativeMeta, tablename: quoted_name, table: Table + ) -> str: + ... def classname_for_table( @@ -637,6 +654,17 @@ def classname_for_table( return str(tablename) +class NameForScalarRelationshipType(Protocol): + def __call__( + self, + base: DeclarativeMeta, + local_cls: DeclarativeMeta, + referred_cls: DeclarativeMeta, + constraint: ForeignKeyConstraint, + ) -> str: + ... + + def name_for_scalar_relationship( base: DeclarativeMeta, local_cls: DeclarativeMeta, @@ -667,6 +695,17 @@ def name_for_scalar_relationship( return referred_cls.__name__.lower() +class NameForCollectionRelationshipType(Protocol): + def __call__( + self, + base: DeclarativeMeta, + local_cls: DeclarativeMeta, + referred_cls: DeclarativeMeta, + constraint: ForeignKeyConstraint, + ) -> str: + ... + + def name_for_collection_relationship( base: DeclarativeMeta, local_cls: DeclarativeMeta, @@ -698,15 +737,29 @@ def name_for_collection_relationship( return referred_cls.__name__.lower() + "_collection" +class GenerateRelationshipType(Protocol): + def __call__( + self, + base: DeclarativeMeta, + direction: RelationshipDirection, + return_fn: BackrefConstructorType | RelaionshipConstructorType, + attrname: str, + local_cls: DeclarativeMeta, + referred_cls: DeclarativeMeta, + **kw: Any, + ) -> _ORMBackrefArgument | Relationship[Any]: + ... + + def generate_relationship( base: DeclarativeMeta, direction: RelationshipDirection, - return_fn: Callable, + return_fn: BackrefConstructorType | RelaionshipConstructorType, attrname: str, local_cls: DeclarativeMeta, referred_cls: DeclarativeMeta, **kw: Any, -) -> Union[Tuple[str, Dict[str, Any]], Relationship]: +) -> _ORMBackrefArgument | Relationship[Any]: r"""Generate a :func:`_orm.relationship` or :func:`.backref` on behalf of two mapped classes. @@ -755,9 +808,20 @@ def generate_relationship( by the :paramref:`.generate_relationship.return_fn` parameter. """ - if return_fn is backref: + + def is_backref( + func: Callable[..., Any] + ) -> TypeGuard[BackrefConstructorType]: + return func is backref + + def is_relationship( + func: Callable[..., Any] + ) -> TypeGuard[RelaionshipConstructorType]: + return func is relationship + + if is_backref(return_fn): return return_fn(attrname, **kw) - elif return_fn is relationship: + elif is_relationship(return_fn): return return_fn(referred_cls, **kw) else: raise TypeError("Unknown relationship function: %s" % return_fn) @@ -820,13 +884,17 @@ class AutomapBase: engine: Optional[Any] = None, reflect: bool = False, schema: Optional[str] = None, - classname_for_table: Optional[Callable] = None, + classname_for_table: Optional[ClassnameForTableType] = None, collection_class: Optional[Any] = None, - name_for_scalar_relationship: Optional[Callable] = None, - name_for_collection_relationship: Optional[Callable] = None, - generate_relationship: Optional[Callable] = None, + name_for_scalar_relationship: Optional[ + NameForScalarRelationshipType + ] = None, + name_for_collection_relationship: Optional[ + NameForCollectionRelationshipType + ] = None, + generate_relationship: Optional[GenerateRelationshipType] = None, reflection_options: Union[ - Dict[str, Any], immutabledict + Dict[_KT, _VT], immutabledict[_KT, _VT] ] = util.EMPTY_DICT, ) -> None: """Extract mapped classes and relationships from the @@ -1042,8 +1110,10 @@ def automap_base( def _is_many_to_many( - automap_base: DeclarativeMeta, table: Table -) -> Tuple[None, None, None]: + automap_base: Type[AutomapBase], table: Table +) -> Tuple[ + Optional[Table], Optional[Table], Optional[list[ForeignKeyConstraint]] +]: fk_constraints = [ const for const in table.constraints @@ -1078,9 +1148,9 @@ def _relationships_for_fks( Dict[Table, _DeferredMapperConfig], ], collection_class: type, - name_for_scalar_relationship: Callable, - name_for_collection_relationship: Callable, - generate_relationship: Callable, + name_for_scalar_relationship: NameForScalarRelationshipType, + name_for_collection_relationship: NameForCollectionRelationshipType, + generate_relationship: GenerateRelationshipType, ) -> None: local_table = map_config.local_table local_cls = map_config.cls # derived from a weakref, may be None @@ -1189,9 +1259,9 @@ def _m2m_relationship( Dict[Table, _DeferredMapperConfig], ], collection_class: type, - name_for_scalar_relationship: Callable, - name_for_collection_relationship: Callable, - generate_relationship: Callable, + name_for_scalar_relationship: NameForCollectionRelationshipType, + name_for_collection_relationship: NameForCollectionRelationshipType, + generate_relationship: GenerateRelationshipType, ) -> None: map_config = table_to_map_config.get(lcl_m2m, None) diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py index 96acce2ff8..0618b22619 100644 --- a/lib/sqlalchemy/orm/__init__.py +++ b/lib/sqlalchemy/orm/__init__.py @@ -23,6 +23,9 @@ from . import strategy_options as strategy_options from ._orm_constructors import _mapper_fn as mapper from ._orm_constructors import aliased as aliased from ._orm_constructors import backref as backref +from ._orm_constructors import ( + BackrefConstructorType as BackrefConstructorType, +) from ._orm_constructors import clear_mappers as clear_mappers from ._orm_constructors import column_property as column_property from ._orm_constructors import composite as composite @@ -34,6 +37,9 @@ from ._orm_constructors import join as join from ._orm_constructors import mapped_column as mapped_column from ._orm_constructors import outerjoin as outerjoin from ._orm_constructors import query_expression as query_expression +from ._orm_constructors import ( + RelaionshipConstructorType as RelaionshipConstructorType, +) from ._orm_constructors import relationship as relationship from ._orm_constructors import synonym as synonym from ._orm_constructors import with_loader_criteria as with_loader_criteria diff --git a/lib/sqlalchemy/orm/_orm_constructors.py b/lib/sqlalchemy/orm/_orm_constructors.py index 30119d9d79..0a506e9eef 100644 --- a/lib/sqlalchemy/orm/_orm_constructors.py +++ b/lib/sqlalchemy/orm/_orm_constructors.py @@ -46,6 +46,7 @@ from ..sql.schema import SchemaConst from ..sql.selectable import FromClause from ..util.typing import Annotated from ..util.typing import Literal +from ..util.typing import Protocol if TYPE_CHECKING: from ._typing import _EntityType @@ -734,6 +735,54 @@ def with_loader_criteria( ) +class RelaionshipConstructorType(Protocol): + def __call__( + self, + argument: Optional[_RelationshipArgumentType[Any]] = None, + secondary: Optional[Union[FromClause, str]] = None, + *, + uselist: Optional[bool] = None, + collection_class: Optional[ + Union[Type[Collection[Any]], Callable[[], Collection[Any]]] + ] = None, + primaryjoin: Optional[_RelationshipJoinConditionArgument] = None, + secondaryjoin: Optional[_RelationshipJoinConditionArgument] = None, + back_populates: Optional[str] = None, + order_by: _ORMOrderByArgument = False, + backref: Optional[_ORMBackrefArgument] = None, + overlaps: Optional[str] = None, + post_update: bool = False, + cascade: str = "save-update, merge", + viewonly: bool = False, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + default: Union[_NoArg, _T] = _NoArg.NO_ARG, + default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, + kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, + lazy: _LazyLoadArgumentType = "select", + passive_deletes: Union[Literal["all"], bool] = False, + passive_updates: bool = True, + active_history: bool = False, + enable_typechecks: bool = True, + foreign_keys: Optional[_ORMColCollectionArgument] = None, + remote_side: Optional[_ORMColCollectionArgument] = None, + join_depth: Optional[int] = None, + comparator_factory: Optional[ + Type[RelationshipProperty.Comparator[Any]] + ] = None, + single_parent: bool = False, + innerjoin: bool = False, + distinct_target_key: Optional[bool] = None, + load_on_pending: bool = False, + query_class: Optional[Type[Query[Any]]] = None, + info: Optional[_InfoType] = None, + omit_join: Literal[None, False] = None, + sync_backref: Optional[bool] = None, + **kw: Any, + ) -> Relationship[Any]: + ... + + def relationship( argument: Optional[_RelationshipArgumentType[Any]] = None, secondary: Optional[Union[FromClause, str]] = None, @@ -1854,6 +1903,11 @@ def dynamic_loader( return relationship(argument, **kw) +class BackrefConstructorType(Protocol): + def __call__(self, name: str, **kwargs: Any) -> _ORMBackrefArgument: + ... + + def backref(name: str, **kwargs: Any) -> _ORMBackrefArgument: """When using the :paramref:`_orm.relationship.backref` parameter, provides specific parameters to be used when the new diff --git a/test/conftest.py b/test/conftest.py index e2fe7854e7..b7f2d945ca 100755 --- a/test/conftest.py +++ b/test/conftest.py @@ -53,29 +53,3 @@ with open(bootstrap_file) as f: to_bootstrap = "pytest" exec(code, globals(), locals()) from sqla_pytestplugin import * # noqa - - -# def pytest_collection_finish(session): -# """Handle the pytest collection finish hook: configure pyannotate. -# Explicitly delay importing `collect_types` until all tests have -# been collected. This gives gevent a chance to monkey patch the -# world before importing pyannotate. -# """ -# from pyannotate_runtime import collect_types - -# collect_types.init_types_collection() - - -# @pytest.fixture(autouse=True) -# def collect_types_fixture(): -# from pyannotate_runtime import collect_types - -# collect_types.start() -# yield -# collect_types.stop() - - -# def pytest_sessionfinish(session, exitstatus): -# from pyannotate_runtime import collect_types - -# collect_types.dump_stats("type_info.json") -- 2.47.2