]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added connection enter/exit
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 9 Nov 2020 03:35:03 +0000 (03:35 +0000)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 9 Nov 2020 03:35:03 +0000 (03:35 +0000)
Just a simplified version, only commit/rollback + close, waiting for
nested transactions...

psycopg3/psycopg3/connection.py
tests/test_connection.py
tests/test_connection_async.py

index 0cde94d1b9fd112c42e9873a90f079fdf04000e7..66249767bf2d723c459593d77638395fcad1ea97 100644 (file)
@@ -7,6 +7,7 @@ psycopg3 connection objects
 import logging
 import asyncio
 import threading
+from types import TracebackType
 from typing import Any, AsyncGenerator, Callable, Generator, List, NamedTuple
 from typing import Optional, Type, cast
 from weakref import ref, ReferenceType
@@ -234,6 +235,22 @@ class Connection(BaseConnection):
         conn._autocommit = autocommit
         return conn
 
+    def __enter__(self) -> "Connection":
+        return self
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> None:
+        if exc_type:
+            self.rollback()
+        else:
+            self.commit()
+
+        self.close()
+
     def close(self) -> None:
         self.pgconn.finish()
 
@@ -349,6 +366,22 @@ class AsyncConnection(BaseConnection):
         conn._autocommit = autocommit
         return conn
 
+    async def __aenter__(self) -> "AsyncConnection":
+        return self
+
+    async def __aexit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> None:
+        if exc_type:
+            await self.rollback()
+        else:
+            await self.commit()
+
+        await self.close()
+
     async def close(self) -> None:
         self.pgconn.finish()
 
index 777a65ec6f3f945bfbbf087a123c66c07e870eab..2d28f20e9004381632613e7b6b1f44de122c35ba 100644 (file)
@@ -5,6 +5,7 @@ import weakref
 
 import psycopg3
 from psycopg3 import Connection
+from psycopg3.errors import UndefinedTable
 from psycopg3.conninfo import conninfo_to_dict
 
 
@@ -42,6 +43,39 @@ def test_close(conn):
         cur.execute("select 1")
 
 
+def test_context_commit(conn, dsn):
+    with conn:
+        with conn.cursor() as cur:
+            cur.execute("drop table if exists textctx")
+            cur.execute("create table textctx ()")
+
+    assert conn.closed
+
+    with psycopg3.connect(dsn) as conn:
+        with conn.cursor() as cur:
+            cur.execute("select * from textctx")
+            assert cur.fetchall() == []
+
+
+def test_context_rollback(conn, dsn):
+    with conn.cursor() as cur:
+        cur.execute("drop table if exists textctx")
+    conn.commit()
+
+    with pytest.raises(ZeroDivisionError):
+        with conn:
+            with conn.cursor() as cur:
+                cur.execute("create table textctx ()")
+                1 / 0
+
+    assert conn.closed
+
+    with psycopg3.connect(dsn) as conn:
+        with conn.cursor() as cur:
+            with pytest.raises(UndefinedTable):
+                cur.execute("select * from textctx")
+
+
 def test_weakref(dsn):
     conn = psycopg3.connect(dsn)
     w = weakref.ref(conn)
index 14735c78e1968eb175f65636f4bc8c88ba4489d4..60f6a0c687896e23b338953e683815f73cd1bb90 100644 (file)
@@ -5,6 +5,7 @@ import weakref
 
 import psycopg3
 from psycopg3 import AsyncConnection
+from psycopg3.errors import UndefinedTable
 from psycopg3.conninfo import conninfo_to_dict
 
 pytestmark = pytest.mark.asyncio
@@ -45,6 +46,39 @@ async def test_close(aconn):
         await cur.execute("select 1")
 
 
+async def test_context_commit(aconn, dsn):
+    async with aconn:
+        async with await aconn.cursor() as cur:
+            await cur.execute("drop table if exists textctx")
+            await cur.execute("create table textctx ()")
+
+    assert aconn.closed
+
+    async with await psycopg3.AsyncConnection.connect(dsn) as aconn:
+        async with await aconn.cursor() as cur:
+            await cur.execute("select * from textctx")
+            assert await cur.fetchall() == []
+
+
+async def test_context_rollback(aconn, dsn):
+    async with await aconn.cursor() as cur:
+        await cur.execute("drop table if exists textctx")
+    await aconn.commit()
+
+    with pytest.raises(ZeroDivisionError):
+        async with aconn:
+            async with await aconn.cursor() as cur:
+                await cur.execute("create table textctx ()")
+                1 / 0
+
+    assert aconn.closed
+
+    async with await psycopg3.AsyncConnection.connect(dsn) as aconn:
+        async with await aconn.cursor() as cur:
+            with pytest.raises(UndefinedTable):
+                await cur.execute("select * from textctx")
+
+
 async def test_weakref(dsn):
     conn = await psycopg3.AsyncConnection.connect(dsn)
     w = weakref.ref(conn)