]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add pool connection_class parameter
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 10 Mar 2021 02:39:46 +0000 (03:39 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 12 Mar 2021 04:07:25 +0000 (05:07 +0100)
psycopg3/psycopg3/pool/async_pool.py
psycopg3/psycopg3/pool/pool.py
tests/pool/test_pool.py
tests/pool/test_pool_async.py

index 8b8ffbb10e7522fccbfec6935d3733b6095b2d70..dd729732f645fb613b4f5a9ed8e361ea0de4baf0 100644 (file)
@@ -32,6 +32,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection]):
         self,
         conninfo: str = "",
         *,
+        connection_class: Type[AsyncConnection] = AsyncConnection,
         configure: Optional[
             Callable[[AsyncConnection], Awaitable[None]]
         ] = None,
@@ -44,6 +45,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection]):
                 "async pool not supported before Python 3.7"
             )
 
+        self.connection_class = connection_class
         self._configure = configure
         self._reset = reset
 
@@ -354,7 +356,9 @@ class AsyncConnectionPool(BasePool[AsyncConnection]):
 
     async def _connect(self) -> AsyncConnection:
         """Return a new connection configured for the pool."""
-        conn = await AsyncConnection.connect(self.conninfo, **self.kwargs)
+        conn = await self.connection_class.connect(
+            self.conninfo, **self.kwargs
+        )
         conn._pool = self
 
         if self._configure:
index 9c24bda7d8dc86dc61b3fdc6f68bb58398303521..2efbb6e3555a665fd15f88ee1f7742a4163f8107 100644 (file)
@@ -31,10 +31,12 @@ class ConnectionPool(BasePool[Connection]):
         self,
         conninfo: str = "",
         *,
+        connection_class: Type[Connection] = Connection,
         configure: Optional[Callable[[Connection], None]] = None,
         reset: Optional[Callable[[Connection], None]] = None,
         **kwargs: Any,
     ):
+        self.connection_class = connection_class
         self._configure = configure
         self._reset = reset
 
@@ -390,7 +392,7 @@ class ConnectionPool(BasePool[Connection]):
         self._stats[self._CONNECTIONS_NUM] += 1
         t0 = monotonic()
         try:
-            conn = Connection.connect(self.conninfo, **self.kwargs)
+            conn = self.connection_class.connect(self.conninfo, **self.kwargs)
         except Exception:
             self._stats[self._CONNECTIONS_ERRORS] += 1
             raise
index 369d25a45a6cfcea1c82853685e7557dd6b485b8..ec71b04ad9fd5994ddd992c4924f018946a2a139 100644 (file)
@@ -33,6 +33,15 @@ def test_minconn_maxconn(dsn):
         pool.ConnectionPool(dsn, minconn=4, maxconn=2)
 
 
+def test_connection_class(dsn):
+    class MyConn(psycopg3.Connection):
+        pass
+
+    with pool.ConnectionPool(dsn, connection_class=MyConn, minconn=1) as p:
+        with p.connection() as conn:
+            assert isinstance(conn, MyConn)
+
+
 def test_kwargs(dsn):
     with pool.ConnectionPool(dsn, kwargs={"autocommit": True}, minconn=1) as p:
         with p.connection() as conn:
index d9436c9375d60b5a0d6186b77212a8819a850e09..1aab09cc41ea8033ccf70264ac184322125ffb59 100644 (file)
@@ -41,6 +41,17 @@ async def test_minconn_maxconn(dsn):
         pool.AsyncConnectionPool(dsn, minconn=4, maxconn=2)
 
 
+async def test_connection_class(dsn):
+    class MyConn(psycopg3.AsyncConnection):
+        pass
+
+    async with pool.AsyncConnectionPool(
+        dsn, connection_class=MyConn, minconn=1
+    ) as p:
+        async with p.connection() as conn:
+            assert isinstance(conn, MyConn)
+
+
 async def test_kwargs(dsn):
     async with pool.AsyncConnectionPool(
         dsn, kwargs={"autocommit": True}, minconn=1