From: Jack Wotherspoon Date: Sun, 4 Jun 2023 08:59:23 +0000 (-0400) Subject: feat: add `async_creator` argument to `create_async_engine` X-Git-Tag: rel_2_0_16~3^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=dcd938c68eb0bbb33876ab57cf67ba2ef9f9947f;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git feat: add `async_creator` argument to `create_async_engine` Added new :paramref:`_asyncio.create_async_engine.async_creator` parameter to :func:`.create_async_engine`, which accomplishes the same purpose as the :paramref:`.create_engine.creator` parameter of :func:`.create_engine`. This is a no-argument callable that provides a new asyncio connection, using the asyncio database driver directly. The :func:`.create_async_engine` function will wrap the driver-level connection in the appropriate structures. Pull request curtesy of Jack Wotherspoon. Fixes #8215 Closes: #9854 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/9854 Pull-request-sha: 537073e71e745696f4adb86191b72dd3547b5c95 Change-Id: I184c59ee68436e910464b717f2cbb7e314c1c2cc --- diff --git a/doc/build/changelog/unreleased_20/8215.rst b/doc/build/changelog/unreleased_20/8215.rst new file mode 100644 index 0000000000..fc4e5fe159 --- /dev/null +++ b/doc/build/changelog/unreleased_20/8215.rst @@ -0,0 +1,11 @@ +.. change:: + :tags: usecase, asyncio + :tickets: 8215 + + Added new :paramref:`_asyncio.create_async_engine.async_creator` parameter + to :func:`.create_async_engine`, which accomplishes the same purpose as the + :paramref:`.create_engine.creator` parameter of :func:`.create_engine`. + This is a no-argument callable that provides a new asyncio connection, + using the asyncio database driver directly. The + :func:`.create_async_engine` function will wrap the driver-level connection + in the appropriate structures. Pull request curtesy of Jack Wotherspoon. diff --git a/lib/sqlalchemy/dialects/mysql/aiomysql.py b/lib/sqlalchemy/dialects/mysql/aiomysql.py index 4533353253..d4540785c1 100644 --- a/lib/sqlalchemy/dialects/mysql/aiomysql.py +++ b/lib/sqlalchemy/dialects/mysql/aiomysql.py @@ -34,7 +34,6 @@ This dialect should normally be used only with the """ # noqa - from .pymysql import MySQLDialect_pymysql from ... import pool from ... import util @@ -255,16 +254,17 @@ class AsyncAdapt_aiomysql_dbapi: def connect(self, *arg, **kw): async_fallback = kw.pop("async_fallback", False) + creator_fn = kw.pop("async_creator_fn", self.aiomysql.connect) if util.asbool(async_fallback): return AsyncAdaptFallback_aiomysql_connection( self, - await_fallback(self.aiomysql.connect(*arg, **kw)), + await_fallback(creator_fn(*arg, **kw)), ) else: return AsyncAdapt_aiomysql_connection( self, - await_only(self.aiomysql.connect(*arg, **kw)), + await_only(creator_fn(*arg, **kw)), ) diff --git a/lib/sqlalchemy/dialects/mysql/asyncmy.py b/lib/sqlalchemy/dialects/mysql/asyncmy.py index 8289daa7d1..f454dc38fa 100644 --- a/lib/sqlalchemy/dialects/mysql/asyncmy.py +++ b/lib/sqlalchemy/dialects/mysql/asyncmy.py @@ -29,7 +29,6 @@ This dialect should normally be used only with the """ # noqa - from contextlib import asynccontextmanager from .pymysql import MySQLDialect_pymysql @@ -267,16 +266,17 @@ class AsyncAdapt_asyncmy_dbapi: def connect(self, *arg, **kw): async_fallback = kw.pop("async_fallback", False) + creator_fn = kw.pop("async_creator_fn", self.asyncmy.connect) if util.asbool(async_fallback): return AsyncAdaptFallback_asyncmy_connection( self, - await_fallback(self.asyncmy.connect(*arg, **kw)), + await_fallback(creator_fn(*arg, **kw)), ) else: return AsyncAdapt_asyncmy_connection( self, - await_only(self.asyncmy.connect(*arg, **kw)), + await_only(creator_fn(*arg, **kw)), ) diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index 53e27fb746..9eb17801e7 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -872,6 +872,7 @@ class AsyncAdapt_asyncpg_dbapi: def connect(self, *arg, **kw): async_fallback = kw.pop("async_fallback", False) + creator_fn = kw.pop("async_creator_fn", self.asyncpg.connect) prepared_statement_cache_size = kw.pop( "prepared_statement_cache_size", 100 ) @@ -882,14 +883,14 @@ class AsyncAdapt_asyncpg_dbapi: if util.asbool(async_fallback): return AsyncAdaptFallback_asyncpg_connection( self, - await_fallback(self.asyncpg.connect(*arg, **kw)), + await_fallback(creator_fn(*arg, **kw)), prepared_statement_cache_size=prepared_statement_cache_size, prepared_statement_name_func=prepared_statement_name_func, ) else: return AsyncAdapt_asyncpg_connection( self, - await_only(self.asyncpg.connect(*arg, **kw)), + await_only(creator_fn(*arg, **kw)), prepared_statement_cache_size=prepared_statement_cache_size, prepared_statement_name_func=prepared_statement_name_func, ) diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg.py b/lib/sqlalchemy/dialects/postgresql/psycopg.py index e65ccfea8f..5c58daa3e7 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg.py @@ -672,15 +672,16 @@ class PsycopgAdaptDBAPI: def connect(self, *arg, **kw): async_fallback = kw.pop("async_fallback", False) + creator_fn = kw.pop( + "async_creator_fn", self.psycopg.AsyncConnection.connect + ) if util.asbool(async_fallback): return AsyncAdaptFallback_psycopg_connection( - await_fallback( - self.psycopg.AsyncConnection.connect(*arg, **kw) - ) + await_fallback(creator_fn(*arg, **kw)) ) else: return AsyncAdapt_psycopg_connection( - await_only(self.psycopg.AsyncConnection.connect(*arg, **kw)) + await_only(creator_fn(*arg, **kw)) ) diff --git a/lib/sqlalchemy/dialects/sqlite/aiosqlite.py b/lib/sqlalchemy/dialects/sqlite/aiosqlite.py index 78c7e08b9a..b8011a50ec 100644 --- a/lib/sqlalchemy/dialects/sqlite/aiosqlite.py +++ b/lib/sqlalchemy/dialects/sqlite/aiosqlite.py @@ -298,10 +298,13 @@ class AsyncAdapt_aiosqlite_dbapi: def connect(self, *arg, **kw): async_fallback = kw.pop("async_fallback", False) - connection = self.aiosqlite.connect(*arg, **kw) - - # it's a Thread. you'll thank us later - connection.daemon = True + creator_fn = kw.pop("async_creator_fn", None) + if creator_fn: + connection = creator_fn(*arg, **kw) + else: + connection = self.aiosqlite.connect(*arg, **kw) + # it's a Thread. you'll thank us later + connection.daemon = True if util.asbool(async_fallback): return AsyncAdaptFallback_aiosqlite_connection( diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py index fdf9580f48..e77c3df102 100644 --- a/lib/sqlalchemy/ext/asyncio/engine.py +++ b/lib/sqlalchemy/ext/asyncio/engine.py @@ -39,6 +39,7 @@ from ...engine import create_pool_from_url as _create_pool_from_url from ...engine import Engine from ...engine.base import NestedTransaction from ...engine.base import Transaction +from ...exc import ArgumentError from ...util.concurrency import greenlet_spawn if TYPE_CHECKING: @@ -73,6 +74,20 @@ def create_async_engine(url: Union[str, URL], **kw: Any) -> AsyncEngine: .. versionadded:: 1.4 + :param async_creator: an async callable which returns a driver-level + asyncio connection. If given, the function should take no arguments, + and return a new asyncio connection from the underlying asyncio + database driver; the connection will be wrapped in the appropriate + structures to be used with the :class:`.AsyncEngine`. Note that the + parameters specified in the URL are not applied here, and the creator + function should use its own connection parameters. + + This parameter is the asyncio equivalent of the + :paramref:`_sa.create_engine.creator` parameter of the + :func:`_sa.create_engine` function. + + .. versionadded:: 2.0.16 + """ if kw.get("server_side_cursors", False): @@ -82,6 +97,23 @@ def create_async_engine(url: Union[str, URL], **kw: Any) -> AsyncEngine: "streaming result set" ) kw["_is_async"] = True + async_creator = kw.pop("async_creator", None) + if async_creator: + if kw.get("creator", None): + raise ArgumentError( + "Can only specify one of 'async_creator' or 'creator', " + "not both." + ) + + def creator() -> Any: + # note that to send adapted arguments like + # prepared_statement_cache_size, user would use + # "creator" and emulate this form here + return sync_engine.dialect.dbapi.connect( # type: ignore + async_creator_fn=async_creator + ) + + kw["creator"] = creator sync_engine = _create_engine(url, **kw) return AsyncEngine(sync_engine) diff --git a/test/ext/asyncio/test_engine_py3k.py b/test/ext/asyncio/test_engine_py3k.py index ff4fcbf28e..bbbdbf512f 100644 --- a/test/ext/asyncio/test_engine_py3k.py +++ b/test/ext/asyncio/test_engine_py3k.py @@ -44,7 +44,7 @@ from sqlalchemy.testing import is_not from sqlalchemy.testing import is_true from sqlalchemy.testing import mock from sqlalchemy.testing import ne_ -from sqlalchemy.util.concurrency import greenlet_spawn +from sqlalchemy.util import greenlet_spawn class AsyncFixture: @@ -678,13 +678,14 @@ class AsyncEngineTest(EngineFixture): @async_test async def test_create_async_engine_server_side_cursor(self, async_engine): - testing.assert_raises_message( + with expect_raises_message( asyncio_exc.AsyncMethodRequired, "Can't set server_side_cursors for async engine globally", - create_async_engine, - testing.db.url, - server_side_cursors=True, - ) + ): + create_async_engine( + testing.db.url, + server_side_cursors=True, + ) def test_async_engine_from_config(self): config = { @@ -698,6 +699,79 @@ class AsyncEngineTest(EngineFixture): assert engine.echo is True assert engine.dialect.is_async is True + def test_async_creator_and_creator(self): + async def ac(): + return None + + def c(): + return None + + with expect_raises_message( + exc.ArgumentError, + "Can only specify one of 'async_creator' or 'creator', " + "not both.", + ): + create_async_engine(testing.db.url, creator=c, async_creator=ac) + + @async_test + async def test_async_creator_invoked(self, async_testing_engine): + """test for #8215""" + + existing_creator = testing.db.pool._creator + + async def async_creator(): + sync_conn = await greenlet_spawn(existing_creator) + return sync_conn.driver_connection + + async_creator = mock.Mock(side_effect=async_creator) + + eq_(async_creator.mock_calls, []) + + engine = async_testing_engine(options={"async_creator": async_creator}) + async with engine.connect() as conn: + result = await conn.scalar(select(1)) + eq_(result, 1) + + eq_(async_creator.mock_calls, [mock.call()]) + + @async_test + async def test_async_creator_accepts_args_if_called_directly( + self, async_testing_engine + ): + """supplemental test for #8215. + + The "async_creator" passed to create_async_engine() is expected to take + no arguments, the same way as "creator" passed to create_engine() + works. + + However, the ultimate "async_creator" received by the sync-emulating + DBAPI *does* take arguments in its ``.connect()`` method, which will be + all the other arguments passed to ``.connect()``. This functionality + is not currently used, however was decided that the creator should + internally work this way for improved flexibility; see + https://github.com/sqlalchemy/sqlalchemy/issues/8215#issuecomment-1181791539. + That contract is tested here. + + """ # noqa: E501 + + existing_creator = testing.db.pool._creator + + async def async_creator(x, y, *, z=None): + sync_conn = await greenlet_spawn(existing_creator) + return sync_conn.driver_connection + + async_creator = mock.Mock(side_effect=async_creator) + + async_dbapi = testing.db.dialect.loaded_dbapi + + conn = await greenlet_spawn( + async_dbapi.connect, 5, y=10, z=8, async_creator_fn=async_creator + ) + try: + eq_(async_creator.mock_calls, [mock.call(5, y=10, z=8)]) + finally: + await greenlet_spawn(conn.close) + class AsyncCreatePoolTest(fixtures.TestBase): @config.fixture