From 026afdfc03a98b47948a0655ba2ca23862dd01a0 Mon Sep 17 00:00:00 2001 From: Yurii Karabas <1998uriyyo@gmail.com> Date: Mon, 26 Jan 2026 20:54:01 -0500 Subject: [PATCH] Fix mypy error on scalar call with tuple[Any, ...] Fixed issue in new :pep:`646` support for result sets where an issue in the mypy type checker prevented "scalar" methods including :meth:`.Connection.scalar`, :meth:`.Result.scalar`, :meth:`_orm.Session.scalar`, as well as async versions of these methods from applying the correct type to the scalar result value, when the columns in the originating :func:`_sql.select` were typed as ``Any``. Pull request courtesy Yurii Karabas. Fixes: #13091 Closes: #13092 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/13092 Pull-request-sha: c0535dd7691c8b57d87fd5411cb0202866256fd6 Change-Id: Ifbebaf0bcfdda0dd293c3e278a639cd0c71d45cd --- doc/build/changelog/unreleased_21/13091.rst | 12 +++++ lib/sqlalchemy/engine/base.py | 12 +++++ lib/sqlalchemy/engine/result.py | 11 +++++ lib/sqlalchemy/ext/asyncio/engine.py | 12 +++++ lib/sqlalchemy/ext/asyncio/scoping.py | 12 +++++ lib/sqlalchemy/ext/asyncio/session.py | 14 ++++++ lib/sqlalchemy/orm/scoping.py | 12 +++++ lib/sqlalchemy/orm/session.py | 14 ++++++ test/typing/plain_files/orm/session.py | 47 ++++++++++++++++++ test/typing/plain_files/sql/typed_results.py | 52 ++++++++++++++++++++ 10 files changed, 198 insertions(+) create mode 100644 doc/build/changelog/unreleased_21/13091.rst diff --git a/doc/build/changelog/unreleased_21/13091.rst b/doc/build/changelog/unreleased_21/13091.rst new file mode 100644 index 0000000000..0ce7cbb279 --- /dev/null +++ b/doc/build/changelog/unreleased_21/13091.rst @@ -0,0 +1,12 @@ +.. change:: + :tags: bug, typing + :tickets: 13091 + + Fixed issue in new :pep:`646` support for result sets where an issue in the + mypy type checker prevented "scalar" methods including + :meth:`.Connection.scalar`, :meth:`.Result.scalar`, + :meth:`_orm.Session.scalar`, as well as async versions of these methods + from applying the correct type to the scalar result value, when the columns + in the originating :func:`_sql.select` were typed as ``Any``. Pull request + courtesy Yurii Karabas. + diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 419ba6bf38..c7cd814773 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -41,6 +41,7 @@ from .. import log from .. import util from ..sql import compiler from ..sql import util as sql_util +from ..util.typing import Never from ..util.typing import TupleAny from ..util.typing import TypeVarTuple from ..util.typing import Unpack @@ -1279,6 +1280,17 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): self._dbapi_connection = None self.__can_reconnect = False + # special case to handle mypy issue: + # https://github.com/python/mypy/issues/20651 + @overload + def scalar( + self, + statement: TypedReturnsRows[Never], + parameters: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> Optional[Any]: ... + @overload def scalar( self, diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py index 946f3192bd..5c1cd85a58 100644 --- a/lib/sqlalchemy/engine/result.py +++ b/lib/sqlalchemy/engine/result.py @@ -47,6 +47,7 @@ from ..sql.base import _generative from ..sql.base import InPlaceGenerative from ..util import deprecated from ..util import NONE_SET +from ..util.typing import Never from ..util.typing import Self from ..util.typing import TupleAny from ..util.typing import TypeVarTuple @@ -1064,6 +1065,16 @@ class Result(_WithKeys, ResultInternal[Row[Unpack[_Ts]]]): raise_for_second_row=True, raise_for_none=True, scalar=False ) + # special case to handle mypy issue: + # https://github.com/python/mypy/issues/20651 + @overload + def scalar(self: Result[Never, Unpack[TupleAny]]) -> Optional[Any]: + pass + + @overload + def scalar(self: Result[_T, Unpack[TupleAny]]) -> Optional[_T]: + pass + def scalar(self: Result[_T, Unpack[TupleAny]]) -> Optional[_T]: """Fetch the first column of the first row, and close the result set. diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py index a969d76fee..722f3d6b3c 100644 --- a/lib/sqlalchemy/ext/asyncio/engine.py +++ b/lib/sqlalchemy/ext/asyncio/engine.py @@ -43,6 +43,7 @@ from ...engine.base import Transaction from ...exc import ArgumentError from ...util import immutabledict from ...util.concurrency import greenlet_spawn +from ...util.typing import Never from ...util.typing import TupleAny from ...util.typing import TypeVarTuple from ...util.typing import Unpack @@ -671,6 +672,17 @@ class AsyncConnection( # type:ignore[misc] ) return await _ensure_sync_result(result, self.execute) + # special case to handle mypy issue: + # https://github.com/python/mypy/issues/20651 + @overload + async def scalar( + self, + statement: TypedReturnsRows[Never], + parameters: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> Optional[Any]: ... + @overload async def scalar( self, diff --git a/lib/sqlalchemy/ext/asyncio/scoping.py b/lib/sqlalchemy/ext/asyncio/scoping.py index 7ae240ebec..80d7a15a0f 100644 --- a/lib/sqlalchemy/ext/asyncio/scoping.py +++ b/lib/sqlalchemy/ext/asyncio/scoping.py @@ -31,6 +31,7 @@ from ...util import create_proxy_methods from ...util import ScopedRegistry from ...util import warn from ...util import warn_deprecated +from ...util.typing import Never from ...util.typing import TupleAny from ...util.typing import TypeVarTuple from ...util.typing import Unpack @@ -1038,6 +1039,17 @@ class async_scoped_session(Generic[_AS]): return await self._proxied.rollback() + @overload + async def scalar( + self, + statement: TypedReturnsRows[Never], + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Optional[Any]: ... + @overload async def scalar( self, diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py index 7c3c0f5539..69c98c57bb 100644 --- a/lib/sqlalchemy/ext/asyncio/session.py +++ b/lib/sqlalchemy/ext/asyncio/session.py @@ -40,6 +40,7 @@ from ...orm import Session from ...orm import SessionTransaction from ...orm import state as _instance_state from ...util.concurrency import greenlet_spawn +from ...util.typing import Never from ...util.typing import TupleAny from ...util.typing import TypeVarTuple from ...util.typing import Unpack @@ -461,6 +462,19 @@ class AsyncSession(ReversibleProxy[Session]): ) return await _ensure_sync_result(result, self.execute) + # special case to handle mypy issue: + # https://github.com/python/mypy/issues/20651 + @overload + async def scalar( + self, + statement: TypedReturnsRows[Never], + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Optional[Any]: ... + @overload async def scalar( self, diff --git a/lib/sqlalchemy/orm/scoping.py b/lib/sqlalchemy/orm/scoping.py index 6b68cb5acd..0634086ea2 100644 --- a/lib/sqlalchemy/orm/scoping.py +++ b/lib/sqlalchemy/orm/scoping.py @@ -32,6 +32,7 @@ from ..util import ScopedRegistry from ..util import ThreadLocalRegistry from ..util import warn from ..util import warn_deprecated +from ..util.typing import Never from ..util.typing import TupleAny from ..util.typing import TypeVarTuple from ..util.typing import Unpack @@ -1852,6 +1853,17 @@ class scoped_session(Generic[_S]): return self._proxied.rollback() + @overload + def scalar( + self, + statement: TypedReturnsRows[Never], + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Optional[Any]: ... + @overload def scalar( self, diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index c026cdd85b..b33a6ba651 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -91,6 +91,7 @@ from ..sql.schema import Table from ..sql.selectable import ForUpdateArg from ..util import deprecated_params from ..util import IdentitySet +from ..util.typing import Never from ..util.typing import TupleAny from ..util.typing import TypeVarTuple from ..util.typing import Unpack @@ -2410,6 +2411,19 @@ class Session(_SessionClassMethods, EventTarget): _add_event=_add_event, ) + # special case to handle mypy issue: + # https://github.com/python/mypy/issues/20651 + @overload + def scalar( + self, + statement: TypedReturnsRows[Never], + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Optional[Any]: ... + @overload def scalar( self, diff --git a/test/typing/plain_files/orm/session.py b/test/typing/plain_files/orm/session.py index af0de3386b..311c522a6d 100644 --- a/test/typing/plain_files/orm/session.py +++ b/test/typing/plain_files/orm/session.py @@ -1,11 +1,22 @@ from __future__ import annotations import asyncio +from typing import Any from typing import assert_type from typing import List +from typing import Tuple +from typing import Unpack +from sqlalchemy import Column from sqlalchemy import create_engine from sqlalchemy import ForeignKey +from sqlalchemy import Integer +from sqlalchemy import MetaData +from sqlalchemy import Result +from sqlalchemy import Select +from sqlalchemy import select +from sqlalchemy import String +from sqlalchemy import Table from sqlalchemy.engine.row import Row from sqlalchemy.ext.asyncio import async_scoped_session from sqlalchemy.ext.asyncio import async_sessionmaker @@ -43,6 +54,14 @@ class Address(Base): user: Mapped[User] = relationship(back_populates="addresses") +user_table = Table( + "user", + MetaData(), + Column("id", Integer, primary_key=True), + Column("name", String, primary_key=True), +) + + e = create_engine("sqlite://") Base.metadata.create_all(e) @@ -172,3 +191,31 @@ async def async_test_exec_options() -> None: await scoped.connection( execution_options={"isolation_level": "REPEATABLE READ"} ) + + +def test_13091() -> None: + session = Session() + stmt = select(user_table.c.id) + assert_type(stmt, Select[Unpack[Tuple[Any, ...]]]) + result = session.execute(stmt) + + assert_type(result, Result[Unpack[Tuple[Any, ...]]]) + data1 = result.scalar() + assert_type(data1, Any | None) + + data2 = session.scalar(stmt) + assert_type(data2, Any | None) + + +async def async_test_13091() -> None: + session = AsyncSession() + stmt = select(user_table.c.id) + assert_type(stmt, Select[Unpack[Tuple[Any, ...]]]) + result = await session.execute(stmt) + + assert_type(result, Result[Unpack[Tuple[Any, ...]]]) + data1 = result.scalar() + assert_type(data1, Any | None) + + data2 = await session.scalar(stmt) + assert_type(data2, Any | None) diff --git a/test/typing/plain_files/sql/typed_results.py b/test/typing/plain_files/sql/typed_results.py index 98dde5ad9f..f648ecbc46 100644 --- a/test/typing/plain_files/sql/typed_results.py +++ b/test/typing/plain_files/sql/typed_results.py @@ -6,11 +6,13 @@ from typing import assert_type from typing import cast from typing import Optional from typing import Sequence +from typing import Tuple from typing import Type from typing import Unpack from sqlalchemy import Column from sqlalchemy import column +from sqlalchemy import Connection from sqlalchemy import create_engine from sqlalchemy import func from sqlalchemy import insert @@ -623,3 +625,53 @@ def test_outerjoin_10173() -> None: print(stmt4) print(stmt, stmt2, stmt3) + + +def test_13091() -> None: + with e.connect() as conn: + stmt = select(t_user.c.id) + assert_type(stmt, Select[Unpack[Tuple[Any, ...]]]) + result = conn.execute(stmt) + + assert_type(result, CursorResult[Unpack[Tuple[Any, ...]]]) + data1 = result.scalar() + assert_type(data1, Any | None) + + data2 = conn.scalar(stmt) + assert_type(data2, Any | None) + + +async def async_test_13091() -> None: + async with ae.connect() as conn: + stmt = select(t_user.c.id) + assert_type(stmt, Select[Unpack[Tuple[Any, ...]]]) + result = await conn.execute(stmt) + + assert_type(result, CursorResult[Unpack[Tuple[Any, ...]]]) + data1 = result.scalar() + assert_type(data1, Any | None) + + data2 = await conn.scalar(stmt) + assert_type(data2, Any | None) + + +def test_13091_2( + conn: Connection, table: Table, c: Column[int], c2: Column[str] +) -> None: + assert_type(table.select(), Select[Unpack[Tuple[Any, ...]]]) + r1 = conn.execute(table.select()) + assert_type(r1, CursorResult[Unpack[Tuple[Any, ...]]]) + d1 = r1.scalar() + assert_type(d1, Any | None) + r2 = conn.execute(select(table)) + assert_type(r2, CursorResult[Unpack[Tuple[Any, ...]]]) + d2 = r2.scalar() + assert_type(d2, Any | None) + r3 = conn.execute(select(c)) + assert_type(r3, CursorResult[int]) + d3 = r3.scalar() + assert_type(d3, int | None) + r4 = conn.execute(select(c, c2)) + assert_type(r4, CursorResult[int, str]) + d4 = r3.scalar() + assert_type(d4, int | None) -- 2.47.3