]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Change typing tests to use `assert_type` instead of `reveal_type`
authorRebecca Chen <rechen@fb.com>
Sat, 18 Oct 2025 14:20:55 +0000 (10:20 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 18 Oct 2025 15:23:20 +0000 (11:23 -0400)
Closes: #12922
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12922
Pull-request-sha: 580f6638168c33e6c50e95066312ac605433665f

Change-Id: I9f3bdb4c105971f53fa10ed8a934356203ddb080

41 files changed:
lib/sqlalchemy/testing/fixtures/mypy.py
test/typing/plain_files/dialects/postgresql/pg_stuff.py
test/typing/plain_files/engine/engine_inspection.py
test/typing/plain_files/engine/engine_result.py
test/typing/plain_files/engine/engines.py
test/typing/plain_files/ext/association_proxy/association_proxy_one.py
test/typing/plain_files/ext/association_proxy/association_proxy_three.py
test/typing/plain_files/ext/association_proxy/association_proxy_two.py
test/typing/plain_files/ext/asyncio/engines.py
test/typing/plain_files/ext/hybrid/hybrid_one.py
test/typing/plain_files/ext/hybrid/hybrid_two.py
test/typing/plain_files/ext/indexable.py
test/typing/plain_files/ext/orderinglist/orderinglist_one.py
test/typing/plain_files/inspection_inspect.py
test/typing/plain_files/orm/composite.py
test/typing/plain_files/orm/composite_dc.py
test/typing/plain_files/orm/dataclass_transforms_decorator.py
test/typing/plain_files/orm/dataclass_transforms_decorator_w_mixins.py
test/typing/plain_files/orm/dataclass_transforms_one.py
test/typing/plain_files/orm/declared_attr_one.py
test/typing/plain_files/orm/declared_attr_two.py
test/typing/plain_files/orm/dynamic_rel.py
test/typing/plain_files/orm/issue_9340.py
test/typing/plain_files/orm/keyfunc_dict.py
test/typing/plain_files/orm/relationship.py
test/typing/plain_files/orm/session.py
test/typing/plain_files/orm/sessionmakers.py
test/typing/plain_files/orm/trad_relationship_uselist.py
test/typing/plain_files/orm/traditional_relationship.py
test/typing/plain_files/orm/typed_queries.py
test/typing/plain_files/orm/write_only.py
test/typing/plain_files/sql/common_sql_element.py
test/typing/plain_files/sql/functions.py
test/typing/plain_files/sql/functions_again.py
test/typing/plain_files/sql/lambda_stmt.py
test/typing/plain_files/sql/misc.py
test/typing/plain_files/sql/operators.py
test/typing/plain_files/sql/sql_operations.py
test/typing/plain_files/sql/sqltypes.py
test/typing/plain_files/sql/typed_results.py
tools/generate_sql_functions.py

index b1d2ee0e81642a6158d31f84e00744b005132440..cc16fa37448bd9127e5742c97a6f80c2b71a837a 100644 (file)
@@ -29,8 +29,6 @@ try:
 except ImportError:
     _mypy_vers_tuple = (0, 0, 0)
 
-mypy_14 = _mypy_vers_tuple >= (1, 4)
-
 
 @config.add_to_marker.mypy
 class MypyTest(TestBase):
@@ -128,9 +126,7 @@ class MypyTest(TestBase):
 
     def _collect_messages(self, path):
         expected_messages = []
-        expected_re = re.compile(
-            r"\s*# EXPECTED(_MYPY)?(_RE)?(_ROW)?(_TYPE)?: (.+)"
-        )
+        expected_re = re.compile(r"\s*# EXPECTED(_MYPY)?(_RE)?(_TYPE)?: (.+)")
         py_ver_re = re.compile(r"^#\s*PYTHON_VERSION\s?>=\s?(\d+\.\d+)")
         with open(path) as file_:
             current_assert_messages = []
@@ -148,71 +144,21 @@ class MypyTest(TestBase):
                 if m:
                     is_mypy = bool(m.group(1))
                     is_re = bool(m.group(2))
-                    is_row = bool(m.group(3))
-                    is_type = bool(m.group(4))
-
-                    expected_msg = re.sub(r"# noqa[:]? ?.*", "", m.group(5))
-                    if is_row:
-                        expected_msg = re.sub(
-                            r"Row\[([^\]]+)\]",
-                            lambda m: f"tuple[{m.group(1)}, fallback=s"
-                            f"qlalchemy.engine.row.{m.group(0)}]",
-                            expected_msg,
-                        )
-                        # For some reason it does not use or syntax (|)
-                        expected_msg = re.sub(
-                            r"Optional\[(.*)\]",
-                            lambda m: f"Union[{m.group(1)}, None]",
-                            expected_msg,
-                        )
+                    is_type = bool(m.group(3))
+                    expected_msg = re.sub(r"# noqa[:]? ?.*", "", m.group(4))
 
                     if is_type:
-                        if not is_re:
-                            # the goal here is that we can cut-and-paste
-                            # from vscode -> pylance into the
-                            # EXPECTED_TYPE: line, then the test suite will
-                            # validate that line against what mypy produces
-                            expected_msg = re.sub(
-                                r"([\[\]])",
-                                lambda m: rf"\{m.group(0)}",
-                                expected_msg,
-                            )
-
-                            # note making sure preceding text matches
-                            # with a dot, so that an expect for "Select"
-                            # does not match "TypedSelect"
-                            expected_msg = re.sub(
-                                r"([\w_]+)",
-                                lambda m: rf"(?:.*\.)?{m.group(1)}\*?",
-                                expected_msg,
-                            )
-
-                            expected_msg = re.sub(
-                                "List", "builtins.list", expected_msg
-                            )
-
-                            expected_msg = re.sub(
-                                r"\b(int|str|float|bool)\b",
-                                lambda m: rf"builtins.{m.group(0)}\*?",
-                                expected_msg,
-                            )
-                            # expected_msg = re.sub(
-                            #     r"(Sequence|Tuple|List|Union)",
-                            #     lambda m: fr"typing.{m.group(0)}\*?",
-                            #     expected_msg,
-                            # )
 
                         is_mypy = is_re = True
                         expected_msg = f'Revealed type is "{expected_msg}"'
 
-                    if mypy_14:
-                        # use_or_syntax
-                        # https://github.com/python/mypy/blob/304997bfb85200fb521ac727ee0ce3e6085e5278/mypy/options.py#L368  # noqa: E501
-                        expected_msg = re.sub(
-                            r"Optional\[(.*?)\]",
-                            lambda m: f"{m.group(1)} | None",
-                            expected_msg,
-                        )
+                    # use_or_syntax
+                    # https://github.com/python/mypy/blob/304997bfb85200fb521ac727ee0ce3e6085e5278/mypy/options.py#L368  # noqa: E501
+                    expected_msg = re.sub(
+                        r"Optional\[(.*?)\]",
+                        lambda m: f"{m.group(1)} | None",
+                        expected_msg,
+                    )
                     current_assert_messages.append(
                         (is_mypy, is_re, expected_msg.strip())
                     )
index 20785bc2cb34560a51d6bac8cfcf6750f785535e..ec99ec1b0c5e4b3af6ccf5c41e8d1b4520d1adf9 100644 (file)
@@ -1,5 +1,9 @@
+from datetime import date
+from datetime import datetime
 from typing import Any
+from typing import assert_type
 from typing import Dict
+from typing import Sequence
 from uuid import UUID as _py_uuid
 
 from sqlalchemy import cast
@@ -7,6 +11,7 @@ from sqlalchemy import Column
 from sqlalchemy import func
 from sqlalchemy import Integer
 from sqlalchemy import or_
+from sqlalchemy import Select
 from sqlalchemy import select
 from sqlalchemy import Text
 from sqlalchemy import UniqueConstraint
@@ -18,6 +23,7 @@ from sqlalchemy.dialects.postgresql import insert
 from sqlalchemy.dialects.postgresql import INT4RANGE
 from sqlalchemy.dialects.postgresql import INT8MULTIRANGE
 from sqlalchemy.dialects.postgresql import JSONB
+from sqlalchemy.dialects.postgresql import Range
 from sqlalchemy.dialects.postgresql import TSTZMULTIRANGE
 from sqlalchemy.dialects.postgresql import UUID
 from sqlalchemy.orm import DeclarativeBase
@@ -28,13 +34,11 @@ from sqlalchemy.orm import mapped_column
 
 c1 = Column(UUID())
 
-# EXPECTED_TYPE: Column[UUID]
-reveal_type(c1)
+assert_type(c1, Column[_py_uuid])
 
 c2 = Column(UUID(as_uuid=False))
 
-# EXPECTED_TYPE: Column[str]
-reveal_type(c2)
+assert_type(c2, Column[str])
 
 
 class Base(DeclarativeBase):
@@ -69,11 +73,9 @@ print(stmt)
 
 t1 = Test()
 
-# EXPECTED_RE_TYPE: .*dict\[.*str, Any\]
-reveal_type(t1.data)
+assert_type(t1.data, dict[str, Any])
 
-# EXPECTED_TYPE: UUID
-reveal_type(t1.ident)
+assert_type(t1.ident, _py_uuid)
 
 unique = UniqueConstraint(name="my_constraint")
 insert(Test).on_conflict_do_nothing(
@@ -86,52 +88,40 @@ s1 = insert(Test)
 s1.on_conflict_do_update(set_=s1.excluded)
 
 
-# EXPECTED_TYPE: Column[Range[int]]
-reveal_type(Column(INT4RANGE()))
-# EXPECTED_TYPE: Column[Range[datetime.date]]
-reveal_type(Column("foo", DATERANGE()))
-# EXPECTED_TYPE: Column[Sequence[Range[int]]]
-reveal_type(Column(INT8MULTIRANGE()))
-# EXPECTED_TYPE: Column[Sequence[Range[datetime.datetime]]]
-reveal_type(Column("foo", TSTZMULTIRANGE()))
+assert_type(Column(INT4RANGE()), Column[Range[int]])
+assert_type(Column("foo", DATERANGE()), Column[Range[date]])
+assert_type(Column(INT8MULTIRANGE()), Column[Sequence[Range[int]]])
+assert_type(Column("foo", TSTZMULTIRANGE()), Column[Sequence[Range[datetime]]])
 
 
 range_col_stmt = select(Column(INT4RANGE()), Column(INT8MULTIRANGE()))
 
-# EXPECTED_TYPE: Select[Range[int], Sequence[Range[int]]]
-reveal_type(range_col_stmt)
+assert_type(range_col_stmt, Select[Range[int], Sequence[Range[int]]])
 
 array_from_ints = array(range(2))
 
-# EXPECTED_TYPE: array[int]
-reveal_type(array_from_ints)
+assert_type(array_from_ints, array[int])
 
 array_of_strings = array([], type_=Text)
 
-# EXPECTED_TYPE: array[str]
-reveal_type(array_of_strings)
+assert_type(array_of_strings, array[str])
 
 array_of_ints = array([0], type_=Integer)
 
-# EXPECTED_TYPE: array[int]
-reveal_type(array_of_ints)
+assert_type(array_of_ints, array[int])
 
 # EXPECTED_MYPY_RE: Cannot infer .* of "array"
 array([0], type_=Text)
 
-# EXPECTED_TYPE: ARRAY[str]
-reveal_type(ARRAY(Text))
+assert_type(ARRAY(Text), ARRAY[str])
 
-# EXPECTED_TYPE: Column[Sequence[int]]
-reveal_type(Column(type_=ARRAY(Integer)))
+assert_type(Column(type_=ARRAY(Integer)), Column[Sequence[int]])
 
 stmt_array_agg = select(func.array_agg(Column("num", type_=Integer)))
 
-# EXPECTED_TYPE: Select[Sequence[int]]
-reveal_type(stmt_array_agg)
+assert_type(stmt_array_agg, Select[Sequence[int]])
 
-# EXPECTED_TYPE: Select[Sequence[str]]
-reveal_type(select(func.array_agg(Test.ident_str)))
+assert_type(select(func.array_agg(Test.ident_str)), Select[Sequence[str]])
 
 stmt_array_agg_order_by_1 = select(
     func.array_agg(
@@ -143,8 +133,7 @@ stmt_array_agg_order_by_1 = select(
     )
 )
 
-# EXPECTED_TYPE: Select[Sequence[str]]
-reveal_type(stmt_array_agg_order_by_1)
+assert_type(stmt_array_agg_order_by_1, Select[Sequence[str]])
 
 stmt_array_agg_order_by_2 = select(
     func.array_agg(
@@ -152,5 +141,4 @@ stmt_array_agg_order_by_2 = select(
     )
 )
 
-# EXPECTED_TYPE: Select[Sequence[str]]
-reveal_type(stmt_array_agg_order_by_2)
+assert_type(stmt_array_agg_order_by_2, Select[Sequence[str]])
index 0ca331f189f8c702361bee6b160430c6c81686ba..0660f44380163f8d9b0e81e11fe6fc5d11198e5a 100644 (file)
@@ -1,7 +1,11 @@
 import typing
+from typing import assert_type
 
 from sqlalchemy import create_engine
 from sqlalchemy import inspect
+from sqlalchemy.engine import Inspector
+from sqlalchemy.engine.base import Engine
+from sqlalchemy.engine.interfaces import ReflectedColumn
 
 
 e = create_engine("sqlite://")
@@ -13,11 +17,8 @@ cols = insp.get_columns("some_table")
 c1 = cols[0]
 
 if typing.TYPE_CHECKING:
-    # EXPECTED_RE_TYPE: sqlalchemy.engine.base.Engine
-    reveal_type(e)
+    assert_type(e, Engine)
 
-    # EXPECTED_RE_TYPE: sqlalchemy.engine.reflection.Inspector.*
-    reveal_type(insp)
+    assert_type(insp, Inspector)
 
-    # EXPECTED_RE_TYPE: .*list.*TypedDict.*ReflectedColumn.*
-    reveal_type(cols)
+    assert_type(cols, list[ReflectedColumn])
index 1c76cf68b44d1c4196b3871a6d26adda28b44e18..4c4b030f18c3f763cfaf7b3905e687aa46715367 100644 (file)
@@ -1,73 +1,56 @@
+from typing import Any
+from typing import assert_type
+from typing import Sequence
+
 from sqlalchemy import column
 from sqlalchemy.engine import Result
 from sqlalchemy.engine import Row
+from sqlalchemy.engine import RowMapping
+from sqlalchemy.engine.result import FrozenResult
+from sqlalchemy.engine.result import MappingResult
+from sqlalchemy.engine.result import ScalarResult
 
 
 def row_one(row: Row[int, str, bool]) -> None:
-    # EXPECTED_TYPE: int
-    reveal_type(row[0])
-    # EXPECTED_TYPE: str
-    reveal_type(row[1])
-    # EXPECTED_TYPE: bool
-    reveal_type(row[2])
+    assert_type(row[0], int)
+    assert_type(row[1], str)
+    assert_type(row[2], bool)
 
     # EXPECTED_MYPY: Tuple index out of range
     row[3]
     # EXPECTED_MYPY: No overload variant of "__getitem__" of "tuple" matches argument type "str"  # noqa: E501
     row["a"]
 
-    # EXPECTED_TYPE: RowMapping
-    reveal_type(row._mapping)
+    assert_type(row._mapping, RowMapping)
     rm = row._mapping
-    # EXPECTED_TYPE: Any
-    reveal_type(rm["foo"])
-    # EXPECTED_TYPE: Any
-    reveal_type(rm[column("bar")])
+    assert_type(rm["foo"], Any)
+    assert_type(rm[column("bar")], Any)
 
     # EXPECTED_MYPY_RE: Invalid index type "int" for "RowMapping"; expected type "(str \| SQLCoreOperations\[Any\]|Union\[str, SQLCoreOperations\[Any\]\])"  # noqa: E501
     rm[3]
 
 
 def result_one(res: Result[int, str]) -> None:
-    # EXPECTED_ROW_TYPE: Row[int, str]
-    reveal_type(res.one())
-    # EXPECTED_ROW_TYPE: Row[int, str] | None
-    reveal_type(res.one_or_none())
-    # EXPECTED_ROW_TYPE: Row[int, str] | None
-    reveal_type(res.fetchone())
-    # EXPECTED_ROW_TYPE: Row[int, str] | None
-    reveal_type(res.first())
-    # EXPECTED_ROW_TYPE: Sequence[Row[int, str]]
-    reveal_type(res.all())
-    # EXPECTED_ROW_TYPE: Sequence[Row[int, str]]
-    reveal_type(res.fetchmany())
-    # EXPECTED_ROW_TYPE: Sequence[Row[int, str]]
-    reveal_type(res.fetchall())
-    # EXPECTED_ROW_TYPE: Row[int, str]
-    reveal_type(next(res))
+    assert_type(res.one(), Row[int, str])
+    assert_type(res.one_or_none(), Row[int, str] | None)
+    assert_type(res.fetchone(), Row[int, str] | None)
+    assert_type(res.first(), Row[int, str] | None)
+    assert_type(res.all(), Sequence[Row[int, str]])
+    assert_type(res.fetchmany(), Sequence[Row[int, str]])
+    assert_type(res.fetchall(), Sequence[Row[int, str]])
+    assert_type(next(res), Row[int, str])
     for rf in res:
-        # EXPECTED_ROW_TYPE: Row[int, str]
-        reveal_type(rf)
+        assert_type(rf, Row[int, str])
     for rp in res.partitions():
-        # EXPECTED_ROW_TYPE: Sequence[Row[int, str]]
-        reveal_type(rp)
-
-    # EXPECTED_TYPE: ScalarResult[int]
-    res_s = reveal_type(res.scalars())
-    # EXPECTED_TYPE: ScalarResult[int]
-    res_s = reveal_type(res.scalars(0))
-    # EXPECTED_TYPE: int
-    reveal_type(res_s.one())
-    # EXPECTED_TYPE: ScalarResult[Any]
-    reveal_type(res.scalars(1))
-    # EXPECTED_TYPE: MappingResult
-    reveal_type(res.mappings())
-    # EXPECTED_TYPE: FrozenResult[int, str]
-    reveal_type(res.freeze())
-
-    # EXPECTED_TYPE: int
-    reveal_type(res.scalar_one())
-    # EXPECTED_TYPE: int | None
-    reveal_type(res.scalar_one_or_none())
-    # EXPECTED_TYPE: int | None
-    reveal_type(res.scalar())
+        assert_type(rp, Sequence[Row[int, str]])
+
+    res_s = assert_type(res.scalars(), ScalarResult[int])
+    res_s = assert_type(res.scalars(0), ScalarResult[int])
+    assert_type(res_s.one(), int)
+    assert_type(res.scalars(1), ScalarResult[Any])
+    assert_type(res.mappings(), MappingResult)
+    assert_type(res.freeze(), FrozenResult[int, str])
+
+    assert_type(res.scalar_one(), int)
+    assert_type(res.scalar_one_or_none(), int | None)
+    assert_type(res.scalar(), int | None)
index 15aa774e6aed891b1026e9058321425f96735fb0..7e06beaedee318f9226c8f78746bf251cd798ee3 100644 (file)
@@ -1,32 +1,34 @@
+from typing import Any
+from typing import assert_type
+from typing import Unpack
+
+from sqlalchemy import Connection
 from sqlalchemy import create_engine
 from sqlalchemy import Pool
 from sqlalchemy import select
 from sqlalchemy import text
+from sqlalchemy.engine.base import Engine
+from sqlalchemy.engine.cursor import CursorResult
 
 
 def regular() -> None:
     e = create_engine("sqlite://")
 
-    # EXPECTED_TYPE: Engine
-    reveal_type(e)
+    assert_type(e, Engine)
 
     with e.connect() as conn:
-        # EXPECTED_TYPE: Connection
-        reveal_type(conn)
+        assert_type(conn, Connection)
 
         result = conn.execute(text("select * from table"))
 
-        # EXPECTED_TYPE: CursorResult[Unpack[.*tuple[Any, ...]]]
-        reveal_type(result)
+        assert_type(result, CursorResult[Unpack[tuple[Any, ...]]])
 
     with e.begin() as conn:
-        # EXPECTED_TYPE: Connection
-        reveal_type(conn)
+        assert_type(conn, Connection)
 
         result = conn.execute(text("select * from table"))
 
-        # EXPECTED_TYPE: CursorResult[Unpack[.*tuple[Any, ...]]]
-        reveal_type(result)
+        assert_type(result, CursorResult[Unpack[tuple[Any, ...]]])
 
     engine = create_engine("postgresql://scott:tiger@localhost/test")
     status: str = engine.pool.status()
index cb9f0b85d7dc85646e416db51a87e048cd29a175..c6dd37b0c92b59fa308edde40c4eccef9e72efd8 100644 (file)
@@ -1,4 +1,5 @@
 import typing
+from typing import assert_type
 from typing import Set
 
 from sqlalchemy import ForeignKey
@@ -6,6 +7,7 @@ from sqlalchemy import Integer
 from sqlalchemy import String
 from sqlalchemy.ext.associationproxy import association_proxy
 from sqlalchemy.ext.associationproxy import AssociationProxy
+from sqlalchemy.ext.associationproxy import AssociationProxyInstance
 from sqlalchemy.orm import DeclarativeBase
 from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
@@ -40,8 +42,6 @@ class Address(Base):
 u1 = User()
 
 if typing.TYPE_CHECKING:
-    # EXPECTED_RE_TYPE: sqlalchemy.*.associationproxy.AssociationProxyInstance\[builtins.set\*?\[builtins.str\]\]
-    reveal_type(User.email_addresses)
+    assert_type(User.email_addresses, AssociationProxyInstance[set[str]])
 
-    # EXPECTED_RE_TYPE: builtins.set\*?\[builtins.str\]
-    reveal_type(u1.email_addresses)
+    assert_type(u1.email_addresses, set[str])
index f338681f7c498e0554daa8004e48a6c20f0a1f21..2f18a9aff34b08452d0554a3894acfa47aeb8f00 100644 (file)
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+from typing import assert_type
 from typing import List
 
 from sqlalchemy import ForeignKey
@@ -42,5 +43,4 @@ bm = BranchMilestone()
 
 x1 = bm.user_ids
 
-# EXPECTED_TYPE: list[int]
-reveal_type(x1)
+assert_type(x1, list[int])
index 074a6a71a8369972246e1f9080ab00194f31df10..95bc47da3d353cdd277c94f3ad8b8101c18df97b 100644 (file)
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+from typing import assert_type
 from typing import Final
 
 from sqlalchemy import Column
@@ -52,14 +53,12 @@ user_keyword_table: Final[Table] = Table(
 
 user = User("jek")
 
-# EXPECTED_TYPE: list[Keyword]
-reveal_type(user.kw)
+assert_type(user.kw, list[Keyword])
 
 user.kw.append(Keyword("cheese-inspector"))
 
 user.keywords.append("cheese-inspector")
 
-# EXPECTED_TYPE: list[str]
-reveal_type(user.keywords)
+assert_type(user.keywords, list[str])
 
 user.keywords.append("snack ninja")
index 7af764ecd8afe61fdbc5cc875e043026d6bcdf40..9ddd59c8987b9162120d780d52248eb5b429cc94 100644 (file)
@@ -1,11 +1,18 @@
 from typing import Any
+from typing import assert_type
+from typing import Unpack
 
 from sqlalchemy import Connection
 from sqlalchemy import Enum
 from sqlalchemy import MetaData
 from sqlalchemy import select
 from sqlalchemy import text
+from sqlalchemy.engine.cursor import CursorResult
 from sqlalchemy.ext.asyncio import create_async_engine
+from sqlalchemy.ext.asyncio.engine import AsyncConnection
+from sqlalchemy.ext.asyncio.engine import AsyncEngine
+from sqlalchemy.ext.asyncio.result import AsyncResult
+from sqlalchemy.ext.asyncio.result import AsyncScalarResult
 
 
 def work_sync(conn: Connection, foo: int) -> Any:
@@ -15,54 +22,45 @@ def work_sync(conn: Connection, foo: int) -> Any:
 async def asyncio() -> None:
     e = create_async_engine("sqlite://")
 
-    # EXPECTED_TYPE: AsyncEngine
-    reveal_type(e)
+    assert_type(e, AsyncEngine)
 
     async with e.connect() as conn:
-        # EXPECTED_TYPE: AsyncConnection
-        reveal_type(conn)
+        assert_type(conn, AsyncConnection)
 
         result = await conn.execute(text("select * from table"))
 
-        # EXPECTED_TYPE: CursorResult[Unpack[.*tuple[Any, ...]]]
-        reveal_type(result)
+        assert_type(result, CursorResult[Unpack[tuple[Any, ...]]])
 
         # stream with direct await
         async_result = await conn.stream(text("select * from table"))
 
-        # EXPECTED_TYPE: AsyncResult[Unpack[.*tuple[Any, ...]]]
-        reveal_type(async_result)
+        assert_type(async_result, AsyncResult[Unpack[tuple[Any, ...]]])
 
         # stream with context manager
         async with conn.stream(
             text("select * from table")
         ) as ctx_async_result:
-            # EXPECTED_TYPE: AsyncResult[Unpack[.*tuple[Any, ...]]]
-            reveal_type(ctx_async_result)
+            assert_type(ctx_async_result, AsyncResult[Unpack[tuple[Any, ...]]])
 
         # stream_scalars with direct await
         async_scalar_result = await conn.stream_scalars(
             text("select * from table")
         )
 
-        # EXPECTED_TYPE: AsyncScalarResult[Any]
-        reveal_type(async_scalar_result)
+        assert_type(async_scalar_result, AsyncScalarResult[Any])
 
         # stream_scalars with context manager
         async with conn.stream_scalars(
             text("select * from table")
         ) as ctx_async_scalar_result:
-            # EXPECTED_TYPE: AsyncScalarResult[Any]
-            reveal_type(ctx_async_scalar_result)
+            assert_type(ctx_async_scalar_result, AsyncScalarResult[Any])
 
     async with e.begin() as conn:
-        # EXPECTED_TYPE: AsyncConnection
-        reveal_type(conn)
+        assert_type(conn, AsyncConnection)
 
         result = await conn.execute(text("select * from table"))
 
-        # EXPECTED_TYPE: CursorResult[Unpack[.*tuple[Any, ...]]]
-        reveal_type(result)
+        assert_type(result, CursorResult[Unpack[tuple[Any, ...]]])
 
         await conn.run_sync(work_sync, 1)
 
index aef41395fee4b85d72cd2dda2286b84b4c9a1b8c..09f98781c202d7895b72430f57917b6b5ded9b9f 100644 (file)
@@ -1,13 +1,18 @@
 from __future__ import annotations
 
 import typing
+from typing import assert_type
 
+from sqlalchemy import Select
 from sqlalchemy import select
+from sqlalchemy.ext.hybrid import _HybridClassLevelAccessor
 from sqlalchemy.ext.hybrid import hybrid_method
 from sqlalchemy.ext.hybrid import hybrid_property
 from sqlalchemy.orm import DeclarativeBase
 from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
+from sqlalchemy.sql.elements import SQLCoreOperations
+from sqlalchemy.sql.expression import BinaryExpression
 
 
 class Base(DeclarativeBase):
@@ -66,26 +71,18 @@ stmt1 = select(Interval).where(expr1).where(expr4)
 stmt2 = select(expr4)
 
 if typing.TYPE_CHECKING:
-    # EXPECTED_RE_TYPE: builtins.int\*?
-    reveal_type(i1.length)
+    assert_type(i1.length, int)
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*._HybridClassLevelAccessor\[builtins.int\*?\]
-    reveal_type(Interval.length)
+    assert_type(Interval.length, _HybridClassLevelAccessor[int])
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*.BinaryExpression\[builtins.bool\*?\]
-    reveal_type(expr1)
+    assert_type(expr1, BinaryExpression[bool])
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.int\*?\]
-    reveal_type(expr2)
+    assert_type(expr2, SQLCoreOperations[int])
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.int\*?\]
-    reveal_type(expr3)
+    assert_type(expr3, SQLCoreOperations[int])
 
-    # EXPECTED_TYPE: bool
-    reveal_type(i1.fancy_thing(1, 2, 3))
+    assert_type(i1.fancy_thing(1, 2, 3), bool)
 
-    # EXPECTED_TYPE: SQLCoreOperations[bool]
-    reveal_type(expr4)
+    assert_type(expr4, SQLCoreOperations[bool])
 
-    # EXPECTED_TYPE: Select[bool]
-    reveal_type(stmt2)
+    assert_type(stmt2, Select[bool])
index b4f2aca769527ccdc9bda6b6276439e0fcf4345f..a0b8e325427ed8738cb1388eede673fa67ebf757 100644 (file)
@@ -1,13 +1,17 @@
 from __future__ import annotations
 
 import typing
+from typing import assert_type
 
 from sqlalchemy import Float
 from sqlalchemy import func
+from sqlalchemy import Function
+from sqlalchemy.ext.hybrid import _HybridClassLevelAccessor
 from sqlalchemy.ext.hybrid import hybrid_property
 from sqlalchemy.orm import DeclarativeBase
 from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
+from sqlalchemy.sql.expression import BinaryExpression
 from sqlalchemy.sql.expression import ColumnElement
 
 
@@ -44,11 +48,9 @@ class Interval(Base):
 
         # while we are here, check some Float[] / div type stuff
         if typing.TYPE_CHECKING:
-            # EXPECTED_RE_TYPE: sqlalchemy.*Function\[builtins.float\*?\]
-            reveal_type(f1)
+            assert_type(f1, Function[float])
 
-            # EXPECTED_RE_TYPE: sqlalchemy.*ColumnElement\[builtins.float\*?\]
-            reveal_type(expr)
+            assert_type(expr, ColumnElement[float])
         return expr
 
     # new way - use the original decorator with inplace
@@ -66,11 +68,9 @@ class Interval(Base):
 
         # while we are here, check some Float[] / div type stuff
         if typing.TYPE_CHECKING:
-            # EXPECTED_RE_TYPE: sqlalchemy.*Function\[builtins.float\*?\]
-            reveal_type(f1)
+            assert_type(f1, Function[float])
 
-            # EXPECTED_RE_TYPE: sqlalchemy.*ColumnElement\[builtins.float\*?\]
-            reveal_type(expr)
+            assert_type(expr, ColumnElement[float])
         return expr
 
 
@@ -92,38 +92,27 @@ expr3n = Interval.new_radius.in_([0.5, 5.2])
 
 
 if typing.TYPE_CHECKING:
-    # EXPECTED_RE_TYPE: builtins.int\*?
-    reveal_type(i1.length)
+    assert_type(i1.length, int)
 
-    # EXPECTED_RE_TYPE: builtins.float\*?
-    reveal_type(i2.old_radius)
+    assert_type(i2.old_radius, float)
 
-    # EXPECTED_RE_TYPE: builtins.float\*?
-    reveal_type(i2.new_radius)
+    assert_type(i2.new_radius, float)
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*._HybridClassLevelAccessor\[builtins.int\*?\]
-    reveal_type(Interval.length)
+    assert_type(Interval.length, _HybridClassLevelAccessor[int])
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*._HybridClassLevelAccessor\[builtins.float\*?\]
-    reveal_type(Interval.old_radius)
+    assert_type(Interval.old_radius, _HybridClassLevelAccessor[float])
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*._HybridClassLevelAccessor\[builtins.float\*?\]
-    reveal_type(Interval.new_radius)
+    assert_type(Interval.new_radius, _HybridClassLevelAccessor[float])
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*.BinaryExpression\[builtins.bool\*?\]
-    reveal_type(expr1)
+    assert_type(expr1, BinaryExpression[bool])
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*._HybridClassLevelAccessor\[builtins.float\*?\]
-    reveal_type(expr2o)
+    assert_type(expr2o, _HybridClassLevelAccessor[float])
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*._HybridClassLevelAccessor\[builtins.float\*?\]
-    reveal_type(expr2n)
+    assert_type(expr2n, _HybridClassLevelAccessor[float])
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*.BinaryExpression\[builtins.bool\*?\]
-    reveal_type(expr3o)
+    assert_type(expr3o, BinaryExpression[bool])
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*.BinaryExpression\[builtins.bool\*?\]
-    reveal_type(expr3n)
+    assert_type(expr3n, BinaryExpression[bool])
 
 # test #9268
 
index c6c1c35299bbc470be1527656c8db50e172902e6..976577d6fb2fd3688cfcac90f0aecb2d8fd1e079 100644 (file)
@@ -1,12 +1,15 @@
 from __future__ import annotations
 
 from datetime import date
+from typing import assert_type
 from typing import Dict
 from typing import List
 
 from sqlalchemy import ARRAY
 from sqlalchemy import JSON
+from sqlalchemy import Select
 from sqlalchemy import select
+from sqlalchemy.ext.hybrid import _HybridClassLevelAccessor
 from sqlalchemy.ext.hybrid import hybrid_property
 from sqlalchemy.ext.indexable import index_property
 from sqlalchemy.orm import DeclarativeBase
@@ -38,29 +41,22 @@ a = Article(
     updates=[date(2025, 7, 28), date(2025, 7, 29)],
 )
 
-# EXPECTED_TYPE: str
-reveal_type(a.topic)
+assert_type(a.topic, str)
 
-# EXPECTED_RE_TYPE: sqlalchemy.*._HybridClassLevelAccessor\[builtins.str\*?\]
-reveal_type(Article.topic)
+assert_type(Article.topic, _HybridClassLevelAccessor[str])
 
-# EXPECTED_TYPE: date
-reveal_type(a.created_at)
+assert_type(a.created_at, date)
 
-# EXPECTED_TYPE: date
-reveal_type(a.updated_at)
+assert_type(a.updated_at, date)
 
 a.created_at = date(2025, 7, 30)
 
-# EXPECTED_RE_TYPE: sqlalchemy.*._HybridClassLevelAccessor\[datetime.date\*?\]
-reveal_type(Article.created_at)
+assert_type(Article.created_at, _HybridClassLevelAccessor[date])
 
-# EXPECTED_RE_TYPE: sqlalchemy.*._HybridClassLevelAccessor\[datetime.date\*?\]
-reveal_type(Article.updated_at)
+assert_type(Article.updated_at, _HybridClassLevelAccessor[date])
 
 stmt = select(Article.id, Article.topic, Article.created_at).where(
     Article.id == 1
 )
 
-# EXPECTED_RE_TYPE: .*Select\[.*int, .*str, datetime\.date\]
-reveal_type(stmt)
+assert_type(stmt, Select[int, str, date])
index d2b7c5ece0ed934ea4628b461c9f85bf8d1f3736..8371f0037588f6c78301cf60463091e5c28dcec7 100644 (file)
@@ -1,11 +1,14 @@
 from __future__ import annotations
 
 import re
+from typing import assert_type
+from typing import Callable
 from typing import Sequence
 from typing import TYPE_CHECKING
 
 from sqlalchemy import ForeignKey
 from sqlalchemy.ext.orderinglist import ordering_list
+from sqlalchemy.ext.orderinglist import OrderingList
 from sqlalchemy.orm import DeclarativeBase
 from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
@@ -47,8 +50,6 @@ slide = Slide()
 
 
 if TYPE_CHECKING:
-    # EXPECTED_RE_TYPE: def \(\) -> sqlalchemy.*.orderinglist.OrderingList\[orderinglist_one.Bullet\]
-    reveal_type(pos_from_text)
+    assert_type(pos_from_text, Callable[[], OrderingList[Bullet]])
 
-    # EXPECTED_TYPE: builtins.list[orderinglist_one.Bullet]
-    reveal_type(slide.bullets)
+    assert_type(slide.bullets, list[Bullet])
index 886484dbc9aa574e6efd432f6229c7aa6e45fdc6..a37d400efb730dc6932f193e9209ab004f3832d0 100644 (file)
@@ -1,3 +1,5 @@
+from typing import Any
+from typing import assert_type
 from typing import List
 
 from sqlalchemy import create_engine
@@ -8,6 +10,7 @@ from sqlalchemy.orm import DeclarativeBaseNoMeta
 from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
 from sqlalchemy.orm import Mapper
+from sqlalchemy.orm.state import InstanceState
 
 
 class Base(DeclarativeBase):
@@ -32,10 +35,8 @@ class B(BaseNoMeta):
     data: Mapped[str]
 
 
-# EXPECTED_TYPE: Mapper[Any]
-reveal_type(A.__mapper__)
-# EXPECTED_TYPE: Mapper[Any]
-reveal_type(B.__mapper__)
+assert_type(A.__mapper__, Mapper[Any])
+assert_type(B.__mapper__, Mapper[Any])
 
 a1 = A(data="d")
 b1 = B(data="d")
@@ -45,22 +46,17 @@ e = create_engine("sqlite://")
 insp_a1 = inspect(a1)
 
 t: bool = insp_a1.transient
-# EXPECTED_TYPE: InstanceState[A]
-reveal_type(insp_a1)
-# EXPECTED_TYPE: InstanceState[B]
-reveal_type(inspect(b1))
+assert_type(insp_a1, InstanceState[A])
+assert_type(inspect(b1), InstanceState[B])
 
 m: Mapper[A] = inspect(A)
-# EXPECTED_TYPE: Mapper[A]
-reveal_type(inspect(A))
-# EXPECTED_TYPE: Mapper[B]
-reveal_type(inspect(B))
+assert_type(inspect(A), Mapper[A])
+assert_type(inspect(B), Mapper[B])
 
 tables: List[str] = inspect(e).get_table_names()
 
 i: Inspector = inspect(e)
-# EXPECTED_TYPE: Inspector
-reveal_type(inspect(e))
+assert_type(inspect(e), Inspector)
 
 
 with e.connect() as conn:
index f82bbe7c2dfd7d40d992cd11ac2b85af42dd50e3..f808696ca540497ae2e89256e8ab55afff0dcef2 100644 (file)
@@ -1,6 +1,8 @@
 from typing import Any
+from typing import assert_type
 from typing import Tuple
 
+from sqlalchemy import Select
 from sqlalchemy import select
 from sqlalchemy.orm import composite
 from sqlalchemy.orm import DeclarativeBase
@@ -58,14 +60,10 @@ v1 = Vertex(start=Point(3, 4), end=Point(5, 6))
 
 stmt = select(Vertex).where(Vertex.start.in_([Point(3, 4)]))
 
-# EXPECTED_TYPE: Select[Vertex]
-reveal_type(stmt)
+assert_type(stmt, Select[Vertex])
 
-# EXPECTED_TYPE: composite.Point
-reveal_type(v1.start)
+assert_type(v1.start, Point)
 
-# EXPECTED_TYPE: composite.Point
-reveal_type(v1.end)
+assert_type(v1.end, Point)
 
-# EXPECTED_TYPE: int
-reveal_type(v1.end.y)
+assert_type(v1.end.y, int)
index 3d8117a999a79e5a7edc1110687c7f8437f5ee93..25aaae17036d9f0f808a58eb7127f7373a21f910 100644 (file)
@@ -1,5 +1,7 @@
 import dataclasses
+from typing import assert_type
 
+from sqlalchemy import Select
 from sqlalchemy import select
 from sqlalchemy.orm import composite
 from sqlalchemy.orm import DeclarativeBase
@@ -38,14 +40,10 @@ v1 = Vertex(start=Point(3, 4), end=Point(5, 6))
 
 stmt = select(Vertex).where(Vertex.start.in_([Point(3, 4)]))
 
-# EXPECTED_TYPE: Select[Vertex]
-reveal_type(stmt)
+assert_type(stmt, Select[Vertex])
 
-# EXPECTED_TYPE: composite.Point
-reveal_type(v1.start)
+assert_type(v1.start, Point)
 
-# EXPECTED_TYPE: composite.Point
-reveal_type(v1.end)
+assert_type(v1.end, Point)
 
-# EXPECTED_TYPE: int
-reveal_type(v1.end.y)
+assert_type(v1.end.y, int)
index 01114c51e99c4253123f855f77bb26cf43a40825..6738788d462aa32cd284b37f0b9623ec0158802d 100644 (file)
@@ -1,3 +1,5 @@
+from typing import assert_type
+
 from sqlalchemy import Integer
 from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_as_dataclass
@@ -19,5 +21,4 @@ class Relationships:
 rs = Relationships(entity_id1=1, entity_id2=2, level=1)
 
 
-# EXPECTED_TYPE: int
-reveal_type(rs.entity_id1)
+assert_type(rs.entity_id1, int)
index ecdb8b3eda8426213d804941bffc172878973c86..5ca3be4612bcb1080ebabfc956d2dbd73d1aba69 100644 (file)
@@ -1,3 +1,5 @@
+from typing import assert_type
+
 from sqlalchemy import Integer
 from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_as_dataclass
@@ -32,8 +34,6 @@ class Relationships(RelationshipsModel):
 # (this is the type checker, not us)
 rs = Relationships(entity_id1=1, entity_id2=2, level=1)
 
-# EXPECTED_TYPE: int
-reveal_type(rs.entity_id1)
+assert_type(rs.entity_id1, int)
 
-# EXPECTED_TYPE: int
-reveal_type(rs.level)
+assert_type(rs.level, int)
index 986483d8ef06ae743585d8154bd724d88afadc58..b99adfab0d1ea296e7ce01068914350da45a5cb0 100644 (file)
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+from typing import assert_type
 from typing import Optional
 
 from sqlalchemy.orm import column_property
@@ -25,11 +26,9 @@ class TestInitialSupport(Base):
 
 tis = TestInitialSupport(data="some data", y=5)
 
-# EXPECTED_TYPE: str
-reveal_type(tis.data)
+assert_type(tis.data, str)
 
-# EXPECTED_RE_TYPE: .*builtins.int \| None
-reveal_type(tis.y)
+assert_type(tis.y, int | None)
 
 tis.data = "some other data"
 
index 79f1548e36504653181da2ccd6ced4fc11266054..4493a2667a246244956cb8e9ace1c8c093e8c986 100644 (file)
@@ -1,16 +1,22 @@
 from datetime import datetime
 import typing
+from typing import Any
+from typing import assert_type
+from typing import Unpack
 
 from sqlalchemy import DateTime
 from sqlalchemy import Index
 from sqlalchemy import Integer
+from sqlalchemy import Select
 from sqlalchemy import String
 from sqlalchemy import UniqueConstraint
 from sqlalchemy.orm import DeclarativeBase
 from sqlalchemy.orm import declared_attr
+from sqlalchemy.orm import InstrumentedAttribute
 from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
 from sqlalchemy.orm import MappedClassProtocol
+from sqlalchemy.orm.mapper import Mapper
 from sqlalchemy.sql.schema import PrimaryKeyConstraint
 
 
@@ -74,14 +80,11 @@ class Manager(Employee):
 def do_something_with_mapped_class(
     cls_: MappedClassProtocol[Employee],
 ) -> None:
-    # EXPECTED_TYPE: Select[Unpack[.*tuple[Any, ...]]]
-    reveal_type(cls_.__table__.select())
+    assert_type(cls_.__table__.select(), Select[Unpack[tuple[Any, ...]]])
 
-    # EXPECTED_TYPE: Mapper[Employee]
-    reveal_type(cls_.__mapper__)
+    assert_type(cls_.__mapper__, Mapper[Employee])
 
-    # EXPECTED_TYPE: Employee
-    reveal_type(cls_())
+    assert_type(cls_(), Employee)
 
 
 do_something_with_mapped_class(Manager)
@@ -89,8 +92,6 @@ do_something_with_mapped_class(Engineer)
 
 
 if typing.TYPE_CHECKING:
-    # EXPECTED_TYPE: InstrumentedAttribute[datetime]
-    reveal_type(Engineer.start_date)
+    assert_type(Engineer.start_date, InstrumentedAttribute[datetime])
 
-    # EXPECTED_TYPE: InstrumentedAttribute[datetime]
-    reveal_type(Manager.start_date)
+    assert_type(Manager.start_date, InstrumentedAttribute[datetime])
index c8e12ee93199aad6e0d2c9dfd3ab8f3b66741bf8..3792570513baa4a02e7dd0dd34c4eead0042e158 100644 (file)
@@ -1,9 +1,11 @@
 import typing
+from typing import assert_type
 
 from sqlalchemy import Integer
 from sqlalchemy import Text
 from sqlalchemy.orm import DeclarativeBase
 from sqlalchemy.orm import declared_attr
+from sqlalchemy.orm import InstrumentedAttribute
 from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
 
@@ -39,14 +41,10 @@ class Foo(Base):
 u1 = User()
 
 if typing.TYPE_CHECKING:
-    # EXPECTED_TYPE: str
-    reveal_type(User.__tablename__)
+    assert_type(User.__tablename__, str)
 
-    # EXPECTED_TYPE: str
-    reveal_type(Foo.__tablename__)
+    assert_type(Foo.__tablename__, str)
 
-    # EXPECTED_TYPE: str
-    reveal_type(u1.related_data)
+    assert_type(u1.related_data, str)
 
-    # EXPECTED_TYPE: InstrumentedAttribute[str]
-    reveal_type(User.related_data)
+    assert_type(User.related_data, InstrumentedAttribute[str])
index 8b406bb171e09f680f2dc177a9a02c90c20e11e7..c9ebdbc1e441038ea4531544c755718c8749fa44 100644 (file)
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import typing
+from typing import assert_type
 
 from sqlalchemy import ForeignKey
 from sqlalchemy import select
@@ -10,6 +11,7 @@ from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import Session
+from sqlalchemy.orm.dynamic import AppenderQuery
 
 
 class Base(DeclarativeBase):
@@ -37,19 +39,16 @@ with Session() as session:
     session.commit()
 
     if typing.TYPE_CHECKING:
-        # EXPECTED_TYPE: AppenderQuery[Address]
-        reveal_type(u.addresses)
+        assert_type(u.addresses, AppenderQuery[Address])
 
     count = u.addresses.count()
     if typing.TYPE_CHECKING:
-        # EXPECTED_TYPE: int
-        reveal_type(count)
+        assert_type(count, int)
 
     address = u.addresses.filter(Address.email_address.like("xyz")).one()
 
     if typing.TYPE_CHECKING:
-        # EXPECTED_TYPE: Address
-        reveal_type(address)
+        assert_type(address, Address)
 
     u.addresses.append(Address())
     u.addresses.extend([Address(), Address()])
@@ -57,8 +56,7 @@ with Session() as session:
     current_addresses = list(u.addresses)
 
     if typing.TYPE_CHECKING:
-        # EXPECTED_TYPE: list[Address]
-        reveal_type(current_addresses)
+        assert_type(current_addresses, list[Address])
 
     # can assign plain list
     u.addresses = []
@@ -68,15 +66,13 @@ with Session() as session:
 
     if typing.TYPE_CHECKING:
         # still an AppenderQuery
-        # EXPECTED_TYPE: AppenderQuery[Address]
-        reveal_type(u.addresses)
+        assert_type(u.addresses, AppenderQuery[Address])
 
     u.addresses = {Address(), Address()}
 
     if typing.TYPE_CHECKING:
         # still an AppenderQuery
-        # EXPECTED_TYPE: AppenderQuery[Address]
-        reveal_type(u.addresses)
+        assert_type(u.addresses, AppenderQuery[Address])
 
     u.addresses.append(Address())
 
index 6ccd2eed3145c2ed7bcb3dd48ebe37d0e6efe433..81155f4c80c878517ee81cc732a46d73517b6b16 100644 (file)
@@ -1,13 +1,16 @@
+from typing import assert_type
 from typing import Sequence
 from typing import TYPE_CHECKING
 
 from sqlalchemy import create_engine
+from sqlalchemy import Select
 from sqlalchemy import select
 from sqlalchemy.orm import DeclarativeBase
 from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
 from sqlalchemy.orm import Session
 from sqlalchemy.orm import with_polymorphic
+from sqlalchemy.orm.util import AliasedClass
 
 
 class Base(DeclarativeBase): ...
@@ -39,8 +42,7 @@ def get_messages() -> Sequence[Message]:
         message_query = select(Message)
 
         if TYPE_CHECKING:
-            # EXPECTED_TYPE: Select[Message]
-            reveal_type(message_query)
+            assert_type(message_query, Select[Message])
 
         return session.scalars(message_query).all()
 
@@ -50,13 +52,11 @@ def get_poly_messages() -> Sequence[Message]:
         PolymorphicMessage = with_polymorphic(Message, (UserComment,))
 
         if TYPE_CHECKING:
-            # EXPECTED_TYPE: AliasedClass[Message]
-            reveal_type(PolymorphicMessage)
+            assert_type(PolymorphicMessage, AliasedClass[Message])
 
         poly_query = select(PolymorphicMessage)
 
         if TYPE_CHECKING:
-            # EXPECTED_TYPE: Select[Message]
-            reveal_type(poly_query)
+            assert_type(poly_query, Select[Message])
 
         return session.scalars(poly_query).all()
index 831861000ce8782f515ad6dd1858d1c7ba8ac9d7..0b275bac8de353346b9700a734433dc4e334b8d4 100644 (file)
@@ -1,4 +1,5 @@
 import typing
+from typing import assert_type
 from typing import Dict
 from typing import Optional
 
@@ -42,5 +43,4 @@ item = Item()
 item.notes["a"] = Note("a", "atext")
 
 if typing.TYPE_CHECKING:
-    # EXPECTED_TYPE: dict_items[str, Note]
-    reveal_type(item.notes.items())
+    assert_type(list(item.notes.items()), list[tuple[str, Note]])
index 82e668ceebe6093e2cb213389e73a7a0fa754158..f818791970f7923bb1af437d9f854386132047bd 100644 (file)
@@ -3,6 +3,7 @@
 from __future__ import annotations
 
 import typing
+from typing import assert_type
 from typing import ClassVar
 from typing import List
 from typing import Optional
@@ -16,6 +17,7 @@ from sqlalchemy import select
 from sqlalchemy import Table
 from sqlalchemy.orm import aliased
 from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import InstrumentedAttribute
 from sqlalchemy.orm import joinedload
 from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
@@ -24,6 +26,7 @@ from sqlalchemy.orm import Relationship
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import Session
 from sqlalchemy.orm import with_polymorphic
+from sqlalchemy.orm.attributes import QueryableAttribute
 
 
 class Base(DeclarativeBase):
@@ -131,47 +134,46 @@ class Engineer(Employee):
 
 
 if typing.TYPE_CHECKING:
-    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.str \| None\]
-    reveal_type(User.extra)
+    assert_type(User.extra, InstrumentedAttribute[str | None])
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.str \| None\]
-    reveal_type(User.extra_name)
+    assert_type(User.extra_name, InstrumentedAttribute[str | None])
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.str\*?\]
-    reveal_type(Address.email)
+    assert_type(Address.email, InstrumentedAttribute[str])
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.str\*?\]
-    reveal_type(Address.email_name)
+    assert_type(Address.email_name, InstrumentedAttribute[str])
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[relationship.Address\]\]
-    reveal_type(User.addresses_style_one)
+    assert_type(User.addresses_style_one, InstrumentedAttribute[list[Address]])
 
-    # EXPECTED_RE_TYPE: sqlalchemy.orm.attributes.InstrumentedAttribute\[builtins.set\*?\[relationship.Address\]\]
-    reveal_type(User.addresses_style_two)
+    assert_type(User.addresses_style_two, InstrumentedAttribute[set[Address]])
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[relationship.User\]\]
-    reveal_type(Group.addresses_style_one_anno_only)
+    assert_type(
+        Group.addresses_style_one_anno_only, InstrumentedAttribute[list[User]]
+    )
 
-    # EXPECTED_RE_TYPE: sqlalchemy.orm.attributes.InstrumentedAttribute\[builtins.set\*?\[relationship.User\]\]
-    reveal_type(Group.addresses_style_two_anno_only)
+    assert_type(
+        Group.addresses_style_two_anno_only, InstrumentedAttribute[set[User]]
+    )
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[relationship.MoreMail\]\]
-    reveal_type(Address.rel_style_one)
+    assert_type(Address.rel_style_one, InstrumentedAttribute[list[MoreMail]])
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.set\*?\[relationship.MoreMail\]\]
-    reveal_type(Address.rel_style_one_anno_only)
+    assert_type(
+        Address.rel_style_one_anno_only, InstrumentedAttribute[set[MoreMail]]
+    )
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*.QueryableAttribute\[relationship.Engineer\]
-    reveal_type(Team.employees.of_type(Engineer))
+    assert_type(Team.employees.of_type(Engineer), QueryableAttribute[Engineer])
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*.QueryableAttribute\[relationship.Employee\]
-    reveal_type(Team.employees.of_type(aliased(Employee)))
+    assert_type(
+        Team.employees.of_type(aliased(Employee)), QueryableAttribute[Employee]
+    )
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*.QueryableAttribute\[relationship.Engineer\]
-    reveal_type(Team.employees.of_type(aliased(Engineer)))
+    assert_type(
+        Team.employees.of_type(aliased(Engineer)), QueryableAttribute[Engineer]
+    )
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*.QueryableAttribute\[relationship.Employee\]
-    reveal_type(Team.employees.of_type(with_polymorphic(Employee, [Engineer])))
+    assert_type(
+        Team.employees.of_type(with_polymorphic(Employee, [Engineer])),
+        QueryableAttribute[Employee],
+    )
 
 
 mapper_registry: registry = registry()
index 1cc5b1c014a52cf6e69d224f67b053f1bed730f3..af0de3386b7bb68d43efd205094e1d3a76105f5b 100644 (file)
@@ -1,10 +1,12 @@
 from __future__ import annotations
 
 import asyncio
+from typing import assert_type
 from typing import List
 
 from sqlalchemy import create_engine
 from sqlalchemy import ForeignKey
+from sqlalchemy.engine.row import Row
 from sqlalchemy.ext.asyncio import async_scoped_session
 from sqlalchemy.ext.asyncio import async_sessionmaker
 from sqlalchemy.ext.asyncio import AsyncSession
@@ -15,6 +17,8 @@ from sqlalchemy.orm import relationship
 from sqlalchemy.orm import scoped_session
 from sqlalchemy.orm import Session
 from sqlalchemy.orm import sessionmaker
+from sqlalchemy.orm import SessionTransaction
+from sqlalchemy.orm.query import Query
 
 
 class Base(DeclarativeBase):
@@ -50,19 +54,16 @@ with Session(e) as sess:
 
     q = sess.query(User).filter_by(id=7)
 
-    # EXPECTED_TYPE: Query[User]
-    reveal_type(q)
+    assert_type(q, Query[User])
 
     rows1 = q.all()
 
-    # EXPECTED_RE_TYPE: builtins.list\[.*User\*?\]
-    reveal_type(rows1)
+    assert_type(rows1, list[User])
 
     q2 = sess.query(User.id).filter_by(id=7)
     rows2 = q2.all()
 
-    # EXPECTED_TYPE: list[.*Row[.*int].*]
-    reveal_type(rows2)
+    assert_type(rows2, list[Row[int]])
 
     # test #8280
 
@@ -86,12 +87,10 @@ with Session(e) as sess:
     # test #9125
 
     for row in sess.query(User.id, User.name):
-        # EXPECTED_TYPE: .*Row[int, str].*
-        reveal_type(row)
+        assert_type(row, Row[int, str])
 
     for uobj1 in sess.query(User):
-        # EXPECTED_TYPE: User
-        reveal_type(uobj1)
+        assert_type(uobj1, User)
 
     sess.query(User).limit(None).offset(None).limit(10).offset(10).limit(
         User.id
@@ -100,8 +99,7 @@ with Session(e) as sess:
     # test #11083
 
     with sess.begin() as tx:
-        # EXPECTED_TYPE: SessionTransaction
-        reveal_type(tx)
+        assert_type(tx, SessionTransaction)
 
 # more result tests in typed_results.py
 
index 60d2e8b33e3ee63dc7aeddb819acb6bd8a97d096..8e959cea9859e493ed678302ab14629b0c8b3034 100644 (file)
@@ -1,5 +1,7 @@
 """test sessionmaker, originally for #7656"""
 
+from typing import assert_type
+
 from sqlalchemy import create_engine
 from sqlalchemy import Engine
 from sqlalchemy.ext.asyncio import async_scoped_session
@@ -11,6 +13,7 @@ from sqlalchemy.orm import QueryPropertyDescriptor
 from sqlalchemy.orm import scoped_session
 from sqlalchemy.orm import Session
 from sqlalchemy.orm import sessionmaker
+from sqlalchemy.orm.query import Query
 
 
 async_engine = create_async_engine("...")
@@ -39,19 +42,16 @@ async def async_main() -> None:
     fac = async_session_factory(async_engine)
 
     async with fac() as sess:
-        # EXPECTED_TYPE: MyAsyncSession
-        reveal_type(sess)
+        assert_type(sess, MyAsyncSession)
 
     async with fac.begin() as sess:
-        # EXPECTED_TYPE: MyAsyncSession
-        reveal_type(sess)
+        assert_type(sess, MyAsyncSession)
 
     scoped_fac = async_scoped_session_factory(async_engine)
 
     sess = scoped_fac()
 
-    # EXPECTED_TYPE: MyAsyncSession
-    reveal_type(sess)
+    assert_type(sess, MyAsyncSession)
 
 
 engine = create_engine("...")
@@ -75,49 +75,41 @@ def main() -> None:
     fac = session_factory(engine)
 
     with fac() as sess:
-        # EXPECTED_TYPE: MySession
-        reveal_type(sess)
+        assert_type(sess, MySession)
 
     with fac.begin() as sess:
-        # EXPECTED_TYPE: MySession
-        reveal_type(sess)
+        assert_type(sess, MySession)
 
     scoped_fac = scoped_session_factory(engine)
 
     sess = scoped_fac()
-    # EXPECTED_TYPE: MySession
-    reveal_type(sess)
+    assert_type(sess, MySession)
 
 
 def test_8837_sync() -> None:
     sm = sessionmaker()
 
-    # EXPECTED_TYPE: sessionmaker[Session]
-    reveal_type(sm)
+    assert_type(sm, sessionmaker[Session])
 
     session = sm()
 
-    # EXPECTED_TYPE: Session
-    reveal_type(session)
+    assert_type(session, Session)
 
 
 def test_8837_async() -> None:
     as_ = async_sessionmaker()
 
-    # EXPECTED_TYPE: async_sessionmaker[AsyncSession]
-    reveal_type(as_)
+    assert_type(as_, async_sessionmaker[AsyncSession])
 
     async_session = as_()
 
-    # EXPECTED_TYPE: AsyncSession
-    reveal_type(async_session)
+    assert_type(async_session, AsyncSession)
 
 
 # test #9338
 ss_9338 = scoped_session_factory(engine)
 
-# EXPECTED_TYPE: QueryPropertyDescriptor
-reveal_type(ss_9338.query_property())
+assert_type(ss_9338.query_property(), QueryPropertyDescriptor)
 qp: QueryPropertyDescriptor = ss_9338.query_property()
 
 
@@ -125,16 +117,13 @@ class Foo:
     query = qp
 
 
-# EXPECTED_TYPE: Query[Foo]
-reveal_type(Foo.query)
+assert_type(Foo.query, Query[Foo])
 
-# EXPECTED_TYPE: list[Foo]
-reveal_type(Foo.query.all())
+assert_type(Foo.query.all(), list[Foo])
 
 
 class Bar:
     query: QueryPropertyDescriptor = ss_9338.query_property()
 
 
-# EXPECTED_TYPE: Query[Bar]
-reveal_type(Bar.query)
+assert_type(Bar.query, Query[Bar])
index e15fe709341904fe19aee819adca2b3c0a1bbdbb..f8f9111e82a2c3061c590149aa4b4aea896bc0cd 100644 (file)
@@ -1,6 +1,8 @@
 """traditional relationship patterns with explicit uselist."""
 
 import typing
+from typing import Any
+from typing import assert_type
 from typing import cast
 from typing import Dict
 from typing import List
@@ -11,6 +13,7 @@ from sqlalchemy import ForeignKey
 from sqlalchemy import Integer
 from sqlalchemy import String
 from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import InstrumentedAttribute
 from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
 from sqlalchemy.orm import relationship
@@ -99,45 +102,34 @@ class Address(Base):
 
 
 if typing.TYPE_CHECKING:
-    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[trad_relationship_uselist.Address\]\]
-    reveal_type(User.addresses_style_one)
+    assert_type(User.addresses_style_one, InstrumentedAttribute[list[Address]])
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.set\*?\[trad_relationship_uselist.Address\]\]
-    reveal_type(User.addresses_style_two)
+    assert_type(User.addresses_style_two, InstrumentedAttribute[set[Address]])
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\]
-    reveal_type(User.addresses_style_three)
+    assert_type(User.addresses_style_three, InstrumentedAttribute[Any])
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\]
-    reveal_type(User.addresses_style_three_cast)
+    assert_type(User.addresses_style_three_cast, InstrumentedAttribute[Any])
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\]
-    reveal_type(User.addresses_style_four)
+    assert_type(User.addresses_style_four, InstrumentedAttribute[Any])
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\]
-    reveal_type(Address.user_style_one)
+    assert_type(Address.user_style_one, InstrumentedAttribute[Any])
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[trad_relationship_uselist.User\*?\]
-    reveal_type(Address.user_style_one_typed)
+    assert_type(Address.user_style_one_typed, InstrumentedAttribute[User])
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\]
-    reveal_type(Address.user_style_two)
+    assert_type(Address.user_style_two, InstrumentedAttribute[Any])
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[trad_relationship_uselist.User\*?\]
-    reveal_type(Address.user_style_two_typed)
+    assert_type(Address.user_style_two_typed, InstrumentedAttribute[User])
 
     # reveal_type(Address.user_style_six)
 
     # reveal_type(Address.user_style_seven)
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\]
-    reveal_type(Address.user_style_eight)
+    assert_type(Address.user_style_eight, InstrumentedAttribute[Any])
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\]
-    reveal_type(Address.user_style_nine)
+    assert_type(Address.user_style_nine, InstrumentedAttribute[Any])
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\]
-    reveal_type(Address.user_style_ten)
+    assert_type(Address.user_style_ten, InstrumentedAttribute[Any])
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.dict\*?\[builtins.str, trad_relationship_uselist.User\]\]
-    reveal_type(Address.user_style_ten_typed)
+    assert_type(
+        Address.user_style_ten_typed, InstrumentedAttribute[dict[str, User]]
+    )
index bd6bada528c4a614465e50365125cb39dbb05548..062ea2b3f086fb0d1d8a67f37e3e004007cf6288 100644 (file)
@@ -7,6 +7,8 @@ if no uselists are present.
 """
 
 import typing
+from typing import Any
+from typing import assert_type
 from typing import List
 from typing import Set
 
@@ -14,6 +16,7 @@ from sqlalchemy import ForeignKey
 from sqlalchemy import Integer
 from sqlalchemy import String
 from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import InstrumentedAttribute
 from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
 from sqlalchemy.orm import relationship
@@ -80,29 +83,20 @@ class Address(Base):
 
 
 if typing.TYPE_CHECKING:
-    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[traditional_relationship.Address\]\]
-    reveal_type(User.addresses_style_one)
+    assert_type(User.addresses_style_one, InstrumentedAttribute[list[Address]])
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.set\*?\[traditional_relationship.Address\]\]
-    reveal_type(User.addresses_style_two)
+    assert_type(User.addresses_style_two, InstrumentedAttribute[set[Address]])
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\]
-    reveal_type(Address.user_style_one)
+    assert_type(Address.user_style_one, InstrumentedAttribute[Any])
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[traditional_relationship.User\*?\]
-    reveal_type(Address.user_style_one_typed)
+    assert_type(Address.user_style_one_typed, InstrumentedAttribute[User])
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\]
-    reveal_type(Address.user_style_two)
+    assert_type(Address.user_style_two, InstrumentedAttribute[Any])
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[traditional_relationship.User\*?\]
-    reveal_type(Address.user_style_two_typed)
+    assert_type(Address.user_style_two_typed, InstrumentedAttribute[User])
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[traditional_relationship.User\]\]
-    reveal_type(Address.user_style_three)
+    assert_type(Address.user_style_three, InstrumentedAttribute[list[User]])
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[traditional_relationship.User\]\]
-    reveal_type(Address.user_style_four)
+    assert_type(Address.user_style_four, InstrumentedAttribute[list[User]])
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\]
-    reveal_type(Address.user_style_five)
+    assert_type(Address.user_style_five, InstrumentedAttribute[Any])
index a3c07dd016f77e3603f7a8d90a21e1c34522b5c6..d21922e8691e4e610d87afbda3dc6d927052c184 100644 (file)
@@ -1,5 +1,9 @@
 from __future__ import annotations
 
+from typing import Any
+from typing import assert_type
+from typing import Unpack
+
 from sqlalchemy import Column
 from sqlalchemy import column
 from sqlalchemy import create_engine
@@ -16,11 +20,21 @@ from sqlalchemy import String
 from sqlalchemy import Table
 from sqlalchemy import text
 from sqlalchemy import update
+from sqlalchemy.engine import Result
+from sqlalchemy.engine.row import Row
 from sqlalchemy.orm import aliased
 from sqlalchemy.orm import DeclarativeBase
 from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
 from sqlalchemy.orm import Session
+from sqlalchemy.orm.query import Query
+from sqlalchemy.orm.query import RowReturningQuery
+from sqlalchemy.sql.dml import ReturningInsert
+from sqlalchemy.sql.elements import KeyedColumnElement
+from sqlalchemy.sql.expression import FromClause
+from sqlalchemy.sql.expression import TextClause
+from sqlalchemy.sql.selectable import ScalarSelect
+from sqlalchemy.sql.selectable import TextualSelect
 
 
 class Base(DeclarativeBase):
@@ -51,13 +65,11 @@ connection = e.connect()
 def t_select_1() -> None:
     stmt = select(User.id, User.name).filter(User.id == 5)
 
-    # EXPECTED_TYPE: Select[int, str]
-    reveal_type(stmt)
+    assert_type(stmt, Select[int, str])
 
     result = session.execute(stmt)
 
-    # EXPECTED_TYPE: .*Result[int, str].*
-    reveal_type(result)
+    assert_type(result, Result[int, str])
 
 
 def t_select_2() -> None:
@@ -75,13 +87,11 @@ def t_select_2() -> None:
         .fetch(User.id)
     )
 
-    # EXPECTED_TYPE: Select[User]
-    reveal_type(stmt)
+    assert_type(stmt, Select[User])
 
     result = session.execute(stmt)
 
-    # EXPECTED_TYPE: .*Result[User].*
-    reveal_type(result)
+    assert_type(result, Result[User])
 
 
 def t_select_3() -> None:
@@ -95,196 +105,157 @@ def t_select_3() -> None:
     # awkwardnesses that aren't really worth it
     ua(id=1, name="foo")
 
-    # EXPECTED_RE_TYPE: [tT]ype\[.*\.User\]
-    reveal_type(ua)
+    assert_type(ua, type[User])
 
     stmt = select(ua.id, ua.name).filter(User.id == 5)
 
-    # EXPECTED_TYPE: Select[int, str]
-    reveal_type(stmt)
+    assert_type(stmt, Select[int, str])
 
     result = session.execute(stmt)
 
-    # EXPECTED_TYPE: .*Result[int, str].*
-    reveal_type(result)
+    assert_type(result, Result[int, str])
 
 
 def t_select_4() -> None:
     ua = aliased(User)
     stmt = select(ua, User).filter(User.id == 5)
 
-    # EXPECTED_TYPE: Select[User, User]
-    reveal_type(stmt)
+    assert_type(stmt, Select[User, User])
 
     result = session.execute(stmt)
 
-    # EXPECTED_TYPE: Result[User, User]
-    reveal_type(result)
+    assert_type(result, Result[User, User])
 
 
 def t_legacy_query_single_entity() -> None:
     q1 = session.query(User).filter(User.id == 5)
 
-    # EXPECTED_TYPE: Query[User]
-    reveal_type(q1)
+    assert_type(q1, Query[User])
 
-    # EXPECTED_TYPE: User
-    reveal_type(q1.one())
+    assert_type(q1.one(), User)
 
-    # EXPECTED_TYPE: list[User]
-    reveal_type(q1.all())
+    assert_type(q1.all(), list[User])
 
     # mypy switches to builtins.list for some reason here
-    # EXPECTED_RE_TYPE: .*\.list\[.*Row\*?\[.*User\].*\]
-    reveal_type(q1.only_return_tuples(True).all())
+    assert_type(q1.only_return_tuples(True).all(), list[Row[User]])
 
-    # EXPECTED_TYPE: list[tuple[User]]
-    reveal_type(q1.tuples().all())
+    assert_type(q1.tuples().all(), list[tuple[User]])
 
 
 def t_legacy_query_cols_1() -> None:
     q1 = session.query(User.id, User.name).filter(User.id == 5)
 
-    # EXPECTED_TYPE: RowReturningQuery[int, str]
-    reveal_type(q1)
+    assert_type(q1, RowReturningQuery[int, str])
 
-    # EXPECTED_TYPE: .*Row[int, str].*
-    reveal_type(q1.one())
+    assert_type(q1.one(), Row[int, str])
 
     r1 = q1.one()
 
     x, y = r1
 
-    # EXPECTED_TYPE: int
-    reveal_type(x)
+    assert_type(x, int)
 
-    # EXPECTED_TYPE: str
-    reveal_type(y)
+    assert_type(y, str)
 
 
 def t_legacy_query_cols_tupleq_1() -> None:
     q1 = session.query(User.id, User.name).filter(User.id == 5)
 
-    # EXPECTED_TYPE: RowReturningQuery[int, str]
-    reveal_type(q1)
+    assert_type(q1, RowReturningQuery[int, str])
 
     q2 = q1.tuples()
 
-    # EXPECTED_TYPE: tuple[int, str]
-    reveal_type(q2.one())
+    assert_type(q2.one(), tuple[int, str])
 
     r1 = q2.one()
 
     x, y = r1
 
-    # EXPECTED_TYPE: int
-    reveal_type(x)
+    assert_type(x, int)
 
-    # EXPECTED_TYPE: str
-    reveal_type(y)
+    assert_type(y, str)
 
 
 def t_legacy_query_cols_1_with_entities() -> None:
     q1 = session.query(User).filter(User.id == 5)
 
-    # EXPECTED_TYPE: Query[User]
-    reveal_type(q1)
+    assert_type(q1, Query[User])
 
     q2 = q1.with_entities(User.id, User.name)
 
-    # EXPECTED_TYPE: RowReturningQuery[int, str]
-    reveal_type(q2)
+    assert_type(q2, RowReturningQuery[int, str])
 
-    # EXPECTED_TYPE: .*Row[int, str].*
-    reveal_type(q2.one())
+    assert_type(q2.one(), Row[int, str])
 
     r1 = q2.one()
 
     x, y = r1
 
-    # EXPECTED_TYPE: int
-    reveal_type(x)
+    assert_type(x, int)
 
-    # EXPECTED_TYPE: str
-    reveal_type(y)
+    assert_type(y, str)
 
 
 def t_select_with_only_cols() -> None:
     q1 = select(User).where(User.id == 5)
 
-    # EXPECTED_TYPE: Select[User]
-    reveal_type(q1)
+    assert_type(q1, Select[User])
 
     q2 = q1.with_only_columns(User.id, User.name)
 
-    # EXPECTED_TYPE: Select[int, str]
-    reveal_type(q2)
+    assert_type(q2, Select[int, str])
 
     row = connection.execute(q2).one()
 
-    # EXPECTED_TYPE: .*Row[int, str].*
-    reveal_type(row)
+    assert_type(row, Row[int, str])
 
     x, y = row
 
-    # EXPECTED_TYPE: int
-    reveal_type(x)
+    assert_type(x, int)
 
-    # EXPECTED_TYPE: str
-    reveal_type(y)
+    assert_type(y, str)
 
 
 def t_legacy_query_cols_2() -> None:
     a1 = aliased(User)
     q1 = session.query(User, a1, User.name).filter(User.id == 5)
 
-    # EXPECTED_TYPE: RowReturningQuery[User, User, str]
-    reveal_type(q1)
+    assert_type(q1, RowReturningQuery[User, User, str])
 
-    # EXPECTED_TYPE: .*Row[User, User, str].*
-    reveal_type(q1.one())
+    assert_type(q1.one(), Row[User, User, str])
 
     r1 = q1.one()
 
     x, y, z = r1
 
-    # EXPECTED_TYPE: User
-    reveal_type(x)
+    assert_type(x, User)
 
-    # EXPECTED_TYPE: User
-    reveal_type(y)
+    assert_type(y, User)
 
-    # EXPECTED_TYPE: str
-    reveal_type(z)
+    assert_type(z, str)
 
 
 def t_legacy_query_cols_2_with_entities() -> None:
     q1 = session.query(User)
 
-    # EXPECTED_TYPE: Query[User]
-    reveal_type(q1)
+    assert_type(q1, Query[User])
 
     a1 = aliased(User)
     q2 = q1.with_entities(User, a1, User.name).filter(User.id == 5)
 
-    # EXPECTED_TYPE: RowReturningQuery[User, User, str]
-    reveal_type(q2)
+    assert_type(q2, RowReturningQuery[User, User, str])
 
-    # EXPECTED_TYPE: .*Row[User, User, str].*
-    reveal_type(q2.one())
+    assert_type(q2.one(), Row[User, User, str])
 
     r1 = q2.one()
 
     x, y, z = r1
 
-    # EXPECTED_TYPE: User
-    reveal_type(x)
+    assert_type(x, User)
 
-    # EXPECTED_TYPE: User
-    reveal_type(y)
+    assert_type(y, User)
 
-    # EXPECTED_TYPE: str
-    reveal_type(z)
+    assert_type(z, str)
 
 
 def t_select_add_col_loses_type() -> None:
@@ -293,8 +264,7 @@ def t_select_add_col_loses_type() -> None:
     q2 = q1.add_columns(User.data)
 
     # note this should not match Select
-    # EXPECTED_TYPE: Select[Unpack[.*tuple[Any, ...]]]
-    reveal_type(q2)
+    assert_type(q2, Select[Unpack[tuple[Any, ...]]])
 
 
 def t_legacy_query_add_col_loses_type() -> None:
@@ -303,14 +273,12 @@ def t_legacy_query_add_col_loses_type() -> None:
     q2 = q1.add_columns(User.data)
 
     # this should match only Any
-    # EXPECTED_TYPE: Query[Any]
-    reveal_type(q2)
+    assert_type(q2, Query[Any])
 
     ua = aliased(User)
     q3 = q1.add_entity(ua)
 
-    # EXPECTED_TYPE: Query[Any]
-    reveal_type(q3)
+    assert_type(q3, Query[Any])
 
 
 def t_legacy_query_scalar_subquery() -> None:
@@ -322,29 +290,25 @@ def t_legacy_query_scalar_subquery() -> None:
 
     # this should be int but mypy can't see it due to the
     # overload that tries to match an entity.
-    # EXPECTED_RE_TYPE: .*ScalarSelect\[(?:int|Any)\]
-    reveal_type(q2)
+    assert_type(q2, ScalarSelect[Any])
 
     q3 = session.query(User)
 
     q4 = q3.scalar_subquery()
 
-    # EXPECTED_TYPE: ScalarSelect[Any]
-    reveal_type(q4)
+    assert_type(q4, ScalarSelect[Any])
 
     q5 = session.query(User, User.name)
 
     q6 = q5.scalar_subquery()
 
-    # EXPECTED_TYPE: ScalarSelect[Any]
-    reveal_type(q6)
+    assert_type(q6, ScalarSelect[Any])
 
     # try to simulate the problem with select()
     q7 = session.query(User).only_return_tuples(True)
     q8 = q7.scalar_subquery()
 
-    # EXPECTED_TYPE: ScalarSelect[Any]
-    reveal_type(q8)
+    assert_type(q8, ScalarSelect[Any])
 
 
 def t_select_scalar_subquery() -> None:
@@ -355,16 +319,14 @@ def t_select_scalar_subquery() -> None:
 
     # this should be int but mypy can't see it due to the
     # overload that tries to match an entity.
-    # EXPECTED_TYPE: ScalarSelect[Any]
-    reveal_type(s2)
+    assert_type(s2, ScalarSelect[Any])
 
     s3 = select(User)
     s4 = s3.scalar_subquery()
 
     # it's more important that mypy doesn't get a false positive of
     # 'User' here
-    # EXPECTED_TYPE: ScalarSelect[Any]
-    reveal_type(s4)
+    assert_type(s4, ScalarSelect[Any])
 
 
 def t_select_w_core_selectables() -> None:
@@ -374,8 +336,7 @@ def t_select_w_core_selectables() -> None:
     """
     s1 = select(User.id, User.name).subquery()
 
-    # EXPECTED_TYPE: KeyedColumnElement[Any]
-    reveal_type(s1.c.name)
+    assert_type(s1.c.name, KeyedColumnElement[Any])
 
     s2 = select(User.id, s1.c.name)
 
@@ -386,31 +347,26 @@ def t_select_w_core_selectables() -> None:
     # mypy would downgrade to Any rather than picking the basemost type.
     # with typing integrated into Select etc. we can at least get a Select
     # object back.
-    # EXPECTED_TYPE: Select[Unpack[.*tuple[Any, ...]]]
-    reveal_type(s2)
+    assert_type(s2, Select[Unpack[tuple[Any, ...]]])
 
     # so a fully explicit type may be given
     s2_typed: Select[tuple[int, str]] = select(User.id, s1.c.name)
 
-    # EXPECTED_TYPE: Select[tuple[int, str]]
-    reveal_type(s2_typed)
+    assert_type(s2_typed, Select[tuple[int, str]])
 
     # plain FromClause etc we at least get Select
     s3 = select(s1)
 
-    # EXPECTED_TYPE: Select[Unpack[.*tuple[Any, ...]]]
-    reveal_type(s3)
+    assert_type(s3, Select[Unpack[tuple[Any, ...]]])
 
     t1 = User.__table__
     assert t1 is not None
 
-    # EXPECTED_TYPE: FromClause
-    reveal_type(t1)
+    assert_type(t1, FromClause)
 
     s4 = select(t1)
 
-    # EXPECTED_TYPE: Select[Unpack[.*tuple[Any, ...]]]
-    reveal_type(s4)
+    assert_type(s4, Select[Unpack[tuple[Any, ...]]])
 
 
 def t_dml_insert() -> None:
@@ -418,53 +374,45 @@ def t_dml_insert() -> None:
 
     r1 = session.execute(s1)
 
-    # EXPECTED_TYPE: Result[int, str]
-    reveal_type(r1)
+    assert_type(r1, Result[int, str])
 
     s2 = insert(User).returning(User)
 
     r2 = session.execute(s2)
 
-    # EXPECTED_TYPE: Result[User]
-    reveal_type(r2)
+    assert_type(r2, Result[User])
 
     s3 = insert(User).returning(func.foo(), column("q"))
 
-    # EXPECTED_TYPE: ReturningInsert[Unpack[.*tuple[Any, ...]]]
-    reveal_type(s3)
+    assert_type(s3, ReturningInsert[Unpack[tuple[Any, ...]]])
 
     r3 = session.execute(s3)
 
-    # EXPECTED_TYPE: Result[Unpack[.*tuple[Any, ...]]]
-    reveal_type(r3)
+    assert_type(r3, Result[Unpack[tuple[Any, ...]]])
 
 
 def t_dml_bare_insert() -> None:
     s1 = insert(User)
     r1 = session.execute(s1)
-    # EXPECTED_TYPE: Result[Unpack[.*tuple[Any, ...]]]
-    reveal_type(r1)
+    assert_type(r1, Result[Unpack[tuple[Any, ...]]])
 
 
 def t_dml_bare_update() -> None:
     s1 = update(User)
     r1 = session.execute(s1)
-    # EXPECTED_TYPE: Result[Unpack[.*tuple[Any, ...]]]
-    reveal_type(r1)
+    assert_type(r1, Result[Unpack[tuple[Any, ...]]])
 
 
 def t_dml_update_with_values() -> None:
     s1 = update(User).values({User.id: 123, User.data: "value"})
     r1 = session.execute(s1)
-    # EXPECTED_TYPE: Result[Unpack[.*tuple[Any, ...]]]
-    reveal_type(r1)
+    assert_type(r1, Result[Unpack[tuple[Any, ...]]])
 
 
 def t_dml_bare_delete() -> None:
     s1 = delete(User)
     r1 = session.execute(s1)
-    # EXPECTED_TYPE: Result[Unpack[.*tuple[Any, ...]]]
-    reveal_type(r1)
+    assert_type(r1, Result[Unpack[tuple[Any, ...]]])
 
 
 def t_dml_update() -> None:
@@ -472,8 +420,7 @@ def t_dml_update() -> None:
 
     r1 = session.execute(s1)
 
-    # EXPECTED_TYPE: Result[int, str]
-    reveal_type(r1)
+    assert_type(r1, Result[int, str])
 
 
 def t_dml_delete() -> None:
@@ -481,22 +428,19 @@ def t_dml_delete() -> None:
 
     r1 = session.execute(s1)
 
-    # EXPECTED_TYPE: Result[int, str]
-    reveal_type(r1)
+    assert_type(r1, Result[int, str])
 
 
 def t_from_statement() -> None:
     t = text("select * from user")
 
-    # EXPECTED_TYPE: TextClause
-    reveal_type(t)
+    assert_type(t, TextClause)
 
     select(User).from_statement(t)
 
     ts = text("select * from user").columns(User.id, User.name)
 
-    # EXPECTED_TYPE: TextualSelect
-    reveal_type(ts)
+    assert_type(ts, TextualSelect)
 
     select(User).from_statement(ts)
 
@@ -504,8 +448,7 @@ def t_from_statement() -> None:
         user_table.c.id, user_table.c.name
     )
 
-    # EXPECTED_TYPE: TextualSelect
-    reveal_type(ts2)
+    assert_type(ts2, TextualSelect)
 
     select(User).from_statement(ts2)
 
@@ -519,17 +462,13 @@ def t_aliased_fromclause() -> None:
 
     a4 = aliased(user_table)
 
-    # EXPECTED_RE_TYPE: [tT]ype\[.*\.User\]
-    reveal_type(a1)
+    assert_type(a1, type[User])
 
-    # EXPECTED_RE_TYPE: [tT]ype\[.*\.User\]
-    reveal_type(a2)
+    assert_type(a2, type[User])
 
-    # EXPECTED_RE_TYPE: [tT]ype\[.*\.User\]
-    reveal_type(a3)
+    assert_type(a3, type[User])
 
-    # EXPECTED_TYPE: FromClause
-    reveal_type(a4)
+    assert_type(a4, FromClause)
 
 
 def test_select_from() -> None:
index 619cde74e8cadcf0439c8aa076cd800bdd473ab9..0ea8663e2dbe7eb3ef046ee35568f40be51a5754 100644 (file)
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import typing
+from typing import assert_type
 
 from sqlalchemy import ForeignKey
 from sqlalchemy import select
@@ -10,6 +11,7 @@ from sqlalchemy.orm import mapped_column
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import Session
 from sqlalchemy.orm import WriteOnlyMapped
+from sqlalchemy.orm.writeonly import WriteOnlyCollection
 
 
 class Base(DeclarativeBase):
@@ -35,16 +37,14 @@ with Session() as session:
     session.commit()
 
     if typing.TYPE_CHECKING:
-        # EXPECTED_TYPE: WriteOnlyCollection[Address]
-        reveal_type(u.addresses)
+        assert_type(u.addresses, WriteOnlyCollection[Address])
 
     address = session.scalars(
         u.addresses.select().filter(Address.email_address.like("xyz"))
     ).one()
 
     if typing.TYPE_CHECKING:
-        # EXPECTED_TYPE: Address
-        reveal_type(address)
+        assert_type(address, Address)
 
     u.addresses.add(Address())
     u.addresses.add_all([Address(), Address()])
index 3428a640df8dc77105c47a204eedb0e349099ba0..e8a10e553db85adb4d238e413774ed068c6add7b 100644 (file)
@@ -8,6 +8,10 @@ unions.
 
 from __future__ import annotations
 
+from typing import assert_type
+from typing import Never
+from typing import Unpack
+
 from sqlalchemy import asc
 from sqlalchemy import Column
 from sqlalchemy import column
@@ -20,16 +24,21 @@ from sqlalchemy import intersect
 from sqlalchemy import intersect_all
 from sqlalchemy import literal
 from sqlalchemy import MetaData
+from sqlalchemy import Select
 from sqlalchemy import select
 from sqlalchemy import SQLColumnExpression
 from sqlalchemy import String
 from sqlalchemy import Table
 from sqlalchemy import union
 from sqlalchemy import union_all
+from sqlalchemy.engine import Result
 from sqlalchemy.orm import DeclarativeBase
 from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
 from sqlalchemy.orm import Session
+from sqlalchemy.orm.query import RowReturningQuery
+from sqlalchemy.sql.expression import BindParameter
+from sqlalchemy.sql.expression import CompoundSelect
 
 
 class Base(DeclarativeBase):
@@ -67,26 +76,22 @@ def core_expr(email: str) -> SQLColumnExpression[bool]:
 
 e1 = orm_expr("hi")
 
-# EXPECTED_TYPE: SQLColumnExpression[bool]
-reveal_type(e1)
+assert_type(e1, SQLColumnExpression[bool])
 
 stmt = select(e1)
 
-# EXPECTED_TYPE: Select[bool]
-reveal_type(stmt)
+assert_type(stmt, Select[bool])
 
 stmt = stmt.where(e1)
 
 
 e2 = core_expr("hi")
 
-# EXPECTED_TYPE: SQLColumnExpression[bool]
-reveal_type(e2)
+assert_type(e2, SQLColumnExpression[bool])
 
 stmt = select(e2)
 
-# EXPECTED_TYPE: Select[bool]
-reveal_type(stmt)
+assert_type(stmt, Select[bool])
 
 stmt = stmt.where(e2)
 
@@ -95,20 +100,17 @@ stmt2 = select(User.id).order_by("id", "email").group_by("email", "id")
 stmt2 = (
     select(User.id).order_by(asc("id"), desc("email")).group_by("email", "id")
 )
-# EXPECTED_TYPE: Select[int]
-reveal_type(stmt2)
+assert_type(stmt2, Select[int])
 
 stmt2 = select(User.id).order_by(User.id).group_by(User.email)
 stmt2 = (
     select(User.id).order_by(User.id, User.email).group_by(User.email, User.id)
 )
-# EXPECTED_TYPE: Select[int]
-reveal_type(stmt2)
+assert_type(stmt2, Select[int])
 
 stmt3 = select(User.id).exists().select()
 
-# EXPECTED_TYPE: Select[bool]
-reveal_type(stmt3)
+assert_type(stmt3, Select[bool])
 
 
 receives_str_col_expr(User.email)
@@ -129,8 +131,7 @@ receives_bool_col_expr(user_table.c.email == "x")
 
 q1 = Session().query(User.id).order_by("email").group_by("email")
 q1 = Session().query(User.id).order_by("id", "email").group_by("email", "id")
-# EXPECTED_TYPE: RowReturningQuery[int]
-reveal_type(q1)
+assert_type(q1, RowReturningQuery[int])
 
 q1 = Session().query(User.id).order_by(User.id).group_by(User.email)
 q1 = (
@@ -139,8 +140,7 @@ q1 = (
     .order_by(User.id, User.email)
     .group_by(User.email, User.id)
 )
-# EXPECTED_TYPE: RowReturningQuery[int]
-reveal_type(q1)
+assert_type(q1, RowReturningQuery[int])
 
 # test 9174
 s9174_1 = select(User).with_for_update(of=User)
@@ -164,14 +164,10 @@ user = session.query(user_table).with_for_update(
 )
 
 # literal
-# EXPECTED_TYPE: BindParameter[str]
-reveal_type(literal("5"))
-# EXPECTED_TYPE: BindParameter[str]
-reveal_type(literal("5", None))
-# EXPECTED_TYPE: BindParameter[int]
-reveal_type(literal("123", Integer))
-# EXPECTED_TYPE: BindParameter[int]
-reveal_type(literal("123", Integer))
+assert_type(literal("5"), BindParameter[str])
+assert_type(literal("5", None), BindParameter[str])
+assert_type(literal("123", Integer), BindParameter[int])
+assert_type(literal("123", Integer), BindParameter[int])
 
 
 # hashable (issue #10353):
@@ -193,36 +189,26 @@ first_stmt = select(str_col, int_col)
 second_stmt = select(str_col, int_col)
 third_stmt = select(int_col, str_col)
 
-# EXPECTED_TYPE: CompoundSelect[str, int]
-reveal_type(union(first_stmt, second_stmt))
-# EXPECTED_TYPE: CompoundSelect[str, int]
-reveal_type(union_all(first_stmt, second_stmt))
-# EXPECTED_TYPE: CompoundSelect[str, int]
-reveal_type(except_(first_stmt, second_stmt))
-# EXPECTED_TYPE: CompoundSelect[str, int]
-reveal_type(except_all(first_stmt, second_stmt))
-# EXPECTED_TYPE: CompoundSelect[str, int]
-reveal_type(intersect(first_stmt, second_stmt))
-# EXPECTED_TYPE: CompoundSelect[str, int]
-reveal_type(intersect_all(first_stmt, second_stmt))
-
-# EXPECTED_TYPE: Result[str, int]
-reveal_type(Session().execute(union(first_stmt, second_stmt)))
-# EXPECTED_TYPE: Result[str, int]
-reveal_type(Session().execute(union_all(first_stmt, second_stmt)))
-
-# EXPECTED_TYPE: CompoundSelect[str, int]
-reveal_type(first_stmt.union(second_stmt))
-# EXPECTED_TYPE: CompoundSelect[str, int]
-reveal_type(first_stmt.union_all(second_stmt))
-# EXPECTED_TYPE: CompoundSelect[str, int]
-reveal_type(first_stmt.except_(second_stmt))
-# EXPECTED_TYPE: CompoundSelect[str, int]
-reveal_type(first_stmt.except_all(second_stmt))
-# EXPECTED_TYPE: CompoundSelect[str, int]
-reveal_type(first_stmt.intersect(second_stmt))
-# EXPECTED_TYPE: CompoundSelect[str, int]
-reveal_type(first_stmt.intersect_all(second_stmt))
+assert_type(union(first_stmt, second_stmt), CompoundSelect[str, int])
+assert_type(union_all(first_stmt, second_stmt), CompoundSelect[str, int])
+assert_type(except_(first_stmt, second_stmt), CompoundSelect[str, int])
+assert_type(except_all(first_stmt, second_stmt), CompoundSelect[str, int])
+assert_type(intersect(first_stmt, second_stmt), CompoundSelect[str, int])
+assert_type(intersect_all(first_stmt, second_stmt), CompoundSelect[str, int])
+
+assert_type(
+    Session().execute(union(first_stmt, second_stmt)), Result[str, int]
+)
+assert_type(
+    Session().execute(union_all(first_stmt, second_stmt)), Result[str, int]
+)
+
+assert_type(first_stmt.union(second_stmt), CompoundSelect[str, int])
+assert_type(first_stmt.union_all(second_stmt), CompoundSelect[str, int])
+assert_type(first_stmt.except_(second_stmt), CompoundSelect[str, int])
+assert_type(first_stmt.except_all(second_stmt), CompoundSelect[str, int])
+assert_type(first_stmt.intersect(second_stmt), CompoundSelect[str, int])
+assert_type(first_stmt.intersect_all(second_stmt), CompoundSelect[str, int])
 
 # TODO: the following do not error because _SelectStatementForCompoundArgument
 # includes untyped elements so the type checker falls back on them when
@@ -230,28 +216,32 @@ reveal_type(first_stmt.intersect_all(second_stmt))
 # looses the plot and returns a random type back. See TODO in the
 # overloads
 
-# EXPECTED_TYPE: CompoundSelect[Unpack[tuple[Never, ...]]]
-reveal_type(union(first_stmt, third_stmt))
-# EXPECTED_TYPE: CompoundSelect[Unpack[tuple[Never, ...]]]
-reveal_type(union_all(first_stmt, third_stmt))
-# EXPECTED_TYPE: CompoundSelect[Unpack[tuple[Never, ...]]]
-reveal_type(except_(first_stmt, third_stmt))
-# EXPECTED_TYPE: CompoundSelect[Unpack[tuple[Never, ...]]]
-reveal_type(except_all(first_stmt, third_stmt))
-# EXPECTED_TYPE: CompoundSelect[Unpack[tuple[Never, ...]]]
-reveal_type(intersect(first_stmt, third_stmt))
-# EXPECTED_TYPE: CompoundSelect[Unpack[tuple[Never, ...]]]
-reveal_type(intersect_all(first_stmt, third_stmt))
-
-# EXPECTED_TYPE: CompoundSelect[str, int]
-reveal_type(first_stmt.union(third_stmt))
-# EXPECTED_TYPE: CompoundSelect[str, int]
-reveal_type(first_stmt.union_all(third_stmt))
-# EXPECTED_TYPE: CompoundSelect[str, int]
-reveal_type(first_stmt.except_(third_stmt))
-# EXPECTED_TYPE: CompoundSelect[str, int]
-reveal_type(first_stmt.except_all(third_stmt))
-# EXPECTED_TYPE: CompoundSelect[str, int]
-reveal_type(first_stmt.intersect(third_stmt))
-# EXPECTED_TYPE: CompoundSelect[str, int]
-reveal_type(first_stmt.intersect_all(third_stmt))
+assert_type(
+    union(first_stmt, third_stmt), CompoundSelect[Unpack[tuple[Never, ...]]]
+)
+assert_type(
+    union_all(first_stmt, third_stmt),
+    CompoundSelect[Unpack[tuple[Never, ...]]],
+)
+assert_type(
+    except_(first_stmt, third_stmt), CompoundSelect[Unpack[tuple[Never, ...]]]
+)
+assert_type(
+    except_all(first_stmt, third_stmt),
+    CompoundSelect[Unpack[tuple[Never, ...]]],
+)
+assert_type(
+    intersect(first_stmt, third_stmt),
+    CompoundSelect[Unpack[tuple[Never, ...]]],
+)
+assert_type(
+    intersect_all(first_stmt, third_stmt),
+    CompoundSelect[Unpack[tuple[Never, ...]]],
+)
+
+assert_type(first_stmt.union(third_stmt), CompoundSelect[str, int])
+assert_type(first_stmt.union_all(third_stmt), CompoundSelect[str, int])
+assert_type(first_stmt.except_(third_stmt), CompoundSelect[str, int])
+assert_type(first_stmt.except_all(third_stmt), CompoundSelect[str, int])
+assert_type(first_stmt.intersect(third_stmt), CompoundSelect[str, int])
+assert_type(first_stmt.intersect_all(third_stmt), CompoundSelect[str, int])
index 3660417887915a932c08df7c7ddf5698685c8653..beb72c4df6b41457ad2f725bf25e11442bcbe32b 100644 (file)
@@ -1,11 +1,18 @@
 """this file is generated by tools/generate_sql_functions.py"""
 
+from datetime import date
+from datetime import datetime
+from datetime import time
+from decimal import Decimal
+from typing import assert_type
+from typing import Sequence
+
 from sqlalchemy import column
 from sqlalchemy import func
 from sqlalchemy import Integer
 from sqlalchemy import Select
 from sqlalchemy import select
-from sqlalchemy import Sequence
+from sqlalchemy import Sequence as SqlAlchemySequence
 from sqlalchemy import String
 
 # START GENERATED FUNCTION TYPING TESTS
@@ -15,152 +22,127 @@ from sqlalchemy import String
 
 stmt1 = select(func.aggregate_strings(column("x", String), ","))
 
-# EXPECTED_RE_TYPE: .*Select\[.*str\]
-reveal_type(stmt1)
+assert_type(stmt1, Select[str])
 
 
 stmt2 = select(func.array_agg(column("x", Integer)))
 
-# EXPECTED_RE_TYPE: .*Select\[.*Sequence\[.*int\]\]
-reveal_type(stmt2)
+assert_type(stmt2, Select[Sequence[int]])
 
 
 stmt3 = select(func.char_length(column("x")))
 
-# EXPECTED_RE_TYPE: .*Select\[.*int\]
-reveal_type(stmt3)
+assert_type(stmt3, Select[int])
 
 
 stmt4 = select(func.coalesce(column("x", Integer)))
 
-# EXPECTED_RE_TYPE: .*Select\[.*int\]
-reveal_type(stmt4)
+assert_type(stmt4, Select[int])
 
 
 stmt5 = select(func.concat())
 
-# EXPECTED_RE_TYPE: .*Select\[.*str\]
-reveal_type(stmt5)
+assert_type(stmt5, Select[str])
 
 
 stmt6 = select(func.count(column("x")))
 
-# EXPECTED_RE_TYPE: .*Select\[.*int\]
-reveal_type(stmt6)
+assert_type(stmt6, Select[int])
 
 
 stmt7 = select(func.cume_dist())
 
-# EXPECTED_RE_TYPE: .*Select\[.*Decimal\]
-reveal_type(stmt7)
+assert_type(stmt7, Select[Decimal])
 
 
 stmt8 = select(func.current_date())
 
-# EXPECTED_RE_TYPE: .*Select\[.*date\]
-reveal_type(stmt8)
+assert_type(stmt8, Select[date])
 
 
 stmt9 = select(func.current_time())
 
-# EXPECTED_RE_TYPE: .*Select\[.*time\]
-reveal_type(stmt9)
+assert_type(stmt9, Select[time])
 
 
 stmt10 = select(func.current_timestamp())
 
-# EXPECTED_RE_TYPE: .*Select\[.*datetime\]
-reveal_type(stmt10)
+assert_type(stmt10, Select[datetime])
 
 
 stmt11 = select(func.current_user())
 
-# EXPECTED_RE_TYPE: .*Select\[.*str\]
-reveal_type(stmt11)
+assert_type(stmt11, Select[str])
 
 
 stmt12 = select(func.dense_rank())
 
-# EXPECTED_RE_TYPE: .*Select\[.*int\]
-reveal_type(stmt12)
+assert_type(stmt12, Select[int])
 
 
 stmt13 = select(func.localtime())
 
-# EXPECTED_RE_TYPE: .*Select\[.*datetime\]
-reveal_type(stmt13)
+assert_type(stmt13, Select[datetime])
 
 
 stmt14 = select(func.localtimestamp())
 
-# EXPECTED_RE_TYPE: .*Select\[.*datetime\]
-reveal_type(stmt14)
+assert_type(stmt14, Select[datetime])
 
 
 stmt15 = select(func.max(column("x", Integer)))
 
-# EXPECTED_RE_TYPE: .*Select\[.*int\]
-reveal_type(stmt15)
+assert_type(stmt15, Select[int])
 
 
 stmt16 = select(func.min(column("x", Integer)))
 
-# EXPECTED_RE_TYPE: .*Select\[.*int\]
-reveal_type(stmt16)
+assert_type(stmt16, Select[int])
 
 
-stmt17 = select(func.next_value(Sequence("x_seq")))
+stmt17 = select(func.next_value(SqlAlchemySequence("x_seq")))
 
-# EXPECTED_RE_TYPE: .*Select\[.*int\]
-reveal_type(stmt17)
+assert_type(stmt17, Select[int])
 
 
 stmt18 = select(func.now())
 
-# EXPECTED_RE_TYPE: .*Select\[.*datetime\]
-reveal_type(stmt18)
+assert_type(stmt18, Select[datetime])
 
 
 stmt19 = select(func.percent_rank())
 
-# EXPECTED_RE_TYPE: .*Select\[.*Decimal\]
-reveal_type(stmt19)
+assert_type(stmt19, Select[Decimal])
 
 
 stmt20 = select(func.pow(column("x", Integer)))
 
-# EXPECTED_RE_TYPE: .*Select\[.*int\]
-reveal_type(stmt20)
+assert_type(stmt20, Select[int])
 
 
 stmt21 = select(func.rank())
 
-# EXPECTED_RE_TYPE: .*Select\[.*int\]
-reveal_type(stmt21)
+assert_type(stmt21, Select[int])
 
 
 stmt22 = select(func.session_user())
 
-# EXPECTED_RE_TYPE: .*Select\[.*str\]
-reveal_type(stmt22)
+assert_type(stmt22, Select[str])
 
 
 stmt23 = select(func.sum(column("x", Integer)))
 
-# EXPECTED_RE_TYPE: .*Select\[.*int\]
-reveal_type(stmt23)
+assert_type(stmt23, Select[int])
 
 
 stmt24 = select(func.sysdate())
 
-# EXPECTED_RE_TYPE: .*Select\[.*datetime\]
-reveal_type(stmt24)
+assert_type(stmt24, Select[datetime])
 
 
 stmt25 = select(func.user())
 
-# EXPECTED_RE_TYPE: .*Select\[.*str\]
-reveal_type(stmt25)
+assert_type(stmt25, Select[str])
 
 # END GENERATED FUNCTION TYPING TESTS
 
index fc000277d06a1b53ad51833d6a5084cc7951e9fa..1be8c5ce7822f59e0acbfee5168521a3c087e711 100644 (file)
@@ -1,10 +1,21 @@
+from typing import Any
+from typing import assert_type
+
 from sqlalchemy import column
 from sqlalchemy import func
+from sqlalchemy import Function
 from sqlalchemy import Integer
+from sqlalchemy import Select
 from sqlalchemy import select
 from sqlalchemy.orm import DeclarativeBase
 from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
+from sqlalchemy.sql.expression import FunctionFilter
+from sqlalchemy.sql.expression import Over
+from sqlalchemy.sql.expression import WithinGroup
+from sqlalchemy.sql.functions import coalesce
+from sqlalchemy.sql.functions import max as functions_max
+from sqlalchemy.sql.selectable import TableValuedAlias
 
 
 class Base(DeclarativeBase):
@@ -20,53 +31,45 @@ class Foo(Base):
     c: Mapped[str]
 
 
-# EXPECTED_TYPE: Over[Any]
-reveal_type(func.row_number().over(order_by=Foo.a, partition_by=Foo.b.desc()))
+assert_type(
+    func.row_number().over(order_by=Foo.a, partition_by=Foo.b.desc()),
+    Over[Any],
+)
 func.row_number().over(order_by=[Foo.a.desc(), Foo.b.desc()])
 func.row_number().over(partition_by=[Foo.a.desc(), Foo.b.desc()])
 func.row_number().over(order_by="a", partition_by=("a", "b"))
 func.row_number().over(partition_by="a", order_by=("a", "b"))
 
 
-# EXPECTED_TYPE: Function[Any]
-reveal_type(func.row_number().filter())
-# EXPECTED_TYPE: FunctionFilter[Any]
-reveal_type(func.row_number().filter(Foo.a > 0))
-# EXPECTED_TYPE: FunctionFilter[Any]
-reveal_type(func.row_number().within_group(Foo.a).filter(Foo.b < 0))
-# EXPECTED_TYPE: WithinGroup[Any]
-reveal_type(func.row_number().within_group(Foo.a))
-# EXPECTED_TYPE: WithinGroup[Any]
-reveal_type(func.row_number().filter(Foo.a > 0).within_group(Foo.a))
-# EXPECTED_TYPE: Over[Any]
-reveal_type(func.row_number().filter(Foo.a > 0).over())
-# EXPECTED_TYPE: Over[Any]
-reveal_type(func.row_number().within_group(Foo.a).over())
+assert_type(func.row_number().filter(), Function[Any])
+assert_type(func.row_number().filter(Foo.a > 0), FunctionFilter[Any])
+assert_type(
+    func.row_number().within_group(Foo.a).filter(Foo.b < 0),
+    FunctionFilter[Any],
+)
+assert_type(func.row_number().within_group(Foo.a), WithinGroup[Any])
+assert_type(
+    func.row_number().filter(Foo.a > 0).within_group(Foo.a), WithinGroup[Any]
+)
+assert_type(func.row_number().filter(Foo.a > 0).over(), Over[Any])
+assert_type(func.row_number().within_group(Foo.a).over(), Over[Any])
 
 # test #10801
-# EXPECTED_TYPE: max[int]
-reveal_type(func.max(Foo.b))
+assert_type(func.max(Foo.b), functions_max[int])
 
 
 stmt1 = select(Foo.a, func.min(Foo.b)).group_by(Foo.a)
-# EXPECTED_TYPE: Select[int, int]
-reveal_type(stmt1)
+assert_type(stmt1, Select[int, int])
 
 # test #10818
-# EXPECTED_TYPE: coalesce[str]
-reveal_type(func.coalesce(Foo.c, "a", "b"))
-# EXPECTED_TYPE: coalesce[str]
-reveal_type(func.coalesce("a", "b"))
-# EXPECTED_TYPE: coalesce[int]
-reveal_type(func.coalesce(column("x", Integer), 3))
+assert_type(func.coalesce(Foo.c, "a", "b"), coalesce[str])
+assert_type(func.coalesce("a", "b"), coalesce[str])
+assert_type(func.coalesce(column("x", Integer), 3), coalesce[int])
 
 
 stmt2 = select(Foo.a, func.coalesce(Foo.c, "a", "b")).group_by(Foo.a)
-# EXPECTED_TYPE: Select[int, str]
-reveal_type(stmt2)
+assert_type(stmt2, Select[int, str])
 
 
-# EXPECTED_TYPE: TableValuedAlias
-reveal_type(func.json_each().table_valued("key", "value"))
-# EXPECTED_TYPE: TableValuedAlias
-reveal_type(func.json_each().table_valued(Foo.a, Foo.b))
+assert_type(func.json_each().table_valued("key", "value"), TableValuedAlias)
+assert_type(func.json_each().table_valued(Foo.a, Foo.b), TableValuedAlias)
index 035fde800d5d0da612c0bf70a3ec478d62702ef8..1725a57b33a11a6c4a58869d790e17c63d90cb58 100644 (file)
@@ -1,6 +1,9 @@
 from __future__ import annotations
 
+from typing import Any
+from typing import assert_type
 from typing import TYPE_CHECKING
+from typing import Unpack
 
 from sqlalchemy import Column
 from sqlalchemy import create_engine
@@ -11,9 +14,11 @@ from sqlalchemy import Result
 from sqlalchemy import select
 from sqlalchemy import String
 from sqlalchemy import Table
+from sqlalchemy.engine.cursor import CursorResult
 from sqlalchemy.orm import DeclarativeBase
 from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
+from sqlalchemy.sql.lambdas import StatementLambdaElement
 
 
 class Base(DeclarativeBase):
@@ -48,11 +53,9 @@ s6 = lambda_stmt(lambda: select(User)) + (lambda s: s.where(User.id == 5))
 
 
 if TYPE_CHECKING:
-    # EXPECTED_TYPE: StatementLambdaElement
-    reveal_type(s5)
+    assert_type(s5, StatementLambdaElement)
 
-    # EXPECTED_TYPE: StatementLambdaElement
-    reveal_type(s6)
+    assert_type(s6, StatementLambdaElement)
 
 
 e = create_engine("sqlite://")
@@ -61,8 +64,7 @@ with e.connect() as conn:
     result = conn.execute(s6)
 
     if TYPE_CHECKING:
-        # EXPECTED_TYPE: CursorResult[Unpack[.*tuple[Any, ...]]]
-        reveal_type(result)
+        assert_type(result, CursorResult[Unpack[tuple[Any, ...]]])
 
     # we can type these like this
     my_result: Result[User] = conn.execute(s6)
@@ -71,5 +73,4 @@ with e.connect() as conn:
         # pyright and mypy disagree on the specific type here,
         # mypy sees Result as we said, pyright seems to upgrade it to
         # CursorResult
-        # EXPECTED_RE_TYPE: .*(?:Cursor)?Result\[.*User\]
-        reveal_type(my_result)
+        assert_type(my_result, Result[User])
index 2a9e539dc38d7c4691218d5464e0242fc6c84a2b..338ee9800726b820cf882b277c7ce121eb475c29 100644 (file)
@@ -1,10 +1,12 @@
 from typing import Any
+from typing import assert_type
 
 from sqlalchemy import column
 from sqlalchemy import ColumnElement
 from sqlalchemy import Integer
 from sqlalchemy import literal
 from sqlalchemy import table
+from sqlalchemy.sql.expression import ColumnClause
 
 
 def test_col_accessors() -> None:
@@ -27,11 +29,10 @@ def test_col_get() -> None:
     col_alt = column("alt", Integer)
     tbl = table("mytable", col_id)
 
-    # EXPECTED_TYPE: ColumnClause[Any] | None
-    reveal_type(tbl.c.get("id"))
-    # EXPECTED_TYPE: ColumnClause[Any] | None
-    reveal_type(tbl.c.get("id", None))
-    # EXPECTED_TYPE: ColumnClause[Any] | ColumnClause[int]
-    reveal_type(tbl.c.get("alt", col_alt))
+    assert_type(tbl.c.get("id"), ColumnClause[Any] | None)
+    assert_type(tbl.c.get("id", None), ColumnClause[Any] | None)
+    assert_type(
+        tbl.c.get("alt", col_alt), ColumnClause[Any] | ColumnClause[int]
+    )
     col: ColumnElement[Any] = tbl.c.get("foo", literal("bar"))
     print(col)
index c09029fc14812d0fef692851b25c82180a32305f..e9680afe716267009830b9427527a6054465036b 100644 (file)
@@ -1,6 +1,7 @@
 import datetime as dt
 from decimal import Decimal
 from typing import Any
+from typing import assert_type
 from typing import List
 
 from sqlalchemy import ARRAY
@@ -15,6 +16,8 @@ from sqlalchemy.orm import DeclarativeBase
 from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
 from sqlalchemy.sql import operators
+from sqlalchemy.sql.elements import Grouping
+from sqlalchemy.sql.expression import BinaryExpression
 
 
 class Base(DeclarativeBase):
@@ -147,15 +150,14 @@ op_e: "ColumnElement[bool]" = col.bool_op("&")("1")
 
 
 op_a1 = col.op("&")(1)
-# EXPECTED_TYPE: BinaryExpression[Any]
-reveal_type(op_a1)
+assert_type(op_a1, BinaryExpression[Any])
 
 
 # op functions
 t1 = operators.eq(A.id, 1)
 select().where(t1)
 
-# EXPECTED_TYPE: BinaryExpression[Any]
-reveal_type(col.op("->>")("field"))
-# EXPECTED_TYPE: BinaryExpression[Any] | Grouping[Any]
-reveal_type(col.op("->>")("field").self_group())
+assert_type(col.op("->>")("field"), BinaryExpression[Any])
+assert_type(
+    col.op("->>")("field").self_group(), BinaryExpression[Any] | Grouping[Any]
+)
index f0025b2cb340bfe85e0f598c6ed32acfa3757fcf..ef3b2dc390b5b223d8a103292d9e8e68f68d1553 100644 (file)
@@ -1,4 +1,7 @@
+from decimal import Decimal
 import typing
+from typing import Any
+from typing import assert_type
 
 from sqlalchemy import and_
 from sqlalchemy import Boolean
@@ -13,6 +16,10 @@ from sqlalchemy import or_
 from sqlalchemy import select
 from sqlalchemy import String
 from sqlalchemy import true
+from sqlalchemy.sql.elements import ColumnElement
+from sqlalchemy.sql.elements import UnaryExpression
+from sqlalchemy.sql.expression import BinaryExpression
+from sqlalchemy.sql.expression import ColumnClause
 
 
 # builtin.pyi stubs define object.__eq__() as returning bool,  which
@@ -99,57 +106,38 @@ def test_issue_9650_char() -> None:
 
 
 def test_issue_9650_bitwise() -> None:
-    # EXPECTED_TYPE: BinaryExpression[Any]
-    reveal_type(c2.bitwise_and(5))
-    # EXPECTED_TYPE: BinaryExpression[Any]
-    reveal_type(c2.bitwise_or(5))
-    # EXPECTED_TYPE: BinaryExpression[Any]
-    reveal_type(c2.bitwise_xor(5))
-    # EXPECTED_TYPE: UnaryExpression[int]
-    reveal_type(c2.bitwise_not())
-    # EXPECTED_TYPE: BinaryExpression[Any]
-    reveal_type(c2.bitwise_lshift(5))
-    # EXPECTED_TYPE: BinaryExpression[Any]
-    reveal_type(c2.bitwise_rshift(5))
-    # EXPECTED_TYPE: ColumnElement[int]
-    reveal_type(c2 << 5)
-    # EXPECTED_TYPE: ColumnElement[int]
-    reveal_type(c2 >> 5)
+    assert_type(c2.bitwise_and(5), BinaryExpression[Any])
+    assert_type(c2.bitwise_or(5), BinaryExpression[Any])
+    assert_type(c2.bitwise_xor(5), BinaryExpression[Any])
+    assert_type(c2.bitwise_not(), UnaryExpression[int])
+    assert_type(c2.bitwise_lshift(5), BinaryExpression[Any])
+    assert_type(c2.bitwise_rshift(5), BinaryExpression[Any])
+    assert_type(c2 << 5, ColumnElement[int])
+    assert_type(c2 >> 5, ColumnElement[int])
 
 
 if typing.TYPE_CHECKING:
     # as far as if this is ColumnElement, BinaryElement, SQLCoreOperations,
     # that might change.  main thing is it's SomeSQLColThing[bool] and
     # not 'bool' or 'Any'.
-    # EXPECTED_RE_TYPE: sqlalchemy..*ColumnElement\[builtins.bool\]
-    reveal_type(expr1)
+    assert_type(expr1, ColumnElement[bool])
 
-    # EXPECTED_RE_TYPE: sqlalchemy..*ColumnClause\[builtins.str.?\]
-    reveal_type(c1)
+    assert_type(c1, ColumnClause[str])
 
-    # EXPECTED_RE_TYPE: sqlalchemy..*ColumnClause\[builtins.int.?\]
-    reveal_type(c2)
+    assert_type(c2, ColumnClause[int])
 
-    # EXPECTED_RE_TYPE: sqlalchemy..*BinaryExpression\[builtins.bool\]
-    reveal_type(expr2)
+    assert_type(expr2, BinaryExpression[bool])
 
-    # EXPECTED_RE_TYPE: sqlalchemy..*ColumnElement\[builtins.float | .*\.Decimal\]
-    reveal_type(expr3)
+    assert_type(expr3, ColumnElement[float | Decimal])
 
-    # EXPECTED_RE_TYPE: sqlalchemy..*UnaryExpression\[builtins.int.?\]
-    reveal_type(expr4)
+    assert_type(expr4, UnaryExpression[int])
 
-    # EXPECTED_RE_TYPE: sqlalchemy..*ColumnElement\[builtins.bool.?\]
-    reveal_type(expr5)
+    assert_type(expr5, ColumnElement[bool])
 
-    # EXPECTED_RE_TYPE: sqlalchemy..*ColumnElement\[builtins.bool.?\]
-    reveal_type(expr6)
+    assert_type(expr6, ColumnElement[bool])
 
-    # EXPECTED_RE_TYPE: sqlalchemy..*ColumnElement\[builtins.str\]
-    reveal_type(expr7)
+    assert_type(expr7, ColumnElement[str])
 
-    # EXPECTED_RE_TYPE: sqlalchemy..*ColumnElement\[builtins.int.?\]
-    reveal_type(expr8)
+    assert_type(expr8, ColumnElement[int])
 
-    # EXPECTED_TYPE: BinaryExpression[bool]
-    reveal_type(expr9)
+    assert_type(expr9, BinaryExpression[bool])
index 230cb957d4a9542f6f8429415af3fa3b75f740a7..0b5cc1bc92c9be1dff6b7cb847f43575d8f2aabe 100644 (file)
@@ -1,12 +1,11 @@
+from decimal import Decimal
+from typing import assert_type
+
 from sqlalchemy import Float
 from sqlalchemy import Numeric
 
-# EXPECTED_TYPE: Float[float]
-reveal_type(Float())
-# EXPECTED_TYPE: Float[Decimal]
-reveal_type(Float(asdecimal=True))
+assert_type(Float(), Float[float])
+assert_type(Float(asdecimal=True), Float[Decimal])
 
-# EXPECTED_TYPE: Numeric[Decimal]
-reveal_type(Numeric())
-# EXPECTED_TYPE: Numeric[float]
-reveal_type(Numeric(asdecimal=False))
+assert_type(Numeric(), Numeric[Decimal])
+assert_type(Numeric(asdecimal=False), Numeric[float])
index a544aa434f470695bcd64f916cb3268cefe02bf8..98dde5ad9f7b98a77fb4de2f5a23d9961aadfbbd 100644 (file)
@@ -1,9 +1,13 @@
 from __future__ import annotations
 
 import asyncio
+from typing import Any
+from typing import assert_type
 from typing import cast
 from typing import Optional
+from typing import Sequence
 from typing import Type
+from typing import Unpack
 
 from sqlalchemy import Column
 from sqlalchemy import column
@@ -19,9 +23,17 @@ from sqlalchemy import select
 from sqlalchemy import String
 from sqlalchemy import Table
 from sqlalchemy import table
+from sqlalchemy.engine import Result
+from sqlalchemy.engine.cursor import CursorResult
+from sqlalchemy.engine.result import MappingResult
+from sqlalchemy.engine.result import ScalarResult
+from sqlalchemy.engine.result import TupleResult
+from sqlalchemy.engine.row import Row
 from sqlalchemy.ext.asyncio import AsyncConnection
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.ext.asyncio import create_async_engine
+from sqlalchemy.ext.asyncio.result import AsyncScalarResult
+from sqlalchemy.ext.asyncio.result import AsyncTupleResult
 from sqlalchemy.orm import aliased
 from sqlalchemy.orm import DeclarativeBase
 from sqlalchemy.orm import Mapped
@@ -66,8 +78,7 @@ async def async_connect() -> AsyncConnection:
 
 async_connection = asyncio.run(async_connect())
 
-# EXPECTED_RE_TYPE: sqlalchemy..*AsyncConnection\*?
-reveal_type(async_connection)
+assert_type(async_connection, AsyncConnection)
 
 async_session = AsyncSession(async_connection)
 
@@ -81,41 +92,33 @@ user = session.query(User).one()
 
 user_iter = iter(session.scalars(select(User)))
 
-# EXPECTED_RE_TYPE: sqlalchemy..*AsyncSession\*?
-reveal_type(async_session)
+assert_type(async_session, AsyncSession)
 
 
 single_stmt = select(User.name).where(User.name == "foo")
 
-# EXPECTED_RE_TYPE: sqlalchemy..*Select\*?\[builtins.str\*?\]
-reveal_type(single_stmt)
+assert_type(single_stmt, Select[str])
 
 multi_stmt = select(User.id, User.name).where(User.name == "foo")
 
-# EXPECTED_RE_TYPE: sqlalchemy..*Select\*?\[builtins.int\*?, builtins.str\*?\]
-reveal_type(multi_stmt)
+assert_type(multi_stmt, Select[int, str])
 
 
 def t_result_ctxmanager() -> None:
     with connection.execute(select(column("q", Integer))) as r1:
-        # EXPECTED_TYPE: CursorResult[int]
-        reveal_type(r1)
+        assert_type(r1, CursorResult[int])
 
         with r1.mappings() as r1m:
-            # EXPECTED_TYPE: MappingResult
-            reveal_type(r1m)
+            assert_type(r1m, MappingResult)
 
     with connection.scalars(select(column("q", Integer))) as r2:
-        # EXPECTED_TYPE: ScalarResult[int]
-        reveal_type(r2)
+        assert_type(r2, ScalarResult[int])
 
     with session.execute(select(User.id)) as r3:
-        # EXPECTED_TYPE: Result[int]
-        reveal_type(r3)
+        assert_type(r3, Result[int])
 
     with session.scalars(select(User.id)) as r4:
-        # EXPECTED_TYPE: ScalarResult[int]
-        reveal_type(r4)
+        assert_type(r4, ScalarResult[int])
 
 
 def t_mappings() -> None:
@@ -143,22 +146,18 @@ def t_entity_varieties() -> None:
 
     r1 = session.execute(s1)
 
-    # EXPECTED_RE_TYPE: sqlalchemy..*.Result\[builtins.int\*?, typed_results.User\*?, builtins.str\*?\]
-    reveal_type(r1)
+    assert_type(r1, Result[int, User, str])
 
     s2 = select(User, a1).where(User.name == "foo")
 
     r2 = session.execute(s2)
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*Result\[typed_results.User\*?, typed_results.User\*?\]
-    reveal_type(r2)
+    assert_type(r2, Result[User, User])
 
     row = r2.t.one()
 
-    # EXPECTED_RE_TYPE: .*typed_results.User\*?
-    reveal_type(row[0])
-    # EXPECTED_RE_TYPE: .*typed_results.User\*?
-    reveal_type(row[1])
+    assert_type(row[0], User)
+    assert_type(row[1], User)
 
     # testing that plain Mapped[x] gets picked up as well as
     # aliased class
@@ -166,19 +165,17 @@ def t_entity_varieties() -> None:
     # automatically typed since they are dynamically generated
     a1_id = cast(Mapped[int], a1.id)
     s3 = select(User.id, a1_id, a1, User).where(User.name == "foo")
-    # EXPECTED_RE_TYPE: sqlalchemy.*Select\*?\[builtins.int\*?, builtins.int\*?, typed_results.User\*?, typed_results.User\*?\]
-    reveal_type(s3)
+    assert_type(s3, Select[int, int, User, User])
 
     # testing Mapped[entity]
     some_mp = cast(Mapped[User], object())
     s4 = select(some_mp, a1, User).where(User.name == "foo")
 
-    # NOTEXPECTED_RE_TYPE: sqlalchemy..*Select\*?\[typed_results.User\*?, typed_results.User\*?, typed_results.User\*?\]
+    # NOTEXPECTED_RE_TYPE: sqlalchemy..*Select\*?\[User\*?, User\*?, User\*?\]
 
-    # sqlalchemy.sql._gen_overloads.Select[typed_results.User, typed_results.User, typed_results.User]
+    # sqlalchemy.sql._gen_overloads.Select[User, User, User]
 
-    # EXPECTED_TYPE: Select[User, User, User]
-    reveal_type(s4)
+    assert_type(s4, Select[User, User, User])
 
     # test plain core expressions
     x = Column("x", Integer)
@@ -186,43 +183,36 @@ def t_entity_varieties() -> None:
 
     s5 = select(x, y, User.name + "hi")
 
-    # EXPECTED_RE_TYPE: sqlalchemy..*Select\*?\[builtins.int\*?, builtins.int\*?\, builtins.str\*?]
-    reveal_type(s5)
+    assert_type(s5, Select[int, int, str])
 
 
 def t_ambiguous_result_type_one() -> None:
     stmt = select(column("q", Integer), table("x", column("y")))
 
-    # EXPECTED_TYPE: Select[Unpack[.*tuple[Any, ...]]]
-    reveal_type(stmt)
+    assert_type(stmt, Select[Unpack[tuple[Any, ...]]])
 
     result = session.execute(stmt)
 
-    # EXPECTED_TYPE: Result[Unpack[.*tuple[Any, ...]]]
-    reveal_type(result)
+    assert_type(result, Result[Unpack[tuple[Any, ...]]])
 
 
 def t_ambiguous_result_type_two() -> None:
     stmt = select(column("q"))
 
-    # EXPECTED_TYPE: Select[Any]
-    reveal_type(stmt)
+    assert_type(stmt, Select[Any])
     result = session.execute(stmt)
 
-    # EXPECTED_TYPE: Result[Unpack[.*tuple[Any, ...]]]
-    reveal_type(result)
+    assert_type(result, Result[Unpack[tuple[Any, ...]]])
 
 
 def t_aliased() -> None:
     a1 = aliased(User)
 
     s1 = select(a1)
-    # EXPECTED_TYPE: Select[User]
-    reveal_type(s1)
+    assert_type(s1, Select[User])
 
     s4 = select(a1.name, a1, a1, User).where(User.name == "foo")
-    # EXPECTED_TYPE: Select[str, User, User, User]
-    reveal_type(s4)
+    assert_type(s4, Select[str, User, User, User])
 
 
 def t_result_scalar_accessors() -> None:
@@ -230,28 +220,23 @@ def t_result_scalar_accessors() -> None:
 
     r1 = result.scalar()
 
-    # EXPECTED_RE_TYPE: builtins.str \| None
-    reveal_type(r1)
+    assert_type(r1, str | None)
 
     r2 = result.scalar_one()
 
-    # EXPECTED_RE_TYPE: builtins.str\*?
-    reveal_type(r2)
+    assert_type(r2, str)
 
     r3 = result.scalar_one_or_none()
 
-    # EXPECTED_RE_TYPE: builtins.str \| None
-    reveal_type(r3)
+    assert_type(r3, str | None)
 
     r4 = result.scalars()
 
-    # EXPECTED_RE_TYPE: sqlalchemy..*ScalarResult\[builtins.str.*?\]
-    reveal_type(r4)
+    assert_type(r4, ScalarResult[str])
 
     r5 = result.scalars(0)
 
-    # EXPECTED_RE_TYPE: sqlalchemy..*ScalarResult\[builtins.str.*?\]
-    reveal_type(r5)
+    assert_type(r5, ScalarResult[str])
 
 
 async def t_async_result_scalar_accessors() -> None:
@@ -259,28 +244,23 @@ async def t_async_result_scalar_accessors() -> None:
 
     r1 = await result.scalar()
 
-    # EXPECTED_RE_TYPE: builtins.str \| None
-    reveal_type(r1)
+    assert_type(r1, str | None)
 
     r2 = await result.scalar_one()
 
-    # EXPECTED_RE_TYPE: builtins.str\*?
-    reveal_type(r2)
+    assert_type(r2, str)
 
     r3 = await result.scalar_one_or_none()
 
-    # EXPECTED_RE_TYPE: builtins.str \| None
-    reveal_type(r3)
+    assert_type(r3, str | None)
 
     r4 = result.scalars()
 
-    # EXPECTED_RE_TYPE: sqlalchemy..*ScalarResult\[builtins.str.*?\]
-    reveal_type(r4)
+    assert_type(r4, AsyncScalarResult[str])
 
     r5 = result.scalars(0)
 
-    # EXPECTED_RE_TYPE: sqlalchemy..*ScalarResult\[builtins.str.*?\]
-    reveal_type(r5)
+    assert_type(r5, AsyncScalarResult[str])
 
 
 def t_result_insertmanyvalues_scalars() -> None:
@@ -295,8 +275,7 @@ def t_result_insertmanyvalues_scalars() -> None:
         ],
     ).all()
 
-    # EXPECTED_TYPE: Sequence[int]
-    reveal_type(uids1)
+    assert_type(uids1, Sequence[int])
 
     uids2 = (
         connection.execute(
@@ -311,8 +290,7 @@ def t_result_insertmanyvalues_scalars() -> None:
         .all()
     )
 
-    # EXPECTED_TYPE: Sequence[int]
-    reveal_type(uids2)
+    assert_type(uids2, Sequence[int])
 
 
 async def t_async_result_insertmanyvalues_scalars() -> None:
@@ -329,8 +307,7 @@ async def t_async_result_insertmanyvalues_scalars() -> None:
         )
     ).all()
 
-    # EXPECTED_TYPE: Sequence[int]
-    reveal_type(uids1)
+    assert_type(uids1, Sequence[int])
 
     uids2 = (
         (
@@ -347,344 +324,279 @@ async def t_async_result_insertmanyvalues_scalars() -> None:
         .all()
     )
 
-    # EXPECTED_TYPE: Sequence[int]
-    reveal_type(uids2)
+    assert_type(uids2, Sequence[int])
 
 
 def t_connection_execute_multi_row_t() -> None:
     result = connection.execute(multi_stmt)
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*CursorResult\[builtins.int\*?, builtins.str\*?\]
-    reveal_type(result)
+    assert_type(result, CursorResult[int, str])
     row = result.one()
 
-    # EXPECTED_RE_TYPE: .*sqlalchemy.*Row\[builtins.int\*?, builtins.str\*?\].*
-    reveal_type(row)
+    assert_type(row, Row[int, str])
 
     x, y = row.t
 
-    # EXPECTED_RE_TYPE: builtins.int\*?
-    reveal_type(x)
+    assert_type(x, int)
 
-    # EXPECTED_RE_TYPE: builtins.str\*?
-    reveal_type(y)
+    assert_type(y, str)
 
 
 def t_connection_execute_multi() -> None:
     result = connection.execute(multi_stmt).t
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*TupleResult\[tuple\[builtins.int\*?, builtins.str\*?\]\]
-    reveal_type(result)
+    assert_type(result, TupleResult[tuple[int, str]])
     row = result.one()
 
-    # EXPECTED_RE_TYPE: tuple\[builtins.int\*?, builtins.str\*?\]
-    reveal_type(row)
+    assert_type(row, tuple[int, str])
 
     x, y = row
 
-    # EXPECTED_RE_TYPE: builtins.int\*?
-    reveal_type(x)
+    assert_type(x, int)
 
-    # EXPECTED_RE_TYPE: builtins.str\*?
-    reveal_type(y)
+    assert_type(y, str)
 
 
 def t_connection_execute_single() -> None:
     result = connection.execute(single_stmt).t
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*TupleResult\[tuple\[builtins.str\*?\]\]
-    reveal_type(result)
+    assert_type(result, TupleResult[tuple[str]])
     row = result.one()
 
-    # EXPECTED_RE_TYPE: tuple\[builtins.str\*?\]
-    reveal_type(row)
+    assert_type(row, tuple[str])
 
     (x,) = row
 
-    # EXPECTED_RE_TYPE: builtins.str\*?
-    reveal_type(x)
+    assert_type(x, str)
 
 
 def t_connection_execute_single_row_scalar() -> None:
     result = connection.execute(single_stmt).t
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*TupleResult\[tuple\[builtins.str\*?\]\]
-    reveal_type(result)
+    assert_type(result, TupleResult[tuple[str]])
 
     x = result.scalar()
 
-    # EXPECTED_RE_TYPE: builtins.str \| None
-    reveal_type(x)
+    assert_type(x, str | None)
 
 
 def t_connection_scalar() -> None:
     obj = connection.scalar(single_stmt)
 
-    # EXPECTED_RE_TYPE: builtins.str \| None
-    reveal_type(obj)
+    assert_type(obj, str | None)
 
 
 def t_connection_scalars() -> None:
     result = connection.scalars(single_stmt)
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*ScalarResult\[builtins.str\*?\]
-    reveal_type(result)
+    assert_type(result, ScalarResult[str])
     data = result.all()
 
-    # EXPECTED_RE_TYPE: typing.Sequence\[builtins.str\*?\]
-    reveal_type(data)
+    assert_type(data, Sequence[str])
 
 
 def t_session_execute_multi() -> None:
     result = session.execute(multi_stmt).t
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*TupleResult\[tuple\[builtins.int\*?, builtins.str\*?\]\]
-    reveal_type(result)
+    assert_type(result, TupleResult[tuple[int, str]])
     row = result.one()
 
-    # EXPECTED_RE_TYPE: tuple\[builtins.int\*?, builtins.str\*?\]
-    reveal_type(row)
+    assert_type(row, tuple[int, str])
 
     x, y = row
 
-    # EXPECTED_RE_TYPE: builtins.int\*?
-    reveal_type(x)
+    assert_type(x, int)
 
-    # EXPECTED_RE_TYPE: builtins.str\*?
-    reveal_type(y)
+    assert_type(y, str)
 
 
 def t_session_execute_single() -> None:
     result = session.execute(single_stmt).t
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*TupleResult\[tuple\[builtins.str\*?\]\]
-    reveal_type(result)
+    assert_type(result, TupleResult[tuple[str]])
     row = result.one()
 
-    # EXPECTED_RE_TYPE: tuple\[builtins.str\*?\]
-    reveal_type(row)
+    assert_type(row, tuple[str])
 
     (x,) = row
 
-    # EXPECTED_RE_TYPE: builtins.str\*?
-    reveal_type(x)
+    assert_type(x, str)
 
 
 def t_session_scalar() -> None:
     obj = session.scalar(single_stmt)
 
-    # EXPECTED_RE_TYPE: builtins.str \| None
-    reveal_type(obj)
+    assert_type(obj, str | None)
 
 
 def t_session_scalars() -> None:
     result = session.scalars(single_stmt)
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*ScalarResult\[builtins.str\*?\]
-    reveal_type(result)
+    assert_type(result, ScalarResult[str])
     data = result.all()
 
-    # EXPECTED_RE_TYPE: typing.Sequence\[builtins.str\*?\]
-    reveal_type(data)
+    assert_type(data, Sequence[str])
 
 
 async def t_async_connection_execute_multi() -> None:
     result = (await async_connection.execute(multi_stmt)).t
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*TupleResult\[tuple\[builtins.int\*?, builtins.str\*?\]\]
-    reveal_type(result)
+    assert_type(result, TupleResult[tuple[int, str]])
     row = result.one()
 
-    # EXPECTED_RE_TYPE: tuple\[builtins.int\*?, builtins.str\*?\]
-    reveal_type(row)
+    assert_type(row, tuple[int, str])
 
     x, y = row
 
-    # EXPECTED_RE_TYPE: builtins.int\*?
-    reveal_type(x)
+    assert_type(x, int)
 
-    # EXPECTED_RE_TYPE: builtins.str\*?
-    reveal_type(y)
+    assert_type(y, str)
 
 
 async def t_async_connection_execute_single() -> None:
     result = (await async_connection.execute(single_stmt)).t
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*TupleResult\[tuple\[builtins.str\*?\]\]
-    reveal_type(result)
+    assert_type(result, TupleResult[tuple[str]])
 
     row = result.one()
 
-    # EXPECTED_RE_TYPE: tuple\[builtins.str\*?\]
-    reveal_type(row)
+    assert_type(row, tuple[str])
 
     (x,) = row
 
-    # EXPECTED_RE_TYPE: builtins.str\*?
-    reveal_type(x)
+    assert_type(x, str)
 
 
 async def t_async_connection_scalar() -> None:
     obj = await async_connection.scalar(single_stmt)
 
-    # EXPECTED_RE_TYPE: builtins.str \| None
-    reveal_type(obj)
+    assert_type(obj, str | None)
 
 
 async def t_async_connection_scalars() -> None:
     result = await async_connection.scalars(single_stmt)
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*ScalarResult\*?\[builtins.str\*?\]
-    reveal_type(result)
+    assert_type(result, ScalarResult[str])
     data = result.all()
 
-    # EXPECTED_RE_TYPE: typing.Sequence\[builtins.str\*?\]
-    reveal_type(data)
+    assert_type(data, Sequence[str])
 
 
 async def t_async_session_execute_multi() -> None:
     result = (await async_session.execute(multi_stmt)).t
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*TupleResult\[tuple\[builtins.int\*?, builtins.str\*?\]\]
-    reveal_type(result)
+    assert_type(result, TupleResult[tuple[int, str]])
     row = result.one()
 
-    # EXPECTED_RE_TYPE: tuple\[builtins.int\*?, builtins.str\*?\]
-    reveal_type(row)
+    assert_type(row, tuple[int, str])
 
     x, y = row
 
-    # EXPECTED_RE_TYPE: builtins.int\*?
-    reveal_type(x)
+    assert_type(x, int)
 
-    # EXPECTED_RE_TYPE: builtins.str\*?
-    reveal_type(y)
+    assert_type(y, str)
 
 
 async def t_async_session_execute_single() -> None:
     result = (await async_session.execute(single_stmt)).t
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*TupleResult\[tuple\[builtins.str\*?\]\]
-    reveal_type(result)
+    assert_type(result, TupleResult[tuple[str]])
     row = result.one()
 
-    # EXPECTED_RE_TYPE: tuple\[builtins.str\*?\]
-    reveal_type(row)
+    assert_type(row, tuple[str])
 
     (x,) = row
 
-    # EXPECTED_RE_TYPE: builtins.str\*?
-    reveal_type(x)
+    assert_type(x, str)
 
 
 async def t_async_session_scalar() -> None:
     obj = await async_session.scalar(single_stmt)
 
-    # EXPECTED_RE_TYPE: builtins.str \| None
-    reveal_type(obj)
+    assert_type(obj, str | None)
 
 
 async def t_async_session_scalars() -> None:
     result = await async_session.scalars(single_stmt)
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*ScalarResult\*?\[builtins.str\*?\]
-    reveal_type(result)
+    assert_type(result, ScalarResult[str])
     data = result.all()
 
-    # EXPECTED_RE_TYPE: typing.Sequence\[builtins.str\*?\]
-    reveal_type(data)
+    assert_type(data, Sequence[str])
 
 
 async def t_async_connection_stream_multi() -> None:
     result = (await async_connection.stream(multi_stmt)).t
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*AsyncTupleResult\[tuple\[builtins.int\*?, builtins.str\*?\]\]
-    reveal_type(result)
+    assert_type(result, AsyncTupleResult[tuple[int, str]])
     row = await result.one()
 
-    # EXPECTED_RE_TYPE: tuple\[builtins.int\*?, builtins.str\*?\]
-    reveal_type(row)
+    assert_type(row, tuple[int, str])
 
     x, y = row
 
-    # EXPECTED_RE_TYPE: builtins.int\*?
-    reveal_type(x)
+    assert_type(x, int)
 
-    # EXPECTED_RE_TYPE: builtins.str\*?
-    reveal_type(y)
+    assert_type(y, str)
 
 
 async def t_async_connection_stream_single() -> None:
     result = (await async_connection.stream(single_stmt)).t
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*AsyncTupleResult\[tuple\[builtins.str\*?\]\]
-    reveal_type(result)
+    assert_type(result, AsyncTupleResult[tuple[str]])
     row = await result.one()
 
-    # EXPECTED_RE_TYPE: tuple\[builtins.str\*?\]
-    reveal_type(row)
+    assert_type(row, tuple[str])
 
     (x,) = row
 
-    # EXPECTED_RE_TYPE: builtins.str\*?
-    reveal_type(x)
+    assert_type(x, str)
 
 
 async def t_async_connection_stream_scalars() -> None:
     result = await async_connection.stream_scalars(single_stmt)
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*AsyncScalarResult\*?\[builtins.str\*?\]
-    reveal_type(result)
+    assert_type(result, AsyncScalarResult[str])
     data = await result.all()
 
-    # EXPECTED_RE_TYPE: typing.Sequence\*?\[builtins.str\*?\]
-    reveal_type(data)
+    assert_type(data, Sequence[str])
 
 
 async def t_async_session_stream_multi() -> None:
     result = (await async_session.stream(multi_stmt)).t
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*TupleResult\[tuple\[builtins.int\*?, builtins.str\*?\]\]
-    reveal_type(result)
+    assert_type(result, AsyncTupleResult[tuple[int, str]])
     row = await result.one()
 
-    # EXPECTED_RE_TYPE: tuple\[builtins.int\*?, builtins.str\*?\]
-    reveal_type(row)
+    assert_type(row, tuple[int, str])
 
     x, y = row
 
-    # EXPECTED_RE_TYPE: builtins.int\*?
-    reveal_type(x)
+    assert_type(x, int)
 
-    # EXPECTED_RE_TYPE: builtins.str\*?
-    reveal_type(y)
+    assert_type(y, str)
 
 
 async def t_async_session_stream_single() -> None:
     result = (await async_session.stream(single_stmt)).t
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*AsyncTupleResult\[tuple\[builtins.str\*?\]\]
-    reveal_type(result)
+    assert_type(result, AsyncTupleResult[tuple[str]])
     row = await result.one()
 
-    # EXPECTED_RE_TYPE: tuple\[builtins.str\*?\]
-    reveal_type(row)
+    assert_type(row, tuple[str])
 
     (x,) = row
 
-    # EXPECTED_RE_TYPE: builtins.str\*?
-    reveal_type(x)
+    assert_type(x, str)
 
 
 async def t_async_session_stream_scalars() -> None:
     result = await async_session.stream_scalars(single_stmt)
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*AsyncScalarResult\*?\[builtins.str\*?\]
-    reveal_type(result)
+    assert_type(result, AsyncScalarResult[str])
     data = await result.all()
 
-    # EXPECTED_RE_TYPE: typing.Sequence\*?\[builtins.str\*?\]
-    reveal_type(data)
+    assert_type(data, Sequence[str])
 
 
 def test_outerjoin_10173() -> None:
index 7b6c93de14bd322ab76b8ef360ac0e0493dc554b..a78e2492a54aaf1cc08ce46437ca5adaee56c906 100644 (file)
@@ -176,17 +176,16 @@ def {key}(self) -> Type[{_type}]:{_reserved_word}
                         # The origin type, if rtype is a generic
                         orig_type = typing.get_origin(rtype)
                         if orig_type is not None:
-                            coltype = rf".*{orig_type.__name__}\[.*int\]"
+                            coltype = rf"{orig_type.__name__}[int]"
                         else:
-                            coltype = ".*int"
+                            coltype = "int"
 
                         buf.write(
                             textwrap.indent(
                                 rf"""
 stmt{count} = select(func.{key}(column('x', Integer)))
 
-# EXPECTED_RE_TYPE: .*Select\[{coltype}\]
-reveal_type(stmt{count})
+assert_type(stmt{count}, Select[{coltype}])
 
 """,
                                 indent,
@@ -199,8 +198,7 @@ reveal_type(stmt{count})
                                 rf"""
 stmt{count} = select(func.{key}(column('x', String), ','))
 
-# EXPECTED_RE_TYPE: .*Select\[.*str\]
-reveal_type(stmt{count})
+assert_type(stmt{count}, Select[str])
 
 """,
                                 indent,
@@ -211,10 +209,10 @@ reveal_type(stmt{count})
                         fn_class.type, TypeEngine
                     ):
                         python_type = fn_class.type.python_type
-                        python_expr = rf".*{python_type.__name__}"
+                        python_expr = python_type.__name__
                         argspec = inspect.getfullargspec(fn_class)
                         if fn_class.__name__ == "next_value":
-                            args = "Sequence('x_seq')"
+                            args = "SqlAlchemySequence('x_seq')"
                         else:
                             args = ", ".join(
                                 'column("x")' for elem in argspec.args[1:]
@@ -226,8 +224,7 @@ reveal_type(stmt{count})
                                 rf"""
 stmt{count} = select(func.{key}({args}))
 
-# EXPECTED_RE_TYPE: .*Select\[{python_expr}\]
-reveal_type(stmt{count})
+assert_type(stmt{count}, Select[{python_expr}])
 
 """,
                                 indent,