]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
chore: add async_creator to asyncpg
authorjackwotherspoon <jackwoth@google.com>
Mon, 29 May 2023 15:56:59 +0000 (15:56 +0000)
committerjackwotherspoon <jackwoth@google.com>
Mon, 29 May 2023 15:56:59 +0000 (15:56 +0000)
lib/sqlalchemy/dialects/postgresql/asyncpg.py
lib/sqlalchemy/ext/asyncio/engine.py

index 6827151f363aa97c57d1fb565738b9d47d599c1d..172917e77387cab842e62d225c6064be37e9f693 100644 (file)
@@ -161,6 +161,7 @@ from __future__ import annotations
 
 import collections
 import decimal
+import functools
 import json as _py_json
 import re
 import time
@@ -875,6 +876,7 @@ class AsyncAdapt_asyncpg_dbapi:
 
     def connect(self, *arg, **kw):
         async_fallback = kw.pop("async_fallback", False)
+        creator_fn = kw.pop("creator_fn", functools.partial(self.asyncpg.connect))
         prepared_statement_cache_size = kw.pop(
             "prepared_statement_cache_size", 100
         )
@@ -885,14 +887,14 @@ class AsyncAdapt_asyncpg_dbapi:
         if util.asbool(async_fallback):
             return AsyncAdaptFallback_asyncpg_connection(
                 self,
-                await_fallback(self.asyncpg.connect(*arg, **kw)),
+                await_fallback(creator_fn(*arg, **kw)),
                 prepared_statement_cache_size=prepared_statement_cache_size,
                 prepared_statement_name_func=prepared_statement_name_func,
             )
         else:
             return AsyncAdapt_asyncpg_connection(
                 self,
-                await_only(self.asyncpg.connect(*arg, **kw)),
+                await_only(creator_fn(*arg, **kw)),
                 prepared_statement_cache_size=prepared_statement_cache_size,
                 prepared_statement_name_func=prepared_statement_name_func,
             )
index 531abdde52ea04d8aef5271d9e8f30e3a6534ee7..e89c4b4e6dbdf58e8e5a8fb38de9c70d4baf603b 100644 (file)
@@ -82,6 +82,19 @@ def create_async_engine(url: Union[str, URL], **kw: Any) -> AsyncEngine:
             "streaming result set"
         )
     kw["_is_async"] = True
+    async_creator = kw.pop("async_creator", None)
+    if async_creator:
+        async def wrap_async_creator():
+            return await async_creator()
+
+        def creator():
+            # note that to send adapted arguments like
+            # prepared_statement_cache_size, user would use
+            # "creator" and emulate this form here
+            return sync_engine.dialect.dbapi.connect(
+                creator_fn=wrap_async_creator
+            )
+        kw["creator"] = creator
     sync_engine = _create_engine(url, **kw)
     return AsyncEngine(sync_engine)