From: Daniele Varrazzo Date: Mon, 9 Nov 2020 03:35:03 +0000 (+0000) Subject: Added connection enter/exit X-Git-Tag: 3.0.dev0~387 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=4a47121784d028c26a1f5be4ced96107b46baed2;p=thirdparty%2Fpsycopg.git Added connection enter/exit Just a simplified version, only commit/rollback + close, waiting for nested transactions... --- diff --git a/psycopg3/psycopg3/connection.py b/psycopg3/psycopg3/connection.py index 0cde94d1b..66249767b 100644 --- a/psycopg3/psycopg3/connection.py +++ b/psycopg3/psycopg3/connection.py @@ -7,6 +7,7 @@ psycopg3 connection objects import logging import asyncio import threading +from types import TracebackType from typing import Any, AsyncGenerator, Callable, Generator, List, NamedTuple from typing import Optional, Type, cast from weakref import ref, ReferenceType @@ -234,6 +235,22 @@ class Connection(BaseConnection): conn._autocommit = autocommit return conn + def __enter__(self) -> "Connection": + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + if exc_type: + self.rollback() + else: + self.commit() + + self.close() + def close(self) -> None: self.pgconn.finish() @@ -349,6 +366,22 @@ class AsyncConnection(BaseConnection): conn._autocommit = autocommit return conn + async def __aenter__(self) -> "AsyncConnection": + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + if exc_type: + await self.rollback() + else: + await self.commit() + + await self.close() + async def close(self) -> None: self.pgconn.finish() diff --git a/tests/test_connection.py b/tests/test_connection.py index 777a65ec6..2d28f20e9 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -5,6 +5,7 @@ import weakref import psycopg3 from psycopg3 import Connection +from psycopg3.errors import UndefinedTable from psycopg3.conninfo import conninfo_to_dict @@ -42,6 +43,39 @@ def test_close(conn): cur.execute("select 1") +def test_context_commit(conn, dsn): + with conn: + with conn.cursor() as cur: + cur.execute("drop table if exists textctx") + cur.execute("create table textctx ()") + + assert conn.closed + + with psycopg3.connect(dsn) as conn: + with conn.cursor() as cur: + cur.execute("select * from textctx") + assert cur.fetchall() == [] + + +def test_context_rollback(conn, dsn): + with conn.cursor() as cur: + cur.execute("drop table if exists textctx") + conn.commit() + + with pytest.raises(ZeroDivisionError): + with conn: + with conn.cursor() as cur: + cur.execute("create table textctx ()") + 1 / 0 + + assert conn.closed + + with psycopg3.connect(dsn) as conn: + with conn.cursor() as cur: + with pytest.raises(UndefinedTable): + cur.execute("select * from textctx") + + def test_weakref(dsn): conn = psycopg3.connect(dsn) w = weakref.ref(conn) diff --git a/tests/test_connection_async.py b/tests/test_connection_async.py index 14735c78e..60f6a0c68 100644 --- a/tests/test_connection_async.py +++ b/tests/test_connection_async.py @@ -5,6 +5,7 @@ import weakref import psycopg3 from psycopg3 import AsyncConnection +from psycopg3.errors import UndefinedTable from psycopg3.conninfo import conninfo_to_dict pytestmark = pytest.mark.asyncio @@ -45,6 +46,39 @@ async def test_close(aconn): await cur.execute("select 1") +async def test_context_commit(aconn, dsn): + async with aconn: + async with await aconn.cursor() as cur: + await cur.execute("drop table if exists textctx") + await cur.execute("create table textctx ()") + + assert aconn.closed + + async with await psycopg3.AsyncConnection.connect(dsn) as aconn: + async with await aconn.cursor() as cur: + await cur.execute("select * from textctx") + assert await cur.fetchall() == [] + + +async def test_context_rollback(aconn, dsn): + async with await aconn.cursor() as cur: + await cur.execute("drop table if exists textctx") + await aconn.commit() + + with pytest.raises(ZeroDivisionError): + async with aconn: + async with await aconn.cursor() as cur: + await cur.execute("create table textctx ()") + 1 / 0 + + assert aconn.closed + + async with await psycopg3.AsyncConnection.connect(dsn) as aconn: + async with await aconn.cursor() as cur: + with pytest.raises(UndefinedTable): + await cur.execute("select * from textctx") + + async def test_weakref(dsn): conn = await psycopg3.AsyncConnection.connect(dsn) w = weakref.ref(conn)