]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added context manager interface to copy objects
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 22 Jun 2020 09:41:58 +0000 (21:41 +1200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 22 Jun 2020 11:30:09 +0000 (23:30 +1200)
psycopg3/copy.py
tests/test_copy.py
tests/test_copy_async.py

index 8daf4158f246571499798250dc7dec687035d171..e07c1b6bbddbf563b0e83c2306552d62bd16bbab 100644 (file)
@@ -6,7 +6,8 @@ psycopg3 copy support
 
 import re
 from typing import cast, TYPE_CHECKING
-from typing import Any, Deque, Dict, List, Match, Optional, Tuple
+from typing import Any, Deque, Dict, List, Match, Optional, Tuple, Type
+from types import TracebackType
 from collections import deque
 
 from . import pq
@@ -152,6 +153,20 @@ class Copy(BaseCopy):
                 result, encoding=self.connection.codec.name
             )
 
+    def __enter__(self) -> "Copy":
+        return self
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> None:
+        if exc_val is None:
+            self.finish()
+        else:
+            self.finish(str(exc_val))
+
 
 class AsyncCopy(BaseCopy):
     def __init__(
@@ -184,4 +199,22 @@ class AsyncCopy(BaseCopy):
             if error is not None
             else None
         )
-        await conn.wait(copy_end(conn.pgconn, berr))
+        result = await conn.wait(copy_end(conn.pgconn, berr))
+        if result.status != pq.ExecStatus.COMMAND_OK:
+            raise e.error_from_result(
+                result, encoding=self.connection.codec.name
+            )
+
+    async def __aenter__(self) -> "AsyncCopy":
+        return self
+
+    async def __aexit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> None:
+        if exc_val is None:
+            await self.finish()
+        else:
+            await self.finish(str(exc_val))
index 49635fc4084fc5497e9b5226535d2bd3a1024465..3bdec57b8cf9683b14c663232686e86478a526b3 100644 (file)
@@ -1,6 +1,7 @@
 import pytest
 
 from psycopg3 import pq
+from psycopg3 import errors as e
 from psycopg3.adapt import Format
 from psycopg3.types import builtins
 
@@ -130,7 +131,17 @@ def test_copy_in_buffers(conn, format, buffer):
     assert data == sample_records
 
 
-@pytest.mark.xfail
+def test_copy_in_buffers_pg_error(conn):
+    cur = conn.cursor()
+    ensure_table(cur, sample_tabledef)
+    copy = cur.copy("copy copy_in from stdin (format text)")
+    copy.write(sample_text)
+    copy.write(sample_text)
+    with pytest.raises(e.UniqueViolation):
+        copy.finish()
+    assert conn.pgconn.transaction_status == conn.TransactionStatus.INERROR
+
+
 @pytest.mark.parametrize(
     "format, buffer",
     [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
@@ -145,6 +156,29 @@ def test_copy_in_buffers_with(conn, format, buffer):
     assert data == sample_records
 
 
+def test_copy_in_buffers_with_pg_error(conn):
+    cur = conn.cursor()
+    ensure_table(cur, sample_tabledef)
+    with pytest.raises(e.UniqueViolation):
+        with cur.copy("copy copy_in from stdin (format text)") as copy:
+            copy.write(sample_text)
+            copy.write(sample_text)
+
+    assert conn.pgconn.transaction_status == conn.TransactionStatus.INERROR
+
+
+def test_copy_in_buffers_with_py_error(conn):
+    cur = conn.cursor()
+    ensure_table(cur, sample_tabledef)
+    with pytest.raises(e.QueryCanceled) as exc:
+        with cur.copy("copy copy_in from stdin (format text)") as copy:
+            copy.write(sample_text)
+            raise Exception("nuttengoggenio")
+
+    assert "nuttengoggenio" in str(exc.value)
+    assert conn.pgconn.transaction_status == conn.TransactionStatus.INERROR
+
+
 @pytest.mark.xfail
 @pytest.mark.parametrize(
     "format", [(Format.TEXT,), (Format.BINARY,)],
index f6b414c2e8f00ef9d9cecaa8e49cb1f8cab5dfca..c5906453ea761f298ff9dbb2558876c07b6e6f70 100644 (file)
@@ -1,5 +1,6 @@
 import pytest
 
+from psycopg3 import errors as e
 from psycopg3.adapt import Format
 
 from .test_copy import sample_text, sample_binary  # noqa
@@ -38,6 +39,61 @@ async def test_copy_in_buffers(aconn, format, buffer):
     assert data == sample_records
 
 
+async def test_copy_in_buffers_pg_error(aconn):
+    cur = aconn.cursor()
+    await ensure_table(cur, sample_tabledef)
+    copy = await cur.copy("copy copy_in from stdin (format text)")
+    await copy.write(sample_text)
+    await copy.write(sample_text)
+    with pytest.raises(e.UniqueViolation):
+        await copy.finish()
+    assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INERROR
+
+
+@pytest.mark.parametrize(
+    "format, buffer",
+    [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
+)
+async def test_copy_in_buffers_with(aconn, format, buffer):
+    cur = aconn.cursor()
+    await ensure_table(cur, sample_tabledef)
+    async with (
+        await cur.copy(f"copy copy_in from stdin (format {format.name})")
+    ) as copy:
+        await copy.write(globals()[buffer])
+
+    await cur.execute("select * from copy_in order by 1")
+    data = await cur.fetchall()
+    assert data == sample_records
+
+
+async def test_copy_in_buffers_with_pg_error(aconn):
+    cur = aconn.cursor()
+    await ensure_table(cur, sample_tabledef)
+    with pytest.raises(e.UniqueViolation):
+        async with (
+            await cur.copy("copy copy_in from stdin (format text)")
+        ) as copy:
+            await copy.write(sample_text)
+            await copy.write(sample_text)
+
+    assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INERROR
+
+
+async def test_copy_in_buffers_with_py_error(aconn):
+    cur = aconn.cursor()
+    await ensure_table(cur, sample_tabledef)
+    with pytest.raises(e.QueryCanceled) as exc:
+        async with (
+            await cur.copy("copy copy_in from stdin (format text)")
+        ) as copy:
+            await copy.write(sample_text)
+            raise Exception("nuttengoggenio")
+
+    assert "nuttengoggenio" in str(exc.value)
+    assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INERROR
+
+
 async def ensure_table(cur, tabledef, name="copy_in"):
     await cur.execute(f"drop table if exists {name}")
     await cur.execute(f"create table {name} ({tabledef})")