]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add pool context managers
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 27 Feb 2021 20:44:07 +0000 (21:44 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 12 Mar 2021 04:07:25 +0000 (05:07 +0100)
psycopg3/psycopg3/pool/async_pool.py
psycopg3/psycopg3/pool/pool.py
tests/pool/test_pool.py
tests/pool/test_pool_async.py

index c1e34add87fbf5635d9f10405906cfb5557c2736..c024214d4b2c0db85da6adb4bb1a8c3cc83b0f80 100644 (file)
@@ -8,7 +8,9 @@ import sys
 import asyncio
 import logging
 from time import monotonic
-from typing import Any, Awaitable, Callable, Deque, AsyncIterator, Optional
+from types import TracebackType
+from typing import Any, AsyncIterator, Awaitable, Callable, Deque
+from typing import Optional, Type
 from collections import deque
 
 from ..pq import TransactionStatus
@@ -233,6 +235,17 @@ class AsyncConnectionPool(BasePool[AsyncConnection]):
                         timeout,
                     )
 
+    async def __aenter__(self) -> "AsyncConnectionPool":
+        return self
+
+    async def __aexit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> None:
+        await self.close()
+
     async def resize(
         self, minconn: int, maxconn: Optional[int] = None
     ) -> None:
index c68a8615ec0a0e76e26ddd12849e614305357fe4..71c07ebc7d1c117c1eb5bce4fbe7913e0edcc3ea 100644 (file)
@@ -7,7 +7,8 @@ psycopg3 synchronous connection pool
 import logging
 import threading
 from time import monotonic
-from typing import Any, Callable, Deque, Iterator, Optional
+from types import TracebackType
+from typing import Any, Callable, Deque, Iterator, Optional, Type
 from contextlib import contextmanager
 from collections import deque
 
@@ -215,6 +216,17 @@ class ConnectionPool(BasePool[Connection]):
                         timeout,
                     )
 
+    def __enter__(self) -> "ConnectionPool":
+        return self
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> None:
+        self.close()
+
     def resize(self, minconn: int, maxconn: Optional[int] = None) -> None:
         if maxconn is None:
             maxconn = minconn
index 840b3b1e265be9dc31b1f14a1f556c3f633ae6e9..80758bb6d3f50d02fb29feae6c54734fc9241488 100644 (file)
@@ -62,6 +62,12 @@ def test_connection_not_lost(dsn):
         assert conn2.pgconn.backend_pid == pid
 
 
+def test_context(dsn):
+    with pool.ConnectionPool(dsn, minconn=1) as p:
+        assert not p.closed
+    assert p.closed
+
+
 @pytest.mark.slow
 def test_concurrent_filling(dsn, monkeypatch):
     delay_connection(monkeypatch, 0.1)
index 71542f817601f4f4fed9a1fd77a896a00942de7d..09794db98ec631dce4cf2af321c7729f97cfc439 100644 (file)
@@ -67,6 +67,12 @@ async def test_its_really_a_pool(dsn):
     await p.close()
 
 
+async def test_context(dsn):
+    async with pool.AsyncConnectionPool(dsn, minconn=1) as p:
+        assert not p.closed
+    assert p.closed
+
+
 async def test_connection_not_lost(dsn):
     p = pool.AsyncConnectionPool(dsn, minconn=1)
     with pytest.raises(ZeroDivisionError):