]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Align `AsyncSession` method annotations with `Session` equivalents
authorJanek Nouvertné <25355197+provinzkraut@users.noreply.github.com>
Sun, 11 Jun 2023 10:07:26 +0000 (06:07 -0400)
committersqla-tester <sqla-tester@sqlalchemy.org>
Sun, 11 Jun 2023 10:07:26 +0000 (06:07 -0400)
Fixes a few differences in the parameter signatures of `asyncio.ext.AsyncSession` that were misaligned with `orm.Session`. Fixes #9925

### Description

- Change the annotation of the `params` parameter of `.scalar`, `.scalars` and `.stream_scalars` from `_CoreSingleExecuteParams` to `_CoreAnyExecuteParams`
- Add named keyword arguments `bind_arguments` and `execution_options` to `.connection`

### Checklist
<!-- go over following points. check them with an `x` if they do apply, (they turn into clickable checkboxes once the PR is submitted, so no need to do everything at once)

-->

This pull request is:

- [ ] A documentation / typographical error fix
- Good to go, no issue or tests are needed
- [x] A short code fix
- please include the issue number, and create an issue if none exists, which
  must include a complete example of the issue.  one line code fixes without an
  issue and demonstration will not be accepted.
- Please include: `Fixes: #<issue number>` in the commit message
- please include tests.   one line code fixes without tests will not be accepted.
- [ ] A new feature implementation
- please include the issue number, and create an issue if none exists, which must
  include a complete example of how the feature would look.
- Please include: `Fixes: #<issue number>` in the commit message
- please include tests.

**Have a nice day!**

Closes: #9929
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/9929
Pull-request-sha: 481f3ad94efab14f7a63a38c195810811b7ed90f

Change-Id: I84c5a68f5d95c903dd64928a23ad0cb796df778c

lib/sqlalchemy/ext/asyncio/scoping.py
lib/sqlalchemy/ext/asyncio/session.py

index 49d8b3af937643866a0599a3cf0025052dd755d8..d72b00468543180698aaa9c357cc2742ada57194 100644 (file)
@@ -43,7 +43,7 @@ if TYPE_CHECKING:
     from ...engine import Row
     from ...engine import RowMapping
     from ...engine.interfaces import _CoreAnyExecuteParams
-    from ...engine.interfaces import _CoreSingleExecuteParams
+    from ...engine.interfaces import _ExecuteOptions
     from ...engine.result import ScalarResult
     from ...orm._typing import _IdentityKeyType
     from ...orm._typing import _O
@@ -465,7 +465,12 @@ class async_scoped_session(Generic[_AS]):
 
         return await self._proxied.commit()
 
-    async def connection(self, **kw: Any) -> AsyncConnection:
+    async def connection(
+        self,
+        bind_arguments: Optional[_BindArguments] = None,
+        execution_options: Optional[_ExecuteOptions] = None,
+        **kw: Any,
+    ) -> AsyncConnection:
         r"""Return a :class:`_asyncio.AsyncConnection` object corresponding to
         this :class:`.Session` object's transactional state.
 
@@ -488,7 +493,11 @@ class async_scoped_session(Generic[_AS]):
 
         """  # noqa: E501
 
-        return await self._proxied.connection(**kw)
+        return await self._proxied.connection(
+            bind_arguments=bind_arguments,
+            execution_options=execution_options,
+            **kw,
+        )
 
     async def delete(self, instance: object) -> None:
         r"""Mark an instance as deleted.
@@ -978,7 +987,7 @@ class async_scoped_session(Generic[_AS]):
     async def scalar(
         self,
         statement: TypedReturnsRows[Tuple[_T]],
-        params: Optional[_CoreSingleExecuteParams] = None,
+        params: Optional[_CoreAnyExecuteParams] = None,
         *,
         execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
         bind_arguments: Optional[_BindArguments] = None,
@@ -990,7 +999,7 @@ class async_scoped_session(Generic[_AS]):
     async def scalar(
         self,
         statement: Executable,
-        params: Optional[_CoreSingleExecuteParams] = None,
+        params: Optional[_CoreAnyExecuteParams] = None,
         *,
         execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
         bind_arguments: Optional[_BindArguments] = None,
@@ -1001,7 +1010,7 @@ class async_scoped_session(Generic[_AS]):
     async def scalar(
         self,
         statement: Executable,
-        params: Optional[_CoreSingleExecuteParams] = None,
+        params: Optional[_CoreAnyExecuteParams] = None,
         *,
         execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
         bind_arguments: Optional[_BindArguments] = None,
@@ -1033,7 +1042,7 @@ class async_scoped_session(Generic[_AS]):
     async def scalars(
         self,
         statement: TypedReturnsRows[Tuple[_T]],
-        params: Optional[_CoreSingleExecuteParams] = None,
+        params: Optional[_CoreAnyExecuteParams] = None,
         *,
         execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
         bind_arguments: Optional[_BindArguments] = None,
@@ -1045,7 +1054,7 @@ class async_scoped_session(Generic[_AS]):
     async def scalars(
         self,
         statement: Executable,
-        params: Optional[_CoreSingleExecuteParams] = None,
+        params: Optional[_CoreAnyExecuteParams] = None,
         *,
         execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
         bind_arguments: Optional[_BindArguments] = None,
@@ -1056,7 +1065,7 @@ class async_scoped_session(Generic[_AS]):
     async def scalars(
         self,
         statement: Executable,
-        params: Optional[_CoreSingleExecuteParams] = None,
+        params: Optional[_CoreAnyExecuteParams] = None,
         *,
         execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
         bind_arguments: Optional[_BindArguments] = None,
@@ -1149,7 +1158,7 @@ class async_scoped_session(Generic[_AS]):
     async def stream_scalars(
         self,
         statement: TypedReturnsRows[Tuple[_T]],
-        params: Optional[_CoreSingleExecuteParams] = None,
+        params: Optional[_CoreAnyExecuteParams] = None,
         *,
         execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
         bind_arguments: Optional[_BindArguments] = None,
@@ -1161,7 +1170,7 @@ class async_scoped_session(Generic[_AS]):
     async def stream_scalars(
         self,
         statement: Executable,
-        params: Optional[_CoreSingleExecuteParams] = None,
+        params: Optional[_CoreAnyExecuteParams] = None,
         *,
         execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
         bind_arguments: Optional[_BindArguments] = None,
@@ -1172,7 +1181,7 @@ class async_scoped_session(Generic[_AS]):
     async def stream_scalars(
         self,
         statement: Executable,
-        params: Optional[_CoreSingleExecuteParams] = None,
+        params: Optional[_CoreAnyExecuteParams] = None,
         *,
         execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
         bind_arguments: Optional[_BindArguments] = None,
index 14551e6e28e4a310d769d20e6baba805d14d67af..72a98f576b6386e707d166454aaf7eeec8862353 100644 (file)
@@ -48,7 +48,7 @@ if TYPE_CHECKING:
     from ...engine import RowMapping
     from ...engine import ScalarResult
     from ...engine.interfaces import _CoreAnyExecuteParams
-    from ...engine.interfaces import _CoreSingleExecuteParams
+    from ...engine.interfaces import _ExecuteOptions
     from ...event import dispatcher
     from ...orm._typing import _IdentityKeyType
     from ...orm._typing import _O
@@ -447,7 +447,7 @@ class AsyncSession(ReversibleProxy[Session]):
     async def scalar(
         self,
         statement: TypedReturnsRows[Tuple[_T]],
-        params: Optional[_CoreSingleExecuteParams] = None,
+        params: Optional[_CoreAnyExecuteParams] = None,
         *,
         execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
         bind_arguments: Optional[_BindArguments] = None,
@@ -459,7 +459,7 @@ class AsyncSession(ReversibleProxy[Session]):
     async def scalar(
         self,
         statement: Executable,
-        params: Optional[_CoreSingleExecuteParams] = None,
+        params: Optional[_CoreAnyExecuteParams] = None,
         *,
         execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
         bind_arguments: Optional[_BindArguments] = None,
@@ -470,7 +470,7 @@ class AsyncSession(ReversibleProxy[Session]):
     async def scalar(
         self,
         statement: Executable,
-        params: Optional[_CoreSingleExecuteParams] = None,
+        params: Optional[_CoreAnyExecuteParams] = None,
         *,
         execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
         bind_arguments: Optional[_BindArguments] = None,
@@ -505,7 +505,7 @@ class AsyncSession(ReversibleProxy[Session]):
     async def scalars(
         self,
         statement: TypedReturnsRows[Tuple[_T]],
-        params: Optional[_CoreSingleExecuteParams] = None,
+        params: Optional[_CoreAnyExecuteParams] = None,
         *,
         execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
         bind_arguments: Optional[_BindArguments] = None,
@@ -517,7 +517,7 @@ class AsyncSession(ReversibleProxy[Session]):
     async def scalars(
         self,
         statement: Executable,
-        params: Optional[_CoreSingleExecuteParams] = None,
+        params: Optional[_CoreAnyExecuteParams] = None,
         *,
         execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
         bind_arguments: Optional[_BindArguments] = None,
@@ -528,7 +528,7 @@ class AsyncSession(ReversibleProxy[Session]):
     async def scalars(
         self,
         statement: Executable,
-        params: Optional[_CoreSingleExecuteParams] = None,
+        params: Optional[_CoreAnyExecuteParams] = None,
         *,
         execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
         bind_arguments: Optional[_BindArguments] = None,
@@ -653,7 +653,7 @@ class AsyncSession(ReversibleProxy[Session]):
     async def stream_scalars(
         self,
         statement: TypedReturnsRows[Tuple[_T]],
-        params: Optional[_CoreSingleExecuteParams] = None,
+        params: Optional[_CoreAnyExecuteParams] = None,
         *,
         execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
         bind_arguments: Optional[_BindArguments] = None,
@@ -665,7 +665,7 @@ class AsyncSession(ReversibleProxy[Session]):
     async def stream_scalars(
         self,
         statement: Executable,
-        params: Optional[_CoreSingleExecuteParams] = None,
+        params: Optional[_CoreAnyExecuteParams] = None,
         *,
         execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
         bind_arguments: Optional[_BindArguments] = None,
@@ -676,7 +676,7 @@ class AsyncSession(ReversibleProxy[Session]):
     async def stream_scalars(
         self,
         statement: Executable,
-        params: Optional[_CoreSingleExecuteParams] = None,
+        params: Optional[_CoreAnyExecuteParams] = None,
         *,
         execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
         bind_arguments: Optional[_BindArguments] = None,
@@ -864,7 +864,12 @@ class AsyncSession(ReversibleProxy[Session]):
             mapper=mapper, clause=clause, bind=bind, **kw
         )
 
-    async def connection(self, **kw: Any) -> AsyncConnection:
+    async def connection(
+        self,
+        bind_arguments: Optional[_BindArguments] = None,
+        execution_options: Optional[_ExecuteOptions] = None,
+        **kw: Any,
+    ) -> AsyncConnection:
         r"""Return a :class:`_asyncio.AsyncConnection` object corresponding to
         this :class:`.Session` object's transactional state.
 
@@ -882,7 +887,10 @@ class AsyncSession(ReversibleProxy[Session]):
         """
 
         sync_connection = await greenlet_spawn(
-            self.sync_session.connection, **kw
+            self.sync_session.connection,
+            bind_arguments=bind_arguments,
+            execution_options=execution_options,
+            **kw,
         )
         return engine.AsyncConnection._retrieve_proxy_for_target(
             sync_connection