From: Daniele Varrazzo Date: Mon, 22 Jun 2020 09:41:58 +0000 (+1200) Subject: Added context manager interface to copy objects X-Git-Tag: 3.0.dev0~481 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=8cd877646ffacb36940750f89a7cde545aff279d;p=thirdparty%2Fpsycopg.git Added context manager interface to copy objects --- diff --git a/psycopg3/copy.py b/psycopg3/copy.py index 8daf4158f..e07c1b6bb 100644 --- a/psycopg3/copy.py +++ b/psycopg3/copy.py @@ -6,7 +6,8 @@ psycopg3 copy support import re from typing import cast, TYPE_CHECKING -from typing import Any, Deque, Dict, List, Match, Optional, Tuple +from typing import Any, Deque, Dict, List, Match, Optional, Tuple, Type +from types import TracebackType from collections import deque from . import pq @@ -152,6 +153,20 @@ class Copy(BaseCopy): result, encoding=self.connection.codec.name ) + def __enter__(self) -> "Copy": + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + if exc_val is None: + self.finish() + else: + self.finish(str(exc_val)) + class AsyncCopy(BaseCopy): def __init__( @@ -184,4 +199,22 @@ class AsyncCopy(BaseCopy): if error is not None else None ) - await conn.wait(copy_end(conn.pgconn, berr)) + result = await conn.wait(copy_end(conn.pgconn, berr)) + if result.status != pq.ExecStatus.COMMAND_OK: + raise e.error_from_result( + result, encoding=self.connection.codec.name + ) + + async def __aenter__(self) -> "AsyncCopy": + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + if exc_val is None: + await self.finish() + else: + await self.finish(str(exc_val)) diff --git a/tests/test_copy.py b/tests/test_copy.py index 49635fc40..3bdec57b8 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -1,6 +1,7 @@ import pytest from psycopg3 import pq +from psycopg3 import errors as e from psycopg3.adapt import Format from psycopg3.types import builtins @@ -130,7 +131,17 @@ def test_copy_in_buffers(conn, format, buffer): assert data == sample_records -@pytest.mark.xfail +def test_copy_in_buffers_pg_error(conn): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + copy = cur.copy("copy copy_in from stdin (format text)") + copy.write(sample_text) + copy.write(sample_text) + with pytest.raises(e.UniqueViolation): + copy.finish() + assert conn.pgconn.transaction_status == conn.TransactionStatus.INERROR + + @pytest.mark.parametrize( "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")], @@ -145,6 +156,29 @@ def test_copy_in_buffers_with(conn, format, buffer): assert data == sample_records +def test_copy_in_buffers_with_pg_error(conn): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + with pytest.raises(e.UniqueViolation): + with cur.copy("copy copy_in from stdin (format text)") as copy: + copy.write(sample_text) + copy.write(sample_text) + + assert conn.pgconn.transaction_status == conn.TransactionStatus.INERROR + + +def test_copy_in_buffers_with_py_error(conn): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + with pytest.raises(e.QueryCanceled) as exc: + with cur.copy("copy copy_in from stdin (format text)") as copy: + copy.write(sample_text) + raise Exception("nuttengoggenio") + + assert "nuttengoggenio" in str(exc.value) + assert conn.pgconn.transaction_status == conn.TransactionStatus.INERROR + + @pytest.mark.xfail @pytest.mark.parametrize( "format", [(Format.TEXT,), (Format.BINARY,)], diff --git a/tests/test_copy_async.py b/tests/test_copy_async.py index f6b414c2e..c5906453e 100644 --- a/tests/test_copy_async.py +++ b/tests/test_copy_async.py @@ -1,5 +1,6 @@ import pytest +from psycopg3 import errors as e from psycopg3.adapt import Format from .test_copy import sample_text, sample_binary # noqa @@ -38,6 +39,61 @@ async def test_copy_in_buffers(aconn, format, buffer): assert data == sample_records +async def test_copy_in_buffers_pg_error(aconn): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + copy = await cur.copy("copy copy_in from stdin (format text)") + await copy.write(sample_text) + await copy.write(sample_text) + with pytest.raises(e.UniqueViolation): + await copy.finish() + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INERROR + + +@pytest.mark.parametrize( + "format, buffer", + [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")], +) +async def test_copy_in_buffers_with(aconn, format, buffer): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + async with ( + await cur.copy(f"copy copy_in from stdin (format {format.name})") + ) as copy: + await copy.write(globals()[buffer]) + + await cur.execute("select * from copy_in order by 1") + data = await cur.fetchall() + assert data == sample_records + + +async def test_copy_in_buffers_with_pg_error(aconn): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + with pytest.raises(e.UniqueViolation): + async with ( + await cur.copy("copy copy_in from stdin (format text)") + ) as copy: + await copy.write(sample_text) + await copy.write(sample_text) + + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INERROR + + +async def test_copy_in_buffers_with_py_error(aconn): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + with pytest.raises(e.QueryCanceled) as exc: + async with ( + await cur.copy("copy copy_in from stdin (format text)") + ) as copy: + await copy.write(sample_text) + raise Exception("nuttengoggenio") + + assert "nuttengoggenio" in str(exc.value) + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INERROR + + async def ensure_table(cur, tabledef, name="copy_in"): await cur.execute(f"drop table if exists {name}") await cur.execute(f"create table {name} ({tabledef})")