]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Added protocols, refined some of the runtime-caught type hints
authorGleb Kisenkov <g.kisenkov@godeltech.com>
Sun, 27 Nov 2022 15:01:34 +0000 (16:01 +0100)
committerGleb Kisenkov <g.kisenkov@godeltech.com>
Sun, 27 Nov 2022 15:01:34 +0000 (16:01 +0100)
lib/sqlalchemy/ext/automap.py
lib/sqlalchemy/orm/__init__.py
lib/sqlalchemy/orm/_orm_constructors.py
test/conftest.py

index d5832e4ef23efe4fe169f8dc8e7505bfc6f2ca86..b7ef48a38ee5b2452e47dad392a0655c6b79b86a 100644 (file)
@@ -579,27 +579,44 @@ from typing import Dict
 from typing import List
 from typing import Optional
 from typing import Tuple
+from typing import Type
+from typing import TypeVar
 from typing import Union
 
 from mypy_extensions import NoReturn
 
-from sqlalchemy.engine.base import Engine
-from sqlalchemy.orm.base import RelationshipDirection
-from sqlalchemy.orm.decl_api import DeclarativeMeta
-from sqlalchemy.orm.relationships import Relationship
-from sqlalchemy.sql.elements import quoted_name
-from sqlalchemy.sql.schema import Table
-from sqlalchemy.util._py_collections import immutabledict
 from .. import util
+from ..engine.base import Engine
 from ..orm import backref
+from ..orm import BackrefConstructorType
 from ..orm import declarative_base as _declarative_base
 from ..orm import exc as orm_exc
 from ..orm import interfaces
+from ..orm import RelaionshipConstructorType
 from ..orm import relationship
+from ..orm.base import RelationshipDirection
+from ..orm.decl_api import DeclarativeMeta
 from ..orm.decl_base import _DeferredMapperConfig
 from ..orm.mapper import _CONFIGURE_MUTEX
+from ..orm.relationships import _ORMBackrefArgument
+from ..orm.relationships import Relationship
 from ..schema import ForeignKeyConstraint
 from ..sql import and_
+from ..sql.elements import quoted_name
+from ..sql.schema import Table
+from ..util._py_collections import immutabledict
+from ..util.typing import Protocol
+from ..util.typing import TypeGuard
+
+_KT = TypeVar("_KT", bound=Any)
+_VT = TypeVar("_VT", bound=Any)
+
+
+class ClassnameForTableType(Protocol):
+    def __call__(
+        self, base: DeclarativeMeta, tablename: quoted_name, table: Table
+    ) -> str:
+        ...
 
 
 def classname_for_table(
@@ -637,6 +654,17 @@ def classname_for_table(
     return str(tablename)
 
 
+class NameForScalarRelationshipType(Protocol):
+    def __call__(
+        self,
+        base: DeclarativeMeta,
+        local_cls: DeclarativeMeta,
+        referred_cls: DeclarativeMeta,
+        constraint: ForeignKeyConstraint,
+    ) -> str:
+        ...
+
+
 def name_for_scalar_relationship(
     base: DeclarativeMeta,
     local_cls: DeclarativeMeta,
@@ -667,6 +695,17 @@ def name_for_scalar_relationship(
     return referred_cls.__name__.lower()
 
 
+class NameForCollectionRelationshipType(Protocol):
+    def __call__(
+        self,
+        base: DeclarativeMeta,
+        local_cls: DeclarativeMeta,
+        referred_cls: DeclarativeMeta,
+        constraint: ForeignKeyConstraint,
+    ) -> str:
+        ...
+
+
 def name_for_collection_relationship(
     base: DeclarativeMeta,
     local_cls: DeclarativeMeta,
@@ -698,15 +737,29 @@ def name_for_collection_relationship(
     return referred_cls.__name__.lower() + "_collection"
 
 
+class GenerateRelationshipType(Protocol):
+    def __call__(
+        self,
+        base: DeclarativeMeta,
+        direction: RelationshipDirection,
+        return_fn: BackrefConstructorType | RelaionshipConstructorType,
+        attrname: str,
+        local_cls: DeclarativeMeta,
+        referred_cls: DeclarativeMeta,
+        **kw: Any,
+    ) -> _ORMBackrefArgument | Relationship[Any]:
+        ...
+
+
 def generate_relationship(
     base: DeclarativeMeta,
     direction: RelationshipDirection,
-    return_fn: Callable,
+    return_fn: BackrefConstructorType | RelaionshipConstructorType,
     attrname: str,
     local_cls: DeclarativeMeta,
     referred_cls: DeclarativeMeta,
     **kw: Any,
-) -> Union[Tuple[str, Dict[str, Any]], Relationship]:
+) -> _ORMBackrefArgument | Relationship[Any]:
     r"""Generate a :func:`_orm.relationship` or :func:`.backref`
     on behalf of two
     mapped classes.
@@ -755,9 +808,20 @@ def generate_relationship(
      by the :paramref:`.generate_relationship.return_fn` parameter.
 
     """
-    if return_fn is backref:
+
+    def is_backref(
+        func: Callable[..., Any]
+    ) -> TypeGuard[BackrefConstructorType]:
+        return func is backref
+
+    def is_relationship(
+        func: Callable[..., Any]
+    ) -> TypeGuard[RelaionshipConstructorType]:
+        return func is relationship
+
+    if is_backref(return_fn):
         return return_fn(attrname, **kw)
-    elif return_fn is relationship:
+    elif is_relationship(return_fn):
         return return_fn(referred_cls, **kw)
     else:
         raise TypeError("Unknown relationship function: %s" % return_fn)
@@ -820,13 +884,17 @@ class AutomapBase:
         engine: Optional[Any] = None,
         reflect: bool = False,
         schema: Optional[str] = None,
-        classname_for_table: Optional[Callable] = None,
+        classname_for_table: Optional[ClassnameForTableType] = None,
         collection_class: Optional[Any] = None,
-        name_for_scalar_relationship: Optional[Callable] = None,
-        name_for_collection_relationship: Optional[Callable] = None,
-        generate_relationship: Optional[Callable] = None,
+        name_for_scalar_relationship: Optional[
+            NameForScalarRelationshipType
+        ] = None,
+        name_for_collection_relationship: Optional[
+            NameForCollectionRelationshipType
+        ] = None,
+        generate_relationship: Optional[GenerateRelationshipType] = None,
         reflection_options: Union[
-            Dict[str, Any], immutabledict
+            Dict[_KT, _VT], immutabledict[_KT, _VT]
         ] = util.EMPTY_DICT,
     ) -> None:
         """Extract mapped classes and relationships from the
@@ -1042,8 +1110,10 @@ def automap_base(
 
 
 def _is_many_to_many(
-    automap_base: DeclarativeMeta, table: Table
-) -> Tuple[None, None, None]:
+    automap_base: Type[AutomapBase], table: Table
+) -> Tuple[
+    Optional[Table], Optional[Table], Optional[list[ForeignKeyConstraint]]
+]:
     fk_constraints = [
         const
         for const in table.constraints
@@ -1078,9 +1148,9 @@ def _relationships_for_fks(
         Dict[Table, _DeferredMapperConfig],
     ],
     collection_class: type,
-    name_for_scalar_relationship: Callable,
-    name_for_collection_relationship: Callable,
-    generate_relationship: Callable,
+    name_for_scalar_relationship: NameForScalarRelationshipType,
+    name_for_collection_relationship: NameForCollectionRelationshipType,
+    generate_relationship: GenerateRelationshipType,
 ) -> None:
     local_table = map_config.local_table
     local_cls = map_config.cls  # derived from a weakref, may be None
@@ -1189,9 +1259,9 @@ def _m2m_relationship(
         Dict[Table, _DeferredMapperConfig],
     ],
     collection_class: type,
-    name_for_scalar_relationship: Callable,
-    name_for_collection_relationship: Callable,
-    generate_relationship: Callable,
+    name_for_scalar_relationship: NameForCollectionRelationshipType,
+    name_for_collection_relationship: NameForCollectionRelationshipType,
+    generate_relationship: GenerateRelationshipType,
 ) -> None:
 
     map_config = table_to_map_config.get(lcl_m2m, None)
index 96acce2ff885b4109404f680bcc523a1cc75497f..0618b226198df0c7de649a9a8e50d149b6fdfad1 100644 (file)
@@ -23,6 +23,9 @@ from . import strategy_options as strategy_options
 from ._orm_constructors import _mapper_fn as mapper
 from ._orm_constructors import aliased as aliased
 from ._orm_constructors import backref as backref
+from ._orm_constructors import (
+    BackrefConstructorType as BackrefConstructorType,
+)
 from ._orm_constructors import clear_mappers as clear_mappers
 from ._orm_constructors import column_property as column_property
 from ._orm_constructors import composite as composite
@@ -34,6 +37,9 @@ from ._orm_constructors import join as join
 from ._orm_constructors import mapped_column as mapped_column
 from ._orm_constructors import outerjoin as outerjoin
 from ._orm_constructors import query_expression as query_expression
+from ._orm_constructors import (
+    RelaionshipConstructorType as RelaionshipConstructorType,
+)
 from ._orm_constructors import relationship as relationship
 from ._orm_constructors import synonym as synonym
 from ._orm_constructors import with_loader_criteria as with_loader_criteria
index 30119d9d79e4813ed83d90e648cd342b42fea807..0a506e9eefe8003f42ae4015a9bff41a45ce2fed 100644 (file)
@@ -46,6 +46,7 @@ from ..sql.schema import SchemaConst
 from ..sql.selectable import FromClause
 from ..util.typing import Annotated
 from ..util.typing import Literal
+from ..util.typing import Protocol
 
 if TYPE_CHECKING:
     from ._typing import _EntityType
@@ -734,6 +735,54 @@ def with_loader_criteria(
     )
 
 
+class RelaionshipConstructorType(Protocol):
+    def __call__(
+        self,
+        argument: Optional[_RelationshipArgumentType[Any]] = None,
+        secondary: Optional[Union[FromClause, str]] = None,
+        *,
+        uselist: Optional[bool] = None,
+        collection_class: Optional[
+            Union[Type[Collection[Any]], Callable[[], Collection[Any]]]
+        ] = None,
+        primaryjoin: Optional[_RelationshipJoinConditionArgument] = None,
+        secondaryjoin: Optional[_RelationshipJoinConditionArgument] = None,
+        back_populates: Optional[str] = None,
+        order_by: _ORMOrderByArgument = False,
+        backref: Optional[_ORMBackrefArgument] = None,
+        overlaps: Optional[str] = None,
+        post_update: bool = False,
+        cascade: str = "save-update, merge",
+        viewonly: bool = False,
+        init: Union[_NoArg, bool] = _NoArg.NO_ARG,
+        repr: Union[_NoArg, bool] = _NoArg.NO_ARG,  # noqa: A002
+        default: Union[_NoArg, _T] = _NoArg.NO_ARG,
+        default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG,
+        kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG,
+        lazy: _LazyLoadArgumentType = "select",
+        passive_deletes: Union[Literal["all"], bool] = False,
+        passive_updates: bool = True,
+        active_history: bool = False,
+        enable_typechecks: bool = True,
+        foreign_keys: Optional[_ORMColCollectionArgument] = None,
+        remote_side: Optional[_ORMColCollectionArgument] = None,
+        join_depth: Optional[int] = None,
+        comparator_factory: Optional[
+            Type[RelationshipProperty.Comparator[Any]]
+        ] = None,
+        single_parent: bool = False,
+        innerjoin: bool = False,
+        distinct_target_key: Optional[bool] = None,
+        load_on_pending: bool = False,
+        query_class: Optional[Type[Query[Any]]] = None,
+        info: Optional[_InfoType] = None,
+        omit_join: Literal[None, False] = None,
+        sync_backref: Optional[bool] = None,
+        **kw: Any,
+    ) -> Relationship[Any]:
+        ...
+
+
 def relationship(
     argument: Optional[_RelationshipArgumentType[Any]] = None,
     secondary: Optional[Union[FromClause, str]] = None,
@@ -1854,6 +1903,11 @@ def dynamic_loader(
     return relationship(argument, **kw)
 
 
+class BackrefConstructorType(Protocol):
+    def __call__(self, name: str, **kwargs: Any) -> _ORMBackrefArgument:
+        ...
+
+
 def backref(name: str, **kwargs: Any) -> _ORMBackrefArgument:
     """When using the :paramref:`_orm.relationship.backref` parameter,
     provides specific parameters to be used when the new
index e2fe7854e7b66486ff379187b6930747ed8890bf..b7f2d945cacce6b892323753dcb4cbb56f0aacb0 100755 (executable)
@@ -53,29 +53,3 @@ with open(bootstrap_file) as f:
     to_bootstrap = "pytest"
     exec(code, globals(), locals())
     from sqla_pytestplugin import *  # noqa
-
-
-# def pytest_collection_finish(session):
-#     """Handle the pytest collection finish hook: configure pyannotate.
-#     Explicitly delay importing `collect_types` until all tests have
-#     been collected.  This gives gevent a chance to monkey patch the
-#     world before importing pyannotate.
-#     """
-#     from pyannotate_runtime import collect_types
-
-#     collect_types.init_types_collection()
-
-
-# @pytest.fixture(autouse=True)
-# def collect_types_fixture():
-#     from pyannotate_runtime import collect_types
-
-#     collect_types.start()
-#     yield
-#     collect_types.stop()
-
-
-# def pytest_sessionfinish(session, exitstatus):
-#     from pyannotate_runtime import collect_types
-
-#     collect_types.dump_stats("type_info.json")