]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Improved typing
authorFederico Caselli <cfederico87@gmail.com>
Fri, 23 Jun 2023 18:09:29 +0000 (20:09 +0200)
committerFederico Caselli <cfederico87@gmail.com>
Tue, 11 Jul 2023 20:20:50 +0000 (22:20 +0200)
- correctly inspect orm classes and instances
- add missing generics in run_sync on async connection and session
- minor fixes to relationship order params and shift overload

Change-Id: I9aebd3e1312bca855c8451ec0cc02a891983063f

lib/sqlalchemy/ext/asyncio/engine.py
lib/sqlalchemy/ext/asyncio/session.py
lib/sqlalchemy/ext/hybrid.py
lib/sqlalchemy/inspection.py
lib/sqlalchemy/orm/decl_api.py
lib/sqlalchemy/orm/relationships.py
lib/sqlalchemy/sql/elements.py
test/typing/plain_files/ext/hybrid/hybrid_four.py
test/typing/plain_files/inspection_inspect.py
test/typing/plain_files/orm/relationship.py
test/typing/plain_files/sql/sql_operations.py

index e77c3df10219d937e4e7914aebf6b5b1a9bc365c..594eb02a7df29d01d0e334cae8deb85a839516d8 100644 (file)
@@ -805,8 +805,8 @@ class AsyncConnection(
             yield result.scalars()
 
     async def run_sync(
-        self, fn: Callable[..., Any], *arg: Any, **kw: Any
-    ) -> Any:
+        self, fn: Callable[..., _T], *arg: Any, **kw: Any
+    ) -> _T:
         """Invoke the given synchronous (i.e. not async) callable,
         passing a synchronous-style :class:`_engine.Connection` as the first
         argument.
index ceeaff4d90f31d2055123ac3b872030b1a3d24dc..19a441ca613943ddc2c0b677e28b30e6a1a9991c 100644 (file)
@@ -329,8 +329,8 @@ class AsyncSession(ReversibleProxy[Session]):
         )
 
     async def run_sync(
-        self, fn: Callable[..., Any], *arg: Any, **kw: Any
-    ) -> Any:
+        self, fn: Callable[..., _T], *arg: Any, **kw: Any
+    ) -> _T:
         """Invoke the given synchronous (i.e. not async) callable,
         passing a synchronous-style :class:`_orm.Session` as the first
         argument.
index 5dfc529389483272180f291e9e99ab51fca49e0a..83dfb5033791d7187ee696710ea576b9e321a98c 100644 (file)
@@ -923,13 +923,13 @@ class _HybridUpdaterType(Protocol[_T_con]):
 
 
 class _HybridDeleterType(Protocol[_T_co]):
-    def __call__(self, instance: Any) -> None:
+    def __call__(s, self: Any) -> None:
         ...
 
 
 class _HybridExprCallableType(Protocol[_T]):
     def __call__(
-        self, cls: Any
+        s, cls: Any
     ) -> Union[_HasClauseElement, SQLColumnExpression[_T]]:
         ...
 
index f7d542dc51ef3062f3d2b5b84ebf81e631ea75b4..7d8479b5ecff863a8546be05e3864b43dd0fd573 100644 (file)
@@ -42,11 +42,13 @@ from typing import Union
 
 from . import exc
 from .util.typing import Literal
+from .util.typing import Protocol
 
 _T = TypeVar("_T", bound=Any)
+_TCov = TypeVar("_TCov", bound=Any, covariant=True)
 _F = TypeVar("_F", bound=Callable[..., Any])
 
-_IN = TypeVar("_IN", bound="Inspectable[Any]")
+_IN = TypeVar("_IN", bound=Any)
 
 _registrars: Dict[type, Union[Literal[True], Callable[[Any], Any]]] = {}
 
@@ -66,6 +68,38 @@ class Inspectable(Generic[_T]):
     __slots__ = ()
 
 
+class _InspectableTypeProtocol(Protocol[_TCov]):
+    """a protocol defining a method that's used when a type (ie the class
+    itself) is passed to inspect().
+
+    """
+
+    def _sa_inspect_type(self) -> _TCov:
+        ...
+
+
+class _InspectableProtocol(Protocol[_TCov]):
+    """a protocol defining a method that's used when an instance is
+    passed to inspect().
+
+    """
+
+    def _sa_inspect_instance(self) -> _TCov:
+        ...
+
+
+@overload
+def inspect(
+    subject: Type[_InspectableTypeProtocol[_IN]], raiseerr: bool = True
+) -> _IN:
+    ...
+
+
+@overload
+def inspect(subject: _InspectableProtocol[_IN], raiseerr: bool = True) -> _IN:
+    ...
+
+
 @overload
 def inspect(subject: Inspectable[_IN], raiseerr: bool = True) -> _IN:
     ...
index 6439561d82f5790bbbc4848eabb3d2e0c7eda57e..e6b67b326c32b9c752dc043911c4656a1a1d5ae2 100644 (file)
@@ -78,6 +78,7 @@ from ..util.typing import is_generic
 from ..util.typing import is_literal
 from ..util.typing import is_newtype
 from ..util.typing import Literal
+from ..util.typing import Self
 
 if TYPE_CHECKING:
     from ._typing import _O
@@ -137,7 +138,9 @@ class _DynamicAttributesType(type):
 
 
 class DeclarativeAttributeIntercept(
-    _DynamicAttributesType, inspection.Inspectable[Mapper[Any]]
+    _DynamicAttributesType,
+    # Inspectable is used only by the mypy plugin
+    inspection.Inspectable[Mapper[Any]],
 ):
     """Metaclass that may be used in conjunction with the
     :class:`_orm.DeclarativeBase` class to support addition of class
@@ -163,9 +166,7 @@ class DCTransformDeclarative(DeclarativeAttributeIntercept):
     """metaclass that includes @dataclass_transforms"""
 
 
-class DeclarativeMeta(
-    _DynamicAttributesType, inspection.Inspectable[Mapper[Any]]
-):
+class DeclarativeMeta(DeclarativeAttributeIntercept):
     metadata: MetaData
     registry: RegistryType
 
@@ -633,6 +634,7 @@ class MappedAsDataclass(metaclass=DCTransformDeclarative):
 
 
 class DeclarativeBase(
+    # Inspectable is used only by the mypy plugin
     inspection.Inspectable[InstanceState[Any]],
     metaclass=DeclarativeAttributeIntercept,
 ):
@@ -748,6 +750,13 @@ class DeclarativeBase(
     """
 
     if typing.TYPE_CHECKING:
+
+        def _sa_inspect_type(self) -> Mapper[Self]:
+            ...
+
+        def _sa_inspect_instance(self) -> InstanceState[Self]:
+            ...
+
         _sa_registry: ClassVar[_RegistryType]
 
         registry: ClassVar[_RegistryType]
@@ -766,6 +775,9 @@ class DeclarativeBase(
 
         __name__: ClassVar[str]
 
+        # this ideally should be Mapper[Self], but mypy as of 1.4.1 does not
+        # like it, and breaks the declared_attr_one test. Pyright/pylance is
+        # ok with it.
         __mapper__: ClassVar[Mapper[Any]]
         """The :class:`_orm.Mapper` object to which a particular class is
         mapped.
@@ -851,7 +863,10 @@ def _check_not_declarative(cls: Type[Any], base: Type[Any]) -> None:
         )
 
 
-class DeclarativeBaseNoMeta(inspection.Inspectable[InstanceState[Any]]):
+class DeclarativeBaseNoMeta(
+    # Inspectable is used only by the mypy plugin
+    inspection.Inspectable[InstanceState[Any]]
+):
     """Same as :class:`_orm.DeclarativeBase`, but does not use a metaclass
     to intercept new attributes.
 
@@ -879,6 +894,9 @@ class DeclarativeBaseNoMeta(inspection.Inspectable[InstanceState[Any]]):
 
     """
 
+    # this ideally should be Mapper[Self], but mypy as of 1.4.1 does not
+    # like it, and breaks the declared_attr_one test. Pyright/pylance is
+    # ok with it.
     __mapper__: ClassVar[Mapper[Any]]
     """The :class:`_orm.Mapper` object to which a particular class is
     mapped.
@@ -903,6 +921,13 @@ class DeclarativeBaseNoMeta(inspection.Inspectable[InstanceState[Any]]):
     """
 
     if typing.TYPE_CHECKING:
+
+        def _sa_inspect_type(self) -> Mapper[Self]:
+            ...
+
+        def _sa_inspect_instance(self) -> InstanceState[Self]:
+            ...
+
         __tablename__: Any
         """String name to assign to the generated
         :class:`_schema.Table` object, if not specified directly via
index 610be37e92ff24ca570b2edfb2e362a5181691fb..e2ba6989266cd6e55747ade73962cf42f7e84fd9 100644 (file)
@@ -171,7 +171,8 @@ _ORMOrderByArgument = Union[
     Literal[False],
     str,
     _ColumnExpressionArgument[Any],
-    Callable[[], Iterable[ColumnElement[Any]]],
+    Callable[[], _ColumnExpressionArgument[Any]],
+    Callable[[], Iterable[_ColumnExpressionArgument[Any]]],
     Iterable[Union[str, _ColumnExpressionArgument[Any]]],
 ]
 ORMBackrefArgument = Union[str, Tuple[str, Dict[str, Any]]]
index ba074db80c6ec11de4031f8f5ca0452883622387..8381ee7601831cb328bab049b28c63a4225cff54 100644 (file)
@@ -854,9 +854,25 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly):
         def __getitem__(self, index: Any) -> ColumnElement[Any]:
             ...
 
+        @overload
+        def __lshift__(self: _SQO[int], other: Any) -> ColumnElement[int]:
+            ...
+
+        @overload
+        def __lshift__(self, other: Any) -> ColumnElement[Any]:
+            ...
+
         def __lshift__(self, other: Any) -> ColumnElement[Any]:
             ...
 
+        @overload
+        def __rshift__(self: _SQO[int], other: Any) -> ColumnElement[int]:
+            ...
+
+        @overload
+        def __rshift__(self, other: Any) -> ColumnElement[Any]:
+            ...
+
         def __rshift__(self, other: Any) -> ColumnElement[Any]:
             ...
 
index c4ee15bc8322040c08edffd5dccabda561e05f07..a9df08f4b9a891683fbd288f89e29d2edcd4eaea 100644 (file)
@@ -50,6 +50,15 @@ class FirstNameOnly(Base):
     def _name_setter(self, value: str) -> None:
         self.first_name = value
 
+    @name.inplace.deleter
+    def _name_del(self) -> None:
+        self.first_name = ""
+
+    @name.inplace.expression
+    @classmethod
+    def _name_expr(cls) -> ColumnElement[str]:
+        return cls.first_name + "-"
+
 
 class FirstNameLastName(FirstNameOnly):
     last_name: Mapped[str]
index 155ceffc0358d574a4ae43a884b5aa27935a15d5..be9657318b9f2f76fd586dcb9ddf6d81dc8ba346 100644 (file)
@@ -1,19 +1,10 @@
-"""
-test inspect()
-
-however this is not really working
-
-"""
-from typing import Any
-from typing import Optional
-
-from sqlalchemy import Column
 from sqlalchemy import create_engine
 from sqlalchemy import inspect
-from sqlalchemy import Integer
-from sqlalchemy import String
 from sqlalchemy.engine.reflection import Inspector
 from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import DeclarativeBaseNoMeta
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
 from sqlalchemy.orm import Mapper
 
 
@@ -24,24 +15,50 @@ class Base(DeclarativeBase):
 class A(Base):
     __tablename__ = "a"
 
-    id = Column(Integer, primary_key=True)
-    data = Column(String)
+    id: Mapped[int] = mapped_column(primary_key=True)
+    data: Mapped[str]
+
+
+class BaseNoMeta(DeclarativeBaseNoMeta):
+    pass
+
+
+class B(BaseNoMeta):
+    __tablename__ = "b"
+
+    id: Mapped[int] = mapped_column(primary_key=True)
+    data: Mapped[str]
+
 
+# EXPECTED_TYPE: Mapper[Any]
+reveal_type(A.__mapper__)
+# EXPECTED_TYPE: Mapper[Any]
+reveal_type(B.__mapper__)
 
 a1 = A(data="d")
+b1 = B(data="d")
 
 e = create_engine("sqlite://")
 
-# TODO: I can't get these to work, pylance and mypy both don't want
-# to accommodate for different types for the first argument
+insp_a1 = inspect(a1)
 
-t: Optional[Any] = inspect(a1)
+t: bool = insp_a1.transient
+# EXPECTED_TYPE: InstanceState[A]
+reveal_type(insp_a1)
+# EXPECTED_TYPE: InstanceState[B]
+reveal_type(inspect(b1))
 
-m: Mapper[Any] = inspect(A)
+m: Mapper[A] = inspect(A)
+# EXPECTED_TYPE: Mapper[A]
+reveal_type(inspect(A))
+# EXPECTED_TYPE: Mapper[B]
+reveal_type(inspect(B))
 
-inspect(e).get_table_names()
+tables: list[str] = inspect(e).get_table_names()
 
 i: Inspector = inspect(e)
+# EXPECTED_TYPE: Inspector
+reveal_type(inspect(e))
 
 
 with e.connect() as conn:
index ddd51e21e4376265c340944a8656845b63a9bb1a..d0ab35249d1bbaf79ddb5622cb48721fc35a30b6 100644 (file)
@@ -4,16 +4,24 @@
 from __future__ import annotations
 
 import typing
+from typing import ClassVar
 from typing import List
 from typing import Optional
 from typing import Set
+from typing import TYPE_CHECKING
 
+from sqlalchemy import create_engine
 from sqlalchemy import ForeignKey
 from sqlalchemy import Integer
+from sqlalchemy import select
+from sqlalchemy import Table
 from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import joinedload
 from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
+from sqlalchemy.orm import registry
 from sqlalchemy.orm import relationship
+from sqlalchemy.orm import Session
 
 
 class Base(DeclarativeBase):
@@ -90,3 +98,66 @@ if typing.TYPE_CHECKING:
 
     # EXPECTED_RE_TYPE: sqlalchemy.orm.attributes.InstrumentedAttribute\[builtins.set\*?\[relationship.Address\]\]
     reveal_type(User.addresses_style_two)
+
+
+mapper_registry: registry = registry()
+
+e = create_engine("sqlite:///")
+
+
+@mapper_registry.mapped
+class A:
+    __tablename__ = "a"
+    id: Mapped[int] = mapped_column(primary_key=True)
+    b_id: Mapped[int] = mapped_column(ForeignKey("b.id"))
+    number: Mapped[int] = mapped_column(primary_key=True)
+    number2: Mapped[int] = mapped_column(primary_key=True)
+    if TYPE_CHECKING:
+        __table__: ClassVar[Table]
+
+
+@mapper_registry.mapped
+class B:
+    __tablename__ = "b"
+    id: Mapped[int] = mapped_column(primary_key=True)
+
+    # Omit order_by
+    a1: Mapped[list[A]] = relationship("A", uselist=True)
+
+    # All kinds of order_by
+    a2: Mapped[list[A]] = relationship(
+        "A", uselist=True, order_by=(A.id, A.number)
+    )
+    a3: Mapped[list[A]] = relationship(
+        "A", uselist=True, order_by=[A.id, A.number]
+    )
+    a4: Mapped[list[A]] = relationship("A", uselist=True, order_by=A.id)
+    a5: Mapped[list[A]] = relationship(
+        "A", uselist=True, order_by=A.__table__.c.id
+    )
+    a6: Mapped[list[A]] = relationship("A", uselist=True, order_by="A.number")
+
+    # Same kinds but lambda'd
+    a7: Mapped[list[A]] = relationship(
+        "A", uselist=True, order_by=lambda: (A.id, A.number)
+    )
+    a8: Mapped[list[A]] = relationship(
+        "A", uselist=True, order_by=lambda: [A.id, A.number]
+    )
+    a9: Mapped[list[A]] = relationship(
+        "A", uselist=True, order_by=lambda: A.id
+    )
+
+
+mapper_registry.metadata.drop_all(e)
+mapper_registry.metadata.create_all(e)
+
+with Session(e) as s:
+    s.execute(select(B).options(joinedload(B.a1)))
+    s.execute(select(B).options(joinedload(B.a2)))
+    s.execute(select(B).options(joinedload(B.a3)))
+    s.execute(select(B).options(joinedload(B.a4)))
+    s.execute(select(B).options(joinedload(B.a5)))
+    s.execute(select(B).options(joinedload(B.a7)))
+    s.execute(select(B).options(joinedload(B.a8)))
+    s.execute(select(B).options(joinedload(B.a9)))
index 4d3775293e8e3f28469a8315f63b438396397975..56d0529da407babccf549ebf55df3fd6e8c38d81 100644 (file)
@@ -111,9 +111,9 @@ def test_issue_9650_bitwise() -> None:
     reveal_type(c2.bitwise_lshift(5))
     # EXPECTED_TYPE: BinaryExpression[Any]
     reveal_type(c2.bitwise_rshift(5))
-    # EXPECTED_TYPE: ColumnElement[Any]
+    # EXPECTED_TYPE: ColumnElement[int]
     reveal_type(c2 << 5)
-    # EXPECTED_TYPE: ColumnElement[Any]
+    # EXPECTED_TYPE: ColumnElement[int]
     reveal_type(c2 >> 5)