From: Jack Wotherspoon Date: Thu, 1 Jun 2023 15:14:18 +0000 (-0400) Subject: chore: add psycopg support (#4) X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=408ab2f51415a11be5e80219eed40553119c527b;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git chore: add psycopg support (#4) --- diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg.py b/lib/sqlalchemy/dialects/postgresql/psycopg.py index 3f11556cf5..43925841c9 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg.py @@ -678,15 +678,14 @@ class PsycopgAdaptDBAPI: def connect(self, *arg, **kw): async_fallback = kw.pop("async_fallback", False) + creator_fn = kw.pop("creator_fn", self.psycopg.AsyncConnection.connect) if util.asbool(async_fallback): return AsyncAdaptFallback_psycopg_connection( - await_fallback( - self.psycopg.AsyncConnection.connect(*arg, **kw) - ) + await_fallback(creator_fn(*arg, **kw)) ) else: return AsyncAdapt_psycopg_connection( - await_only(self.psycopg.AsyncConnection.connect(*arg, **kw)) + await_only(creator_fn(*arg, **kw)) ) diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py index 7bdb8e24f8..14b20fe755 100644 --- a/lib/sqlalchemy/ext/asyncio/engine.py +++ b/lib/sqlalchemy/ext/asyncio/engine.py @@ -95,7 +95,9 @@ def create_async_engine(url: Union[str, URL], **kw: Any) -> AsyncEngine: # note that to send adapted arguments like # prepared_statement_cache_size, user would use # "creator" and emulate this form here - return sync_engine.dialect.dbapi.connect(creator_fn=async_creator) + return sync_engine.dialect.dbapi.connect( # type: ignore + creator_fn=async_creator + ) kw["creator"] = creator sync_engine = _create_engine(url, **kw) diff --git a/test/dialect/postgresql/test_dialect.py b/test/dialect/postgresql/test_dialect.py index 78f561a352..deaea50c02 100644 --- a/test/dialect/postgresql/test_dialect.py +++ b/test/dialect/postgresql/test_dialect.py @@ -39,6 +39,7 @@ from sqlalchemy.dialects.postgresql.psycopg2 import ( ) from sqlalchemy.engine import url from sqlalchemy.sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL +from sqlalchemy.testing import async_test from sqlalchemy.testing import config from sqlalchemy.testing import engines from sqlalchemy.testing import expect_raises_message @@ -1214,3 +1215,27 @@ class Psycopg3Test(fixtures.TestBase): def test_async_version(self): e = create_engine("postgresql+psycopg_async://") is_true(isinstance(e.dialect, psycopg_dialect.PGDialectAsync_psycopg)) + + +class AsyncPostgresTest(fixtures.TestBase): + __requires__ = ("async_dialect",) + + @testing.only_on("postgresql+psycopg") + @async_test + async def test_async_creator(self, async_testing_engine): + import psycopg + + url = config.db.url.render_as_string(hide_password=False) + # format URL properly, strip driver + url = url.replace("+psycopg_async", "") + + async def async_creator(): + conn = await psycopg.AsyncConnection.connect(url) + return conn + + engine = async_testing_engine( + options={"async_creator": async_creator}, + ) + async with engine.connect() as conn: + result = await conn.execute(select(1)) + eq_(result.scalar(), 1)