--- /dev/null
+.. change::
+ :tags: usecase, asyncio
+ :tickets: 8215
+
+ Added new :paramref:`_asyncio.create_async_engine.async_creator` parameter
+ to :func:`.create_async_engine`, which accomplishes the same purpose as the
+ :paramref:`.create_engine.creator` parameter of :func:`.create_engine`.
+ This is a no-argument callable that provides a new asyncio connection,
+ using the asyncio database driver directly. The
+ :func:`.create_async_engine` function will wrap the driver-level connection
+ in the appropriate structures. Pull request curtesy of Jack Wotherspoon.
""" # noqa
-
from .pymysql import MySQLDialect_pymysql
from ... import pool
from ... import util
def connect(self, *arg, **kw):
async_fallback = kw.pop("async_fallback", False)
+ creator_fn = kw.pop("async_creator_fn", self.aiomysql.connect)
if util.asbool(async_fallback):
return AsyncAdaptFallback_aiomysql_connection(
self,
- await_fallback(self.aiomysql.connect(*arg, **kw)),
+ await_fallback(creator_fn(*arg, **kw)),
)
else:
return AsyncAdapt_aiomysql_connection(
self,
- await_only(self.aiomysql.connect(*arg, **kw)),
+ await_only(creator_fn(*arg, **kw)),
)
""" # noqa
-
from contextlib import asynccontextmanager
from .pymysql import MySQLDialect_pymysql
def connect(self, *arg, **kw):
async_fallback = kw.pop("async_fallback", False)
+ creator_fn = kw.pop("async_creator_fn", self.asyncmy.connect)
if util.asbool(async_fallback):
return AsyncAdaptFallback_asyncmy_connection(
self,
- await_fallback(self.asyncmy.connect(*arg, **kw)),
+ await_fallback(creator_fn(*arg, **kw)),
)
else:
return AsyncAdapt_asyncmy_connection(
self,
- await_only(self.asyncmy.connect(*arg, **kw)),
+ await_only(creator_fn(*arg, **kw)),
)
def connect(self, *arg, **kw):
async_fallback = kw.pop("async_fallback", False)
+ creator_fn = kw.pop("async_creator_fn", self.asyncpg.connect)
prepared_statement_cache_size = kw.pop(
"prepared_statement_cache_size", 100
)
if util.asbool(async_fallback):
return AsyncAdaptFallback_asyncpg_connection(
self,
- await_fallback(self.asyncpg.connect(*arg, **kw)),
+ await_fallback(creator_fn(*arg, **kw)),
prepared_statement_cache_size=prepared_statement_cache_size,
prepared_statement_name_func=prepared_statement_name_func,
)
else:
return AsyncAdapt_asyncpg_connection(
self,
- await_only(self.asyncpg.connect(*arg, **kw)),
+ await_only(creator_fn(*arg, **kw)),
prepared_statement_cache_size=prepared_statement_cache_size,
prepared_statement_name_func=prepared_statement_name_func,
)
def connect(self, *arg, **kw):
async_fallback = kw.pop("async_fallback", False)
+ creator_fn = kw.pop(
+ "async_creator_fn", self.psycopg.AsyncConnection.connect
+ )
if util.asbool(async_fallback):
return AsyncAdaptFallback_psycopg_connection(
- await_fallback(
- self.psycopg.AsyncConnection.connect(*arg, **kw)
- )
+ await_fallback(creator_fn(*arg, **kw))
)
else:
return AsyncAdapt_psycopg_connection(
- await_only(self.psycopg.AsyncConnection.connect(*arg, **kw))
+ await_only(creator_fn(*arg, **kw))
)
def connect(self, *arg, **kw):
async_fallback = kw.pop("async_fallback", False)
- connection = self.aiosqlite.connect(*arg, **kw)
-
- # it's a Thread. you'll thank us later
- connection.daemon = True
+ creator_fn = kw.pop("async_creator_fn", None)
+ if creator_fn:
+ connection = creator_fn(*arg, **kw)
+ else:
+ connection = self.aiosqlite.connect(*arg, **kw)
+ # it's a Thread. you'll thank us later
+ connection.daemon = True
if util.asbool(async_fallback):
return AsyncAdaptFallback_aiosqlite_connection(
from ...engine import Engine
from ...engine.base import NestedTransaction
from ...engine.base import Transaction
+from ...exc import ArgumentError
from ...util.concurrency import greenlet_spawn
if TYPE_CHECKING:
.. versionadded:: 1.4
+ :param async_creator: an async callable which returns a driver-level
+ asyncio connection. If given, the function should take no arguments,
+ and return a new asyncio connection from the underlying asyncio
+ database driver; the connection will be wrapped in the appropriate
+ structures to be used with the :class:`.AsyncEngine`. Note that the
+ parameters specified in the URL are not applied here, and the creator
+ function should use its own connection parameters.
+
+ This parameter is the asyncio equivalent of the
+ :paramref:`_sa.create_engine.creator` parameter of the
+ :func:`_sa.create_engine` function.
+
+ .. versionadded:: 2.0.16
+
"""
if kw.get("server_side_cursors", False):
"streaming result set"
)
kw["_is_async"] = True
+ async_creator = kw.pop("async_creator", None)
+ if async_creator:
+ if kw.get("creator", None):
+ raise ArgumentError(
+ "Can only specify one of 'async_creator' or 'creator', "
+ "not both."
+ )
+
+ def creator() -> Any:
+ # note that to send adapted arguments like
+ # prepared_statement_cache_size, user would use
+ # "creator" and emulate this form here
+ return sync_engine.dialect.dbapi.connect( # type: ignore
+ async_creator_fn=async_creator
+ )
+
+ kw["creator"] = creator
sync_engine = _create_engine(url, **kw)
return AsyncEngine(sync_engine)
from sqlalchemy.testing import is_true
from sqlalchemy.testing import mock
from sqlalchemy.testing import ne_
-from sqlalchemy.util.concurrency import greenlet_spawn
+from sqlalchemy.util import greenlet_spawn
class AsyncFixture:
@async_test
async def test_create_async_engine_server_side_cursor(self, async_engine):
- testing.assert_raises_message(
+ with expect_raises_message(
asyncio_exc.AsyncMethodRequired,
"Can't set server_side_cursors for async engine globally",
- create_async_engine,
- testing.db.url,
- server_side_cursors=True,
- )
+ ):
+ create_async_engine(
+ testing.db.url,
+ server_side_cursors=True,
+ )
def test_async_engine_from_config(self):
config = {
assert engine.echo is True
assert engine.dialect.is_async is True
+ def test_async_creator_and_creator(self):
+ async def ac():
+ return None
+
+ def c():
+ return None
+
+ with expect_raises_message(
+ exc.ArgumentError,
+ "Can only specify one of 'async_creator' or 'creator', "
+ "not both.",
+ ):
+ create_async_engine(testing.db.url, creator=c, async_creator=ac)
+
+ @async_test
+ async def test_async_creator_invoked(self, async_testing_engine):
+ """test for #8215"""
+
+ existing_creator = testing.db.pool._creator
+
+ async def async_creator():
+ sync_conn = await greenlet_spawn(existing_creator)
+ return sync_conn.driver_connection
+
+ async_creator = mock.Mock(side_effect=async_creator)
+
+ eq_(async_creator.mock_calls, [])
+
+ engine = async_testing_engine(options={"async_creator": async_creator})
+ async with engine.connect() as conn:
+ result = await conn.scalar(select(1))
+ eq_(result, 1)
+
+ eq_(async_creator.mock_calls, [mock.call()])
+
+ @async_test
+ async def test_async_creator_accepts_args_if_called_directly(
+ self, async_testing_engine
+ ):
+ """supplemental test for #8215.
+
+ The "async_creator" passed to create_async_engine() is expected to take
+ no arguments, the same way as "creator" passed to create_engine()
+ works.
+
+ However, the ultimate "async_creator" received by the sync-emulating
+ DBAPI *does* take arguments in its ``.connect()`` method, which will be
+ all the other arguments passed to ``.connect()``. This functionality
+ is not currently used, however was decided that the creator should
+ internally work this way for improved flexibility; see
+ https://github.com/sqlalchemy/sqlalchemy/issues/8215#issuecomment-1181791539.
+ That contract is tested here.
+
+ """ # noqa: E501
+
+ existing_creator = testing.db.pool._creator
+
+ async def async_creator(x, y, *, z=None):
+ sync_conn = await greenlet_spawn(existing_creator)
+ return sync_conn.driver_connection
+
+ async_creator = mock.Mock(side_effect=async_creator)
+
+ async_dbapi = testing.db.dialect.loaded_dbapi
+
+ conn = await greenlet_spawn(
+ async_dbapi.connect, 5, y=10, z=8, async_creator_fn=async_creator
+ )
+ try:
+ eq_(async_creator.mock_calls, [mock.call(5, y=10, z=8)])
+ finally:
+ await greenlet_spawn(conn.close)
+
class AsyncCreatePoolTest(fixtures.TestBase):
@config.fixture