From 813620b226689dfe9788776e1a703b60a7c3699e Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Mon, 29 May 2023 15:56:59 +0000 Subject: [PATCH] chore: add async_creator to asyncpg --- lib/sqlalchemy/dialects/postgresql/asyncpg.py | 6 ++++-- lib/sqlalchemy/ext/asyncio/engine.py | 13 +++++++++++++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index 6827151f36..172917e773 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -161,6 +161,7 @@ from __future__ import annotations import collections import decimal +import functools import json as _py_json import re import time @@ -875,6 +876,7 @@ class AsyncAdapt_asyncpg_dbapi: def connect(self, *arg, **kw): async_fallback = kw.pop("async_fallback", False) + creator_fn = kw.pop("creator_fn", functools.partial(self.asyncpg.connect)) prepared_statement_cache_size = kw.pop( "prepared_statement_cache_size", 100 ) @@ -885,14 +887,14 @@ class AsyncAdapt_asyncpg_dbapi: if util.asbool(async_fallback): return AsyncAdaptFallback_asyncpg_connection( self, - await_fallback(self.asyncpg.connect(*arg, **kw)), + await_fallback(creator_fn(*arg, **kw)), prepared_statement_cache_size=prepared_statement_cache_size, prepared_statement_name_func=prepared_statement_name_func, ) else: return AsyncAdapt_asyncpg_connection( self, - await_only(self.asyncpg.connect(*arg, **kw)), + await_only(creator_fn(*arg, **kw)), prepared_statement_cache_size=prepared_statement_cache_size, prepared_statement_name_func=prepared_statement_name_func, ) diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py index 531abdde52..e89c4b4e6d 100644 --- a/lib/sqlalchemy/ext/asyncio/engine.py +++ b/lib/sqlalchemy/ext/asyncio/engine.py @@ -82,6 +82,19 @@ def create_async_engine(url: Union[str, URL], **kw: Any) -> AsyncEngine: "streaming result set" ) kw["_is_async"] = True + async_creator = kw.pop("async_creator", None) + if async_creator: + async def wrap_async_creator(): + return await async_creator() + + def creator(): + # note that to send adapted arguments like + # prepared_statement_cache_size, user would use + # "creator" and emulate this form here + return sync_engine.dialect.dbapi.connect( + creator_fn=wrap_async_creator + ) + kw["creator"] = creator sync_engine = _create_engine(url, **kw) return AsyncEngine(sync_engine) -- 2.47.3