import re
from typing import cast, TYPE_CHECKING
-from typing import Any, Deque, Dict, List, Match, Optional, Tuple
+from typing import Any, Deque, Dict, List, Match, Optional, Tuple, Type
+from types import TracebackType
from collections import deque
from . import pq
result, encoding=self.connection.codec.name
)
+ def __enter__(self) -> "Copy":
+ return self
+
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ if exc_val is None:
+ self.finish()
+ else:
+ self.finish(str(exc_val))
+
class AsyncCopy(BaseCopy):
def __init__(
if error is not None
else None
)
- await conn.wait(copy_end(conn.pgconn, berr))
+ result = await conn.wait(copy_end(conn.pgconn, berr))
+ if result.status != pq.ExecStatus.COMMAND_OK:
+ raise e.error_from_result(
+ result, encoding=self.connection.codec.name
+ )
+
+ async def __aenter__(self) -> "AsyncCopy":
+ return self
+
+ async def __aexit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ if exc_val is None:
+ await self.finish()
+ else:
+ await self.finish(str(exc_val))
import pytest
from psycopg3 import pq
+from psycopg3 import errors as e
from psycopg3.adapt import Format
from psycopg3.types import builtins
assert data == sample_records
-@pytest.mark.xfail
+def test_copy_in_buffers_pg_error(conn):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ copy = cur.copy("copy copy_in from stdin (format text)")
+ copy.write(sample_text)
+ copy.write(sample_text)
+ with pytest.raises(e.UniqueViolation):
+ copy.finish()
+ assert conn.pgconn.transaction_status == conn.TransactionStatus.INERROR
+
+
@pytest.mark.parametrize(
"format, buffer",
[(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
assert data == sample_records
+def test_copy_in_buffers_with_pg_error(conn):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.UniqueViolation):
+ with cur.copy("copy copy_in from stdin (format text)") as copy:
+ copy.write(sample_text)
+ copy.write(sample_text)
+
+ assert conn.pgconn.transaction_status == conn.TransactionStatus.INERROR
+
+
+def test_copy_in_buffers_with_py_error(conn):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.QueryCanceled) as exc:
+ with cur.copy("copy copy_in from stdin (format text)") as copy:
+ copy.write(sample_text)
+ raise Exception("nuttengoggenio")
+
+ assert "nuttengoggenio" in str(exc.value)
+ assert conn.pgconn.transaction_status == conn.TransactionStatus.INERROR
+
+
@pytest.mark.xfail
@pytest.mark.parametrize(
"format", [(Format.TEXT,), (Format.BINARY,)],
import pytest
+from psycopg3 import errors as e
from psycopg3.adapt import Format
from .test_copy import sample_text, sample_binary # noqa
assert data == sample_records
+async def test_copy_in_buffers_pg_error(aconn):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+ copy = await cur.copy("copy copy_in from stdin (format text)")
+ await copy.write(sample_text)
+ await copy.write(sample_text)
+ with pytest.raises(e.UniqueViolation):
+ await copy.finish()
+ assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INERROR
+
+
+@pytest.mark.parametrize(
+ "format, buffer",
+ [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
+)
+async def test_copy_in_buffers_with(aconn, format, buffer):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+ async with (
+ await cur.copy(f"copy copy_in from stdin (format {format.name})")
+ ) as copy:
+ await copy.write(globals()[buffer])
+
+ await cur.execute("select * from copy_in order by 1")
+ data = await cur.fetchall()
+ assert data == sample_records
+
+
+async def test_copy_in_buffers_with_pg_error(aconn):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.UniqueViolation):
+ async with (
+ await cur.copy("copy copy_in from stdin (format text)")
+ ) as copy:
+ await copy.write(sample_text)
+ await copy.write(sample_text)
+
+ assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INERROR
+
+
+async def test_copy_in_buffers_with_py_error(aconn):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.QueryCanceled) as exc:
+ async with (
+ await cur.copy("copy copy_in from stdin (format text)")
+ ) as copy:
+ await copy.write(sample_text)
+ raise Exception("nuttengoggenio")
+
+ assert "nuttengoggenio" in str(exc.value)
+ assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INERROR
+
+
async def ensure_table(cur, tabledef, name="copy_in"):
await cur.execute(f"drop table if exists {name}")
await cur.execute(f"create table {name} ({tabledef})")