From e99342b9374e1464967058ab43037a44b860ce4c Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Sat, 27 Feb 2021 21:44:07 +0100 Subject: [PATCH] Add pool context managers --- psycopg3/psycopg3/pool/async_pool.py | 15 ++++++++++++++- psycopg3/psycopg3/pool/pool.py | 14 +++++++++++++- tests/pool/test_pool.py | 6 ++++++ tests/pool/test_pool_async.py | 6 ++++++ 4 files changed, 39 insertions(+), 2 deletions(-) diff --git a/psycopg3/psycopg3/pool/async_pool.py b/psycopg3/psycopg3/pool/async_pool.py index c1e34add8..c024214d4 100644 --- a/psycopg3/psycopg3/pool/async_pool.py +++ b/psycopg3/psycopg3/pool/async_pool.py @@ -8,7 +8,9 @@ import sys import asyncio import logging from time import monotonic -from typing import Any, Awaitable, Callable, Deque, AsyncIterator, Optional +from types import TracebackType +from typing import Any, AsyncIterator, Awaitable, Callable, Deque +from typing import Optional, Type from collections import deque from ..pq import TransactionStatus @@ -233,6 +235,17 @@ class AsyncConnectionPool(BasePool[AsyncConnection]): timeout, ) + async def __aenter__(self) -> "AsyncConnectionPool": + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + await self.close() + async def resize( self, minconn: int, maxconn: Optional[int] = None ) -> None: diff --git a/psycopg3/psycopg3/pool/pool.py b/psycopg3/psycopg3/pool/pool.py index c68a8615e..71c07ebc7 100644 --- a/psycopg3/psycopg3/pool/pool.py +++ b/psycopg3/psycopg3/pool/pool.py @@ -7,7 +7,8 @@ psycopg3 synchronous connection pool import logging import threading from time import monotonic -from typing import Any, Callable, Deque, Iterator, Optional +from types import TracebackType +from typing import Any, Callable, Deque, Iterator, Optional, Type from contextlib import contextmanager from collections import deque @@ -215,6 +216,17 @@ class ConnectionPool(BasePool[Connection]): timeout, ) + def __enter__(self) -> "ConnectionPool": + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + self.close() + def resize(self, minconn: int, maxconn: Optional[int] = None) -> None: if maxconn is None: maxconn = minconn diff --git a/tests/pool/test_pool.py b/tests/pool/test_pool.py index 840b3b1e2..80758bb6d 100644 --- a/tests/pool/test_pool.py +++ b/tests/pool/test_pool.py @@ -62,6 +62,12 @@ def test_connection_not_lost(dsn): assert conn2.pgconn.backend_pid == pid +def test_context(dsn): + with pool.ConnectionPool(dsn, minconn=1) as p: + assert not p.closed + assert p.closed + + @pytest.mark.slow def test_concurrent_filling(dsn, monkeypatch): delay_connection(monkeypatch, 0.1) diff --git a/tests/pool/test_pool_async.py b/tests/pool/test_pool_async.py index 71542f817..09794db98 100644 --- a/tests/pool/test_pool_async.py +++ b/tests/pool/test_pool_async.py @@ -67,6 +67,12 @@ async def test_its_really_a_pool(dsn): await p.close() +async def test_context(dsn): + async with pool.AsyncConnectionPool(dsn, minconn=1) as p: + assert not p.closed + assert p.closed + + async def test_connection_not_lost(dsn): p = pool.AsyncConnectionPool(dsn, minconn=1) with pytest.raises(ZeroDivisionError): -- 2.47.3