from typing import Any, Dict, List, Match, Optional, Sequence, Type, Union
from types import TracebackType
-from .pq import Format
+from .pq import Format, ExecStatus
from .proto import ConnectionType, Transformer
from .generators import copy_from, copy_to, copy_end
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
+ # no-op in COPY TO
+ if self.pgresult.status == ExecStatus.COPY_OUT:
+ return
+
if exc_val is None:
if self.format == Format.BINARY and not self._first_row:
# send EOF only if we copied binary rows (_first_row is False)
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
+ # no-op in COPY TO
+ if self.pgresult.status == ExecStatus.COPY_OUT:
+ return
+
if exc_val is None:
if self.format == Format.BINARY and not self._first_row:
# send EOF only if we copied binary rows (_first_row is False)
assert list(copy) == want
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+def test_copy_out_context(conn, format):
+ cur = conn.cursor()
+ out = []
+ with cur.copy(
+ f"copy ({sample_values}) to stdout (format {format.name})"
+ ) as copy:
+ for row in copy:
+ out.append(row)
+
+ if format == pq.Format.TEXT:
+ want = [row + b"\n" for row in sample_text.splitlines()]
+ else:
+ want = sample_binary_rows
+ assert out == want
+
+
@pytest.mark.parametrize(
"format, buffer",
[(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
assert got == want
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+async def test_copy_out_context(aconn, format):
+ cur = await aconn.cursor()
+ out = []
+ async with (
+ await cur.copy(
+ f"copy ({sample_values}) to stdout (format {format.name})"
+ )
+ ) as copy:
+ async for row in copy:
+ out.append(row)
+
+ if format == pq.Format.TEXT:
+ want = [row + b"\n" for row in sample_text.splitlines()]
+ else:
+ want = sample_binary_rows
+ assert out == want
+
+
@pytest.mark.parametrize(
"format, buffer",
[(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],