From: Jack Wotherspoon Date: Tue, 30 May 2023 15:09:32 +0000 (-0400) Subject: feat: add other async drivers X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=0ae1621aa34ea16520a847f4b382e20e1a42fd20;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git feat: add other async drivers --- diff --git a/lib/sqlalchemy/dialects/mysql/aiomysql.py b/lib/sqlalchemy/dialects/mysql/aiomysql.py index bc079ba17a..a4d6ea9c31 100644 --- a/lib/sqlalchemy/dialects/mysql/aiomysql.py +++ b/lib/sqlalchemy/dialects/mysql/aiomysql.py @@ -34,6 +34,7 @@ This dialect should normally be used only with the """ # noqa +from functools import partial from .pymysql import MySQLDialect_pymysql from ... import pool @@ -255,16 +256,17 @@ class AsyncAdapt_aiomysql_dbapi: def connect(self, *arg, **kw): async_fallback = kw.pop("async_fallback", False) + creator_fn = kw.pop("creator_fn", partial(self.aiomysql.connect)) if util.asbool(async_fallback): return AsyncAdaptFallback_aiomysql_connection( self, - await_fallback(self.aiomysql.connect(*arg, **kw)), + await_fallback(creator_fn(*arg, **kw)), ) else: return AsyncAdapt_aiomysql_connection( self, - await_only(self.aiomysql.connect(*arg, **kw)), + await_only(creator_fn(*arg, **kw)), ) diff --git a/lib/sqlalchemy/dialects/mysql/asyncmy.py b/lib/sqlalchemy/dialects/mysql/asyncmy.py index d3f809e6ae..6f5b506a71 100644 --- a/lib/sqlalchemy/dialects/mysql/asyncmy.py +++ b/lib/sqlalchemy/dialects/mysql/asyncmy.py @@ -29,8 +29,8 @@ This dialect should normally be used only with the """ # noqa - from contextlib import asynccontextmanager +from functools import partial from .pymysql import MySQLDialect_pymysql from ... import pool @@ -267,16 +267,17 @@ class AsyncAdapt_asyncmy_dbapi: def connect(self, *arg, **kw): async_fallback = kw.pop("async_fallback", False) + creator_fn = kw.pop("creator_fn", partial(self.asyncmy.connect)) if util.asbool(async_fallback): return AsyncAdaptFallback_asyncmy_connection( self, - await_fallback(self.asyncmy.connect(*arg, **kw)), + await_fallback(creator_fn(*arg, **kw)), ) else: return AsyncAdapt_asyncmy_connection( self, - await_only(self.asyncmy.connect(*arg, **kw)), + await_only(creator_fn(*arg, **kw)), ) diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index efb758a00f..783552dbf5 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -161,7 +161,7 @@ from __future__ import annotations import collections import decimal -import functools +from functools import partial import json as _py_json import re import time @@ -873,9 +873,7 @@ class AsyncAdapt_asyncpg_dbapi: def connect(self, *arg, **kw): async_fallback = kw.pop("async_fallback", False) - creator_fn = kw.pop( - "creator_fn", functools.partial(self.asyncpg.connect) - ) + creator_fn = kw.pop("creator_fn", partial(self.asyncpg.connect)) prepared_statement_cache_size = kw.pop( "prepared_statement_cache_size", 100 ) diff --git a/lib/sqlalchemy/dialects/sqlite/aiosqlite.py b/lib/sqlalchemy/dialects/sqlite/aiosqlite.py index 2981976acc..f63bc45f15 100644 --- a/lib/sqlalchemy/dialects/sqlite/aiosqlite.py +++ b/lib/sqlalchemy/dialects/sqlite/aiosqlite.py @@ -298,8 +298,8 @@ class AsyncAdapt_aiosqlite_dbapi: def connect(self, *arg, **kw): async_fallback = kw.pop("async_fallback", False) - - connection = self.aiosqlite.connect(*arg, **kw) + creator_fn = kw.pop("creator_fn", partial(self.aiosqlite.connect)) + connection = creator_fn(*arg, **kw) # it's a Thread. you'll thank us later connection.daemon = True diff --git a/test/dialect/mysql/test_dialect.py b/test/dialect/mysql/test_dialect.py index ed0fc6faca..333a84d61f 100644 --- a/test/dialect/mysql/test_dialect.py +++ b/test/dialect/mysql/test_dialect.py @@ -1,4 +1,5 @@ import datetime +from functools import partial from sqlalchemy import bindparam from sqlalchemy import Column @@ -14,6 +15,8 @@ from sqlalchemy import testing from sqlalchemy.dialects import mysql from sqlalchemy.engine.url import make_url from sqlalchemy.testing import assert_raises_message +from sqlalchemy.testing import async_test +from sqlalchemy.testing import config from sqlalchemy.testing import engines from sqlalchemy.testing import eq_ from sqlalchemy.testing import expect_warnings @@ -610,3 +613,36 @@ class ExecutionTest(fixtures.TestBase): def test_sysdate(self, connection): d = connection.execute(func.sysdate()).scalar() assert isinstance(d, datetime.datetime) + + +class AsyncMySQLTest(fixtures.TestBase): + __requires__ = ("async_dialect",) + __only_on__ = "mysql+aiomysql", "mysql+asyncmy" + + @async_test + async def test_async_creator(self, async_testing_engine): + if testing.against("mysql+aiomysql"): + import aiomysql + + connect_func = partial(aiomysql.connect) + if testing.against("mysql+asyncmy"): + import asyncmy + + connect_func = partial(asyncmy.connect) + + async def async_creator(): + conn = await connect_func( + host=config.db.url.host, + port=config.db.url.port, + user=config.db.url.username, + password=config.db.url.password, + db=config.db.url.database, + ) + return conn + + engine = async_testing_engine( + options={"async_creator": async_creator}, + ) + async with engine.connect() as conn: + result = await conn.execute(select(1)) + eq_(result.scalar(), 1) diff --git a/test/dialect/postgresql/test_async_pg_py3k.py b/test/dialect/postgresql/test_async_pg_py3k.py index d33924b53f..57639a6e19 100644 --- a/test/dialect/postgresql/test_async_pg_py3k.py +++ b/test/dialect/postgresql/test_async_pg_py3k.py @@ -12,6 +12,7 @@ from sqlalchemy import Table from sqlalchemy import testing from sqlalchemy.dialects.postgresql import ENUM from sqlalchemy.testing import async_test +from sqlalchemy.testing import config from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing import mock @@ -285,13 +286,15 @@ class AsyncPgTest(fixtures.TestBase): assert len(cache) > 0 @async_test - async def test_async_creator(self, metadata, async_testing_engine): + async def test_async_creator(self, async_testing_engine): import asyncpg + url = config.db.url.render_as_string(hide_password=False) + # format URL properly, strip driver + url = url.replace("+asyncpg", "") + async def async_creator(): - conn = await asyncpg.connect( - "postgresql://scott:tiger@127.0.0.1:5432/test" - ) + conn = await asyncpg.connect(url) return conn engine = async_testing_engine( diff --git a/test/dialect/test_sqlite.py b/test/dialect/test_sqlite.py index 0817bdbf36..2cbdad4939 100644 --- a/test/dialect/test_sqlite.py +++ b/test/dialect/test_sqlite.py @@ -44,6 +44,7 @@ from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import AssertsExecutionResults +from sqlalchemy.testing import async_test from sqlalchemy.testing import combinations from sqlalchemy.testing import engines from sqlalchemy.testing import eq_ @@ -3592,3 +3593,23 @@ class ReflectInternalSchemaTables(fixtures.TablesTest): eq_(res, ["sqlitetempview"]) finally: connection.exec_driver_sql("DROP VIEW sqlitetempview") + + +class AsyncSqlliteTest(fixtures.TestBase): + __requires__ = ("async_dialect",) + __only_on__ = "sqlite+aiosqlite" + + @async_test + async def test_async_creator(self, async_testing_engine): + import aiosqlite + + async def async_creator(): + conn = await aiosqlite.connect(":memory:") + return conn + + engine = async_testing_engine( + options={"async_creator": async_creator}, + ) + async with engine.connect() as conn: + result = await conn.execute(select(1)) + eq_(result.scalar(), 1)