]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added cursors context managers
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 28 Oct 2020 16:33:44 +0000 (17:33 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 28 Oct 2020 21:05:58 +0000 (22:05 +0100)
psycopg3/psycopg3/cursor.py
tests/test_cursor.py
tests/test_cursor_async.py

index 81bc86289e6a7414537b51acaf480e61cd296e2d..236165621880b4f838109b26b740367145c59985 100644 (file)
@@ -4,8 +4,9 @@ psycopg3 cursor objects
 
 # Copyright (C) 2020 The Psycopg Team
 
+from types import TracebackType
+from typing import Any, Callable, List, Optional, Sequence, Type, TYPE_CHECKING
 from operator import attrgetter
-from typing import Any, Callable, List, Optional, Sequence, TYPE_CHECKING
 
 from . import errors as e
 from . import pq
@@ -282,6 +283,17 @@ class Cursor(BaseCursor):
     ):
         super().__init__(connection, format=format)
 
+    def __enter__(self) -> "Cursor":
+        return self
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> None:
+        self.close()
+
     def close(self) -> None:
         self._closed = True
         self._reset()
@@ -390,6 +402,17 @@ class AsyncCursor(BaseCursor):
     ):
         super().__init__(connection, format=format)
 
+    async def __aenter__(self) -> "AsyncCursor":
+        return self
+
+    async def __aexit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> None:
+        await self.close()
+
     async def close(self) -> None:
         self._closed = True
         self._reset()
index 416a2108a470410ffbb7a56be45b913ad858b588..a50b9bbecdb03e8c7b407d97baaf630023f986f0 100644 (file)
@@ -18,6 +18,13 @@ def test_close(conn):
     assert cur.closed
 
 
+def test_context(conn):
+    with conn.cursor() as cur:
+        assert not cur.closed
+
+    assert cur.closed
+
+
 def test_weakref(conn):
     cur = conn.cursor()
     w = weakref.ref(cur)
index 9168ddecf2c9af42eaa9c5e941d5a3fdc40771fe..37598b950f2913282a190dd75e64d373df7be343 100644 (file)
@@ -20,6 +20,13 @@ async def test_close(aconn):
     assert cur.closed
 
 
+async def test_context(aconn):
+    async with aconn.cursor() as cur:
+        assert not cur.closed
+
+    assert cur.closed
+
+
 async def test_weakref(aconn):
     cur = aconn.cursor()
     w = weakref.ref(cur)