]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
chore: add psycopg support (#4)
authorJack Wotherspoon <jackwoth@google.com>
Thu, 1 Jun 2023 15:14:18 +0000 (11:14 -0400)
committerGitHub <noreply@github.com>
Thu, 1 Jun 2023 15:14:18 +0000 (11:14 -0400)
lib/sqlalchemy/dialects/postgresql/psycopg.py
lib/sqlalchemy/ext/asyncio/engine.py
test/dialect/postgresql/test_dialect.py

index 3f11556cf542d1f7fd80b71f7a351afc055aa358..43925841c92b1f3cc44385a2ed1705164bfc7491 100644 (file)
@@ -678,15 +678,14 @@ class PsycopgAdaptDBAPI:
 
     def connect(self, *arg, **kw):
         async_fallback = kw.pop("async_fallback", False)
+        creator_fn = kw.pop("creator_fn", self.psycopg.AsyncConnection.connect)
         if util.asbool(async_fallback):
             return AsyncAdaptFallback_psycopg_connection(
-                await_fallback(
-                    self.psycopg.AsyncConnection.connect(*arg, **kw)
-                )
+                await_fallback(creator_fn(*arg, **kw))
             )
         else:
             return AsyncAdapt_psycopg_connection(
-                await_only(self.psycopg.AsyncConnection.connect(*arg, **kw))
+                await_only(creator_fn(*arg, **kw))
             )
 
 
index 7bdb8e24f8aadae7021a30b94cc2da2343537466..14b20fe75539ea7342fa1b4dfee33eb2048d6d70 100644 (file)
@@ -95,7 +95,9 @@ def create_async_engine(url: Union[str, URL], **kw: Any) -> AsyncEngine:
             # note that to send adapted arguments like
             # prepared_statement_cache_size, user would use
             # "creator" and emulate this form here
-            return sync_engine.dialect.dbapi.connect(creator_fn=async_creator)
+            return sync_engine.dialect.dbapi.connect(  # type: ignore
+                creator_fn=async_creator
+            )
 
         kw["creator"] = creator
     sync_engine = _create_engine(url, **kw)
index 78f561a35256b5d77f1a4d5d2f534431e96b4952..deaea50c02046afe42a08d49062d28626b179deb 100644 (file)
@@ -39,6 +39,7 @@ from sqlalchemy.dialects.postgresql.psycopg2 import (
 )
 from sqlalchemy.engine import url
 from sqlalchemy.sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
+from sqlalchemy.testing import async_test
 from sqlalchemy.testing import config
 from sqlalchemy.testing import engines
 from sqlalchemy.testing import expect_raises_message
@@ -1214,3 +1215,27 @@ class Psycopg3Test(fixtures.TestBase):
     def test_async_version(self):
         e = create_engine("postgresql+psycopg_async://")
         is_true(isinstance(e.dialect, psycopg_dialect.PGDialectAsync_psycopg))
+
+
+class AsyncPostgresTest(fixtures.TestBase):
+    __requires__ = ("async_dialect",)
+
+    @testing.only_on("postgresql+psycopg")
+    @async_test
+    async def test_async_creator(self, async_testing_engine):
+        import psycopg
+
+        url = config.db.url.render_as_string(hide_password=False)
+        # format URL properly, strip driver
+        url = url.replace("+psycopg_async", "")
+
+        async def async_creator():
+            conn = await psycopg.AsyncConnection.connect(url)
+            return conn
+
+        engine = async_testing_engine(
+            options={"async_creator": async_creator},
+        )
+        async with engine.connect() as conn:
+            result = await conn.execute(select(1))
+            eq_(result.scalar(), 1)