From: Daniele Varrazzo Date: Wed, 10 Mar 2021 02:39:46 +0000 (+0100) Subject: Add pool connection_class parameter X-Git-Tag: 3.0.dev0~87^2~13 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=f1b8ad802bc0de9494950a80ec1032c96fd2dd91;p=thirdparty%2Fpsycopg.git Add pool connection_class parameter --- diff --git a/psycopg3/psycopg3/pool/async_pool.py b/psycopg3/psycopg3/pool/async_pool.py index 8b8ffbb10..dd729732f 100644 --- a/psycopg3/psycopg3/pool/async_pool.py +++ b/psycopg3/psycopg3/pool/async_pool.py @@ -32,6 +32,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection]): self, conninfo: str = "", *, + connection_class: Type[AsyncConnection] = AsyncConnection, configure: Optional[ Callable[[AsyncConnection], Awaitable[None]] ] = None, @@ -44,6 +45,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection]): "async pool not supported before Python 3.7" ) + self.connection_class = connection_class self._configure = configure self._reset = reset @@ -354,7 +356,9 @@ class AsyncConnectionPool(BasePool[AsyncConnection]): async def _connect(self) -> AsyncConnection: """Return a new connection configured for the pool.""" - conn = await AsyncConnection.connect(self.conninfo, **self.kwargs) + conn = await self.connection_class.connect( + self.conninfo, **self.kwargs + ) conn._pool = self if self._configure: diff --git a/psycopg3/psycopg3/pool/pool.py b/psycopg3/psycopg3/pool/pool.py index 9c24bda7d..2efbb6e35 100644 --- a/psycopg3/psycopg3/pool/pool.py +++ b/psycopg3/psycopg3/pool/pool.py @@ -31,10 +31,12 @@ class ConnectionPool(BasePool[Connection]): self, conninfo: str = "", *, + connection_class: Type[Connection] = Connection, configure: Optional[Callable[[Connection], None]] = None, reset: Optional[Callable[[Connection], None]] = None, **kwargs: Any, ): + self.connection_class = connection_class self._configure = configure self._reset = reset @@ -390,7 +392,7 @@ class ConnectionPool(BasePool[Connection]): self._stats[self._CONNECTIONS_NUM] += 1 t0 = monotonic() try: - conn = Connection.connect(self.conninfo, **self.kwargs) + conn = self.connection_class.connect(self.conninfo, **self.kwargs) except Exception: self._stats[self._CONNECTIONS_ERRORS] += 1 raise diff --git a/tests/pool/test_pool.py b/tests/pool/test_pool.py index 369d25a45..ec71b04ad 100644 --- a/tests/pool/test_pool.py +++ b/tests/pool/test_pool.py @@ -33,6 +33,15 @@ def test_minconn_maxconn(dsn): pool.ConnectionPool(dsn, minconn=4, maxconn=2) +def test_connection_class(dsn): + class MyConn(psycopg3.Connection): + pass + + with pool.ConnectionPool(dsn, connection_class=MyConn, minconn=1) as p: + with p.connection() as conn: + assert isinstance(conn, MyConn) + + def test_kwargs(dsn): with pool.ConnectionPool(dsn, kwargs={"autocommit": True}, minconn=1) as p: with p.connection() as conn: diff --git a/tests/pool/test_pool_async.py b/tests/pool/test_pool_async.py index d9436c937..1aab09cc4 100644 --- a/tests/pool/test_pool_async.py +++ b/tests/pool/test_pool_async.py @@ -41,6 +41,17 @@ async def test_minconn_maxconn(dsn): pool.AsyncConnectionPool(dsn, minconn=4, maxconn=2) +async def test_connection_class(dsn): + class MyConn(psycopg3.AsyncConnection): + pass + + async with pool.AsyncConnectionPool( + dsn, connection_class=MyConn, minconn=1 + ) as p: + async with p.connection() as conn: + assert isinstance(conn, MyConn) + + async def test_kwargs(dsn): async with pool.AsyncConnectionPool( dsn, kwargs={"autocommit": True}, minconn=1