]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
feat: add `async_creator` argument to `create_async_engine`
authorJack Wotherspoon <jackwoth@google.com>
Sun, 4 Jun 2023 08:59:23 +0000 (04:59 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 9 Jun 2023 15:21:40 +0000 (11:21 -0400)
Added new :paramref:`_asyncio.create_async_engine.async_creator` parameter
to :func:`.create_async_engine`, which accomplishes the same purpose as the
:paramref:`.create_engine.creator` parameter of :func:`.create_engine`.
This is a no-argument callable that provides a new asyncio connection,
using the asyncio database driver directly. The
:func:`.create_async_engine` function will wrap the driver-level connection
in the appropriate structures. Pull request curtesy of Jack Wotherspoon.

Fixes #8215
Closes: #9854
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/9854
Pull-request-sha: 537073e71e745696f4adb86191b72dd3547b5c95

Change-Id: I184c59ee68436e910464b717f2cbb7e314c1c2cc

doc/build/changelog/unreleased_20/8215.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mysql/aiomysql.py
lib/sqlalchemy/dialects/mysql/asyncmy.py
lib/sqlalchemy/dialects/postgresql/asyncpg.py
lib/sqlalchemy/dialects/postgresql/psycopg.py
lib/sqlalchemy/dialects/sqlite/aiosqlite.py
lib/sqlalchemy/ext/asyncio/engine.py
test/ext/asyncio/test_engine_py3k.py

diff --git a/doc/build/changelog/unreleased_20/8215.rst b/doc/build/changelog/unreleased_20/8215.rst
new file mode 100644 (file)
index 0000000..fc4e5fe
--- /dev/null
@@ -0,0 +1,11 @@
+.. change::
+    :tags: usecase, asyncio
+    :tickets: 8215
+
+    Added new :paramref:`_asyncio.create_async_engine.async_creator` parameter
+    to :func:`.create_async_engine`, which accomplishes the same purpose as the
+    :paramref:`.create_engine.creator` parameter of :func:`.create_engine`.
+    This is a no-argument callable that provides a new asyncio connection,
+    using the asyncio database driver directly. The
+    :func:`.create_async_engine` function will wrap the driver-level connection
+    in the appropriate structures. Pull request curtesy of Jack Wotherspoon.
index 45333532535ca7b9f36fea83b49cb82d2a5bda22..d4540785c123b99e0ffd76f650374278e8cc0014 100644 (file)
@@ -34,7 +34,6 @@ This dialect should normally be used only with the
 
 
 """  # noqa
-
 from .pymysql import MySQLDialect_pymysql
 from ... import pool
 from ... import util
@@ -255,16 +254,17 @@ class AsyncAdapt_aiomysql_dbapi:
 
     def connect(self, *arg, **kw):
         async_fallback = kw.pop("async_fallback", False)
+        creator_fn = kw.pop("async_creator_fn", self.aiomysql.connect)
 
         if util.asbool(async_fallback):
             return AsyncAdaptFallback_aiomysql_connection(
                 self,
-                await_fallback(self.aiomysql.connect(*arg, **kw)),
+                await_fallback(creator_fn(*arg, **kw)),
             )
         else:
             return AsyncAdapt_aiomysql_connection(
                 self,
-                await_only(self.aiomysql.connect(*arg, **kw)),
+                await_only(creator_fn(*arg, **kw)),
             )
 
 
index 8289daa7d100e0dd518243013fe385bfaea76909..f454dc38fa67bf43957b2dacf2080c3abe441961 100644 (file)
@@ -29,7 +29,6 @@ This dialect should normally be used only with the
 
 
 """  # noqa
-
 from contextlib import asynccontextmanager
 
 from .pymysql import MySQLDialect_pymysql
@@ -267,16 +266,17 @@ class AsyncAdapt_asyncmy_dbapi:
 
     def connect(self, *arg, **kw):
         async_fallback = kw.pop("async_fallback", False)
+        creator_fn = kw.pop("async_creator_fn", self.asyncmy.connect)
 
         if util.asbool(async_fallback):
             return AsyncAdaptFallback_asyncmy_connection(
                 self,
-                await_fallback(self.asyncmy.connect(*arg, **kw)),
+                await_fallback(creator_fn(*arg, **kw)),
             )
         else:
             return AsyncAdapt_asyncmy_connection(
                 self,
-                await_only(self.asyncmy.connect(*arg, **kw)),
+                await_only(creator_fn(*arg, **kw)),
             )
 
 
index 53e27fb746cc567391b23bdd405fdff03207b20b..9eb17801e7c8f989e5b44383fb1c97382b75f5bf 100644 (file)
@@ -872,6 +872,7 @@ class AsyncAdapt_asyncpg_dbapi:
 
     def connect(self, *arg, **kw):
         async_fallback = kw.pop("async_fallback", False)
+        creator_fn = kw.pop("async_creator_fn", self.asyncpg.connect)
         prepared_statement_cache_size = kw.pop(
             "prepared_statement_cache_size", 100
         )
@@ -882,14 +883,14 @@ class AsyncAdapt_asyncpg_dbapi:
         if util.asbool(async_fallback):
             return AsyncAdaptFallback_asyncpg_connection(
                 self,
-                await_fallback(self.asyncpg.connect(*arg, **kw)),
+                await_fallback(creator_fn(*arg, **kw)),
                 prepared_statement_cache_size=prepared_statement_cache_size,
                 prepared_statement_name_func=prepared_statement_name_func,
             )
         else:
             return AsyncAdapt_asyncpg_connection(
                 self,
-                await_only(self.asyncpg.connect(*arg, **kw)),
+                await_only(creator_fn(*arg, **kw)),
                 prepared_statement_cache_size=prepared_statement_cache_size,
                 prepared_statement_name_func=prepared_statement_name_func,
             )
index e65ccfea8fc460a334a31b92b3e6ff6525a3f26e..5c58daa3e7330ef72807a5091e18b53d18328b0c 100644 (file)
@@ -672,15 +672,16 @@ class PsycopgAdaptDBAPI:
 
     def connect(self, *arg, **kw):
         async_fallback = kw.pop("async_fallback", False)
+        creator_fn = kw.pop(
+            "async_creator_fn", self.psycopg.AsyncConnection.connect
+        )
         if util.asbool(async_fallback):
             return AsyncAdaptFallback_psycopg_connection(
-                await_fallback(
-                    self.psycopg.AsyncConnection.connect(*arg, **kw)
-                )
+                await_fallback(creator_fn(*arg, **kw))
             )
         else:
             return AsyncAdapt_psycopg_connection(
-                await_only(self.psycopg.AsyncConnection.connect(*arg, **kw))
+                await_only(creator_fn(*arg, **kw))
             )
 
 
index 78c7e08b9a742c270b193e069073a17529f5037a..b8011a50ec274cf73eed4eb3c5acc0b75a76f1a8 100644 (file)
@@ -298,10 +298,13 @@ class AsyncAdapt_aiosqlite_dbapi:
     def connect(self, *arg, **kw):
         async_fallback = kw.pop("async_fallback", False)
 
-        connection = self.aiosqlite.connect(*arg, **kw)
-
-        # it's a Thread.   you'll thank us later
-        connection.daemon = True
+        creator_fn = kw.pop("async_creator_fn", None)
+        if creator_fn:
+            connection = creator_fn(*arg, **kw)
+        else:
+            connection = self.aiosqlite.connect(*arg, **kw)
+            # it's a Thread.   you'll thank us later
+            connection.daemon = True
 
         if util.asbool(async_fallback):
             return AsyncAdaptFallback_aiosqlite_connection(
index fdf9580f486bcf3f7f9629dd5de4ee06bb141857..e77c3df10219d937e4e7914aebf6b5b1a9bc365c 100644 (file)
@@ -39,6 +39,7 @@ from ...engine import create_pool_from_url as _create_pool_from_url
 from ...engine import Engine
 from ...engine.base import NestedTransaction
 from ...engine.base import Transaction
+from ...exc import ArgumentError
 from ...util.concurrency import greenlet_spawn
 
 if TYPE_CHECKING:
@@ -73,6 +74,20 @@ def create_async_engine(url: Union[str, URL], **kw: Any) -> AsyncEngine:
 
     .. versionadded:: 1.4
 
+    :param async_creator: an async callable which returns a driver-level
+        asyncio connection. If given, the function should take no arguments,
+        and return a new asyncio connection from the underlying asyncio
+        database driver; the connection will be wrapped in the appropriate
+        structures to be used with the :class:`.AsyncEngine`.   Note that the
+        parameters specified in the URL are not applied here, and the creator
+        function should use its own connection parameters.
+
+        This parameter is the asyncio equivalent of the
+        :paramref:`_sa.create_engine.creator` parameter of the
+        :func:`_sa.create_engine` function.
+
+        .. versionadded:: 2.0.16
+
     """
 
     if kw.get("server_side_cursors", False):
@@ -82,6 +97,23 @@ def create_async_engine(url: Union[str, URL], **kw: Any) -> AsyncEngine:
             "streaming result set"
         )
     kw["_is_async"] = True
+    async_creator = kw.pop("async_creator", None)
+    if async_creator:
+        if kw.get("creator", None):
+            raise ArgumentError(
+                "Can only specify one of 'async_creator' or 'creator', "
+                "not both."
+            )
+
+        def creator() -> Any:
+            # note that to send adapted arguments like
+            # prepared_statement_cache_size, user would use
+            # "creator" and emulate this form here
+            return sync_engine.dialect.dbapi.connect(  # type: ignore
+                async_creator_fn=async_creator
+            )
+
+        kw["creator"] = creator
     sync_engine = _create_engine(url, **kw)
     return AsyncEngine(sync_engine)
 
index ff4fcbf28eb9644ca00528d85c7a6964b4610abb..bbbdbf512f4b528026653e9a9700c198b86e2a60 100644 (file)
@@ -44,7 +44,7 @@ from sqlalchemy.testing import is_not
 from sqlalchemy.testing import is_true
 from sqlalchemy.testing import mock
 from sqlalchemy.testing import ne_
-from sqlalchemy.util.concurrency import greenlet_spawn
+from sqlalchemy.util import greenlet_spawn
 
 
 class AsyncFixture:
@@ -678,13 +678,14 @@ class AsyncEngineTest(EngineFixture):
 
     @async_test
     async def test_create_async_engine_server_side_cursor(self, async_engine):
-        testing.assert_raises_message(
+        with expect_raises_message(
             asyncio_exc.AsyncMethodRequired,
             "Can't set server_side_cursors for async engine globally",
-            create_async_engine,
-            testing.db.url,
-            server_side_cursors=True,
-        )
+        ):
+            create_async_engine(
+                testing.db.url,
+                server_side_cursors=True,
+            )
 
     def test_async_engine_from_config(self):
         config = {
@@ -698,6 +699,79 @@ class AsyncEngineTest(EngineFixture):
         assert engine.echo is True
         assert engine.dialect.is_async is True
 
+    def test_async_creator_and_creator(self):
+        async def ac():
+            return None
+
+        def c():
+            return None
+
+        with expect_raises_message(
+            exc.ArgumentError,
+            "Can only specify one of 'async_creator' or 'creator', "
+            "not both.",
+        ):
+            create_async_engine(testing.db.url, creator=c, async_creator=ac)
+
+    @async_test
+    async def test_async_creator_invoked(self, async_testing_engine):
+        """test for #8215"""
+
+        existing_creator = testing.db.pool._creator
+
+        async def async_creator():
+            sync_conn = await greenlet_spawn(existing_creator)
+            return sync_conn.driver_connection
+
+        async_creator = mock.Mock(side_effect=async_creator)
+
+        eq_(async_creator.mock_calls, [])
+
+        engine = async_testing_engine(options={"async_creator": async_creator})
+        async with engine.connect() as conn:
+            result = await conn.scalar(select(1))
+            eq_(result, 1)
+
+        eq_(async_creator.mock_calls, [mock.call()])
+
+    @async_test
+    async def test_async_creator_accepts_args_if_called_directly(
+        self, async_testing_engine
+    ):
+        """supplemental test for #8215.
+
+        The "async_creator" passed to create_async_engine() is expected to take
+        no arguments, the same way as "creator" passed to create_engine()
+        works.
+
+        However, the ultimate "async_creator" received by the sync-emulating
+        DBAPI *does* take arguments in its ``.connect()`` method, which will be
+        all the other arguments passed to ``.connect()``.  This functionality
+        is not currently used, however was decided that the creator should
+        internally work this way for improved flexibility; see
+        https://github.com/sqlalchemy/sqlalchemy/issues/8215#issuecomment-1181791539.
+        That contract is tested here.
+
+        """  # noqa: E501
+
+        existing_creator = testing.db.pool._creator
+
+        async def async_creator(x, y, *, z=None):
+            sync_conn = await greenlet_spawn(existing_creator)
+            return sync_conn.driver_connection
+
+        async_creator = mock.Mock(side_effect=async_creator)
+
+        async_dbapi = testing.db.dialect.loaded_dbapi
+
+        conn = await greenlet_spawn(
+            async_dbapi.connect, 5, y=10, z=8, async_creator_fn=async_creator
+        )
+        try:
+            eq_(async_creator.mock_calls, [mock.call(5, y=10, z=8)])
+        finally:
+            await greenlet_spawn(conn.close)
+
 
 class AsyncCreatePoolTest(fixtures.TestBase):
     @config.fixture