From: Mike Bayer Date: Tue, 21 Feb 2023 16:06:17 +0000 (-0500) Subject: fix with_polymorphic X-Git-Tag: rel_2_0_5~22^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=097e1eacaaf43f728c552df9ebbfa0fb81c4b6c7;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git fix with_polymorphic Fixed typing issue where :func:`_orm.with_polymorphic` would not record the class type correctly. Fixes: #9340 Change-Id: I535ad9aede9b60475231028adb8dc270e55738a4 --- diff --git a/doc/build/changelog/unreleased_20/9340.rst b/doc/build/changelog/unreleased_20/9340.rst new file mode 100644 index 0000000000..28cef6f648 --- /dev/null +++ b/doc/build/changelog/unreleased_20/9340.rst @@ -0,0 +1,6 @@ +.. change:: + :tags: bug, typing + :tickets: 9340 + + Fixed typing issue where :func:`_orm.with_polymorphic` would not + record the class type correctly. diff --git a/lib/sqlalchemy/orm/_orm_constructors.py b/lib/sqlalchemy/orm/_orm_constructors.py index 3bd1db79d8..64e7937f11 100644 --- a/lib/sqlalchemy/orm/_orm_constructors.py +++ b/lib/sqlalchemy/orm/_orm_constructors.py @@ -2208,7 +2208,7 @@ def aliased( def with_polymorphic( - base: Union[_O, Mapper[_O]], + base: Union[Type[_O], Mapper[_O]], classes: Union[Literal["*"], Iterable[Type[Any]]], selectable: Union[Literal[False, None], FromClause] = False, flat: bool = False, diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index ad9ce2013d..1ef0d71591 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -994,7 +994,7 @@ class AliasedInsp( @classmethod def _with_polymorphic_factory( cls, - base: Union[_O, Mapper[_O]], + base: Union[Type[_O], Mapper[_O]], classes: Union[Literal["*"], Iterable[_EntityType[Any]]], selectable: Union[Literal[False, None], FromClause] = False, flat: bool = False, diff --git a/test/ext/mypy/plain_files/issue_9340.py b/test/ext/mypy/plain_files/issue_9340.py new file mode 100644 index 0000000000..72dc72df1e --- /dev/null +++ b/test/ext/mypy/plain_files/issue_9340.py @@ -0,0 +1,63 @@ +from typing import Sequence +from typing import TYPE_CHECKING + +from sqlalchemy import create_engine +from sqlalchemy import select +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import Session +from sqlalchemy.orm import with_polymorphic + + +class Base(DeclarativeBase): + ... + + +class Message(Base): + __tablename__ = "message" + __mapper_args__ = { + "polymorphic_on": "message_type", + "polymorphic_identity": __tablename__, + } + id: Mapped[int] = mapped_column(primary_key=True) + text: Mapped[str] + message_type: Mapped[str] + + +class UserComment(Message): + __mapper_args__ = { + "polymorphic_identity": "user_comment", + } + username: Mapped[str] = mapped_column(nullable=True) + + +engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/") + + +def get_messages() -> Sequence[Message]: + with Session(engine) as session: + message_query = select(Message) + + if TYPE_CHECKING: + # EXPECTED_TYPE: Select[Tuple[Message]] + reveal_type(message_query) + + return session.scalars(message_query).all() + + +def get_poly_messages() -> Sequence[Message]: + with Session(engine) as session: + PolymorphicMessage = with_polymorphic(Message, (UserComment,)) + + if TYPE_CHECKING: + # EXPECTED_TYPE: AliasedClass[Message] + reveal_type(PolymorphicMessage) + + poly_query = select(PolymorphicMessage) + + if TYPE_CHECKING: + # EXPECTED_TYPE: Select[Tuple[Message]] + reveal_type(poly_query) + + return session.scalars(poly_query).all()