From: Daniele Varrazzo Date: Sat, 27 Feb 2021 20:44:07 +0000 (+0100) Subject: Add pool context managers X-Git-Tag: 3.0.dev0~87^2~38 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=e99342b9374e1464967058ab43037a44b860ce4c;p=thirdparty%2Fpsycopg.git Add pool context managers --- diff --git a/psycopg3/psycopg3/pool/async_pool.py b/psycopg3/psycopg3/pool/async_pool.py index c1e34add8..c024214d4 100644 --- a/psycopg3/psycopg3/pool/async_pool.py +++ b/psycopg3/psycopg3/pool/async_pool.py @@ -8,7 +8,9 @@ import sys import asyncio import logging from time import monotonic -from typing import Any, Awaitable, Callable, Deque, AsyncIterator, Optional +from types import TracebackType +from typing import Any, AsyncIterator, Awaitable, Callable, Deque +from typing import Optional, Type from collections import deque from ..pq import TransactionStatus @@ -233,6 +235,17 @@ class AsyncConnectionPool(BasePool[AsyncConnection]): timeout, ) + async def __aenter__(self) -> "AsyncConnectionPool": + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + await self.close() + async def resize( self, minconn: int, maxconn: Optional[int] = None ) -> None: diff --git a/psycopg3/psycopg3/pool/pool.py b/psycopg3/psycopg3/pool/pool.py index c68a8615e..71c07ebc7 100644 --- a/psycopg3/psycopg3/pool/pool.py +++ b/psycopg3/psycopg3/pool/pool.py @@ -7,7 +7,8 @@ psycopg3 synchronous connection pool import logging import threading from time import monotonic -from typing import Any, Callable, Deque, Iterator, Optional +from types import TracebackType +from typing import Any, Callable, Deque, Iterator, Optional, Type from contextlib import contextmanager from collections import deque @@ -215,6 +216,17 @@ class ConnectionPool(BasePool[Connection]): timeout, ) + def __enter__(self) -> "ConnectionPool": + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + self.close() + def resize(self, minconn: int, maxconn: Optional[int] = None) -> None: if maxconn is None: maxconn = minconn diff --git a/tests/pool/test_pool.py b/tests/pool/test_pool.py index 840b3b1e2..80758bb6d 100644 --- a/tests/pool/test_pool.py +++ b/tests/pool/test_pool.py @@ -62,6 +62,12 @@ def test_connection_not_lost(dsn): assert conn2.pgconn.backend_pid == pid +def test_context(dsn): + with pool.ConnectionPool(dsn, minconn=1) as p: + assert not p.closed + assert p.closed + + @pytest.mark.slow def test_concurrent_filling(dsn, monkeypatch): delay_connection(monkeypatch, 0.1) diff --git a/tests/pool/test_pool_async.py b/tests/pool/test_pool_async.py index 71542f817..09794db98 100644 --- a/tests/pool/test_pool_async.py +++ b/tests/pool/test_pool_async.py @@ -67,6 +67,12 @@ async def test_its_really_a_pool(dsn): await p.close() +async def test_context(dsn): + async with pool.AsyncConnectionPool(dsn, minconn=1) as p: + assert not p.closed + assert p.closed + + async def test_connection_not_lost(dsn): p = pool.AsyncConnectionPool(dsn, minconn=1) with pytest.raises(ZeroDivisionError):