From 4b82f9e0a09f04acbd6983524085955a1a15ea40 Mon Sep 17 00:00:00 2001 From: Federico Caselli Date: Fri, 23 Jun 2023 20:09:29 +0200 Subject: [PATCH] Improved typing - correctly inspect orm classes and instances - add missing generics in run_sync on async connection and session - minor fixes to relationship order params and shift overload Change-Id: I9aebd3e1312bca855c8451ec0cc02a891983063f --- lib/sqlalchemy/ext/asyncio/engine.py | 4 +- lib/sqlalchemy/ext/asyncio/session.py | 4 +- lib/sqlalchemy/ext/hybrid.py | 4 +- lib/sqlalchemy/inspection.py | 36 +++++++++- lib/sqlalchemy/orm/decl_api.py | 35 +++++++-- lib/sqlalchemy/orm/relationships.py | 3 +- lib/sqlalchemy/sql/elements.py | 16 +++++ .../plain_files/ext/hybrid/hybrid_four.py | 9 +++ test/typing/plain_files/inspection_inspect.py | 55 +++++++++----- test/typing/plain_files/orm/relationship.py | 71 +++++++++++++++++++ test/typing/plain_files/sql/sql_operations.py | 4 +- 11 files changed, 207 insertions(+), 34 deletions(-) diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py index e77c3df102..594eb02a7d 100644 --- a/lib/sqlalchemy/ext/asyncio/engine.py +++ b/lib/sqlalchemy/ext/asyncio/engine.py @@ -805,8 +805,8 @@ class AsyncConnection( yield result.scalars() async def run_sync( - self, fn: Callable[..., Any], *arg: Any, **kw: Any - ) -> Any: + self, fn: Callable[..., _T], *arg: Any, **kw: Any + ) -> _T: """Invoke the given synchronous (i.e. not async) callable, passing a synchronous-style :class:`_engine.Connection` as the first argument. diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py index ceeaff4d90..19a441ca61 100644 --- a/lib/sqlalchemy/ext/asyncio/session.py +++ b/lib/sqlalchemy/ext/asyncio/session.py @@ -329,8 +329,8 @@ class AsyncSession(ReversibleProxy[Session]): ) async def run_sync( - self, fn: Callable[..., Any], *arg: Any, **kw: Any - ) -> Any: + self, fn: Callable[..., _T], *arg: Any, **kw: Any + ) -> _T: """Invoke the given synchronous (i.e. not async) callable, passing a synchronous-style :class:`_orm.Session` as the first argument. diff --git a/lib/sqlalchemy/ext/hybrid.py b/lib/sqlalchemy/ext/hybrid.py index 5dfc529389..83dfb50337 100644 --- a/lib/sqlalchemy/ext/hybrid.py +++ b/lib/sqlalchemy/ext/hybrid.py @@ -923,13 +923,13 @@ class _HybridUpdaterType(Protocol[_T_con]): class _HybridDeleterType(Protocol[_T_co]): - def __call__(self, instance: Any) -> None: + def __call__(s, self: Any) -> None: ... class _HybridExprCallableType(Protocol[_T]): def __call__( - self, cls: Any + s, cls: Any ) -> Union[_HasClauseElement, SQLColumnExpression[_T]]: ... diff --git a/lib/sqlalchemy/inspection.py b/lib/sqlalchemy/inspection.py index f7d542dc51..7d8479b5ec 100644 --- a/lib/sqlalchemy/inspection.py +++ b/lib/sqlalchemy/inspection.py @@ -42,11 +42,13 @@ from typing import Union from . import exc from .util.typing import Literal +from .util.typing import Protocol _T = TypeVar("_T", bound=Any) +_TCov = TypeVar("_TCov", bound=Any, covariant=True) _F = TypeVar("_F", bound=Callable[..., Any]) -_IN = TypeVar("_IN", bound="Inspectable[Any]") +_IN = TypeVar("_IN", bound=Any) _registrars: Dict[type, Union[Literal[True], Callable[[Any], Any]]] = {} @@ -66,6 +68,38 @@ class Inspectable(Generic[_T]): __slots__ = () +class _InspectableTypeProtocol(Protocol[_TCov]): + """a protocol defining a method that's used when a type (ie the class + itself) is passed to inspect(). + + """ + + def _sa_inspect_type(self) -> _TCov: + ... + + +class _InspectableProtocol(Protocol[_TCov]): + """a protocol defining a method that's used when an instance is + passed to inspect(). + + """ + + def _sa_inspect_instance(self) -> _TCov: + ... + + +@overload +def inspect( + subject: Type[_InspectableTypeProtocol[_IN]], raiseerr: bool = True +) -> _IN: + ... + + +@overload +def inspect(subject: _InspectableProtocol[_IN], raiseerr: bool = True) -> _IN: + ... + + @overload def inspect(subject: Inspectable[_IN], raiseerr: bool = True) -> _IN: ... diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py index 6439561d82..e6b67b326c 100644 --- a/lib/sqlalchemy/orm/decl_api.py +++ b/lib/sqlalchemy/orm/decl_api.py @@ -78,6 +78,7 @@ from ..util.typing import is_generic from ..util.typing import is_literal from ..util.typing import is_newtype from ..util.typing import Literal +from ..util.typing import Self if TYPE_CHECKING: from ._typing import _O @@ -137,7 +138,9 @@ class _DynamicAttributesType(type): class DeclarativeAttributeIntercept( - _DynamicAttributesType, inspection.Inspectable[Mapper[Any]] + _DynamicAttributesType, + # Inspectable is used only by the mypy plugin + inspection.Inspectable[Mapper[Any]], ): """Metaclass that may be used in conjunction with the :class:`_orm.DeclarativeBase` class to support addition of class @@ -163,9 +166,7 @@ class DCTransformDeclarative(DeclarativeAttributeIntercept): """metaclass that includes @dataclass_transforms""" -class DeclarativeMeta( - _DynamicAttributesType, inspection.Inspectable[Mapper[Any]] -): +class DeclarativeMeta(DeclarativeAttributeIntercept): metadata: MetaData registry: RegistryType @@ -633,6 +634,7 @@ class MappedAsDataclass(metaclass=DCTransformDeclarative): class DeclarativeBase( + # Inspectable is used only by the mypy plugin inspection.Inspectable[InstanceState[Any]], metaclass=DeclarativeAttributeIntercept, ): @@ -748,6 +750,13 @@ class DeclarativeBase( """ if typing.TYPE_CHECKING: + + def _sa_inspect_type(self) -> Mapper[Self]: + ... + + def _sa_inspect_instance(self) -> InstanceState[Self]: + ... + _sa_registry: ClassVar[_RegistryType] registry: ClassVar[_RegistryType] @@ -766,6 +775,9 @@ class DeclarativeBase( __name__: ClassVar[str] + # this ideally should be Mapper[Self], but mypy as of 1.4.1 does not + # like it, and breaks the declared_attr_one test. Pyright/pylance is + # ok with it. __mapper__: ClassVar[Mapper[Any]] """The :class:`_orm.Mapper` object to which a particular class is mapped. @@ -851,7 +863,10 @@ def _check_not_declarative(cls: Type[Any], base: Type[Any]) -> None: ) -class DeclarativeBaseNoMeta(inspection.Inspectable[InstanceState[Any]]): +class DeclarativeBaseNoMeta( + # Inspectable is used only by the mypy plugin + inspection.Inspectable[InstanceState[Any]] +): """Same as :class:`_orm.DeclarativeBase`, but does not use a metaclass to intercept new attributes. @@ -879,6 +894,9 @@ class DeclarativeBaseNoMeta(inspection.Inspectable[InstanceState[Any]]): """ + # this ideally should be Mapper[Self], but mypy as of 1.4.1 does not + # like it, and breaks the declared_attr_one test. Pyright/pylance is + # ok with it. __mapper__: ClassVar[Mapper[Any]] """The :class:`_orm.Mapper` object to which a particular class is mapped. @@ -903,6 +921,13 @@ class DeclarativeBaseNoMeta(inspection.Inspectable[InstanceState[Any]]): """ if typing.TYPE_CHECKING: + + def _sa_inspect_type(self) -> Mapper[Self]: + ... + + def _sa_inspect_instance(self) -> InstanceState[Self]: + ... + __tablename__: Any """String name to assign to the generated :class:`_schema.Table` object, if not specified directly via diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index 610be37e92..e2ba698926 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -171,7 +171,8 @@ _ORMOrderByArgument = Union[ Literal[False], str, _ColumnExpressionArgument[Any], - Callable[[], Iterable[ColumnElement[Any]]], + Callable[[], _ColumnExpressionArgument[Any]], + Callable[[], Iterable[_ColumnExpressionArgument[Any]]], Iterable[Union[str, _ColumnExpressionArgument[Any]]], ] ORMBackrefArgument = Union[str, Tuple[str, Dict[str, Any]]] diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index ba074db80c..8381ee7601 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -854,9 +854,25 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly): def __getitem__(self, index: Any) -> ColumnElement[Any]: ... + @overload + def __lshift__(self: _SQO[int], other: Any) -> ColumnElement[int]: + ... + + @overload + def __lshift__(self, other: Any) -> ColumnElement[Any]: + ... + def __lshift__(self, other: Any) -> ColumnElement[Any]: ... + @overload + def __rshift__(self: _SQO[int], other: Any) -> ColumnElement[int]: + ... + + @overload + def __rshift__(self, other: Any) -> ColumnElement[Any]: + ... + def __rshift__(self, other: Any) -> ColumnElement[Any]: ... diff --git a/test/typing/plain_files/ext/hybrid/hybrid_four.py b/test/typing/plain_files/ext/hybrid/hybrid_four.py index c4ee15bc83..a9df08f4b9 100644 --- a/test/typing/plain_files/ext/hybrid/hybrid_four.py +++ b/test/typing/plain_files/ext/hybrid/hybrid_four.py @@ -50,6 +50,15 @@ class FirstNameOnly(Base): def _name_setter(self, value: str) -> None: self.first_name = value + @name.inplace.deleter + def _name_del(self) -> None: + self.first_name = "" + + @name.inplace.expression + @classmethod + def _name_expr(cls) -> ColumnElement[str]: + return cls.first_name + "-" + class FirstNameLastName(FirstNameOnly): last_name: Mapped[str] diff --git a/test/typing/plain_files/inspection_inspect.py b/test/typing/plain_files/inspection_inspect.py index 155ceffc03..be9657318b 100644 --- a/test/typing/plain_files/inspection_inspect.py +++ b/test/typing/plain_files/inspection_inspect.py @@ -1,19 +1,10 @@ -""" -test inspect() - -however this is not really working - -""" -from typing import Any -from typing import Optional - -from sqlalchemy import Column from sqlalchemy import create_engine from sqlalchemy import inspect -from sqlalchemy import Integer -from sqlalchemy import String from sqlalchemy.engine.reflection import Inspector from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import DeclarativeBaseNoMeta +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column from sqlalchemy.orm import Mapper @@ -24,24 +15,50 @@ class Base(DeclarativeBase): class A(Base): __tablename__ = "a" - id = Column(Integer, primary_key=True) - data = Column(String) + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] + + +class BaseNoMeta(DeclarativeBaseNoMeta): + pass + + +class B(BaseNoMeta): + __tablename__ = "b" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] + +# EXPECTED_TYPE: Mapper[Any] +reveal_type(A.__mapper__) +# EXPECTED_TYPE: Mapper[Any] +reveal_type(B.__mapper__) a1 = A(data="d") +b1 = B(data="d") e = create_engine("sqlite://") -# TODO: I can't get these to work, pylance and mypy both don't want -# to accommodate for different types for the first argument +insp_a1 = inspect(a1) -t: Optional[Any] = inspect(a1) +t: bool = insp_a1.transient +# EXPECTED_TYPE: InstanceState[A] +reveal_type(insp_a1) +# EXPECTED_TYPE: InstanceState[B] +reveal_type(inspect(b1)) -m: Mapper[Any] = inspect(A) +m: Mapper[A] = inspect(A) +# EXPECTED_TYPE: Mapper[A] +reveal_type(inspect(A)) +# EXPECTED_TYPE: Mapper[B] +reveal_type(inspect(B)) -inspect(e).get_table_names() +tables: list[str] = inspect(e).get_table_names() i: Inspector = inspect(e) +# EXPECTED_TYPE: Inspector +reveal_type(inspect(e)) with e.connect() as conn: diff --git a/test/typing/plain_files/orm/relationship.py b/test/typing/plain_files/orm/relationship.py index ddd51e21e4..d0ab35249d 100644 --- a/test/typing/plain_files/orm/relationship.py +++ b/test/typing/plain_files/orm/relationship.py @@ -4,16 +4,24 @@ from __future__ import annotations import typing +from typing import ClassVar from typing import List from typing import Optional from typing import Set +from typing import TYPE_CHECKING +from sqlalchemy import create_engine from sqlalchemy import ForeignKey from sqlalchemy import Integer +from sqlalchemy import select +from sqlalchemy import Table from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import joinedload from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import registry from sqlalchemy.orm import relationship +from sqlalchemy.orm import Session class Base(DeclarativeBase): @@ -90,3 +98,66 @@ if typing.TYPE_CHECKING: # EXPECTED_RE_TYPE: sqlalchemy.orm.attributes.InstrumentedAttribute\[builtins.set\*?\[relationship.Address\]\] reveal_type(User.addresses_style_two) + + +mapper_registry: registry = registry() + +e = create_engine("sqlite:///") + + +@mapper_registry.mapped +class A: + __tablename__ = "a" + id: Mapped[int] = mapped_column(primary_key=True) + b_id: Mapped[int] = mapped_column(ForeignKey("b.id")) + number: Mapped[int] = mapped_column(primary_key=True) + number2: Mapped[int] = mapped_column(primary_key=True) + if TYPE_CHECKING: + __table__: ClassVar[Table] + + +@mapper_registry.mapped +class B: + __tablename__ = "b" + id: Mapped[int] = mapped_column(primary_key=True) + + # Omit order_by + a1: Mapped[list[A]] = relationship("A", uselist=True) + + # All kinds of order_by + a2: Mapped[list[A]] = relationship( + "A", uselist=True, order_by=(A.id, A.number) + ) + a3: Mapped[list[A]] = relationship( + "A", uselist=True, order_by=[A.id, A.number] + ) + a4: Mapped[list[A]] = relationship("A", uselist=True, order_by=A.id) + a5: Mapped[list[A]] = relationship( + "A", uselist=True, order_by=A.__table__.c.id + ) + a6: Mapped[list[A]] = relationship("A", uselist=True, order_by="A.number") + + # Same kinds but lambda'd + a7: Mapped[list[A]] = relationship( + "A", uselist=True, order_by=lambda: (A.id, A.number) + ) + a8: Mapped[list[A]] = relationship( + "A", uselist=True, order_by=lambda: [A.id, A.number] + ) + a9: Mapped[list[A]] = relationship( + "A", uselist=True, order_by=lambda: A.id + ) + + +mapper_registry.metadata.drop_all(e) +mapper_registry.metadata.create_all(e) + +with Session(e) as s: + s.execute(select(B).options(joinedload(B.a1))) + s.execute(select(B).options(joinedload(B.a2))) + s.execute(select(B).options(joinedload(B.a3))) + s.execute(select(B).options(joinedload(B.a4))) + s.execute(select(B).options(joinedload(B.a5))) + s.execute(select(B).options(joinedload(B.a7))) + s.execute(select(B).options(joinedload(B.a8))) + s.execute(select(B).options(joinedload(B.a9))) diff --git a/test/typing/plain_files/sql/sql_operations.py b/test/typing/plain_files/sql/sql_operations.py index 4d3775293e..56d0529da4 100644 --- a/test/typing/plain_files/sql/sql_operations.py +++ b/test/typing/plain_files/sql/sql_operations.py @@ -111,9 +111,9 @@ def test_issue_9650_bitwise() -> None: reveal_type(c2.bitwise_lshift(5)) # EXPECTED_TYPE: BinaryExpression[Any] reveal_type(c2.bitwise_rshift(5)) - # EXPECTED_TYPE: ColumnElement[Any] + # EXPECTED_TYPE: ColumnElement[int] reveal_type(c2 << 5) - # EXPECTED_TYPE: ColumnElement[Any] + # EXPECTED_TYPE: ColumnElement[int] reveal_type(c2 >> 5) -- 2.39.5