From: Daniele Varrazzo Date: Wed, 28 Oct 2020 16:33:44 +0000 (+0100) Subject: Added cursors context managers X-Git-Tag: 3.0.dev0~423 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=5b4368c7ec7d790e32b16a1c0e75219185f2c723;p=thirdparty%2Fpsycopg.git Added cursors context managers --- diff --git a/psycopg3/psycopg3/cursor.py b/psycopg3/psycopg3/cursor.py index 81bc86289..236165621 100644 --- a/psycopg3/psycopg3/cursor.py +++ b/psycopg3/psycopg3/cursor.py @@ -4,8 +4,9 @@ psycopg3 cursor objects # Copyright (C) 2020 The Psycopg Team +from types import TracebackType +from typing import Any, Callable, List, Optional, Sequence, Type, TYPE_CHECKING from operator import attrgetter -from typing import Any, Callable, List, Optional, Sequence, TYPE_CHECKING from . import errors as e from . import pq @@ -282,6 +283,17 @@ class Cursor(BaseCursor): ): super().__init__(connection, format=format) + def __enter__(self) -> "Cursor": + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + self.close() + def close(self) -> None: self._closed = True self._reset() @@ -390,6 +402,17 @@ class AsyncCursor(BaseCursor): ): super().__init__(connection, format=format) + async def __aenter__(self) -> "AsyncCursor": + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + await self.close() + async def close(self) -> None: self._closed = True self._reset() diff --git a/tests/test_cursor.py b/tests/test_cursor.py index 416a2108a..a50b9bbec 100644 --- a/tests/test_cursor.py +++ b/tests/test_cursor.py @@ -18,6 +18,13 @@ def test_close(conn): assert cur.closed +def test_context(conn): + with conn.cursor() as cur: + assert not cur.closed + + assert cur.closed + + def test_weakref(conn): cur = conn.cursor() w = weakref.ref(cur) diff --git a/tests/test_cursor_async.py b/tests/test_cursor_async.py index 9168ddecf..37598b950 100644 --- a/tests/test_cursor_async.py +++ b/tests/test_cursor_async.py @@ -20,6 +20,13 @@ async def test_close(aconn): assert cur.closed +async def test_context(aconn): + async with aconn.cursor() as cur: + assert not cur.closed + + assert cur.closed + + async def test_weakref(aconn): cur = aconn.cursor() w = weakref.ref(cur)