]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
feat: add other async drivers
authorJack Wotherspoon <jackwoth@google.com>
Tue, 30 May 2023 15:09:32 +0000 (11:09 -0400)
committerGitHub <noreply@github.com>
Tue, 30 May 2023 15:09:32 +0000 (11:09 -0400)
lib/sqlalchemy/dialects/mysql/aiomysql.py
lib/sqlalchemy/dialects/mysql/asyncmy.py
lib/sqlalchemy/dialects/postgresql/asyncpg.py
lib/sqlalchemy/dialects/sqlite/aiosqlite.py
test/dialect/mysql/test_dialect.py
test/dialect/postgresql/test_async_pg_py3k.py
test/dialect/test_sqlite.py

index bc079ba17ac34f549329a5e6830708332b0c042b..a4d6ea9c31da4a3aee92f1441dc05731554ec8e1 100644 (file)
@@ -34,6 +34,7 @@ This dialect should normally be used only with the
 
 
 """  # noqa
+from functools import partial
 
 from .pymysql import MySQLDialect_pymysql
 from ... import pool
@@ -255,16 +256,17 @@ class AsyncAdapt_aiomysql_dbapi:
 
     def connect(self, *arg, **kw):
         async_fallback = kw.pop("async_fallback", False)
+        creator_fn = kw.pop("creator_fn", partial(self.aiomysql.connect))
 
         if util.asbool(async_fallback):
             return AsyncAdaptFallback_aiomysql_connection(
                 self,
-                await_fallback(self.aiomysql.connect(*arg, **kw)),
+                await_fallback(creator_fn(*arg, **kw)),
             )
         else:
             return AsyncAdapt_aiomysql_connection(
                 self,
-                await_only(self.aiomysql.connect(*arg, **kw)),
+                await_only(creator_fn(*arg, **kw)),
             )
 
 
index d3f809e6aed3bcb3d098814f532e30f00591826c..6f5b506a71c6afb7e7a8fdbd4229cca7e5f8cdeb 100644 (file)
@@ -29,8 +29,8 @@ This dialect should normally be used only with the
 
 
 """  # noqa
-
 from contextlib import asynccontextmanager
+from functools import partial
 
 from .pymysql import MySQLDialect_pymysql
 from ... import pool
@@ -267,16 +267,17 @@ class AsyncAdapt_asyncmy_dbapi:
 
     def connect(self, *arg, **kw):
         async_fallback = kw.pop("async_fallback", False)
+        creator_fn = kw.pop("creator_fn", partial(self.asyncmy.connect))
 
         if util.asbool(async_fallback):
             return AsyncAdaptFallback_asyncmy_connection(
                 self,
-                await_fallback(self.asyncmy.connect(*arg, **kw)),
+                await_fallback(creator_fn(*arg, **kw)),
             )
         else:
             return AsyncAdapt_asyncmy_connection(
                 self,
-                await_only(self.asyncmy.connect(*arg, **kw)),
+                await_only(creator_fn(*arg, **kw)),
             )
 
 
index efb758a00f138b0f59e34a676cdbce31c3b36191..783552dbf5d980d8cd3e4c3ec3e66658837375b9 100644 (file)
@@ -161,7 +161,7 @@ from __future__ import annotations
 
 import collections
 import decimal
-import functools
+from functools import partial
 import json as _py_json
 import re
 import time
@@ -873,9 +873,7 @@ class AsyncAdapt_asyncpg_dbapi:
 
     def connect(self, *arg, **kw):
         async_fallback = kw.pop("async_fallback", False)
-        creator_fn = kw.pop(
-            "creator_fn", functools.partial(self.asyncpg.connect)
-        )
+        creator_fn = kw.pop("creator_fn", partial(self.asyncpg.connect))
         prepared_statement_cache_size = kw.pop(
             "prepared_statement_cache_size", 100
         )
index 2981976acc108033cb9da67bdf1c3a36ecca1cef..f63bc45f15cc5a82019699b64535d1526edc65ed 100644 (file)
@@ -298,8 +298,8 @@ class AsyncAdapt_aiosqlite_dbapi:
 
     def connect(self, *arg, **kw):
         async_fallback = kw.pop("async_fallback", False)
-
-        connection = self.aiosqlite.connect(*arg, **kw)
+        creator_fn = kw.pop("creator_fn", partial(self.aiosqlite.connect))
+        connection = creator_fn(*arg, **kw)
 
         # it's a Thread.   you'll thank us later
         connection.daemon = True
index ed0fc6faca6ac1c7730833770b2ec3cfeb593db2..333a84d61f92d90be74784114a501bc341f46bc7 100644 (file)
@@ -1,4 +1,5 @@
 import datetime
+from functools import partial
 
 from sqlalchemy import bindparam
 from sqlalchemy import Column
@@ -14,6 +15,8 @@ from sqlalchemy import testing
 from sqlalchemy.dialects import mysql
 from sqlalchemy.engine.url import make_url
 from sqlalchemy.testing import assert_raises_message
+from sqlalchemy.testing import async_test
+from sqlalchemy.testing import config
 from sqlalchemy.testing import engines
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import expect_warnings
@@ -610,3 +613,36 @@ class ExecutionTest(fixtures.TestBase):
     def test_sysdate(self, connection):
         d = connection.execute(func.sysdate()).scalar()
         assert isinstance(d, datetime.datetime)
+
+
+class AsyncMySQLTest(fixtures.TestBase):
+    __requires__ = ("async_dialect",)
+    __only_on__ = "mysql+aiomysql", "mysql+asyncmy"
+
+    @async_test
+    async def test_async_creator(self, async_testing_engine):
+        if testing.against("mysql+aiomysql"):
+            import aiomysql
+
+            connect_func = partial(aiomysql.connect)
+        if testing.against("mysql+asyncmy"):
+            import asyncmy
+
+            connect_func = partial(asyncmy.connect)
+
+        async def async_creator():
+            conn = await connect_func(
+                host=config.db.url.host,
+                port=config.db.url.port,
+                user=config.db.url.username,
+                password=config.db.url.password,
+                db=config.db.url.database,
+            )
+            return conn
+
+        engine = async_testing_engine(
+            options={"async_creator": async_creator},
+        )
+        async with engine.connect() as conn:
+            result = await conn.execute(select(1))
+            eq_(result.scalar(), 1)
index d33924b53f1ca08cebe58a723da7b5e958f86844..57639a6e197483e9ac3b6e6f4516270d38f4e569 100644 (file)
@@ -12,6 +12,7 @@ from sqlalchemy import Table
 from sqlalchemy import testing
 from sqlalchemy.dialects.postgresql import ENUM
 from sqlalchemy.testing import async_test
+from sqlalchemy.testing import config
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import mock
@@ -285,13 +286,15 @@ class AsyncPgTest(fixtures.TestBase):
             assert len(cache) > 0
 
     @async_test
-    async def test_async_creator(self, metadata, async_testing_engine):
+    async def test_async_creator(self, async_testing_engine):
         import asyncpg
 
+        url = config.db.url.render_as_string(hide_password=False)
+        # format URL properly, strip driver
+        url = url.replace("+asyncpg", "")
+
         async def async_creator():
-            conn = await asyncpg.connect(
-                "postgresql://scott:tiger@127.0.0.1:5432/test"
-            )
+            conn = await asyncpg.connect(url)
             return conn
 
         engine = async_testing_engine(
index 0817bdbf36339a8b5ef2d45b02fb2f6e3d2fae0c..2cbdad4939115362226dacc4efbae618e64a84a0 100644 (file)
@@ -44,6 +44,7 @@ from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import AssertsExecutionResults
+from sqlalchemy.testing import async_test
 from sqlalchemy.testing import combinations
 from sqlalchemy.testing import engines
 from sqlalchemy.testing import eq_
@@ -3592,3 +3593,23 @@ class ReflectInternalSchemaTables(fixtures.TablesTest):
             eq_(res, ["sqlitetempview"])
         finally:
             connection.exec_driver_sql("DROP VIEW sqlitetempview")
+
+
+class AsyncSqlliteTest(fixtures.TestBase):
+    __requires__ = ("async_dialect",)
+    __only_on__ = "sqlite+aiosqlite"
+
+    @async_test
+    async def test_async_creator(self, async_testing_engine):
+        import aiosqlite
+
+        async def async_creator():
+            conn = await aiosqlite.connect(":memory:")
+            return conn
+
+        engine = async_testing_engine(
+            options={"async_creator": async_creator},
+        )
+        async with engine.connect() as conn:
+            result = await conn.execute(select(1))
+            eq_(result.scalar(), 1)