From: Mike Bayer Date: Thu, 2 Mar 2023 01:44:49 +0000 (-0500) Subject: allow multiparams with scalars X-Git-Tag: rel_2_0_5~11^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=3ba05fa919be24447540ae9d4d9c95ab509cf929;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git allow multiparams with scalars Fixed bug where the :meth:`_engine.Connection.scalars` method was not typed as allowing a multiple-parameters list, which is now supported using insertmanyvalues operations. Change-Id: I65e22c3bee80fc226d484ff1424421dd78520fa5 --- diff --git a/doc/build/changelog/unreleased_20/type_scalars.rst b/doc/build/changelog/unreleased_20/type_scalars.rst new file mode 100644 index 0000000000..d983e15805 --- /dev/null +++ b/doc/build/changelog/unreleased_20/type_scalars.rst @@ -0,0 +1,6 @@ +.. change:: + :tags: bug, typing + + Fixed bug where the :meth:`_engine.Connection.scalars` method was not typed + as allowing a multiple-parameters list, which is now supported using + insertmanyvalues operations. diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index f6c637aa89..926a08b76f 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -1306,7 +1306,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): def scalars( self, statement: TypedReturnsRows[Tuple[_T]], - parameters: Optional[_CoreSingleExecuteParams] = None, + parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, ) -> ScalarResult[_T]: @@ -1316,7 +1316,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): def scalars( self, statement: Executable, - parameters: Optional[_CoreSingleExecuteParams] = None, + parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, ) -> ScalarResult[Any]: @@ -1325,7 +1325,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): def scalars( self, statement: Executable, - parameters: Optional[_CoreSingleExecuteParams] = None, + parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, ) -> ScalarResult[Any]: diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py index 86e257bdd7..325c58bdab 100644 --- a/lib/sqlalchemy/ext/asyncio/engine.py +++ b/lib/sqlalchemy/ext/asyncio/engine.py @@ -646,7 +646,7 @@ class AsyncConnection( async def scalars( self, statement: TypedReturnsRows[Tuple[_T]], - parameters: Optional[_CoreSingleExecuteParams] = None, + parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, ) -> ScalarResult[_T]: @@ -656,7 +656,7 @@ class AsyncConnection( async def scalars( self, statement: Executable, - parameters: Optional[_CoreSingleExecuteParams] = None, + parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, ) -> ScalarResult[Any]: @@ -665,7 +665,7 @@ class AsyncConnection( async def scalars( self, statement: Executable, - parameters: Optional[_CoreSingleExecuteParams] = None, + parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, ) -> ScalarResult[Any]: diff --git a/test/ext/mypy/plain_files/typed_results.py b/test/ext/mypy/plain_files/typed_results.py index 8fd9e5cd13..2e42bb655b 100644 --- a/test/ext/mypy/plain_files/typed_results.py +++ b/test/ext/mypy/plain_files/typed_results.py @@ -6,6 +6,7 @@ from typing import cast from sqlalchemy import Column from sqlalchemy import column from sqlalchemy import create_engine +from sqlalchemy import insert from sqlalchemy import Integer from sqlalchemy import select from sqlalchemy import table @@ -249,6 +250,74 @@ async def t_async_result_scalar_accessors() -> None: reveal_type(r5) +def t_result_insertmanyvalues_scalars() -> None: + stmt = insert(User).returning(User.id) + + uids1 = connection.scalars( + stmt, + [ + {"name": "n1"}, + {"name": "n2"}, + {"name": "n3"}, + ], + ).all() + + # EXPECTED_TYPE: Sequence[int] + reveal_type(uids1) + + uids2 = ( + connection.execute( + stmt, + [ + {"name": "n1"}, + {"name": "n2"}, + {"name": "n3"}, + ], + ) + .scalars() + .all() + ) + + # EXPECTED_TYPE: Sequence[int] + reveal_type(uids2) + + +async def t_async_result_insertmanyvalues_scalars() -> None: + stmt = insert(User).returning(User.id) + + uids1 = ( + await async_connection.scalars( + stmt, + [ + {"name": "n1"}, + {"name": "n2"}, + {"name": "n3"}, + ], + ) + ).all() + + # EXPECTED_TYPE: Sequence[int] + reveal_type(uids1) + + uids2 = ( + ( + await async_connection.execute( + stmt, + [ + {"name": "n1"}, + {"name": "n2"}, + {"name": "n3"}, + ], + ) + ) + .scalars() + .all() + ) + + # EXPECTED_TYPE: Sequence[int] + reveal_type(uids2) + + def t_connection_execute_multi_row_t() -> None: result = connection.execute(multi_stmt)