From 0ab0d0c98068d5d259a939cdff19ba34dd0433d0 Mon Sep 17 00:00:00 2001 From: Rebecca Chen Date: Sat, 18 Oct 2025 10:20:55 -0400 Subject: [PATCH] Change typing tests to use `assert_type` instead of `reveal_type` Closes: #12922 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12922 Pull-request-sha: 580f6638168c33e6c50e95066312ac605433665f Change-Id: I9f3bdb4c105971f53fa10ed8a934356203ddb080 --- lib/sqlalchemy/testing/fixtures/mypy.py | 74 +--- .../dialects/postgresql/pg_stuff.py | 60 ++-- .../plain_files/engine/engine_inspection.py | 13 +- .../plain_files/engine/engine_result.py | 87 ++--- test/typing/plain_files/engine/engines.py | 22 +- .../association_proxy_one.py | 8 +- .../association_proxy_three.py | 4 +- .../association_proxy_two.py | 7 +- .../typing/plain_files/ext/asyncio/engines.py | 34 +- .../plain_files/ext/hybrid/hybrid_one.py | 29 +- .../plain_files/ext/hybrid/hybrid_two.py | 49 ++- test/typing/plain_files/ext/indexable.py | 24 +- .../ext/orderinglist/orderinglist_one.py | 9 +- test/typing/plain_files/inspection_inspect.py | 24 +- test/typing/plain_files/orm/composite.py | 14 +- test/typing/plain_files/orm/composite_dc.py | 14 +- .../orm/dataclass_transforms_decorator.py | 5 +- ...dataclass_transforms_decorator_w_mixins.py | 8 +- .../orm/dataclass_transforms_one.py | 7 +- .../plain_files/orm/declared_attr_one.py | 21 +- .../plain_files/orm/declared_attr_two.py | 14 +- test/typing/plain_files/orm/dynamic_rel.py | 20 +- test/typing/plain_files/orm/issue_9340.py | 12 +- test/typing/plain_files/orm/keyfunc_dict.py | 4 +- test/typing/plain_files/orm/relationship.py | 58 ++-- test/typing/plain_files/orm/session.py | 22 +- test/typing/plain_files/orm/sessionmakers.py | 45 +-- .../orm/trad_relationship_uselist.py | 44 +-- .../orm/traditional_relationship.py | 30 +- test/typing/plain_files/orm/typed_queries.py | 239 +++++-------- test/typing/plain_files/orm/write_only.py | 8 +- .../plain_files/sql/common_sql_element.py | 152 ++++----- test/typing/plain_files/sql/functions.py | 86 ++--- .../typing/plain_files/sql/functions_again.py | 67 ++-- test/typing/plain_files/sql/lambda_stmt.py | 17 +- test/typing/plain_files/sql/misc.py | 13 +- test/typing/plain_files/sql/operators.py | 14 +- test/typing/plain_files/sql/sql_operations.py | 64 ++-- test/typing/plain_files/sql/sqltypes.py | 15 +- test/typing/plain_files/sql/typed_results.py | 316 +++++++----------- tools/generate_sql_functions.py | 17 +- 41 files changed, 723 insertions(+), 1047 deletions(-) diff --git a/lib/sqlalchemy/testing/fixtures/mypy.py b/lib/sqlalchemy/testing/fixtures/mypy.py index b1d2ee0e81..cc16fa3744 100644 --- a/lib/sqlalchemy/testing/fixtures/mypy.py +++ b/lib/sqlalchemy/testing/fixtures/mypy.py @@ -29,8 +29,6 @@ try: except ImportError: _mypy_vers_tuple = (0, 0, 0) -mypy_14 = _mypy_vers_tuple >= (1, 4) - @config.add_to_marker.mypy class MypyTest(TestBase): @@ -128,9 +126,7 @@ class MypyTest(TestBase): def _collect_messages(self, path): expected_messages = [] - expected_re = re.compile( - r"\s*# EXPECTED(_MYPY)?(_RE)?(_ROW)?(_TYPE)?: (.+)" - ) + expected_re = re.compile(r"\s*# EXPECTED(_MYPY)?(_RE)?(_TYPE)?: (.+)") py_ver_re = re.compile(r"^#\s*PYTHON_VERSION\s?>=\s?(\d+\.\d+)") with open(path) as file_: current_assert_messages = [] @@ -148,71 +144,21 @@ class MypyTest(TestBase): if m: is_mypy = bool(m.group(1)) is_re = bool(m.group(2)) - is_row = bool(m.group(3)) - is_type = bool(m.group(4)) - - expected_msg = re.sub(r"# noqa[:]? ?.*", "", m.group(5)) - if is_row: - expected_msg = re.sub( - r"Row\[([^\]]+)\]", - lambda m: f"tuple[{m.group(1)}, fallback=s" - f"qlalchemy.engine.row.{m.group(0)}]", - expected_msg, - ) - # For some reason it does not use or syntax (|) - expected_msg = re.sub( - r"Optional\[(.*)\]", - lambda m: f"Union[{m.group(1)}, None]", - expected_msg, - ) + is_type = bool(m.group(3)) + expected_msg = re.sub(r"# noqa[:]? ?.*", "", m.group(4)) if is_type: - if not is_re: - # the goal here is that we can cut-and-paste - # from vscode -> pylance into the - # EXPECTED_TYPE: line, then the test suite will - # validate that line against what mypy produces - expected_msg = re.sub( - r"([\[\]])", - lambda m: rf"\{m.group(0)}", - expected_msg, - ) - - # note making sure preceding text matches - # with a dot, so that an expect for "Select" - # does not match "TypedSelect" - expected_msg = re.sub( - r"([\w_]+)", - lambda m: rf"(?:.*\.)?{m.group(1)}\*?", - expected_msg, - ) - - expected_msg = re.sub( - "List", "builtins.list", expected_msg - ) - - expected_msg = re.sub( - r"\b(int|str|float|bool)\b", - lambda m: rf"builtins.{m.group(0)}\*?", - expected_msg, - ) - # expected_msg = re.sub( - # r"(Sequence|Tuple|List|Union)", - # lambda m: fr"typing.{m.group(0)}\*?", - # expected_msg, - # ) is_mypy = is_re = True expected_msg = f'Revealed type is "{expected_msg}"' - if mypy_14: - # use_or_syntax - # https://github.com/python/mypy/blob/304997bfb85200fb521ac727ee0ce3e6085e5278/mypy/options.py#L368 # noqa: E501 - expected_msg = re.sub( - r"Optional\[(.*?)\]", - lambda m: f"{m.group(1)} | None", - expected_msg, - ) + # use_or_syntax + # https://github.com/python/mypy/blob/304997bfb85200fb521ac727ee0ce3e6085e5278/mypy/options.py#L368 # noqa: E501 + expected_msg = re.sub( + r"Optional\[(.*?)\]", + lambda m: f"{m.group(1)} | None", + expected_msg, + ) current_assert_messages.append( (is_mypy, is_re, expected_msg.strip()) ) diff --git a/test/typing/plain_files/dialects/postgresql/pg_stuff.py b/test/typing/plain_files/dialects/postgresql/pg_stuff.py index 20785bc2cb..ec99ec1b0c 100644 --- a/test/typing/plain_files/dialects/postgresql/pg_stuff.py +++ b/test/typing/plain_files/dialects/postgresql/pg_stuff.py @@ -1,5 +1,9 @@ +from datetime import date +from datetime import datetime from typing import Any +from typing import assert_type from typing import Dict +from typing import Sequence from uuid import UUID as _py_uuid from sqlalchemy import cast @@ -7,6 +11,7 @@ from sqlalchemy import Column from sqlalchemy import func from sqlalchemy import Integer from sqlalchemy import or_ +from sqlalchemy import Select from sqlalchemy import select from sqlalchemy import Text from sqlalchemy import UniqueConstraint @@ -18,6 +23,7 @@ from sqlalchemy.dialects.postgresql import insert from sqlalchemy.dialects.postgresql import INT4RANGE from sqlalchemy.dialects.postgresql import INT8MULTIRANGE from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.dialects.postgresql import Range from sqlalchemy.dialects.postgresql import TSTZMULTIRANGE from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import DeclarativeBase @@ -28,13 +34,11 @@ from sqlalchemy.orm import mapped_column c1 = Column(UUID()) -# EXPECTED_TYPE: Column[UUID] -reveal_type(c1) +assert_type(c1, Column[_py_uuid]) c2 = Column(UUID(as_uuid=False)) -# EXPECTED_TYPE: Column[str] -reveal_type(c2) +assert_type(c2, Column[str]) class Base(DeclarativeBase): @@ -69,11 +73,9 @@ print(stmt) t1 = Test() -# EXPECTED_RE_TYPE: .*dict\[.*str, Any\] -reveal_type(t1.data) +assert_type(t1.data, dict[str, Any]) -# EXPECTED_TYPE: UUID -reveal_type(t1.ident) +assert_type(t1.ident, _py_uuid) unique = UniqueConstraint(name="my_constraint") insert(Test).on_conflict_do_nothing( @@ -86,52 +88,40 @@ s1 = insert(Test) s1.on_conflict_do_update(set_=s1.excluded) -# EXPECTED_TYPE: Column[Range[int]] -reveal_type(Column(INT4RANGE())) -# EXPECTED_TYPE: Column[Range[datetime.date]] -reveal_type(Column("foo", DATERANGE())) -# EXPECTED_TYPE: Column[Sequence[Range[int]]] -reveal_type(Column(INT8MULTIRANGE())) -# EXPECTED_TYPE: Column[Sequence[Range[datetime.datetime]]] -reveal_type(Column("foo", TSTZMULTIRANGE())) +assert_type(Column(INT4RANGE()), Column[Range[int]]) +assert_type(Column("foo", DATERANGE()), Column[Range[date]]) +assert_type(Column(INT8MULTIRANGE()), Column[Sequence[Range[int]]]) +assert_type(Column("foo", TSTZMULTIRANGE()), Column[Sequence[Range[datetime]]]) range_col_stmt = select(Column(INT4RANGE()), Column(INT8MULTIRANGE())) -# EXPECTED_TYPE: Select[Range[int], Sequence[Range[int]]] -reveal_type(range_col_stmt) +assert_type(range_col_stmt, Select[Range[int], Sequence[Range[int]]]) array_from_ints = array(range(2)) -# EXPECTED_TYPE: array[int] -reveal_type(array_from_ints) +assert_type(array_from_ints, array[int]) array_of_strings = array([], type_=Text) -# EXPECTED_TYPE: array[str] -reveal_type(array_of_strings) +assert_type(array_of_strings, array[str]) array_of_ints = array([0], type_=Integer) -# EXPECTED_TYPE: array[int] -reveal_type(array_of_ints) +assert_type(array_of_ints, array[int]) # EXPECTED_MYPY_RE: Cannot infer .* of "array" array([0], type_=Text) -# EXPECTED_TYPE: ARRAY[str] -reveal_type(ARRAY(Text)) +assert_type(ARRAY(Text), ARRAY[str]) -# EXPECTED_TYPE: Column[Sequence[int]] -reveal_type(Column(type_=ARRAY(Integer))) +assert_type(Column(type_=ARRAY(Integer)), Column[Sequence[int]]) stmt_array_agg = select(func.array_agg(Column("num", type_=Integer))) -# EXPECTED_TYPE: Select[Sequence[int]] -reveal_type(stmt_array_agg) +assert_type(stmt_array_agg, Select[Sequence[int]]) -# EXPECTED_TYPE: Select[Sequence[str]] -reveal_type(select(func.array_agg(Test.ident_str))) +assert_type(select(func.array_agg(Test.ident_str)), Select[Sequence[str]]) stmt_array_agg_order_by_1 = select( func.array_agg( @@ -143,8 +133,7 @@ stmt_array_agg_order_by_1 = select( ) ) -# EXPECTED_TYPE: Select[Sequence[str]] -reveal_type(stmt_array_agg_order_by_1) +assert_type(stmt_array_agg_order_by_1, Select[Sequence[str]]) stmt_array_agg_order_by_2 = select( func.array_agg( @@ -152,5 +141,4 @@ stmt_array_agg_order_by_2 = select( ) ) -# EXPECTED_TYPE: Select[Sequence[str]] -reveal_type(stmt_array_agg_order_by_2) +assert_type(stmt_array_agg_order_by_2, Select[Sequence[str]]) diff --git a/test/typing/plain_files/engine/engine_inspection.py b/test/typing/plain_files/engine/engine_inspection.py index 0ca331f189..0660f44380 100644 --- a/test/typing/plain_files/engine/engine_inspection.py +++ b/test/typing/plain_files/engine/engine_inspection.py @@ -1,7 +1,11 @@ import typing +from typing import assert_type from sqlalchemy import create_engine from sqlalchemy import inspect +from sqlalchemy.engine import Inspector +from sqlalchemy.engine.base import Engine +from sqlalchemy.engine.interfaces import ReflectedColumn e = create_engine("sqlite://") @@ -13,11 +17,8 @@ cols = insp.get_columns("some_table") c1 = cols[0] if typing.TYPE_CHECKING: - # EXPECTED_RE_TYPE: sqlalchemy.engine.base.Engine - reveal_type(e) + assert_type(e, Engine) - # EXPECTED_RE_TYPE: sqlalchemy.engine.reflection.Inspector.* - reveal_type(insp) + assert_type(insp, Inspector) - # EXPECTED_RE_TYPE: .*list.*TypedDict.*ReflectedColumn.* - reveal_type(cols) + assert_type(cols, list[ReflectedColumn]) diff --git a/test/typing/plain_files/engine/engine_result.py b/test/typing/plain_files/engine/engine_result.py index 1c76cf68b4..4c4b030f18 100644 --- a/test/typing/plain_files/engine/engine_result.py +++ b/test/typing/plain_files/engine/engine_result.py @@ -1,73 +1,56 @@ +from typing import Any +from typing import assert_type +from typing import Sequence + from sqlalchemy import column from sqlalchemy.engine import Result from sqlalchemy.engine import Row +from sqlalchemy.engine import RowMapping +from sqlalchemy.engine.result import FrozenResult +from sqlalchemy.engine.result import MappingResult +from sqlalchemy.engine.result import ScalarResult def row_one(row: Row[int, str, bool]) -> None: - # EXPECTED_TYPE: int - reveal_type(row[0]) - # EXPECTED_TYPE: str - reveal_type(row[1]) - # EXPECTED_TYPE: bool - reveal_type(row[2]) + assert_type(row[0], int) + assert_type(row[1], str) + assert_type(row[2], bool) # EXPECTED_MYPY: Tuple index out of range row[3] # EXPECTED_MYPY: No overload variant of "__getitem__" of "tuple" matches argument type "str" # noqa: E501 row["a"] - # EXPECTED_TYPE: RowMapping - reveal_type(row._mapping) + assert_type(row._mapping, RowMapping) rm = row._mapping - # EXPECTED_TYPE: Any - reveal_type(rm["foo"]) - # EXPECTED_TYPE: Any - reveal_type(rm[column("bar")]) + assert_type(rm["foo"], Any) + assert_type(rm[column("bar")], Any) # EXPECTED_MYPY_RE: Invalid index type "int" for "RowMapping"; expected type "(str \| SQLCoreOperations\[Any\]|Union\[str, SQLCoreOperations\[Any\]\])" # noqa: E501 rm[3] def result_one(res: Result[int, str]) -> None: - # EXPECTED_ROW_TYPE: Row[int, str] - reveal_type(res.one()) - # EXPECTED_ROW_TYPE: Row[int, str] | None - reveal_type(res.one_or_none()) - # EXPECTED_ROW_TYPE: Row[int, str] | None - reveal_type(res.fetchone()) - # EXPECTED_ROW_TYPE: Row[int, str] | None - reveal_type(res.first()) - # EXPECTED_ROW_TYPE: Sequence[Row[int, str]] - reveal_type(res.all()) - # EXPECTED_ROW_TYPE: Sequence[Row[int, str]] - reveal_type(res.fetchmany()) - # EXPECTED_ROW_TYPE: Sequence[Row[int, str]] - reveal_type(res.fetchall()) - # EXPECTED_ROW_TYPE: Row[int, str] - reveal_type(next(res)) + assert_type(res.one(), Row[int, str]) + assert_type(res.one_or_none(), Row[int, str] | None) + assert_type(res.fetchone(), Row[int, str] | None) + assert_type(res.first(), Row[int, str] | None) + assert_type(res.all(), Sequence[Row[int, str]]) + assert_type(res.fetchmany(), Sequence[Row[int, str]]) + assert_type(res.fetchall(), Sequence[Row[int, str]]) + assert_type(next(res), Row[int, str]) for rf in res: - # EXPECTED_ROW_TYPE: Row[int, str] - reveal_type(rf) + assert_type(rf, Row[int, str]) for rp in res.partitions(): - # EXPECTED_ROW_TYPE: Sequence[Row[int, str]] - reveal_type(rp) - - # EXPECTED_TYPE: ScalarResult[int] - res_s = reveal_type(res.scalars()) - # EXPECTED_TYPE: ScalarResult[int] - res_s = reveal_type(res.scalars(0)) - # EXPECTED_TYPE: int - reveal_type(res_s.one()) - # EXPECTED_TYPE: ScalarResult[Any] - reveal_type(res.scalars(1)) - # EXPECTED_TYPE: MappingResult - reveal_type(res.mappings()) - # EXPECTED_TYPE: FrozenResult[int, str] - reveal_type(res.freeze()) - - # EXPECTED_TYPE: int - reveal_type(res.scalar_one()) - # EXPECTED_TYPE: int | None - reveal_type(res.scalar_one_or_none()) - # EXPECTED_TYPE: int | None - reveal_type(res.scalar()) + assert_type(rp, Sequence[Row[int, str]]) + + res_s = assert_type(res.scalars(), ScalarResult[int]) + res_s = assert_type(res.scalars(0), ScalarResult[int]) + assert_type(res_s.one(), int) + assert_type(res.scalars(1), ScalarResult[Any]) + assert_type(res.mappings(), MappingResult) + assert_type(res.freeze(), FrozenResult[int, str]) + + assert_type(res.scalar_one(), int) + assert_type(res.scalar_one_or_none(), int | None) + assert_type(res.scalar(), int | None) diff --git a/test/typing/plain_files/engine/engines.py b/test/typing/plain_files/engine/engines.py index 15aa774e6a..7e06beaede 100644 --- a/test/typing/plain_files/engine/engines.py +++ b/test/typing/plain_files/engine/engines.py @@ -1,32 +1,34 @@ +from typing import Any +from typing import assert_type +from typing import Unpack + +from sqlalchemy import Connection from sqlalchemy import create_engine from sqlalchemy import Pool from sqlalchemy import select from sqlalchemy import text +from sqlalchemy.engine.base import Engine +from sqlalchemy.engine.cursor import CursorResult def regular() -> None: e = create_engine("sqlite://") - # EXPECTED_TYPE: Engine - reveal_type(e) + assert_type(e, Engine) with e.connect() as conn: - # EXPECTED_TYPE: Connection - reveal_type(conn) + assert_type(conn, Connection) result = conn.execute(text("select * from table")) - # EXPECTED_TYPE: CursorResult[Unpack[.*tuple[Any, ...]]] - reveal_type(result) + assert_type(result, CursorResult[Unpack[tuple[Any, ...]]]) with e.begin() as conn: - # EXPECTED_TYPE: Connection - reveal_type(conn) + assert_type(conn, Connection) result = conn.execute(text("select * from table")) - # EXPECTED_TYPE: CursorResult[Unpack[.*tuple[Any, ...]]] - reveal_type(result) + assert_type(result, CursorResult[Unpack[tuple[Any, ...]]]) engine = create_engine("postgresql://scott:tiger@localhost/test") status: str = engine.pool.status() diff --git a/test/typing/plain_files/ext/association_proxy/association_proxy_one.py b/test/typing/plain_files/ext/association_proxy/association_proxy_one.py index cb9f0b85d7..c6dd37b0c9 100644 --- a/test/typing/plain_files/ext/association_proxy/association_proxy_one.py +++ b/test/typing/plain_files/ext/association_proxy/association_proxy_one.py @@ -1,4 +1,5 @@ import typing +from typing import assert_type from typing import Set from sqlalchemy import ForeignKey @@ -6,6 +7,7 @@ from sqlalchemy import Integer from sqlalchemy import String from sqlalchemy.ext.associationproxy import association_proxy from sqlalchemy.ext.associationproxy import AssociationProxy +from sqlalchemy.ext.associationproxy import AssociationProxyInstance from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column @@ -40,8 +42,6 @@ class Address(Base): u1 = User() if typing.TYPE_CHECKING: - # EXPECTED_RE_TYPE: sqlalchemy.*.associationproxy.AssociationProxyInstance\[builtins.set\*?\[builtins.str\]\] - reveal_type(User.email_addresses) + assert_type(User.email_addresses, AssociationProxyInstance[set[str]]) - # EXPECTED_RE_TYPE: builtins.set\*?\[builtins.str\] - reveal_type(u1.email_addresses) + assert_type(u1.email_addresses, set[str]) diff --git a/test/typing/plain_files/ext/association_proxy/association_proxy_three.py b/test/typing/plain_files/ext/association_proxy/association_proxy_three.py index f338681f7c..2f18a9aff3 100644 --- a/test/typing/plain_files/ext/association_proxy/association_proxy_three.py +++ b/test/typing/plain_files/ext/association_proxy/association_proxy_three.py @@ -1,5 +1,6 @@ from __future__ import annotations +from typing import assert_type from typing import List from sqlalchemy import ForeignKey @@ -42,5 +43,4 @@ bm = BranchMilestone() x1 = bm.user_ids -# EXPECTED_TYPE: list[int] -reveal_type(x1) +assert_type(x1, list[int]) diff --git a/test/typing/plain_files/ext/association_proxy/association_proxy_two.py b/test/typing/plain_files/ext/association_proxy/association_proxy_two.py index 074a6a71a8..95bc47da3d 100644 --- a/test/typing/plain_files/ext/association_proxy/association_proxy_two.py +++ b/test/typing/plain_files/ext/association_proxy/association_proxy_two.py @@ -1,5 +1,6 @@ from __future__ import annotations +from typing import assert_type from typing import Final from sqlalchemy import Column @@ -52,14 +53,12 @@ user_keyword_table: Final[Table] = Table( user = User("jek") -# EXPECTED_TYPE: list[Keyword] -reveal_type(user.kw) +assert_type(user.kw, list[Keyword]) user.kw.append(Keyword("cheese-inspector")) user.keywords.append("cheese-inspector") -# EXPECTED_TYPE: list[str] -reveal_type(user.keywords) +assert_type(user.keywords, list[str]) user.keywords.append("snack ninja") diff --git a/test/typing/plain_files/ext/asyncio/engines.py b/test/typing/plain_files/ext/asyncio/engines.py index 7af764ecd8..9ddd59c898 100644 --- a/test/typing/plain_files/ext/asyncio/engines.py +++ b/test/typing/plain_files/ext/asyncio/engines.py @@ -1,11 +1,18 @@ from typing import Any +from typing import assert_type +from typing import Unpack from sqlalchemy import Connection from sqlalchemy import Enum from sqlalchemy import MetaData from sqlalchemy import select from sqlalchemy import text +from sqlalchemy.engine.cursor import CursorResult from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.ext.asyncio.engine import AsyncConnection +from sqlalchemy.ext.asyncio.engine import AsyncEngine +from sqlalchemy.ext.asyncio.result import AsyncResult +from sqlalchemy.ext.asyncio.result import AsyncScalarResult def work_sync(conn: Connection, foo: int) -> Any: @@ -15,54 +22,45 @@ def work_sync(conn: Connection, foo: int) -> Any: async def asyncio() -> None: e = create_async_engine("sqlite://") - # EXPECTED_TYPE: AsyncEngine - reveal_type(e) + assert_type(e, AsyncEngine) async with e.connect() as conn: - # EXPECTED_TYPE: AsyncConnection - reveal_type(conn) + assert_type(conn, AsyncConnection) result = await conn.execute(text("select * from table")) - # EXPECTED_TYPE: CursorResult[Unpack[.*tuple[Any, ...]]] - reveal_type(result) + assert_type(result, CursorResult[Unpack[tuple[Any, ...]]]) # stream with direct await async_result = await conn.stream(text("select * from table")) - # EXPECTED_TYPE: AsyncResult[Unpack[.*tuple[Any, ...]]] - reveal_type(async_result) + assert_type(async_result, AsyncResult[Unpack[tuple[Any, ...]]]) # stream with context manager async with conn.stream( text("select * from table") ) as ctx_async_result: - # EXPECTED_TYPE: AsyncResult[Unpack[.*tuple[Any, ...]]] - reveal_type(ctx_async_result) + assert_type(ctx_async_result, AsyncResult[Unpack[tuple[Any, ...]]]) # stream_scalars with direct await async_scalar_result = await conn.stream_scalars( text("select * from table") ) - # EXPECTED_TYPE: AsyncScalarResult[Any] - reveal_type(async_scalar_result) + assert_type(async_scalar_result, AsyncScalarResult[Any]) # stream_scalars with context manager async with conn.stream_scalars( text("select * from table") ) as ctx_async_scalar_result: - # EXPECTED_TYPE: AsyncScalarResult[Any] - reveal_type(ctx_async_scalar_result) + assert_type(ctx_async_scalar_result, AsyncScalarResult[Any]) async with e.begin() as conn: - # EXPECTED_TYPE: AsyncConnection - reveal_type(conn) + assert_type(conn, AsyncConnection) result = await conn.execute(text("select * from table")) - # EXPECTED_TYPE: CursorResult[Unpack[.*tuple[Any, ...]]] - reveal_type(result) + assert_type(result, CursorResult[Unpack[tuple[Any, ...]]]) await conn.run_sync(work_sync, 1) diff --git a/test/typing/plain_files/ext/hybrid/hybrid_one.py b/test/typing/plain_files/ext/hybrid/hybrid_one.py index aef41395fe..09f98781c2 100644 --- a/test/typing/plain_files/ext/hybrid/hybrid_one.py +++ b/test/typing/plain_files/ext/hybrid/hybrid_one.py @@ -1,13 +1,18 @@ from __future__ import annotations import typing +from typing import assert_type +from sqlalchemy import Select from sqlalchemy import select +from sqlalchemy.ext.hybrid import _HybridClassLevelAccessor from sqlalchemy.ext.hybrid import hybrid_method from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column +from sqlalchemy.sql.elements import SQLCoreOperations +from sqlalchemy.sql.expression import BinaryExpression class Base(DeclarativeBase): @@ -66,26 +71,18 @@ stmt1 = select(Interval).where(expr1).where(expr4) stmt2 = select(expr4) if typing.TYPE_CHECKING: - # EXPECTED_RE_TYPE: builtins.int\*? - reveal_type(i1.length) + assert_type(i1.length, int) - # EXPECTED_RE_TYPE: sqlalchemy.*._HybridClassLevelAccessor\[builtins.int\*?\] - reveal_type(Interval.length) + assert_type(Interval.length, _HybridClassLevelAccessor[int]) - # EXPECTED_RE_TYPE: sqlalchemy.*.BinaryExpression\[builtins.bool\*?\] - reveal_type(expr1) + assert_type(expr1, BinaryExpression[bool]) - # EXPECTED_RE_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.int\*?\] - reveal_type(expr2) + assert_type(expr2, SQLCoreOperations[int]) - # EXPECTED_RE_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.int\*?\] - reveal_type(expr3) + assert_type(expr3, SQLCoreOperations[int]) - # EXPECTED_TYPE: bool - reveal_type(i1.fancy_thing(1, 2, 3)) + assert_type(i1.fancy_thing(1, 2, 3), bool) - # EXPECTED_TYPE: SQLCoreOperations[bool] - reveal_type(expr4) + assert_type(expr4, SQLCoreOperations[bool]) - # EXPECTED_TYPE: Select[bool] - reveal_type(stmt2) + assert_type(stmt2, Select[bool]) diff --git a/test/typing/plain_files/ext/hybrid/hybrid_two.py b/test/typing/plain_files/ext/hybrid/hybrid_two.py index b4f2aca769..a0b8e32542 100644 --- a/test/typing/plain_files/ext/hybrid/hybrid_two.py +++ b/test/typing/plain_files/ext/hybrid/hybrid_two.py @@ -1,13 +1,17 @@ from __future__ import annotations import typing +from typing import assert_type from sqlalchemy import Float from sqlalchemy import func +from sqlalchemy import Function +from sqlalchemy.ext.hybrid import _HybridClassLevelAccessor from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column +from sqlalchemy.sql.expression import BinaryExpression from sqlalchemy.sql.expression import ColumnElement @@ -44,11 +48,9 @@ class Interval(Base): # while we are here, check some Float[] / div type stuff if typing.TYPE_CHECKING: - # EXPECTED_RE_TYPE: sqlalchemy.*Function\[builtins.float\*?\] - reveal_type(f1) + assert_type(f1, Function[float]) - # EXPECTED_RE_TYPE: sqlalchemy.*ColumnElement\[builtins.float\*?\] - reveal_type(expr) + assert_type(expr, ColumnElement[float]) return expr # new way - use the original decorator with inplace @@ -66,11 +68,9 @@ class Interval(Base): # while we are here, check some Float[] / div type stuff if typing.TYPE_CHECKING: - # EXPECTED_RE_TYPE: sqlalchemy.*Function\[builtins.float\*?\] - reveal_type(f1) + assert_type(f1, Function[float]) - # EXPECTED_RE_TYPE: sqlalchemy.*ColumnElement\[builtins.float\*?\] - reveal_type(expr) + assert_type(expr, ColumnElement[float]) return expr @@ -92,38 +92,27 @@ expr3n = Interval.new_radius.in_([0.5, 5.2]) if typing.TYPE_CHECKING: - # EXPECTED_RE_TYPE: builtins.int\*? - reveal_type(i1.length) + assert_type(i1.length, int) - # EXPECTED_RE_TYPE: builtins.float\*? - reveal_type(i2.old_radius) + assert_type(i2.old_radius, float) - # EXPECTED_RE_TYPE: builtins.float\*? - reveal_type(i2.new_radius) + assert_type(i2.new_radius, float) - # EXPECTED_RE_TYPE: sqlalchemy.*._HybridClassLevelAccessor\[builtins.int\*?\] - reveal_type(Interval.length) + assert_type(Interval.length, _HybridClassLevelAccessor[int]) - # EXPECTED_RE_TYPE: sqlalchemy.*._HybridClassLevelAccessor\[builtins.float\*?\] - reveal_type(Interval.old_radius) + assert_type(Interval.old_radius, _HybridClassLevelAccessor[float]) - # EXPECTED_RE_TYPE: sqlalchemy.*._HybridClassLevelAccessor\[builtins.float\*?\] - reveal_type(Interval.new_radius) + assert_type(Interval.new_radius, _HybridClassLevelAccessor[float]) - # EXPECTED_RE_TYPE: sqlalchemy.*.BinaryExpression\[builtins.bool\*?\] - reveal_type(expr1) + assert_type(expr1, BinaryExpression[bool]) - # EXPECTED_RE_TYPE: sqlalchemy.*._HybridClassLevelAccessor\[builtins.float\*?\] - reveal_type(expr2o) + assert_type(expr2o, _HybridClassLevelAccessor[float]) - # EXPECTED_RE_TYPE: sqlalchemy.*._HybridClassLevelAccessor\[builtins.float\*?\] - reveal_type(expr2n) + assert_type(expr2n, _HybridClassLevelAccessor[float]) - # EXPECTED_RE_TYPE: sqlalchemy.*.BinaryExpression\[builtins.bool\*?\] - reveal_type(expr3o) + assert_type(expr3o, BinaryExpression[bool]) - # EXPECTED_RE_TYPE: sqlalchemy.*.BinaryExpression\[builtins.bool\*?\] - reveal_type(expr3n) + assert_type(expr3n, BinaryExpression[bool]) # test #9268 diff --git a/test/typing/plain_files/ext/indexable.py b/test/typing/plain_files/ext/indexable.py index c6c1c35299..976577d6fb 100644 --- a/test/typing/plain_files/ext/indexable.py +++ b/test/typing/plain_files/ext/indexable.py @@ -1,12 +1,15 @@ from __future__ import annotations from datetime import date +from typing import assert_type from typing import Dict from typing import List from sqlalchemy import ARRAY from sqlalchemy import JSON +from sqlalchemy import Select from sqlalchemy import select +from sqlalchemy.ext.hybrid import _HybridClassLevelAccessor from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.ext.indexable import index_property from sqlalchemy.orm import DeclarativeBase @@ -38,29 +41,22 @@ a = Article( updates=[date(2025, 7, 28), date(2025, 7, 29)], ) -# EXPECTED_TYPE: str -reveal_type(a.topic) +assert_type(a.topic, str) -# EXPECTED_RE_TYPE: sqlalchemy.*._HybridClassLevelAccessor\[builtins.str\*?\] -reveal_type(Article.topic) +assert_type(Article.topic, _HybridClassLevelAccessor[str]) -# EXPECTED_TYPE: date -reveal_type(a.created_at) +assert_type(a.created_at, date) -# EXPECTED_TYPE: date -reveal_type(a.updated_at) +assert_type(a.updated_at, date) a.created_at = date(2025, 7, 30) -# EXPECTED_RE_TYPE: sqlalchemy.*._HybridClassLevelAccessor\[datetime.date\*?\] -reveal_type(Article.created_at) +assert_type(Article.created_at, _HybridClassLevelAccessor[date]) -# EXPECTED_RE_TYPE: sqlalchemy.*._HybridClassLevelAccessor\[datetime.date\*?\] -reveal_type(Article.updated_at) +assert_type(Article.updated_at, _HybridClassLevelAccessor[date]) stmt = select(Article.id, Article.topic, Article.created_at).where( Article.id == 1 ) -# EXPECTED_RE_TYPE: .*Select\[.*int, .*str, datetime\.date\] -reveal_type(stmt) +assert_type(stmt, Select[int, str, date]) diff --git a/test/typing/plain_files/ext/orderinglist/orderinglist_one.py b/test/typing/plain_files/ext/orderinglist/orderinglist_one.py index d2b7c5ece0..8371f00375 100644 --- a/test/typing/plain_files/ext/orderinglist/orderinglist_one.py +++ b/test/typing/plain_files/ext/orderinglist/orderinglist_one.py @@ -1,11 +1,14 @@ from __future__ import annotations import re +from typing import assert_type +from typing import Callable from typing import Sequence from typing import TYPE_CHECKING from sqlalchemy import ForeignKey from sqlalchemy.ext.orderinglist import ordering_list +from sqlalchemy.ext.orderinglist import OrderingList from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column @@ -47,8 +50,6 @@ slide = Slide() if TYPE_CHECKING: - # EXPECTED_RE_TYPE: def \(\) -> sqlalchemy.*.orderinglist.OrderingList\[orderinglist_one.Bullet\] - reveal_type(pos_from_text) + assert_type(pos_from_text, Callable[[], OrderingList[Bullet]]) - # EXPECTED_TYPE: builtins.list[orderinglist_one.Bullet] - reveal_type(slide.bullets) + assert_type(slide.bullets, list[Bullet]) diff --git a/test/typing/plain_files/inspection_inspect.py b/test/typing/plain_files/inspection_inspect.py index 886484dbc9..a37d400efb 100644 --- a/test/typing/plain_files/inspection_inspect.py +++ b/test/typing/plain_files/inspection_inspect.py @@ -1,3 +1,5 @@ +from typing import Any +from typing import assert_type from typing import List from sqlalchemy import create_engine @@ -8,6 +10,7 @@ from sqlalchemy.orm import DeclarativeBaseNoMeta from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column from sqlalchemy.orm import Mapper +from sqlalchemy.orm.state import InstanceState class Base(DeclarativeBase): @@ -32,10 +35,8 @@ class B(BaseNoMeta): data: Mapped[str] -# EXPECTED_TYPE: Mapper[Any] -reveal_type(A.__mapper__) -# EXPECTED_TYPE: Mapper[Any] -reveal_type(B.__mapper__) +assert_type(A.__mapper__, Mapper[Any]) +assert_type(B.__mapper__, Mapper[Any]) a1 = A(data="d") b1 = B(data="d") @@ -45,22 +46,17 @@ e = create_engine("sqlite://") insp_a1 = inspect(a1) t: bool = insp_a1.transient -# EXPECTED_TYPE: InstanceState[A] -reveal_type(insp_a1) -# EXPECTED_TYPE: InstanceState[B] -reveal_type(inspect(b1)) +assert_type(insp_a1, InstanceState[A]) +assert_type(inspect(b1), InstanceState[B]) m: Mapper[A] = inspect(A) -# EXPECTED_TYPE: Mapper[A] -reveal_type(inspect(A)) -# EXPECTED_TYPE: Mapper[B] -reveal_type(inspect(B)) +assert_type(inspect(A), Mapper[A]) +assert_type(inspect(B), Mapper[B]) tables: List[str] = inspect(e).get_table_names() i: Inspector = inspect(e) -# EXPECTED_TYPE: Inspector -reveal_type(inspect(e)) +assert_type(inspect(e), Inspector) with e.connect() as conn: diff --git a/test/typing/plain_files/orm/composite.py b/test/typing/plain_files/orm/composite.py index f82bbe7c2d..f808696ca5 100644 --- a/test/typing/plain_files/orm/composite.py +++ b/test/typing/plain_files/orm/composite.py @@ -1,6 +1,8 @@ from typing import Any +from typing import assert_type from typing import Tuple +from sqlalchemy import Select from sqlalchemy import select from sqlalchemy.orm import composite from sqlalchemy.orm import DeclarativeBase @@ -58,14 +60,10 @@ v1 = Vertex(start=Point(3, 4), end=Point(5, 6)) stmt = select(Vertex).where(Vertex.start.in_([Point(3, 4)])) -# EXPECTED_TYPE: Select[Vertex] -reveal_type(stmt) +assert_type(stmt, Select[Vertex]) -# EXPECTED_TYPE: composite.Point -reveal_type(v1.start) +assert_type(v1.start, Point) -# EXPECTED_TYPE: composite.Point -reveal_type(v1.end) +assert_type(v1.end, Point) -# EXPECTED_TYPE: int -reveal_type(v1.end.y) +assert_type(v1.end.y, int) diff --git a/test/typing/plain_files/orm/composite_dc.py b/test/typing/plain_files/orm/composite_dc.py index 3d8117a999..25aaae1703 100644 --- a/test/typing/plain_files/orm/composite_dc.py +++ b/test/typing/plain_files/orm/composite_dc.py @@ -1,5 +1,7 @@ import dataclasses +from typing import assert_type +from sqlalchemy import Select from sqlalchemy import select from sqlalchemy.orm import composite from sqlalchemy.orm import DeclarativeBase @@ -38,14 +40,10 @@ v1 = Vertex(start=Point(3, 4), end=Point(5, 6)) stmt = select(Vertex).where(Vertex.start.in_([Point(3, 4)])) -# EXPECTED_TYPE: Select[Vertex] -reveal_type(stmt) +assert_type(stmt, Select[Vertex]) -# EXPECTED_TYPE: composite.Point -reveal_type(v1.start) +assert_type(v1.start, Point) -# EXPECTED_TYPE: composite.Point -reveal_type(v1.end) +assert_type(v1.end, Point) -# EXPECTED_TYPE: int -reveal_type(v1.end.y) +assert_type(v1.end.y, int) diff --git a/test/typing/plain_files/orm/dataclass_transforms_decorator.py b/test/typing/plain_files/orm/dataclass_transforms_decorator.py index 01114c51e9..6738788d46 100644 --- a/test/typing/plain_files/orm/dataclass_transforms_decorator.py +++ b/test/typing/plain_files/orm/dataclass_transforms_decorator.py @@ -1,3 +1,5 @@ +from typing import assert_type + from sqlalchemy import Integer from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_as_dataclass @@ -19,5 +21,4 @@ class Relationships: rs = Relationships(entity_id1=1, entity_id2=2, level=1) -# EXPECTED_TYPE: int -reveal_type(rs.entity_id1) +assert_type(rs.entity_id1, int) diff --git a/test/typing/plain_files/orm/dataclass_transforms_decorator_w_mixins.py b/test/typing/plain_files/orm/dataclass_transforms_decorator_w_mixins.py index ecdb8b3eda..5ca3be4612 100644 --- a/test/typing/plain_files/orm/dataclass_transforms_decorator_w_mixins.py +++ b/test/typing/plain_files/orm/dataclass_transforms_decorator_w_mixins.py @@ -1,3 +1,5 @@ +from typing import assert_type + from sqlalchemy import Integer from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_as_dataclass @@ -32,8 +34,6 @@ class Relationships(RelationshipsModel): # (this is the type checker, not us) rs = Relationships(entity_id1=1, entity_id2=2, level=1) -# EXPECTED_TYPE: int -reveal_type(rs.entity_id1) +assert_type(rs.entity_id1, int) -# EXPECTED_TYPE: int -reveal_type(rs.level) +assert_type(rs.level, int) diff --git a/test/typing/plain_files/orm/dataclass_transforms_one.py b/test/typing/plain_files/orm/dataclass_transforms_one.py index 986483d8ef..b99adfab0d 100644 --- a/test/typing/plain_files/orm/dataclass_transforms_one.py +++ b/test/typing/plain_files/orm/dataclass_transforms_one.py @@ -1,5 +1,6 @@ from __future__ import annotations +from typing import assert_type from typing import Optional from sqlalchemy.orm import column_property @@ -25,11 +26,9 @@ class TestInitialSupport(Base): tis = TestInitialSupport(data="some data", y=5) -# EXPECTED_TYPE: str -reveal_type(tis.data) +assert_type(tis.data, str) -# EXPECTED_RE_TYPE: .*builtins.int \| None -reveal_type(tis.y) +assert_type(tis.y, int | None) tis.data = "some other data" diff --git a/test/typing/plain_files/orm/declared_attr_one.py b/test/typing/plain_files/orm/declared_attr_one.py index 79f1548e36..4493a2667a 100644 --- a/test/typing/plain_files/orm/declared_attr_one.py +++ b/test/typing/plain_files/orm/declared_attr_one.py @@ -1,16 +1,22 @@ from datetime import datetime import typing +from typing import Any +from typing import assert_type +from typing import Unpack from sqlalchemy import DateTime from sqlalchemy import Index from sqlalchemy import Integer +from sqlalchemy import Select from sqlalchemy import String from sqlalchemy import UniqueConstraint from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import declared_attr +from sqlalchemy.orm import InstrumentedAttribute from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column from sqlalchemy.orm import MappedClassProtocol +from sqlalchemy.orm.mapper import Mapper from sqlalchemy.sql.schema import PrimaryKeyConstraint @@ -74,14 +80,11 @@ class Manager(Employee): def do_something_with_mapped_class( cls_: MappedClassProtocol[Employee], ) -> None: - # EXPECTED_TYPE: Select[Unpack[.*tuple[Any, ...]]] - reveal_type(cls_.__table__.select()) + assert_type(cls_.__table__.select(), Select[Unpack[tuple[Any, ...]]]) - # EXPECTED_TYPE: Mapper[Employee] - reveal_type(cls_.__mapper__) + assert_type(cls_.__mapper__, Mapper[Employee]) - # EXPECTED_TYPE: Employee - reveal_type(cls_()) + assert_type(cls_(), Employee) do_something_with_mapped_class(Manager) @@ -89,8 +92,6 @@ do_something_with_mapped_class(Engineer) if typing.TYPE_CHECKING: - # EXPECTED_TYPE: InstrumentedAttribute[datetime] - reveal_type(Engineer.start_date) + assert_type(Engineer.start_date, InstrumentedAttribute[datetime]) - # EXPECTED_TYPE: InstrumentedAttribute[datetime] - reveal_type(Manager.start_date) + assert_type(Manager.start_date, InstrumentedAttribute[datetime]) diff --git a/test/typing/plain_files/orm/declared_attr_two.py b/test/typing/plain_files/orm/declared_attr_two.py index c8e12ee931..3792570513 100644 --- a/test/typing/plain_files/orm/declared_attr_two.py +++ b/test/typing/plain_files/orm/declared_attr_two.py @@ -1,9 +1,11 @@ import typing +from typing import assert_type from sqlalchemy import Integer from sqlalchemy import Text from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import declared_attr +from sqlalchemy.orm import InstrumentedAttribute from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column @@ -39,14 +41,10 @@ class Foo(Base): u1 = User() if typing.TYPE_CHECKING: - # EXPECTED_TYPE: str - reveal_type(User.__tablename__) + assert_type(User.__tablename__, str) - # EXPECTED_TYPE: str - reveal_type(Foo.__tablename__) + assert_type(Foo.__tablename__, str) - # EXPECTED_TYPE: str - reveal_type(u1.related_data) + assert_type(u1.related_data, str) - # EXPECTED_TYPE: InstrumentedAttribute[str] - reveal_type(User.related_data) + assert_type(User.related_data, InstrumentedAttribute[str]) diff --git a/test/typing/plain_files/orm/dynamic_rel.py b/test/typing/plain_files/orm/dynamic_rel.py index 8b406bb171..c9ebdbc1e4 100644 --- a/test/typing/plain_files/orm/dynamic_rel.py +++ b/test/typing/plain_files/orm/dynamic_rel.py @@ -1,6 +1,7 @@ from __future__ import annotations import typing +from typing import assert_type from sqlalchemy import ForeignKey from sqlalchemy import select @@ -10,6 +11,7 @@ from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column from sqlalchemy.orm import relationship from sqlalchemy.orm import Session +from sqlalchemy.orm.dynamic import AppenderQuery class Base(DeclarativeBase): @@ -37,19 +39,16 @@ with Session() as session: session.commit() if typing.TYPE_CHECKING: - # EXPECTED_TYPE: AppenderQuery[Address] - reveal_type(u.addresses) + assert_type(u.addresses, AppenderQuery[Address]) count = u.addresses.count() if typing.TYPE_CHECKING: - # EXPECTED_TYPE: int - reveal_type(count) + assert_type(count, int) address = u.addresses.filter(Address.email_address.like("xyz")).one() if typing.TYPE_CHECKING: - # EXPECTED_TYPE: Address - reveal_type(address) + assert_type(address, Address) u.addresses.append(Address()) u.addresses.extend([Address(), Address()]) @@ -57,8 +56,7 @@ with Session() as session: current_addresses = list(u.addresses) if typing.TYPE_CHECKING: - # EXPECTED_TYPE: list[Address] - reveal_type(current_addresses) + assert_type(current_addresses, list[Address]) # can assign plain list u.addresses = [] @@ -68,15 +66,13 @@ with Session() as session: if typing.TYPE_CHECKING: # still an AppenderQuery - # EXPECTED_TYPE: AppenderQuery[Address] - reveal_type(u.addresses) + assert_type(u.addresses, AppenderQuery[Address]) u.addresses = {Address(), Address()} if typing.TYPE_CHECKING: # still an AppenderQuery - # EXPECTED_TYPE: AppenderQuery[Address] - reveal_type(u.addresses) + assert_type(u.addresses, AppenderQuery[Address]) u.addresses.append(Address()) diff --git a/test/typing/plain_files/orm/issue_9340.py b/test/typing/plain_files/orm/issue_9340.py index 6ccd2eed31..81155f4c80 100644 --- a/test/typing/plain_files/orm/issue_9340.py +++ b/test/typing/plain_files/orm/issue_9340.py @@ -1,13 +1,16 @@ +from typing import assert_type from typing import Sequence from typing import TYPE_CHECKING from sqlalchemy import create_engine +from sqlalchemy import Select from sqlalchemy import select from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column from sqlalchemy.orm import Session from sqlalchemy.orm import with_polymorphic +from sqlalchemy.orm.util import AliasedClass class Base(DeclarativeBase): ... @@ -39,8 +42,7 @@ def get_messages() -> Sequence[Message]: message_query = select(Message) if TYPE_CHECKING: - # EXPECTED_TYPE: Select[Message] - reveal_type(message_query) + assert_type(message_query, Select[Message]) return session.scalars(message_query).all() @@ -50,13 +52,11 @@ def get_poly_messages() -> Sequence[Message]: PolymorphicMessage = with_polymorphic(Message, (UserComment,)) if TYPE_CHECKING: - # EXPECTED_TYPE: AliasedClass[Message] - reveal_type(PolymorphicMessage) + assert_type(PolymorphicMessage, AliasedClass[Message]) poly_query = select(PolymorphicMessage) if TYPE_CHECKING: - # EXPECTED_TYPE: Select[Message] - reveal_type(poly_query) + assert_type(poly_query, Select[Message]) return session.scalars(poly_query).all() diff --git a/test/typing/plain_files/orm/keyfunc_dict.py b/test/typing/plain_files/orm/keyfunc_dict.py index 831861000c..0b275bac8d 100644 --- a/test/typing/plain_files/orm/keyfunc_dict.py +++ b/test/typing/plain_files/orm/keyfunc_dict.py @@ -1,4 +1,5 @@ import typing +from typing import assert_type from typing import Dict from typing import Optional @@ -42,5 +43,4 @@ item = Item() item.notes["a"] = Note("a", "atext") if typing.TYPE_CHECKING: - # EXPECTED_TYPE: dict_items[str, Note] - reveal_type(item.notes.items()) + assert_type(list(item.notes.items()), list[tuple[str, Note]]) diff --git a/test/typing/plain_files/orm/relationship.py b/test/typing/plain_files/orm/relationship.py index 82e668ceeb..f818791970 100644 --- a/test/typing/plain_files/orm/relationship.py +++ b/test/typing/plain_files/orm/relationship.py @@ -3,6 +3,7 @@ from __future__ import annotations import typing +from typing import assert_type from typing import ClassVar from typing import List from typing import Optional @@ -16,6 +17,7 @@ from sqlalchemy import select from sqlalchemy import Table from sqlalchemy.orm import aliased from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import InstrumentedAttribute from sqlalchemy.orm import joinedload from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column @@ -24,6 +26,7 @@ from sqlalchemy.orm import Relationship from sqlalchemy.orm import relationship from sqlalchemy.orm import Session from sqlalchemy.orm import with_polymorphic +from sqlalchemy.orm.attributes import QueryableAttribute class Base(DeclarativeBase): @@ -131,47 +134,46 @@ class Engineer(Employee): if typing.TYPE_CHECKING: - # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.str \| None\] - reveal_type(User.extra) + assert_type(User.extra, InstrumentedAttribute[str | None]) - # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.str \| None\] - reveal_type(User.extra_name) + assert_type(User.extra_name, InstrumentedAttribute[str | None]) - # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.str\*?\] - reveal_type(Address.email) + assert_type(Address.email, InstrumentedAttribute[str]) - # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.str\*?\] - reveal_type(Address.email_name) + assert_type(Address.email_name, InstrumentedAttribute[str]) - # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[relationship.Address\]\] - reveal_type(User.addresses_style_one) + assert_type(User.addresses_style_one, InstrumentedAttribute[list[Address]]) - # EXPECTED_RE_TYPE: sqlalchemy.orm.attributes.InstrumentedAttribute\[builtins.set\*?\[relationship.Address\]\] - reveal_type(User.addresses_style_two) + assert_type(User.addresses_style_two, InstrumentedAttribute[set[Address]]) - # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[relationship.User\]\] - reveal_type(Group.addresses_style_one_anno_only) + assert_type( + Group.addresses_style_one_anno_only, InstrumentedAttribute[list[User]] + ) - # EXPECTED_RE_TYPE: sqlalchemy.orm.attributes.InstrumentedAttribute\[builtins.set\*?\[relationship.User\]\] - reveal_type(Group.addresses_style_two_anno_only) + assert_type( + Group.addresses_style_two_anno_only, InstrumentedAttribute[set[User]] + ) - # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[relationship.MoreMail\]\] - reveal_type(Address.rel_style_one) + assert_type(Address.rel_style_one, InstrumentedAttribute[list[MoreMail]]) - # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.set\*?\[relationship.MoreMail\]\] - reveal_type(Address.rel_style_one_anno_only) + assert_type( + Address.rel_style_one_anno_only, InstrumentedAttribute[set[MoreMail]] + ) - # EXPECTED_RE_TYPE: sqlalchemy.*.QueryableAttribute\[relationship.Engineer\] - reveal_type(Team.employees.of_type(Engineer)) + assert_type(Team.employees.of_type(Engineer), QueryableAttribute[Engineer]) - # EXPECTED_RE_TYPE: sqlalchemy.*.QueryableAttribute\[relationship.Employee\] - reveal_type(Team.employees.of_type(aliased(Employee))) + assert_type( + Team.employees.of_type(aliased(Employee)), QueryableAttribute[Employee] + ) - # EXPECTED_RE_TYPE: sqlalchemy.*.QueryableAttribute\[relationship.Engineer\] - reveal_type(Team.employees.of_type(aliased(Engineer))) + assert_type( + Team.employees.of_type(aliased(Engineer)), QueryableAttribute[Engineer] + ) - # EXPECTED_RE_TYPE: sqlalchemy.*.QueryableAttribute\[relationship.Employee\] - reveal_type(Team.employees.of_type(with_polymorphic(Employee, [Engineer]))) + assert_type( + Team.employees.of_type(with_polymorphic(Employee, [Engineer])), + QueryableAttribute[Employee], + ) mapper_registry: registry = registry() diff --git a/test/typing/plain_files/orm/session.py b/test/typing/plain_files/orm/session.py index 1cc5b1c014..af0de3386b 100644 --- a/test/typing/plain_files/orm/session.py +++ b/test/typing/plain_files/orm/session.py @@ -1,10 +1,12 @@ from __future__ import annotations import asyncio +from typing import assert_type from typing import List from sqlalchemy import create_engine from sqlalchemy import ForeignKey +from sqlalchemy.engine.row import Row from sqlalchemy.ext.asyncio import async_scoped_session from sqlalchemy.ext.asyncio import async_sessionmaker from sqlalchemy.ext.asyncio import AsyncSession @@ -15,6 +17,8 @@ from sqlalchemy.orm import relationship from sqlalchemy.orm import scoped_session from sqlalchemy.orm import Session from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import SessionTransaction +from sqlalchemy.orm.query import Query class Base(DeclarativeBase): @@ -50,19 +54,16 @@ with Session(e) as sess: q = sess.query(User).filter_by(id=7) - # EXPECTED_TYPE: Query[User] - reveal_type(q) + assert_type(q, Query[User]) rows1 = q.all() - # EXPECTED_RE_TYPE: builtins.list\[.*User\*?\] - reveal_type(rows1) + assert_type(rows1, list[User]) q2 = sess.query(User.id).filter_by(id=7) rows2 = q2.all() - # EXPECTED_TYPE: list[.*Row[.*int].*] - reveal_type(rows2) + assert_type(rows2, list[Row[int]]) # test #8280 @@ -86,12 +87,10 @@ with Session(e) as sess: # test #9125 for row in sess.query(User.id, User.name): - # EXPECTED_TYPE: .*Row[int, str].* - reveal_type(row) + assert_type(row, Row[int, str]) for uobj1 in sess.query(User): - # EXPECTED_TYPE: User - reveal_type(uobj1) + assert_type(uobj1, User) sess.query(User).limit(None).offset(None).limit(10).offset(10).limit( User.id @@ -100,8 +99,7 @@ with Session(e) as sess: # test #11083 with sess.begin() as tx: - # EXPECTED_TYPE: SessionTransaction - reveal_type(tx) + assert_type(tx, SessionTransaction) # more result tests in typed_results.py diff --git a/test/typing/plain_files/orm/sessionmakers.py b/test/typing/plain_files/orm/sessionmakers.py index 60d2e8b33e..8e959cea98 100644 --- a/test/typing/plain_files/orm/sessionmakers.py +++ b/test/typing/plain_files/orm/sessionmakers.py @@ -1,5 +1,7 @@ """test sessionmaker, originally for #7656""" +from typing import assert_type + from sqlalchemy import create_engine from sqlalchemy import Engine from sqlalchemy.ext.asyncio import async_scoped_session @@ -11,6 +13,7 @@ from sqlalchemy.orm import QueryPropertyDescriptor from sqlalchemy.orm import scoped_session from sqlalchemy.orm import Session from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm.query import Query async_engine = create_async_engine("...") @@ -39,19 +42,16 @@ async def async_main() -> None: fac = async_session_factory(async_engine) async with fac() as sess: - # EXPECTED_TYPE: MyAsyncSession - reveal_type(sess) + assert_type(sess, MyAsyncSession) async with fac.begin() as sess: - # EXPECTED_TYPE: MyAsyncSession - reveal_type(sess) + assert_type(sess, MyAsyncSession) scoped_fac = async_scoped_session_factory(async_engine) sess = scoped_fac() - # EXPECTED_TYPE: MyAsyncSession - reveal_type(sess) + assert_type(sess, MyAsyncSession) engine = create_engine("...") @@ -75,49 +75,41 @@ def main() -> None: fac = session_factory(engine) with fac() as sess: - # EXPECTED_TYPE: MySession - reveal_type(sess) + assert_type(sess, MySession) with fac.begin() as sess: - # EXPECTED_TYPE: MySession - reveal_type(sess) + assert_type(sess, MySession) scoped_fac = scoped_session_factory(engine) sess = scoped_fac() - # EXPECTED_TYPE: MySession - reveal_type(sess) + assert_type(sess, MySession) def test_8837_sync() -> None: sm = sessionmaker() - # EXPECTED_TYPE: sessionmaker[Session] - reveal_type(sm) + assert_type(sm, sessionmaker[Session]) session = sm() - # EXPECTED_TYPE: Session - reveal_type(session) + assert_type(session, Session) def test_8837_async() -> None: as_ = async_sessionmaker() - # EXPECTED_TYPE: async_sessionmaker[AsyncSession] - reveal_type(as_) + assert_type(as_, async_sessionmaker[AsyncSession]) async_session = as_() - # EXPECTED_TYPE: AsyncSession - reveal_type(async_session) + assert_type(async_session, AsyncSession) # test #9338 ss_9338 = scoped_session_factory(engine) -# EXPECTED_TYPE: QueryPropertyDescriptor -reveal_type(ss_9338.query_property()) +assert_type(ss_9338.query_property(), QueryPropertyDescriptor) qp: QueryPropertyDescriptor = ss_9338.query_property() @@ -125,16 +117,13 @@ class Foo: query = qp -# EXPECTED_TYPE: Query[Foo] -reveal_type(Foo.query) +assert_type(Foo.query, Query[Foo]) -# EXPECTED_TYPE: list[Foo] -reveal_type(Foo.query.all()) +assert_type(Foo.query.all(), list[Foo]) class Bar: query: QueryPropertyDescriptor = ss_9338.query_property() -# EXPECTED_TYPE: Query[Bar] -reveal_type(Bar.query) +assert_type(Bar.query, Query[Bar]) diff --git a/test/typing/plain_files/orm/trad_relationship_uselist.py b/test/typing/plain_files/orm/trad_relationship_uselist.py index e15fe70934..f8f9111e82 100644 --- a/test/typing/plain_files/orm/trad_relationship_uselist.py +++ b/test/typing/plain_files/orm/trad_relationship_uselist.py @@ -1,6 +1,8 @@ """traditional relationship patterns with explicit uselist.""" import typing +from typing import Any +from typing import assert_type from typing import cast from typing import Dict from typing import List @@ -11,6 +13,7 @@ from sqlalchemy import ForeignKey from sqlalchemy import Integer from sqlalchemy import String from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import InstrumentedAttribute from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column from sqlalchemy.orm import relationship @@ -99,45 +102,34 @@ class Address(Base): if typing.TYPE_CHECKING: - # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[trad_relationship_uselist.Address\]\] - reveal_type(User.addresses_style_one) + assert_type(User.addresses_style_one, InstrumentedAttribute[list[Address]]) - # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.set\*?\[trad_relationship_uselist.Address\]\] - reveal_type(User.addresses_style_two) + assert_type(User.addresses_style_two, InstrumentedAttribute[set[Address]]) - # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] - reveal_type(User.addresses_style_three) + assert_type(User.addresses_style_three, InstrumentedAttribute[Any]) - # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] - reveal_type(User.addresses_style_three_cast) + assert_type(User.addresses_style_three_cast, InstrumentedAttribute[Any]) - # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] - reveal_type(User.addresses_style_four) + assert_type(User.addresses_style_four, InstrumentedAttribute[Any]) - # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] - reveal_type(Address.user_style_one) + assert_type(Address.user_style_one, InstrumentedAttribute[Any]) - # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[trad_relationship_uselist.User\*?\] - reveal_type(Address.user_style_one_typed) + assert_type(Address.user_style_one_typed, InstrumentedAttribute[User]) - # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] - reveal_type(Address.user_style_two) + assert_type(Address.user_style_two, InstrumentedAttribute[Any]) - # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[trad_relationship_uselist.User\*?\] - reveal_type(Address.user_style_two_typed) + assert_type(Address.user_style_two_typed, InstrumentedAttribute[User]) # reveal_type(Address.user_style_six) # reveal_type(Address.user_style_seven) - # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] - reveal_type(Address.user_style_eight) + assert_type(Address.user_style_eight, InstrumentedAttribute[Any]) - # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] - reveal_type(Address.user_style_nine) + assert_type(Address.user_style_nine, InstrumentedAttribute[Any]) - # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] - reveal_type(Address.user_style_ten) + assert_type(Address.user_style_ten, InstrumentedAttribute[Any]) - # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.dict\*?\[builtins.str, trad_relationship_uselist.User\]\] - reveal_type(Address.user_style_ten_typed) + assert_type( + Address.user_style_ten_typed, InstrumentedAttribute[dict[str, User]] + ) diff --git a/test/typing/plain_files/orm/traditional_relationship.py b/test/typing/plain_files/orm/traditional_relationship.py index bd6bada528..062ea2b3f0 100644 --- a/test/typing/plain_files/orm/traditional_relationship.py +++ b/test/typing/plain_files/orm/traditional_relationship.py @@ -7,6 +7,8 @@ if no uselists are present. """ import typing +from typing import Any +from typing import assert_type from typing import List from typing import Set @@ -14,6 +16,7 @@ from sqlalchemy import ForeignKey from sqlalchemy import Integer from sqlalchemy import String from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import InstrumentedAttribute from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column from sqlalchemy.orm import relationship @@ -80,29 +83,20 @@ class Address(Base): if typing.TYPE_CHECKING: - # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[traditional_relationship.Address\]\] - reveal_type(User.addresses_style_one) + assert_type(User.addresses_style_one, InstrumentedAttribute[list[Address]]) - # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.set\*?\[traditional_relationship.Address\]\] - reveal_type(User.addresses_style_two) + assert_type(User.addresses_style_two, InstrumentedAttribute[set[Address]]) - # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] - reveal_type(Address.user_style_one) + assert_type(Address.user_style_one, InstrumentedAttribute[Any]) - # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[traditional_relationship.User\*?\] - reveal_type(Address.user_style_one_typed) + assert_type(Address.user_style_one_typed, InstrumentedAttribute[User]) - # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] - reveal_type(Address.user_style_two) + assert_type(Address.user_style_two, InstrumentedAttribute[Any]) - # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[traditional_relationship.User\*?\] - reveal_type(Address.user_style_two_typed) + assert_type(Address.user_style_two_typed, InstrumentedAttribute[User]) - # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[traditional_relationship.User\]\] - reveal_type(Address.user_style_three) + assert_type(Address.user_style_three, InstrumentedAttribute[list[User]]) - # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[traditional_relationship.User\]\] - reveal_type(Address.user_style_four) + assert_type(Address.user_style_four, InstrumentedAttribute[list[User]]) - # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] - reveal_type(Address.user_style_five) + assert_type(Address.user_style_five, InstrumentedAttribute[Any]) diff --git a/test/typing/plain_files/orm/typed_queries.py b/test/typing/plain_files/orm/typed_queries.py index a3c07dd016..d21922e869 100644 --- a/test/typing/plain_files/orm/typed_queries.py +++ b/test/typing/plain_files/orm/typed_queries.py @@ -1,5 +1,9 @@ from __future__ import annotations +from typing import Any +from typing import assert_type +from typing import Unpack + from sqlalchemy import Column from sqlalchemy import column from sqlalchemy import create_engine @@ -16,11 +20,21 @@ from sqlalchemy import String from sqlalchemy import Table from sqlalchemy import text from sqlalchemy import update +from sqlalchemy.engine import Result +from sqlalchemy.engine.row import Row from sqlalchemy.orm import aliased from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column from sqlalchemy.orm import Session +from sqlalchemy.orm.query import Query +from sqlalchemy.orm.query import RowReturningQuery +from sqlalchemy.sql.dml import ReturningInsert +from sqlalchemy.sql.elements import KeyedColumnElement +from sqlalchemy.sql.expression import FromClause +from sqlalchemy.sql.expression import TextClause +from sqlalchemy.sql.selectable import ScalarSelect +from sqlalchemy.sql.selectable import TextualSelect class Base(DeclarativeBase): @@ -51,13 +65,11 @@ connection = e.connect() def t_select_1() -> None: stmt = select(User.id, User.name).filter(User.id == 5) - # EXPECTED_TYPE: Select[int, str] - reveal_type(stmt) + assert_type(stmt, Select[int, str]) result = session.execute(stmt) - # EXPECTED_TYPE: .*Result[int, str].* - reveal_type(result) + assert_type(result, Result[int, str]) def t_select_2() -> None: @@ -75,13 +87,11 @@ def t_select_2() -> None: .fetch(User.id) ) - # EXPECTED_TYPE: Select[User] - reveal_type(stmt) + assert_type(stmt, Select[User]) result = session.execute(stmt) - # EXPECTED_TYPE: .*Result[User].* - reveal_type(result) + assert_type(result, Result[User]) def t_select_3() -> None: @@ -95,196 +105,157 @@ def t_select_3() -> None: # awkwardnesses that aren't really worth it ua(id=1, name="foo") - # EXPECTED_RE_TYPE: [tT]ype\[.*\.User\] - reveal_type(ua) + assert_type(ua, type[User]) stmt = select(ua.id, ua.name).filter(User.id == 5) - # EXPECTED_TYPE: Select[int, str] - reveal_type(stmt) + assert_type(stmt, Select[int, str]) result = session.execute(stmt) - # EXPECTED_TYPE: .*Result[int, str].* - reveal_type(result) + assert_type(result, Result[int, str]) def t_select_4() -> None: ua = aliased(User) stmt = select(ua, User).filter(User.id == 5) - # EXPECTED_TYPE: Select[User, User] - reveal_type(stmt) + assert_type(stmt, Select[User, User]) result = session.execute(stmt) - # EXPECTED_TYPE: Result[User, User] - reveal_type(result) + assert_type(result, Result[User, User]) def t_legacy_query_single_entity() -> None: q1 = session.query(User).filter(User.id == 5) - # EXPECTED_TYPE: Query[User] - reveal_type(q1) + assert_type(q1, Query[User]) - # EXPECTED_TYPE: User - reveal_type(q1.one()) + assert_type(q1.one(), User) - # EXPECTED_TYPE: list[User] - reveal_type(q1.all()) + assert_type(q1.all(), list[User]) # mypy switches to builtins.list for some reason here - # EXPECTED_RE_TYPE: .*\.list\[.*Row\*?\[.*User\].*\] - reveal_type(q1.only_return_tuples(True).all()) + assert_type(q1.only_return_tuples(True).all(), list[Row[User]]) - # EXPECTED_TYPE: list[tuple[User]] - reveal_type(q1.tuples().all()) + assert_type(q1.tuples().all(), list[tuple[User]]) def t_legacy_query_cols_1() -> None: q1 = session.query(User.id, User.name).filter(User.id == 5) - # EXPECTED_TYPE: RowReturningQuery[int, str] - reveal_type(q1) + assert_type(q1, RowReturningQuery[int, str]) - # EXPECTED_TYPE: .*Row[int, str].* - reveal_type(q1.one()) + assert_type(q1.one(), Row[int, str]) r1 = q1.one() x, y = r1 - # EXPECTED_TYPE: int - reveal_type(x) + assert_type(x, int) - # EXPECTED_TYPE: str - reveal_type(y) + assert_type(y, str) def t_legacy_query_cols_tupleq_1() -> None: q1 = session.query(User.id, User.name).filter(User.id == 5) - # EXPECTED_TYPE: RowReturningQuery[int, str] - reveal_type(q1) + assert_type(q1, RowReturningQuery[int, str]) q2 = q1.tuples() - # EXPECTED_TYPE: tuple[int, str] - reveal_type(q2.one()) + assert_type(q2.one(), tuple[int, str]) r1 = q2.one() x, y = r1 - # EXPECTED_TYPE: int - reveal_type(x) + assert_type(x, int) - # EXPECTED_TYPE: str - reveal_type(y) + assert_type(y, str) def t_legacy_query_cols_1_with_entities() -> None: q1 = session.query(User).filter(User.id == 5) - # EXPECTED_TYPE: Query[User] - reveal_type(q1) + assert_type(q1, Query[User]) q2 = q1.with_entities(User.id, User.name) - # EXPECTED_TYPE: RowReturningQuery[int, str] - reveal_type(q2) + assert_type(q2, RowReturningQuery[int, str]) - # EXPECTED_TYPE: .*Row[int, str].* - reveal_type(q2.one()) + assert_type(q2.one(), Row[int, str]) r1 = q2.one() x, y = r1 - # EXPECTED_TYPE: int - reveal_type(x) + assert_type(x, int) - # EXPECTED_TYPE: str - reveal_type(y) + assert_type(y, str) def t_select_with_only_cols() -> None: q1 = select(User).where(User.id == 5) - # EXPECTED_TYPE: Select[User] - reveal_type(q1) + assert_type(q1, Select[User]) q2 = q1.with_only_columns(User.id, User.name) - # EXPECTED_TYPE: Select[int, str] - reveal_type(q2) + assert_type(q2, Select[int, str]) row = connection.execute(q2).one() - # EXPECTED_TYPE: .*Row[int, str].* - reveal_type(row) + assert_type(row, Row[int, str]) x, y = row - # EXPECTED_TYPE: int - reveal_type(x) + assert_type(x, int) - # EXPECTED_TYPE: str - reveal_type(y) + assert_type(y, str) def t_legacy_query_cols_2() -> None: a1 = aliased(User) q1 = session.query(User, a1, User.name).filter(User.id == 5) - # EXPECTED_TYPE: RowReturningQuery[User, User, str] - reveal_type(q1) + assert_type(q1, RowReturningQuery[User, User, str]) - # EXPECTED_TYPE: .*Row[User, User, str].* - reveal_type(q1.one()) + assert_type(q1.one(), Row[User, User, str]) r1 = q1.one() x, y, z = r1 - # EXPECTED_TYPE: User - reveal_type(x) + assert_type(x, User) - # EXPECTED_TYPE: User - reveal_type(y) + assert_type(y, User) - # EXPECTED_TYPE: str - reveal_type(z) + assert_type(z, str) def t_legacy_query_cols_2_with_entities() -> None: q1 = session.query(User) - # EXPECTED_TYPE: Query[User] - reveal_type(q1) + assert_type(q1, Query[User]) a1 = aliased(User) q2 = q1.with_entities(User, a1, User.name).filter(User.id == 5) - # EXPECTED_TYPE: RowReturningQuery[User, User, str] - reveal_type(q2) + assert_type(q2, RowReturningQuery[User, User, str]) - # EXPECTED_TYPE: .*Row[User, User, str].* - reveal_type(q2.one()) + assert_type(q2.one(), Row[User, User, str]) r1 = q2.one() x, y, z = r1 - # EXPECTED_TYPE: User - reveal_type(x) + assert_type(x, User) - # EXPECTED_TYPE: User - reveal_type(y) + assert_type(y, User) - # EXPECTED_TYPE: str - reveal_type(z) + assert_type(z, str) def t_select_add_col_loses_type() -> None: @@ -293,8 +264,7 @@ def t_select_add_col_loses_type() -> None: q2 = q1.add_columns(User.data) # note this should not match Select - # EXPECTED_TYPE: Select[Unpack[.*tuple[Any, ...]]] - reveal_type(q2) + assert_type(q2, Select[Unpack[tuple[Any, ...]]]) def t_legacy_query_add_col_loses_type() -> None: @@ -303,14 +273,12 @@ def t_legacy_query_add_col_loses_type() -> None: q2 = q1.add_columns(User.data) # this should match only Any - # EXPECTED_TYPE: Query[Any] - reveal_type(q2) + assert_type(q2, Query[Any]) ua = aliased(User) q3 = q1.add_entity(ua) - # EXPECTED_TYPE: Query[Any] - reveal_type(q3) + assert_type(q3, Query[Any]) def t_legacy_query_scalar_subquery() -> None: @@ -322,29 +290,25 @@ def t_legacy_query_scalar_subquery() -> None: # this should be int but mypy can't see it due to the # overload that tries to match an entity. - # EXPECTED_RE_TYPE: .*ScalarSelect\[(?:int|Any)\] - reveal_type(q2) + assert_type(q2, ScalarSelect[Any]) q3 = session.query(User) q4 = q3.scalar_subquery() - # EXPECTED_TYPE: ScalarSelect[Any] - reveal_type(q4) + assert_type(q4, ScalarSelect[Any]) q5 = session.query(User, User.name) q6 = q5.scalar_subquery() - # EXPECTED_TYPE: ScalarSelect[Any] - reveal_type(q6) + assert_type(q6, ScalarSelect[Any]) # try to simulate the problem with select() q7 = session.query(User).only_return_tuples(True) q8 = q7.scalar_subquery() - # EXPECTED_TYPE: ScalarSelect[Any] - reveal_type(q8) + assert_type(q8, ScalarSelect[Any]) def t_select_scalar_subquery() -> None: @@ -355,16 +319,14 @@ def t_select_scalar_subquery() -> None: # this should be int but mypy can't see it due to the # overload that tries to match an entity. - # EXPECTED_TYPE: ScalarSelect[Any] - reveal_type(s2) + assert_type(s2, ScalarSelect[Any]) s3 = select(User) s4 = s3.scalar_subquery() # it's more important that mypy doesn't get a false positive of # 'User' here - # EXPECTED_TYPE: ScalarSelect[Any] - reveal_type(s4) + assert_type(s4, ScalarSelect[Any]) def t_select_w_core_selectables() -> None: @@ -374,8 +336,7 @@ def t_select_w_core_selectables() -> None: """ s1 = select(User.id, User.name).subquery() - # EXPECTED_TYPE: KeyedColumnElement[Any] - reveal_type(s1.c.name) + assert_type(s1.c.name, KeyedColumnElement[Any]) s2 = select(User.id, s1.c.name) @@ -386,31 +347,26 @@ def t_select_w_core_selectables() -> None: # mypy would downgrade to Any rather than picking the basemost type. # with typing integrated into Select etc. we can at least get a Select # object back. - # EXPECTED_TYPE: Select[Unpack[.*tuple[Any, ...]]] - reveal_type(s2) + assert_type(s2, Select[Unpack[tuple[Any, ...]]]) # so a fully explicit type may be given s2_typed: Select[tuple[int, str]] = select(User.id, s1.c.name) - # EXPECTED_TYPE: Select[tuple[int, str]] - reveal_type(s2_typed) + assert_type(s2_typed, Select[tuple[int, str]]) # plain FromClause etc we at least get Select s3 = select(s1) - # EXPECTED_TYPE: Select[Unpack[.*tuple[Any, ...]]] - reveal_type(s3) + assert_type(s3, Select[Unpack[tuple[Any, ...]]]) t1 = User.__table__ assert t1 is not None - # EXPECTED_TYPE: FromClause - reveal_type(t1) + assert_type(t1, FromClause) s4 = select(t1) - # EXPECTED_TYPE: Select[Unpack[.*tuple[Any, ...]]] - reveal_type(s4) + assert_type(s4, Select[Unpack[tuple[Any, ...]]]) def t_dml_insert() -> None: @@ -418,53 +374,45 @@ def t_dml_insert() -> None: r1 = session.execute(s1) - # EXPECTED_TYPE: Result[int, str] - reveal_type(r1) + assert_type(r1, Result[int, str]) s2 = insert(User).returning(User) r2 = session.execute(s2) - # EXPECTED_TYPE: Result[User] - reveal_type(r2) + assert_type(r2, Result[User]) s3 = insert(User).returning(func.foo(), column("q")) - # EXPECTED_TYPE: ReturningInsert[Unpack[.*tuple[Any, ...]]] - reveal_type(s3) + assert_type(s3, ReturningInsert[Unpack[tuple[Any, ...]]]) r3 = session.execute(s3) - # EXPECTED_TYPE: Result[Unpack[.*tuple[Any, ...]]] - reveal_type(r3) + assert_type(r3, Result[Unpack[tuple[Any, ...]]]) def t_dml_bare_insert() -> None: s1 = insert(User) r1 = session.execute(s1) - # EXPECTED_TYPE: Result[Unpack[.*tuple[Any, ...]]] - reveal_type(r1) + assert_type(r1, Result[Unpack[tuple[Any, ...]]]) def t_dml_bare_update() -> None: s1 = update(User) r1 = session.execute(s1) - # EXPECTED_TYPE: Result[Unpack[.*tuple[Any, ...]]] - reveal_type(r1) + assert_type(r1, Result[Unpack[tuple[Any, ...]]]) def t_dml_update_with_values() -> None: s1 = update(User).values({User.id: 123, User.data: "value"}) r1 = session.execute(s1) - # EXPECTED_TYPE: Result[Unpack[.*tuple[Any, ...]]] - reveal_type(r1) + assert_type(r1, Result[Unpack[tuple[Any, ...]]]) def t_dml_bare_delete() -> None: s1 = delete(User) r1 = session.execute(s1) - # EXPECTED_TYPE: Result[Unpack[.*tuple[Any, ...]]] - reveal_type(r1) + assert_type(r1, Result[Unpack[tuple[Any, ...]]]) def t_dml_update() -> None: @@ -472,8 +420,7 @@ def t_dml_update() -> None: r1 = session.execute(s1) - # EXPECTED_TYPE: Result[int, str] - reveal_type(r1) + assert_type(r1, Result[int, str]) def t_dml_delete() -> None: @@ -481,22 +428,19 @@ def t_dml_delete() -> None: r1 = session.execute(s1) - # EXPECTED_TYPE: Result[int, str] - reveal_type(r1) + assert_type(r1, Result[int, str]) def t_from_statement() -> None: t = text("select * from user") - # EXPECTED_TYPE: TextClause - reveal_type(t) + assert_type(t, TextClause) select(User).from_statement(t) ts = text("select * from user").columns(User.id, User.name) - # EXPECTED_TYPE: TextualSelect - reveal_type(ts) + assert_type(ts, TextualSelect) select(User).from_statement(ts) @@ -504,8 +448,7 @@ def t_from_statement() -> None: user_table.c.id, user_table.c.name ) - # EXPECTED_TYPE: TextualSelect - reveal_type(ts2) + assert_type(ts2, TextualSelect) select(User).from_statement(ts2) @@ -519,17 +462,13 @@ def t_aliased_fromclause() -> None: a4 = aliased(user_table) - # EXPECTED_RE_TYPE: [tT]ype\[.*\.User\] - reveal_type(a1) + assert_type(a1, type[User]) - # EXPECTED_RE_TYPE: [tT]ype\[.*\.User\] - reveal_type(a2) + assert_type(a2, type[User]) - # EXPECTED_RE_TYPE: [tT]ype\[.*\.User\] - reveal_type(a3) + assert_type(a3, type[User]) - # EXPECTED_TYPE: FromClause - reveal_type(a4) + assert_type(a4, FromClause) def test_select_from() -> None: diff --git a/test/typing/plain_files/orm/write_only.py b/test/typing/plain_files/orm/write_only.py index 619cde74e8..0ea8663e2d 100644 --- a/test/typing/plain_files/orm/write_only.py +++ b/test/typing/plain_files/orm/write_only.py @@ -1,6 +1,7 @@ from __future__ import annotations import typing +from typing import assert_type from sqlalchemy import ForeignKey from sqlalchemy import select @@ -10,6 +11,7 @@ from sqlalchemy.orm import mapped_column from sqlalchemy.orm import relationship from sqlalchemy.orm import Session from sqlalchemy.orm import WriteOnlyMapped +from sqlalchemy.orm.writeonly import WriteOnlyCollection class Base(DeclarativeBase): @@ -35,16 +37,14 @@ with Session() as session: session.commit() if typing.TYPE_CHECKING: - # EXPECTED_TYPE: WriteOnlyCollection[Address] - reveal_type(u.addresses) + assert_type(u.addresses, WriteOnlyCollection[Address]) address = session.scalars( u.addresses.select().filter(Address.email_address.like("xyz")) ).one() if typing.TYPE_CHECKING: - # EXPECTED_TYPE: Address - reveal_type(address) + assert_type(address, Address) u.addresses.add(Address()) u.addresses.add_all([Address(), Address()]) diff --git a/test/typing/plain_files/sql/common_sql_element.py b/test/typing/plain_files/sql/common_sql_element.py index 3428a640df..e8a10e553d 100644 --- a/test/typing/plain_files/sql/common_sql_element.py +++ b/test/typing/plain_files/sql/common_sql_element.py @@ -8,6 +8,10 @@ unions. from __future__ import annotations +from typing import assert_type +from typing import Never +from typing import Unpack + from sqlalchemy import asc from sqlalchemy import Column from sqlalchemy import column @@ -20,16 +24,21 @@ from sqlalchemy import intersect from sqlalchemy import intersect_all from sqlalchemy import literal from sqlalchemy import MetaData +from sqlalchemy import Select from sqlalchemy import select from sqlalchemy import SQLColumnExpression from sqlalchemy import String from sqlalchemy import Table from sqlalchemy import union from sqlalchemy import union_all +from sqlalchemy.engine import Result from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column from sqlalchemy.orm import Session +from sqlalchemy.orm.query import RowReturningQuery +from sqlalchemy.sql.expression import BindParameter +from sqlalchemy.sql.expression import CompoundSelect class Base(DeclarativeBase): @@ -67,26 +76,22 @@ def core_expr(email: str) -> SQLColumnExpression[bool]: e1 = orm_expr("hi") -# EXPECTED_TYPE: SQLColumnExpression[bool] -reveal_type(e1) +assert_type(e1, SQLColumnExpression[bool]) stmt = select(e1) -# EXPECTED_TYPE: Select[bool] -reveal_type(stmt) +assert_type(stmt, Select[bool]) stmt = stmt.where(e1) e2 = core_expr("hi") -# EXPECTED_TYPE: SQLColumnExpression[bool] -reveal_type(e2) +assert_type(e2, SQLColumnExpression[bool]) stmt = select(e2) -# EXPECTED_TYPE: Select[bool] -reveal_type(stmt) +assert_type(stmt, Select[bool]) stmt = stmt.where(e2) @@ -95,20 +100,17 @@ stmt2 = select(User.id).order_by("id", "email").group_by("email", "id") stmt2 = ( select(User.id).order_by(asc("id"), desc("email")).group_by("email", "id") ) -# EXPECTED_TYPE: Select[int] -reveal_type(stmt2) +assert_type(stmt2, Select[int]) stmt2 = select(User.id).order_by(User.id).group_by(User.email) stmt2 = ( select(User.id).order_by(User.id, User.email).group_by(User.email, User.id) ) -# EXPECTED_TYPE: Select[int] -reveal_type(stmt2) +assert_type(stmt2, Select[int]) stmt3 = select(User.id).exists().select() -# EXPECTED_TYPE: Select[bool] -reveal_type(stmt3) +assert_type(stmt3, Select[bool]) receives_str_col_expr(User.email) @@ -129,8 +131,7 @@ receives_bool_col_expr(user_table.c.email == "x") q1 = Session().query(User.id).order_by("email").group_by("email") q1 = Session().query(User.id).order_by("id", "email").group_by("email", "id") -# EXPECTED_TYPE: RowReturningQuery[int] -reveal_type(q1) +assert_type(q1, RowReturningQuery[int]) q1 = Session().query(User.id).order_by(User.id).group_by(User.email) q1 = ( @@ -139,8 +140,7 @@ q1 = ( .order_by(User.id, User.email) .group_by(User.email, User.id) ) -# EXPECTED_TYPE: RowReturningQuery[int] -reveal_type(q1) +assert_type(q1, RowReturningQuery[int]) # test 9174 s9174_1 = select(User).with_for_update(of=User) @@ -164,14 +164,10 @@ user = session.query(user_table).with_for_update( ) # literal -# EXPECTED_TYPE: BindParameter[str] -reveal_type(literal("5")) -# EXPECTED_TYPE: BindParameter[str] -reveal_type(literal("5", None)) -# EXPECTED_TYPE: BindParameter[int] -reveal_type(literal("123", Integer)) -# EXPECTED_TYPE: BindParameter[int] -reveal_type(literal("123", Integer)) +assert_type(literal("5"), BindParameter[str]) +assert_type(literal("5", None), BindParameter[str]) +assert_type(literal("123", Integer), BindParameter[int]) +assert_type(literal("123", Integer), BindParameter[int]) # hashable (issue #10353): @@ -193,36 +189,26 @@ first_stmt = select(str_col, int_col) second_stmt = select(str_col, int_col) third_stmt = select(int_col, str_col) -# EXPECTED_TYPE: CompoundSelect[str, int] -reveal_type(union(first_stmt, second_stmt)) -# EXPECTED_TYPE: CompoundSelect[str, int] -reveal_type(union_all(first_stmt, second_stmt)) -# EXPECTED_TYPE: CompoundSelect[str, int] -reveal_type(except_(first_stmt, second_stmt)) -# EXPECTED_TYPE: CompoundSelect[str, int] -reveal_type(except_all(first_stmt, second_stmt)) -# EXPECTED_TYPE: CompoundSelect[str, int] -reveal_type(intersect(first_stmt, second_stmt)) -# EXPECTED_TYPE: CompoundSelect[str, int] -reveal_type(intersect_all(first_stmt, second_stmt)) - -# EXPECTED_TYPE: Result[str, int] -reveal_type(Session().execute(union(first_stmt, second_stmt))) -# EXPECTED_TYPE: Result[str, int] -reveal_type(Session().execute(union_all(first_stmt, second_stmt))) - -# EXPECTED_TYPE: CompoundSelect[str, int] -reveal_type(first_stmt.union(second_stmt)) -# EXPECTED_TYPE: CompoundSelect[str, int] -reveal_type(first_stmt.union_all(second_stmt)) -# EXPECTED_TYPE: CompoundSelect[str, int] -reveal_type(first_stmt.except_(second_stmt)) -# EXPECTED_TYPE: CompoundSelect[str, int] -reveal_type(first_stmt.except_all(second_stmt)) -# EXPECTED_TYPE: CompoundSelect[str, int] -reveal_type(first_stmt.intersect(second_stmt)) -# EXPECTED_TYPE: CompoundSelect[str, int] -reveal_type(first_stmt.intersect_all(second_stmt)) +assert_type(union(first_stmt, second_stmt), CompoundSelect[str, int]) +assert_type(union_all(first_stmt, second_stmt), CompoundSelect[str, int]) +assert_type(except_(first_stmt, second_stmt), CompoundSelect[str, int]) +assert_type(except_all(first_stmt, second_stmt), CompoundSelect[str, int]) +assert_type(intersect(first_stmt, second_stmt), CompoundSelect[str, int]) +assert_type(intersect_all(first_stmt, second_stmt), CompoundSelect[str, int]) + +assert_type( + Session().execute(union(first_stmt, second_stmt)), Result[str, int] +) +assert_type( + Session().execute(union_all(first_stmt, second_stmt)), Result[str, int] +) + +assert_type(first_stmt.union(second_stmt), CompoundSelect[str, int]) +assert_type(first_stmt.union_all(second_stmt), CompoundSelect[str, int]) +assert_type(first_stmt.except_(second_stmt), CompoundSelect[str, int]) +assert_type(first_stmt.except_all(second_stmt), CompoundSelect[str, int]) +assert_type(first_stmt.intersect(second_stmt), CompoundSelect[str, int]) +assert_type(first_stmt.intersect_all(second_stmt), CompoundSelect[str, int]) # TODO: the following do not error because _SelectStatementForCompoundArgument # includes untyped elements so the type checker falls back on them when @@ -230,28 +216,32 @@ reveal_type(first_stmt.intersect_all(second_stmt)) # looses the plot and returns a random type back. See TODO in the # overloads -# EXPECTED_TYPE: CompoundSelect[Unpack[tuple[Never, ...]]] -reveal_type(union(first_stmt, third_stmt)) -# EXPECTED_TYPE: CompoundSelect[Unpack[tuple[Never, ...]]] -reveal_type(union_all(first_stmt, third_stmt)) -# EXPECTED_TYPE: CompoundSelect[Unpack[tuple[Never, ...]]] -reveal_type(except_(first_stmt, third_stmt)) -# EXPECTED_TYPE: CompoundSelect[Unpack[tuple[Never, ...]]] -reveal_type(except_all(first_stmt, third_stmt)) -# EXPECTED_TYPE: CompoundSelect[Unpack[tuple[Never, ...]]] -reveal_type(intersect(first_stmt, third_stmt)) -# EXPECTED_TYPE: CompoundSelect[Unpack[tuple[Never, ...]]] -reveal_type(intersect_all(first_stmt, third_stmt)) - -# EXPECTED_TYPE: CompoundSelect[str, int] -reveal_type(first_stmt.union(third_stmt)) -# EXPECTED_TYPE: CompoundSelect[str, int] -reveal_type(first_stmt.union_all(third_stmt)) -# EXPECTED_TYPE: CompoundSelect[str, int] -reveal_type(first_stmt.except_(third_stmt)) -# EXPECTED_TYPE: CompoundSelect[str, int] -reveal_type(first_stmt.except_all(third_stmt)) -# EXPECTED_TYPE: CompoundSelect[str, int] -reveal_type(first_stmt.intersect(third_stmt)) -# EXPECTED_TYPE: CompoundSelect[str, int] -reveal_type(first_stmt.intersect_all(third_stmt)) +assert_type( + union(first_stmt, third_stmt), CompoundSelect[Unpack[tuple[Never, ...]]] +) +assert_type( + union_all(first_stmt, third_stmt), + CompoundSelect[Unpack[tuple[Never, ...]]], +) +assert_type( + except_(first_stmt, third_stmt), CompoundSelect[Unpack[tuple[Never, ...]]] +) +assert_type( + except_all(first_stmt, third_stmt), + CompoundSelect[Unpack[tuple[Never, ...]]], +) +assert_type( + intersect(first_stmt, third_stmt), + CompoundSelect[Unpack[tuple[Never, ...]]], +) +assert_type( + intersect_all(first_stmt, third_stmt), + CompoundSelect[Unpack[tuple[Never, ...]]], +) + +assert_type(first_stmt.union(third_stmt), CompoundSelect[str, int]) +assert_type(first_stmt.union_all(third_stmt), CompoundSelect[str, int]) +assert_type(first_stmt.except_(third_stmt), CompoundSelect[str, int]) +assert_type(first_stmt.except_all(third_stmt), CompoundSelect[str, int]) +assert_type(first_stmt.intersect(third_stmt), CompoundSelect[str, int]) +assert_type(first_stmt.intersect_all(third_stmt), CompoundSelect[str, int]) diff --git a/test/typing/plain_files/sql/functions.py b/test/typing/plain_files/sql/functions.py index 3660417887..beb72c4df6 100644 --- a/test/typing/plain_files/sql/functions.py +++ b/test/typing/plain_files/sql/functions.py @@ -1,11 +1,18 @@ """this file is generated by tools/generate_sql_functions.py""" +from datetime import date +from datetime import datetime +from datetime import time +from decimal import Decimal +from typing import assert_type +from typing import Sequence + from sqlalchemy import column from sqlalchemy import func from sqlalchemy import Integer from sqlalchemy import Select from sqlalchemy import select -from sqlalchemy import Sequence +from sqlalchemy import Sequence as SqlAlchemySequence from sqlalchemy import String # START GENERATED FUNCTION TYPING TESTS @@ -15,152 +22,127 @@ from sqlalchemy import String stmt1 = select(func.aggregate_strings(column("x", String), ",")) -# EXPECTED_RE_TYPE: .*Select\[.*str\] -reveal_type(stmt1) +assert_type(stmt1, Select[str]) stmt2 = select(func.array_agg(column("x", Integer))) -# EXPECTED_RE_TYPE: .*Select\[.*Sequence\[.*int\]\] -reveal_type(stmt2) +assert_type(stmt2, Select[Sequence[int]]) stmt3 = select(func.char_length(column("x"))) -# EXPECTED_RE_TYPE: .*Select\[.*int\] -reveal_type(stmt3) +assert_type(stmt3, Select[int]) stmt4 = select(func.coalesce(column("x", Integer))) -# EXPECTED_RE_TYPE: .*Select\[.*int\] -reveal_type(stmt4) +assert_type(stmt4, Select[int]) stmt5 = select(func.concat()) -# EXPECTED_RE_TYPE: .*Select\[.*str\] -reveal_type(stmt5) +assert_type(stmt5, Select[str]) stmt6 = select(func.count(column("x"))) -# EXPECTED_RE_TYPE: .*Select\[.*int\] -reveal_type(stmt6) +assert_type(stmt6, Select[int]) stmt7 = select(func.cume_dist()) -# EXPECTED_RE_TYPE: .*Select\[.*Decimal\] -reveal_type(stmt7) +assert_type(stmt7, Select[Decimal]) stmt8 = select(func.current_date()) -# EXPECTED_RE_TYPE: .*Select\[.*date\] -reveal_type(stmt8) +assert_type(stmt8, Select[date]) stmt9 = select(func.current_time()) -# EXPECTED_RE_TYPE: .*Select\[.*time\] -reveal_type(stmt9) +assert_type(stmt9, Select[time]) stmt10 = select(func.current_timestamp()) -# EXPECTED_RE_TYPE: .*Select\[.*datetime\] -reveal_type(stmt10) +assert_type(stmt10, Select[datetime]) stmt11 = select(func.current_user()) -# EXPECTED_RE_TYPE: .*Select\[.*str\] -reveal_type(stmt11) +assert_type(stmt11, Select[str]) stmt12 = select(func.dense_rank()) -# EXPECTED_RE_TYPE: .*Select\[.*int\] -reveal_type(stmt12) +assert_type(stmt12, Select[int]) stmt13 = select(func.localtime()) -# EXPECTED_RE_TYPE: .*Select\[.*datetime\] -reveal_type(stmt13) +assert_type(stmt13, Select[datetime]) stmt14 = select(func.localtimestamp()) -# EXPECTED_RE_TYPE: .*Select\[.*datetime\] -reveal_type(stmt14) +assert_type(stmt14, Select[datetime]) stmt15 = select(func.max(column("x", Integer))) -# EXPECTED_RE_TYPE: .*Select\[.*int\] -reveal_type(stmt15) +assert_type(stmt15, Select[int]) stmt16 = select(func.min(column("x", Integer))) -# EXPECTED_RE_TYPE: .*Select\[.*int\] -reveal_type(stmt16) +assert_type(stmt16, Select[int]) -stmt17 = select(func.next_value(Sequence("x_seq"))) +stmt17 = select(func.next_value(SqlAlchemySequence("x_seq"))) -# EXPECTED_RE_TYPE: .*Select\[.*int\] -reveal_type(stmt17) +assert_type(stmt17, Select[int]) stmt18 = select(func.now()) -# EXPECTED_RE_TYPE: .*Select\[.*datetime\] -reveal_type(stmt18) +assert_type(stmt18, Select[datetime]) stmt19 = select(func.percent_rank()) -# EXPECTED_RE_TYPE: .*Select\[.*Decimal\] -reveal_type(stmt19) +assert_type(stmt19, Select[Decimal]) stmt20 = select(func.pow(column("x", Integer))) -# EXPECTED_RE_TYPE: .*Select\[.*int\] -reveal_type(stmt20) +assert_type(stmt20, Select[int]) stmt21 = select(func.rank()) -# EXPECTED_RE_TYPE: .*Select\[.*int\] -reveal_type(stmt21) +assert_type(stmt21, Select[int]) stmt22 = select(func.session_user()) -# EXPECTED_RE_TYPE: .*Select\[.*str\] -reveal_type(stmt22) +assert_type(stmt22, Select[str]) stmt23 = select(func.sum(column("x", Integer))) -# EXPECTED_RE_TYPE: .*Select\[.*int\] -reveal_type(stmt23) +assert_type(stmt23, Select[int]) stmt24 = select(func.sysdate()) -# EXPECTED_RE_TYPE: .*Select\[.*datetime\] -reveal_type(stmt24) +assert_type(stmt24, Select[datetime]) stmt25 = select(func.user()) -# EXPECTED_RE_TYPE: .*Select\[.*str\] -reveal_type(stmt25) +assert_type(stmt25, Select[str]) # END GENERATED FUNCTION TYPING TESTS diff --git a/test/typing/plain_files/sql/functions_again.py b/test/typing/plain_files/sql/functions_again.py index fc000277d0..1be8c5ce78 100644 --- a/test/typing/plain_files/sql/functions_again.py +++ b/test/typing/plain_files/sql/functions_again.py @@ -1,10 +1,21 @@ +from typing import Any +from typing import assert_type + from sqlalchemy import column from sqlalchemy import func +from sqlalchemy import Function from sqlalchemy import Integer +from sqlalchemy import Select from sqlalchemy import select from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column +from sqlalchemy.sql.expression import FunctionFilter +from sqlalchemy.sql.expression import Over +from sqlalchemy.sql.expression import WithinGroup +from sqlalchemy.sql.functions import coalesce +from sqlalchemy.sql.functions import max as functions_max +from sqlalchemy.sql.selectable import TableValuedAlias class Base(DeclarativeBase): @@ -20,53 +31,45 @@ class Foo(Base): c: Mapped[str] -# EXPECTED_TYPE: Over[Any] -reveal_type(func.row_number().over(order_by=Foo.a, partition_by=Foo.b.desc())) +assert_type( + func.row_number().over(order_by=Foo.a, partition_by=Foo.b.desc()), + Over[Any], +) func.row_number().over(order_by=[Foo.a.desc(), Foo.b.desc()]) func.row_number().over(partition_by=[Foo.a.desc(), Foo.b.desc()]) func.row_number().over(order_by="a", partition_by=("a", "b")) func.row_number().over(partition_by="a", order_by=("a", "b")) -# EXPECTED_TYPE: Function[Any] -reveal_type(func.row_number().filter()) -# EXPECTED_TYPE: FunctionFilter[Any] -reveal_type(func.row_number().filter(Foo.a > 0)) -# EXPECTED_TYPE: FunctionFilter[Any] -reveal_type(func.row_number().within_group(Foo.a).filter(Foo.b < 0)) -# EXPECTED_TYPE: WithinGroup[Any] -reveal_type(func.row_number().within_group(Foo.a)) -# EXPECTED_TYPE: WithinGroup[Any] -reveal_type(func.row_number().filter(Foo.a > 0).within_group(Foo.a)) -# EXPECTED_TYPE: Over[Any] -reveal_type(func.row_number().filter(Foo.a > 0).over()) -# EXPECTED_TYPE: Over[Any] -reveal_type(func.row_number().within_group(Foo.a).over()) +assert_type(func.row_number().filter(), Function[Any]) +assert_type(func.row_number().filter(Foo.a > 0), FunctionFilter[Any]) +assert_type( + func.row_number().within_group(Foo.a).filter(Foo.b < 0), + FunctionFilter[Any], +) +assert_type(func.row_number().within_group(Foo.a), WithinGroup[Any]) +assert_type( + func.row_number().filter(Foo.a > 0).within_group(Foo.a), WithinGroup[Any] +) +assert_type(func.row_number().filter(Foo.a > 0).over(), Over[Any]) +assert_type(func.row_number().within_group(Foo.a).over(), Over[Any]) # test #10801 -# EXPECTED_TYPE: max[int] -reveal_type(func.max(Foo.b)) +assert_type(func.max(Foo.b), functions_max[int]) stmt1 = select(Foo.a, func.min(Foo.b)).group_by(Foo.a) -# EXPECTED_TYPE: Select[int, int] -reveal_type(stmt1) +assert_type(stmt1, Select[int, int]) # test #10818 -# EXPECTED_TYPE: coalesce[str] -reveal_type(func.coalesce(Foo.c, "a", "b")) -# EXPECTED_TYPE: coalesce[str] -reveal_type(func.coalesce("a", "b")) -# EXPECTED_TYPE: coalesce[int] -reveal_type(func.coalesce(column("x", Integer), 3)) +assert_type(func.coalesce(Foo.c, "a", "b"), coalesce[str]) +assert_type(func.coalesce("a", "b"), coalesce[str]) +assert_type(func.coalesce(column("x", Integer), 3), coalesce[int]) stmt2 = select(Foo.a, func.coalesce(Foo.c, "a", "b")).group_by(Foo.a) -# EXPECTED_TYPE: Select[int, str] -reveal_type(stmt2) +assert_type(stmt2, Select[int, str]) -# EXPECTED_TYPE: TableValuedAlias -reveal_type(func.json_each().table_valued("key", "value")) -# EXPECTED_TYPE: TableValuedAlias -reveal_type(func.json_each().table_valued(Foo.a, Foo.b)) +assert_type(func.json_each().table_valued("key", "value"), TableValuedAlias) +assert_type(func.json_each().table_valued(Foo.a, Foo.b), TableValuedAlias) diff --git a/test/typing/plain_files/sql/lambda_stmt.py b/test/typing/plain_files/sql/lambda_stmt.py index 035fde800d..1725a57b33 100644 --- a/test/typing/plain_files/sql/lambda_stmt.py +++ b/test/typing/plain_files/sql/lambda_stmt.py @@ -1,6 +1,9 @@ from __future__ import annotations +from typing import Any +from typing import assert_type from typing import TYPE_CHECKING +from typing import Unpack from sqlalchemy import Column from sqlalchemy import create_engine @@ -11,9 +14,11 @@ from sqlalchemy import Result from sqlalchemy import select from sqlalchemy import String from sqlalchemy import Table +from sqlalchemy.engine.cursor import CursorResult from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column +from sqlalchemy.sql.lambdas import StatementLambdaElement class Base(DeclarativeBase): @@ -48,11 +53,9 @@ s6 = lambda_stmt(lambda: select(User)) + (lambda s: s.where(User.id == 5)) if TYPE_CHECKING: - # EXPECTED_TYPE: StatementLambdaElement - reveal_type(s5) + assert_type(s5, StatementLambdaElement) - # EXPECTED_TYPE: StatementLambdaElement - reveal_type(s6) + assert_type(s6, StatementLambdaElement) e = create_engine("sqlite://") @@ -61,8 +64,7 @@ with e.connect() as conn: result = conn.execute(s6) if TYPE_CHECKING: - # EXPECTED_TYPE: CursorResult[Unpack[.*tuple[Any, ...]]] - reveal_type(result) + assert_type(result, CursorResult[Unpack[tuple[Any, ...]]]) # we can type these like this my_result: Result[User] = conn.execute(s6) @@ -71,5 +73,4 @@ with e.connect() as conn: # pyright and mypy disagree on the specific type here, # mypy sees Result as we said, pyright seems to upgrade it to # CursorResult - # EXPECTED_RE_TYPE: .*(?:Cursor)?Result\[.*User\] - reveal_type(my_result) + assert_type(my_result, Result[User]) diff --git a/test/typing/plain_files/sql/misc.py b/test/typing/plain_files/sql/misc.py index 2a9e539dc3..338ee98007 100644 --- a/test/typing/plain_files/sql/misc.py +++ b/test/typing/plain_files/sql/misc.py @@ -1,10 +1,12 @@ from typing import Any +from typing import assert_type from sqlalchemy import column from sqlalchemy import ColumnElement from sqlalchemy import Integer from sqlalchemy import literal from sqlalchemy import table +from sqlalchemy.sql.expression import ColumnClause def test_col_accessors() -> None: @@ -27,11 +29,10 @@ def test_col_get() -> None: col_alt = column("alt", Integer) tbl = table("mytable", col_id) - # EXPECTED_TYPE: ColumnClause[Any] | None - reveal_type(tbl.c.get("id")) - # EXPECTED_TYPE: ColumnClause[Any] | None - reveal_type(tbl.c.get("id", None)) - # EXPECTED_TYPE: ColumnClause[Any] | ColumnClause[int] - reveal_type(tbl.c.get("alt", col_alt)) + assert_type(tbl.c.get("id"), ColumnClause[Any] | None) + assert_type(tbl.c.get("id", None), ColumnClause[Any] | None) + assert_type( + tbl.c.get("alt", col_alt), ColumnClause[Any] | ColumnClause[int] + ) col: ColumnElement[Any] = tbl.c.get("foo", literal("bar")) print(col) diff --git a/test/typing/plain_files/sql/operators.py b/test/typing/plain_files/sql/operators.py index c09029fc14..e9680afe71 100644 --- a/test/typing/plain_files/sql/operators.py +++ b/test/typing/plain_files/sql/operators.py @@ -1,6 +1,7 @@ import datetime as dt from decimal import Decimal from typing import Any +from typing import assert_type from typing import List from sqlalchemy import ARRAY @@ -15,6 +16,8 @@ from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column from sqlalchemy.sql import operators +from sqlalchemy.sql.elements import Grouping +from sqlalchemy.sql.expression import BinaryExpression class Base(DeclarativeBase): @@ -147,15 +150,14 @@ op_e: "ColumnElement[bool]" = col.bool_op("&")("1") op_a1 = col.op("&")(1) -# EXPECTED_TYPE: BinaryExpression[Any] -reveal_type(op_a1) +assert_type(op_a1, BinaryExpression[Any]) # op functions t1 = operators.eq(A.id, 1) select().where(t1) -# EXPECTED_TYPE: BinaryExpression[Any] -reveal_type(col.op("->>")("field")) -# EXPECTED_TYPE: BinaryExpression[Any] | Grouping[Any] -reveal_type(col.op("->>")("field").self_group()) +assert_type(col.op("->>")("field"), BinaryExpression[Any]) +assert_type( + col.op("->>")("field").self_group(), BinaryExpression[Any] | Grouping[Any] +) diff --git a/test/typing/plain_files/sql/sql_operations.py b/test/typing/plain_files/sql/sql_operations.py index f0025b2cb3..ef3b2dc390 100644 --- a/test/typing/plain_files/sql/sql_operations.py +++ b/test/typing/plain_files/sql/sql_operations.py @@ -1,4 +1,7 @@ +from decimal import Decimal import typing +from typing import Any +from typing import assert_type from sqlalchemy import and_ from sqlalchemy import Boolean @@ -13,6 +16,10 @@ from sqlalchemy import or_ from sqlalchemy import select from sqlalchemy import String from sqlalchemy import true +from sqlalchemy.sql.elements import ColumnElement +from sqlalchemy.sql.elements import UnaryExpression +from sqlalchemy.sql.expression import BinaryExpression +from sqlalchemy.sql.expression import ColumnClause # builtin.pyi stubs define object.__eq__() as returning bool, which @@ -99,57 +106,38 @@ def test_issue_9650_char() -> None: def test_issue_9650_bitwise() -> None: - # EXPECTED_TYPE: BinaryExpression[Any] - reveal_type(c2.bitwise_and(5)) - # EXPECTED_TYPE: BinaryExpression[Any] - reveal_type(c2.bitwise_or(5)) - # EXPECTED_TYPE: BinaryExpression[Any] - reveal_type(c2.bitwise_xor(5)) - # EXPECTED_TYPE: UnaryExpression[int] - reveal_type(c2.bitwise_not()) - # EXPECTED_TYPE: BinaryExpression[Any] - reveal_type(c2.bitwise_lshift(5)) - # EXPECTED_TYPE: BinaryExpression[Any] - reveal_type(c2.bitwise_rshift(5)) - # EXPECTED_TYPE: ColumnElement[int] - reveal_type(c2 << 5) - # EXPECTED_TYPE: ColumnElement[int] - reveal_type(c2 >> 5) + assert_type(c2.bitwise_and(5), BinaryExpression[Any]) + assert_type(c2.bitwise_or(5), BinaryExpression[Any]) + assert_type(c2.bitwise_xor(5), BinaryExpression[Any]) + assert_type(c2.bitwise_not(), UnaryExpression[int]) + assert_type(c2.bitwise_lshift(5), BinaryExpression[Any]) + assert_type(c2.bitwise_rshift(5), BinaryExpression[Any]) + assert_type(c2 << 5, ColumnElement[int]) + assert_type(c2 >> 5, ColumnElement[int]) if typing.TYPE_CHECKING: # as far as if this is ColumnElement, BinaryElement, SQLCoreOperations, # that might change. main thing is it's SomeSQLColThing[bool] and # not 'bool' or 'Any'. - # EXPECTED_RE_TYPE: sqlalchemy..*ColumnElement\[builtins.bool\] - reveal_type(expr1) + assert_type(expr1, ColumnElement[bool]) - # EXPECTED_RE_TYPE: sqlalchemy..*ColumnClause\[builtins.str.?\] - reveal_type(c1) + assert_type(c1, ColumnClause[str]) - # EXPECTED_RE_TYPE: sqlalchemy..*ColumnClause\[builtins.int.?\] - reveal_type(c2) + assert_type(c2, ColumnClause[int]) - # EXPECTED_RE_TYPE: sqlalchemy..*BinaryExpression\[builtins.bool\] - reveal_type(expr2) + assert_type(expr2, BinaryExpression[bool]) - # EXPECTED_RE_TYPE: sqlalchemy..*ColumnElement\[builtins.float | .*\.Decimal\] - reveal_type(expr3) + assert_type(expr3, ColumnElement[float | Decimal]) - # EXPECTED_RE_TYPE: sqlalchemy..*UnaryExpression\[builtins.int.?\] - reveal_type(expr4) + assert_type(expr4, UnaryExpression[int]) - # EXPECTED_RE_TYPE: sqlalchemy..*ColumnElement\[builtins.bool.?\] - reveal_type(expr5) + assert_type(expr5, ColumnElement[bool]) - # EXPECTED_RE_TYPE: sqlalchemy..*ColumnElement\[builtins.bool.?\] - reveal_type(expr6) + assert_type(expr6, ColumnElement[bool]) - # EXPECTED_RE_TYPE: sqlalchemy..*ColumnElement\[builtins.str\] - reveal_type(expr7) + assert_type(expr7, ColumnElement[str]) - # EXPECTED_RE_TYPE: sqlalchemy..*ColumnElement\[builtins.int.?\] - reveal_type(expr8) + assert_type(expr8, ColumnElement[int]) - # EXPECTED_TYPE: BinaryExpression[bool] - reveal_type(expr9) + assert_type(expr9, BinaryExpression[bool]) diff --git a/test/typing/plain_files/sql/sqltypes.py b/test/typing/plain_files/sql/sqltypes.py index 230cb957d4..0b5cc1bc92 100644 --- a/test/typing/plain_files/sql/sqltypes.py +++ b/test/typing/plain_files/sql/sqltypes.py @@ -1,12 +1,11 @@ +from decimal import Decimal +from typing import assert_type + from sqlalchemy import Float from sqlalchemy import Numeric -# EXPECTED_TYPE: Float[float] -reveal_type(Float()) -# EXPECTED_TYPE: Float[Decimal] -reveal_type(Float(asdecimal=True)) +assert_type(Float(), Float[float]) +assert_type(Float(asdecimal=True), Float[Decimal]) -# EXPECTED_TYPE: Numeric[Decimal] -reveal_type(Numeric()) -# EXPECTED_TYPE: Numeric[float] -reveal_type(Numeric(asdecimal=False)) +assert_type(Numeric(), Numeric[Decimal]) +assert_type(Numeric(asdecimal=False), Numeric[float]) diff --git a/test/typing/plain_files/sql/typed_results.py b/test/typing/plain_files/sql/typed_results.py index a544aa434f..98dde5ad9f 100644 --- a/test/typing/plain_files/sql/typed_results.py +++ b/test/typing/plain_files/sql/typed_results.py @@ -1,9 +1,13 @@ from __future__ import annotations import asyncio +from typing import Any +from typing import assert_type from typing import cast from typing import Optional +from typing import Sequence from typing import Type +from typing import Unpack from sqlalchemy import Column from sqlalchemy import column @@ -19,9 +23,17 @@ from sqlalchemy import select from sqlalchemy import String from sqlalchemy import Table from sqlalchemy import table +from sqlalchemy.engine import Result +from sqlalchemy.engine.cursor import CursorResult +from sqlalchemy.engine.result import MappingResult +from sqlalchemy.engine.result import ScalarResult +from sqlalchemy.engine.result import TupleResult +from sqlalchemy.engine.row import Row from sqlalchemy.ext.asyncio import AsyncConnection from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.ext.asyncio.result import AsyncScalarResult +from sqlalchemy.ext.asyncio.result import AsyncTupleResult from sqlalchemy.orm import aliased from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import Mapped @@ -66,8 +78,7 @@ async def async_connect() -> AsyncConnection: async_connection = asyncio.run(async_connect()) -# EXPECTED_RE_TYPE: sqlalchemy..*AsyncConnection\*? -reveal_type(async_connection) +assert_type(async_connection, AsyncConnection) async_session = AsyncSession(async_connection) @@ -81,41 +92,33 @@ user = session.query(User).one() user_iter = iter(session.scalars(select(User))) -# EXPECTED_RE_TYPE: sqlalchemy..*AsyncSession\*? -reveal_type(async_session) +assert_type(async_session, AsyncSession) single_stmt = select(User.name).where(User.name == "foo") -# EXPECTED_RE_TYPE: sqlalchemy..*Select\*?\[builtins.str\*?\] -reveal_type(single_stmt) +assert_type(single_stmt, Select[str]) multi_stmt = select(User.id, User.name).where(User.name == "foo") -# EXPECTED_RE_TYPE: sqlalchemy..*Select\*?\[builtins.int\*?, builtins.str\*?\] -reveal_type(multi_stmt) +assert_type(multi_stmt, Select[int, str]) def t_result_ctxmanager() -> None: with connection.execute(select(column("q", Integer))) as r1: - # EXPECTED_TYPE: CursorResult[int] - reveal_type(r1) + assert_type(r1, CursorResult[int]) with r1.mappings() as r1m: - # EXPECTED_TYPE: MappingResult - reveal_type(r1m) + assert_type(r1m, MappingResult) with connection.scalars(select(column("q", Integer))) as r2: - # EXPECTED_TYPE: ScalarResult[int] - reveal_type(r2) + assert_type(r2, ScalarResult[int]) with session.execute(select(User.id)) as r3: - # EXPECTED_TYPE: Result[int] - reveal_type(r3) + assert_type(r3, Result[int]) with session.scalars(select(User.id)) as r4: - # EXPECTED_TYPE: ScalarResult[int] - reveal_type(r4) + assert_type(r4, ScalarResult[int]) def t_mappings() -> None: @@ -143,22 +146,18 @@ def t_entity_varieties() -> None: r1 = session.execute(s1) - # EXPECTED_RE_TYPE: sqlalchemy..*.Result\[builtins.int\*?, typed_results.User\*?, builtins.str\*?\] - reveal_type(r1) + assert_type(r1, Result[int, User, str]) s2 = select(User, a1).where(User.name == "foo") r2 = session.execute(s2) - # EXPECTED_RE_TYPE: sqlalchemy.*Result\[typed_results.User\*?, typed_results.User\*?\] - reveal_type(r2) + assert_type(r2, Result[User, User]) row = r2.t.one() - # EXPECTED_RE_TYPE: .*typed_results.User\*? - reveal_type(row[0]) - # EXPECTED_RE_TYPE: .*typed_results.User\*? - reveal_type(row[1]) + assert_type(row[0], User) + assert_type(row[1], User) # testing that plain Mapped[x] gets picked up as well as # aliased class @@ -166,19 +165,17 @@ def t_entity_varieties() -> None: # automatically typed since they are dynamically generated a1_id = cast(Mapped[int], a1.id) s3 = select(User.id, a1_id, a1, User).where(User.name == "foo") - # EXPECTED_RE_TYPE: sqlalchemy.*Select\*?\[builtins.int\*?, builtins.int\*?, typed_results.User\*?, typed_results.User\*?\] - reveal_type(s3) + assert_type(s3, Select[int, int, User, User]) # testing Mapped[entity] some_mp = cast(Mapped[User], object()) s4 = select(some_mp, a1, User).where(User.name == "foo") - # NOTEXPECTED_RE_TYPE: sqlalchemy..*Select\*?\[typed_results.User\*?, typed_results.User\*?, typed_results.User\*?\] + # NOTEXPECTED_RE_TYPE: sqlalchemy..*Select\*?\[User\*?, User\*?, User\*?\] - # sqlalchemy.sql._gen_overloads.Select[typed_results.User, typed_results.User, typed_results.User] + # sqlalchemy.sql._gen_overloads.Select[User, User, User] - # EXPECTED_TYPE: Select[User, User, User] - reveal_type(s4) + assert_type(s4, Select[User, User, User]) # test plain core expressions x = Column("x", Integer) @@ -186,43 +183,36 @@ def t_entity_varieties() -> None: s5 = select(x, y, User.name + "hi") - # EXPECTED_RE_TYPE: sqlalchemy..*Select\*?\[builtins.int\*?, builtins.int\*?\, builtins.str\*?] - reveal_type(s5) + assert_type(s5, Select[int, int, str]) def t_ambiguous_result_type_one() -> None: stmt = select(column("q", Integer), table("x", column("y"))) - # EXPECTED_TYPE: Select[Unpack[.*tuple[Any, ...]]] - reveal_type(stmt) + assert_type(stmt, Select[Unpack[tuple[Any, ...]]]) result = session.execute(stmt) - # EXPECTED_TYPE: Result[Unpack[.*tuple[Any, ...]]] - reveal_type(result) + assert_type(result, Result[Unpack[tuple[Any, ...]]]) def t_ambiguous_result_type_two() -> None: stmt = select(column("q")) - # EXPECTED_TYPE: Select[Any] - reveal_type(stmt) + assert_type(stmt, Select[Any]) result = session.execute(stmt) - # EXPECTED_TYPE: Result[Unpack[.*tuple[Any, ...]]] - reveal_type(result) + assert_type(result, Result[Unpack[tuple[Any, ...]]]) def t_aliased() -> None: a1 = aliased(User) s1 = select(a1) - # EXPECTED_TYPE: Select[User] - reveal_type(s1) + assert_type(s1, Select[User]) s4 = select(a1.name, a1, a1, User).where(User.name == "foo") - # EXPECTED_TYPE: Select[str, User, User, User] - reveal_type(s4) + assert_type(s4, Select[str, User, User, User]) def t_result_scalar_accessors() -> None: @@ -230,28 +220,23 @@ def t_result_scalar_accessors() -> None: r1 = result.scalar() - # EXPECTED_RE_TYPE: builtins.str \| None - reveal_type(r1) + assert_type(r1, str | None) r2 = result.scalar_one() - # EXPECTED_RE_TYPE: builtins.str\*? - reveal_type(r2) + assert_type(r2, str) r3 = result.scalar_one_or_none() - # EXPECTED_RE_TYPE: builtins.str \| None - reveal_type(r3) + assert_type(r3, str | None) r4 = result.scalars() - # EXPECTED_RE_TYPE: sqlalchemy..*ScalarResult\[builtins.str.*?\] - reveal_type(r4) + assert_type(r4, ScalarResult[str]) r5 = result.scalars(0) - # EXPECTED_RE_TYPE: sqlalchemy..*ScalarResult\[builtins.str.*?\] - reveal_type(r5) + assert_type(r5, ScalarResult[str]) async def t_async_result_scalar_accessors() -> None: @@ -259,28 +244,23 @@ async def t_async_result_scalar_accessors() -> None: r1 = await result.scalar() - # EXPECTED_RE_TYPE: builtins.str \| None - reveal_type(r1) + assert_type(r1, str | None) r2 = await result.scalar_one() - # EXPECTED_RE_TYPE: builtins.str\*? - reveal_type(r2) + assert_type(r2, str) r3 = await result.scalar_one_or_none() - # EXPECTED_RE_TYPE: builtins.str \| None - reveal_type(r3) + assert_type(r3, str | None) r4 = result.scalars() - # EXPECTED_RE_TYPE: sqlalchemy..*ScalarResult\[builtins.str.*?\] - reveal_type(r4) + assert_type(r4, AsyncScalarResult[str]) r5 = result.scalars(0) - # EXPECTED_RE_TYPE: sqlalchemy..*ScalarResult\[builtins.str.*?\] - reveal_type(r5) + assert_type(r5, AsyncScalarResult[str]) def t_result_insertmanyvalues_scalars() -> None: @@ -295,8 +275,7 @@ def t_result_insertmanyvalues_scalars() -> None: ], ).all() - # EXPECTED_TYPE: Sequence[int] - reveal_type(uids1) + assert_type(uids1, Sequence[int]) uids2 = ( connection.execute( @@ -311,8 +290,7 @@ def t_result_insertmanyvalues_scalars() -> None: .all() ) - # EXPECTED_TYPE: Sequence[int] - reveal_type(uids2) + assert_type(uids2, Sequence[int]) async def t_async_result_insertmanyvalues_scalars() -> None: @@ -329,8 +307,7 @@ async def t_async_result_insertmanyvalues_scalars() -> None: ) ).all() - # EXPECTED_TYPE: Sequence[int] - reveal_type(uids1) + assert_type(uids1, Sequence[int]) uids2 = ( ( @@ -347,344 +324,279 @@ async def t_async_result_insertmanyvalues_scalars() -> None: .all() ) - # EXPECTED_TYPE: Sequence[int] - reveal_type(uids2) + assert_type(uids2, Sequence[int]) def t_connection_execute_multi_row_t() -> None: result = connection.execute(multi_stmt) - # EXPECTED_RE_TYPE: sqlalchemy.*CursorResult\[builtins.int\*?, builtins.str\*?\] - reveal_type(result) + assert_type(result, CursorResult[int, str]) row = result.one() - # EXPECTED_RE_TYPE: .*sqlalchemy.*Row\[builtins.int\*?, builtins.str\*?\].* - reveal_type(row) + assert_type(row, Row[int, str]) x, y = row.t - # EXPECTED_RE_TYPE: builtins.int\*? - reveal_type(x) + assert_type(x, int) - # EXPECTED_RE_TYPE: builtins.str\*? - reveal_type(y) + assert_type(y, str) def t_connection_execute_multi() -> None: result = connection.execute(multi_stmt).t - # EXPECTED_RE_TYPE: sqlalchemy.*TupleResult\[tuple\[builtins.int\*?, builtins.str\*?\]\] - reveal_type(result) + assert_type(result, TupleResult[tuple[int, str]]) row = result.one() - # EXPECTED_RE_TYPE: tuple\[builtins.int\*?, builtins.str\*?\] - reveal_type(row) + assert_type(row, tuple[int, str]) x, y = row - # EXPECTED_RE_TYPE: builtins.int\*? - reveal_type(x) + assert_type(x, int) - # EXPECTED_RE_TYPE: builtins.str\*? - reveal_type(y) + assert_type(y, str) def t_connection_execute_single() -> None: result = connection.execute(single_stmt).t - # EXPECTED_RE_TYPE: sqlalchemy.*TupleResult\[tuple\[builtins.str\*?\]\] - reveal_type(result) + assert_type(result, TupleResult[tuple[str]]) row = result.one() - # EXPECTED_RE_TYPE: tuple\[builtins.str\*?\] - reveal_type(row) + assert_type(row, tuple[str]) (x,) = row - # EXPECTED_RE_TYPE: builtins.str\*? - reveal_type(x) + assert_type(x, str) def t_connection_execute_single_row_scalar() -> None: result = connection.execute(single_stmt).t - # EXPECTED_RE_TYPE: sqlalchemy.*TupleResult\[tuple\[builtins.str\*?\]\] - reveal_type(result) + assert_type(result, TupleResult[tuple[str]]) x = result.scalar() - # EXPECTED_RE_TYPE: builtins.str \| None - reveal_type(x) + assert_type(x, str | None) def t_connection_scalar() -> None: obj = connection.scalar(single_stmt) - # EXPECTED_RE_TYPE: builtins.str \| None - reveal_type(obj) + assert_type(obj, str | None) def t_connection_scalars() -> None: result = connection.scalars(single_stmt) - # EXPECTED_RE_TYPE: sqlalchemy.*ScalarResult\[builtins.str\*?\] - reveal_type(result) + assert_type(result, ScalarResult[str]) data = result.all() - # EXPECTED_RE_TYPE: typing.Sequence\[builtins.str\*?\] - reveal_type(data) + assert_type(data, Sequence[str]) def t_session_execute_multi() -> None: result = session.execute(multi_stmt).t - # EXPECTED_RE_TYPE: sqlalchemy.*TupleResult\[tuple\[builtins.int\*?, builtins.str\*?\]\] - reveal_type(result) + assert_type(result, TupleResult[tuple[int, str]]) row = result.one() - # EXPECTED_RE_TYPE: tuple\[builtins.int\*?, builtins.str\*?\] - reveal_type(row) + assert_type(row, tuple[int, str]) x, y = row - # EXPECTED_RE_TYPE: builtins.int\*? - reveal_type(x) + assert_type(x, int) - # EXPECTED_RE_TYPE: builtins.str\*? - reveal_type(y) + assert_type(y, str) def t_session_execute_single() -> None: result = session.execute(single_stmt).t - # EXPECTED_RE_TYPE: sqlalchemy.*TupleResult\[tuple\[builtins.str\*?\]\] - reveal_type(result) + assert_type(result, TupleResult[tuple[str]]) row = result.one() - # EXPECTED_RE_TYPE: tuple\[builtins.str\*?\] - reveal_type(row) + assert_type(row, tuple[str]) (x,) = row - # EXPECTED_RE_TYPE: builtins.str\*? - reveal_type(x) + assert_type(x, str) def t_session_scalar() -> None: obj = session.scalar(single_stmt) - # EXPECTED_RE_TYPE: builtins.str \| None - reveal_type(obj) + assert_type(obj, str | None) def t_session_scalars() -> None: result = session.scalars(single_stmt) - # EXPECTED_RE_TYPE: sqlalchemy.*ScalarResult\[builtins.str\*?\] - reveal_type(result) + assert_type(result, ScalarResult[str]) data = result.all() - # EXPECTED_RE_TYPE: typing.Sequence\[builtins.str\*?\] - reveal_type(data) + assert_type(data, Sequence[str]) async def t_async_connection_execute_multi() -> None: result = (await async_connection.execute(multi_stmt)).t - # EXPECTED_RE_TYPE: sqlalchemy.*TupleResult\[tuple\[builtins.int\*?, builtins.str\*?\]\] - reveal_type(result) + assert_type(result, TupleResult[tuple[int, str]]) row = result.one() - # EXPECTED_RE_TYPE: tuple\[builtins.int\*?, builtins.str\*?\] - reveal_type(row) + assert_type(row, tuple[int, str]) x, y = row - # EXPECTED_RE_TYPE: builtins.int\*? - reveal_type(x) + assert_type(x, int) - # EXPECTED_RE_TYPE: builtins.str\*? - reveal_type(y) + assert_type(y, str) async def t_async_connection_execute_single() -> None: result = (await async_connection.execute(single_stmt)).t - # EXPECTED_RE_TYPE: sqlalchemy.*TupleResult\[tuple\[builtins.str\*?\]\] - reveal_type(result) + assert_type(result, TupleResult[tuple[str]]) row = result.one() - # EXPECTED_RE_TYPE: tuple\[builtins.str\*?\] - reveal_type(row) + assert_type(row, tuple[str]) (x,) = row - # EXPECTED_RE_TYPE: builtins.str\*? - reveal_type(x) + assert_type(x, str) async def t_async_connection_scalar() -> None: obj = await async_connection.scalar(single_stmt) - # EXPECTED_RE_TYPE: builtins.str \| None - reveal_type(obj) + assert_type(obj, str | None) async def t_async_connection_scalars() -> None: result = await async_connection.scalars(single_stmt) - # EXPECTED_RE_TYPE: sqlalchemy.*ScalarResult\*?\[builtins.str\*?\] - reveal_type(result) + assert_type(result, ScalarResult[str]) data = result.all() - # EXPECTED_RE_TYPE: typing.Sequence\[builtins.str\*?\] - reveal_type(data) + assert_type(data, Sequence[str]) async def t_async_session_execute_multi() -> None: result = (await async_session.execute(multi_stmt)).t - # EXPECTED_RE_TYPE: sqlalchemy.*TupleResult\[tuple\[builtins.int\*?, builtins.str\*?\]\] - reveal_type(result) + assert_type(result, TupleResult[tuple[int, str]]) row = result.one() - # EXPECTED_RE_TYPE: tuple\[builtins.int\*?, builtins.str\*?\] - reveal_type(row) + assert_type(row, tuple[int, str]) x, y = row - # EXPECTED_RE_TYPE: builtins.int\*? - reveal_type(x) + assert_type(x, int) - # EXPECTED_RE_TYPE: builtins.str\*? - reveal_type(y) + assert_type(y, str) async def t_async_session_execute_single() -> None: result = (await async_session.execute(single_stmt)).t - # EXPECTED_RE_TYPE: sqlalchemy.*TupleResult\[tuple\[builtins.str\*?\]\] - reveal_type(result) + assert_type(result, TupleResult[tuple[str]]) row = result.one() - # EXPECTED_RE_TYPE: tuple\[builtins.str\*?\] - reveal_type(row) + assert_type(row, tuple[str]) (x,) = row - # EXPECTED_RE_TYPE: builtins.str\*? - reveal_type(x) + assert_type(x, str) async def t_async_session_scalar() -> None: obj = await async_session.scalar(single_stmt) - # EXPECTED_RE_TYPE: builtins.str \| None - reveal_type(obj) + assert_type(obj, str | None) async def t_async_session_scalars() -> None: result = await async_session.scalars(single_stmt) - # EXPECTED_RE_TYPE: sqlalchemy.*ScalarResult\*?\[builtins.str\*?\] - reveal_type(result) + assert_type(result, ScalarResult[str]) data = result.all() - # EXPECTED_RE_TYPE: typing.Sequence\[builtins.str\*?\] - reveal_type(data) + assert_type(data, Sequence[str]) async def t_async_connection_stream_multi() -> None: result = (await async_connection.stream(multi_stmt)).t - # EXPECTED_RE_TYPE: sqlalchemy.*AsyncTupleResult\[tuple\[builtins.int\*?, builtins.str\*?\]\] - reveal_type(result) + assert_type(result, AsyncTupleResult[tuple[int, str]]) row = await result.one() - # EXPECTED_RE_TYPE: tuple\[builtins.int\*?, builtins.str\*?\] - reveal_type(row) + assert_type(row, tuple[int, str]) x, y = row - # EXPECTED_RE_TYPE: builtins.int\*? - reveal_type(x) + assert_type(x, int) - # EXPECTED_RE_TYPE: builtins.str\*? - reveal_type(y) + assert_type(y, str) async def t_async_connection_stream_single() -> None: result = (await async_connection.stream(single_stmt)).t - # EXPECTED_RE_TYPE: sqlalchemy.*AsyncTupleResult\[tuple\[builtins.str\*?\]\] - reveal_type(result) + assert_type(result, AsyncTupleResult[tuple[str]]) row = await result.one() - # EXPECTED_RE_TYPE: tuple\[builtins.str\*?\] - reveal_type(row) + assert_type(row, tuple[str]) (x,) = row - # EXPECTED_RE_TYPE: builtins.str\*? - reveal_type(x) + assert_type(x, str) async def t_async_connection_stream_scalars() -> None: result = await async_connection.stream_scalars(single_stmt) - # EXPECTED_RE_TYPE: sqlalchemy.*AsyncScalarResult\*?\[builtins.str\*?\] - reveal_type(result) + assert_type(result, AsyncScalarResult[str]) data = await result.all() - # EXPECTED_RE_TYPE: typing.Sequence\*?\[builtins.str\*?\] - reveal_type(data) + assert_type(data, Sequence[str]) async def t_async_session_stream_multi() -> None: result = (await async_session.stream(multi_stmt)).t - # EXPECTED_RE_TYPE: sqlalchemy.*TupleResult\[tuple\[builtins.int\*?, builtins.str\*?\]\] - reveal_type(result) + assert_type(result, AsyncTupleResult[tuple[int, str]]) row = await result.one() - # EXPECTED_RE_TYPE: tuple\[builtins.int\*?, builtins.str\*?\] - reveal_type(row) + assert_type(row, tuple[int, str]) x, y = row - # EXPECTED_RE_TYPE: builtins.int\*? - reveal_type(x) + assert_type(x, int) - # EXPECTED_RE_TYPE: builtins.str\*? - reveal_type(y) + assert_type(y, str) async def t_async_session_stream_single() -> None: result = (await async_session.stream(single_stmt)).t - # EXPECTED_RE_TYPE: sqlalchemy.*AsyncTupleResult\[tuple\[builtins.str\*?\]\] - reveal_type(result) + assert_type(result, AsyncTupleResult[tuple[str]]) row = await result.one() - # EXPECTED_RE_TYPE: tuple\[builtins.str\*?\] - reveal_type(row) + assert_type(row, tuple[str]) (x,) = row - # EXPECTED_RE_TYPE: builtins.str\*? - reveal_type(x) + assert_type(x, str) async def t_async_session_stream_scalars() -> None: result = await async_session.stream_scalars(single_stmt) - # EXPECTED_RE_TYPE: sqlalchemy.*AsyncScalarResult\*?\[builtins.str\*?\] - reveal_type(result) + assert_type(result, AsyncScalarResult[str]) data = await result.all() - # EXPECTED_RE_TYPE: typing.Sequence\*?\[builtins.str\*?\] - reveal_type(data) + assert_type(data, Sequence[str]) def test_outerjoin_10173() -> None: diff --git a/tools/generate_sql_functions.py b/tools/generate_sql_functions.py index 7b6c93de14..a78e2492a5 100644 --- a/tools/generate_sql_functions.py +++ b/tools/generate_sql_functions.py @@ -176,17 +176,16 @@ def {key}(self) -> Type[{_type}]:{_reserved_word} # The origin type, if rtype is a generic orig_type = typing.get_origin(rtype) if orig_type is not None: - coltype = rf".*{orig_type.__name__}\[.*int\]" + coltype = rf"{orig_type.__name__}[int]" else: - coltype = ".*int" + coltype = "int" buf.write( textwrap.indent( rf""" stmt{count} = select(func.{key}(column('x', Integer))) -# EXPECTED_RE_TYPE: .*Select\[{coltype}\] -reveal_type(stmt{count}) +assert_type(stmt{count}, Select[{coltype}]) """, indent, @@ -199,8 +198,7 @@ reveal_type(stmt{count}) rf""" stmt{count} = select(func.{key}(column('x', String), ',')) -# EXPECTED_RE_TYPE: .*Select\[.*str\] -reveal_type(stmt{count}) +assert_type(stmt{count}, Select[str]) """, indent, @@ -211,10 +209,10 @@ reveal_type(stmt{count}) fn_class.type, TypeEngine ): python_type = fn_class.type.python_type - python_expr = rf".*{python_type.__name__}" + python_expr = python_type.__name__ argspec = inspect.getfullargspec(fn_class) if fn_class.__name__ == "next_value": - args = "Sequence('x_seq')" + args = "SqlAlchemySequence('x_seq')" else: args = ", ".join( 'column("x")' for elem in argspec.args[1:] @@ -226,8 +224,7 @@ reveal_type(stmt{count}) rf""" stmt{count} = select(func.{key}({args})) -# EXPECTED_RE_TYPE: .*Select\[{python_expr}\] -reveal_type(stmt{count}) +assert_type(stmt{count}, Select[{python_expr}]) """, indent, -- 2.47.3