import logging
import asyncio
import threading
+from types import TracebackType
from typing import Any, AsyncGenerator, Callable, Generator, List, NamedTuple
from typing import Optional, Type, cast
from weakref import ref, ReferenceType
conn._autocommit = autocommit
return conn
+ def __enter__(self) -> "Connection":
+ return self
+
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ if exc_type:
+ self.rollback()
+ else:
+ self.commit()
+
+ self.close()
+
def close(self) -> None:
self.pgconn.finish()
conn._autocommit = autocommit
return conn
+ async def __aenter__(self) -> "AsyncConnection":
+ return self
+
+ async def __aexit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ if exc_type:
+ await self.rollback()
+ else:
+ await self.commit()
+
+ await self.close()
+
async def close(self) -> None:
self.pgconn.finish()
import psycopg3
from psycopg3 import Connection
+from psycopg3.errors import UndefinedTable
from psycopg3.conninfo import conninfo_to_dict
cur.execute("select 1")
+def test_context_commit(conn, dsn):
+ with conn:
+ with conn.cursor() as cur:
+ cur.execute("drop table if exists textctx")
+ cur.execute("create table textctx ()")
+
+ assert conn.closed
+
+ with psycopg3.connect(dsn) as conn:
+ with conn.cursor() as cur:
+ cur.execute("select * from textctx")
+ assert cur.fetchall() == []
+
+
+def test_context_rollback(conn, dsn):
+ with conn.cursor() as cur:
+ cur.execute("drop table if exists textctx")
+ conn.commit()
+
+ with pytest.raises(ZeroDivisionError):
+ with conn:
+ with conn.cursor() as cur:
+ cur.execute("create table textctx ()")
+ 1 / 0
+
+ assert conn.closed
+
+ with psycopg3.connect(dsn) as conn:
+ with conn.cursor() as cur:
+ with pytest.raises(UndefinedTable):
+ cur.execute("select * from textctx")
+
+
def test_weakref(dsn):
conn = psycopg3.connect(dsn)
w = weakref.ref(conn)
import psycopg3
from psycopg3 import AsyncConnection
+from psycopg3.errors import UndefinedTable
from psycopg3.conninfo import conninfo_to_dict
pytestmark = pytest.mark.asyncio
await cur.execute("select 1")
+async def test_context_commit(aconn, dsn):
+ async with aconn:
+ async with await aconn.cursor() as cur:
+ await cur.execute("drop table if exists textctx")
+ await cur.execute("create table textctx ()")
+
+ assert aconn.closed
+
+ async with await psycopg3.AsyncConnection.connect(dsn) as aconn:
+ async with await aconn.cursor() as cur:
+ await cur.execute("select * from textctx")
+ assert await cur.fetchall() == []
+
+
+async def test_context_rollback(aconn, dsn):
+ async with await aconn.cursor() as cur:
+ await cur.execute("drop table if exists textctx")
+ await aconn.commit()
+
+ with pytest.raises(ZeroDivisionError):
+ async with aconn:
+ async with await aconn.cursor() as cur:
+ await cur.execute("create table textctx ()")
+ 1 / 0
+
+ assert aconn.closed
+
+ async with await psycopg3.AsyncConnection.connect(dsn) as aconn:
+ async with await aconn.cursor() as cur:
+ with pytest.raises(UndefinedTable):
+ await cur.execute("select * from textctx")
+
+
async def test_weakref(dsn):
conn = await psycopg3.AsyncConnection.connect(dsn)
w = weakref.ref(conn)