if self._pgenc != pgenc:
if pgenc:
try:
- pyenc = pq.py_codecs[pgenc.decode("ascii")]
+ pyenc = pq.py_codecs[pgenc]
except KeyError:
raise e.NotSupportedError(
f"encoding {pgenc.decode('ascii')} not available in Python"
# Copyright (C) 2020 The Psycopg Team
import re
-from typing import cast, TYPE_CHECKING
+from typing import cast, TYPE_CHECKING, AsyncGenerator, Generator
from typing import Any, Deque, Dict, List, Match, Optional, Tuple, Type
from types import TracebackType
from collections import deque
from . import pq
from . import errors as e
from .proto import AdaptContext
-from .generators import copy_to, copy_end
+from .generators import copy_from, copy_to, copy_end
if TYPE_CHECKING:
from .connection import Connection, AsyncConnection
return self._connection
+ def read(self) -> Optional[bytes]:
+ if self._finished:
+ return None
+
+ conn = self.connection
+ rv = conn.wait(copy_from(conn.pgconn))
+ if rv is None:
+ self._finished = True
+
+ return rv
+
def write(self, buffer: bytes) -> None:
conn = self.connection
conn.wait(copy_to(conn.pgconn, buffer))
if error is not None
else None
)
- result = conn.wait(copy_end(conn.pgconn, berr))
- if result.status != pq.ExecStatus.COMMAND_OK:
- raise e.error_from_result(
- result, encoding=self.connection.codec.name
- )
+ conn.wait(copy_end(conn.pgconn, berr))
+ self._finished = True
def __enter__(self) -> "Copy":
return self
else:
self.finish(str(exc_val))
+ def __iter__(self) -> Generator[bytes, None, None]:
+ while 1:
+ data = self.read()
+ if data is None:
+ break
+ yield data
+
class AsyncCopy(BaseCopy):
def __init__(
return self._connection
+ async def read(self) -> Optional[bytes]:
+ if self._finished:
+ return None
+
+ conn = self.connection
+ rv = await conn.wait(copy_from(conn.pgconn))
+ if rv is None:
+ self._finished = True
+
+ return rv
+
async def write(self, buffer: bytes) -> None:
conn = self.connection
await conn.wait(copy_to(conn.pgconn, buffer))
if error is not None
else None
)
- result = await conn.wait(copy_end(conn.pgconn, berr))
- if result.status != pq.ExecStatus.COMMAND_OK:
- raise e.error_from_result(
- result, encoding=self.connection.codec.name
- )
+ await conn.wait(copy_end(conn.pgconn, berr))
+ self._finished = True
async def __aenter__(self) -> "AsyncCopy":
return self
await self.finish()
else:
await self.finish(str(exc_val))
+
+ async def __aiter__(self) -> AsyncGenerator[bytes, None]:
+ while 1:
+ data = await self.read()
+ if data is None:
+ break
+ yield data
return ns
+def copy_from(pgconn: pq.proto.PGconn) -> PQGen[Optional[bytes]]:
+ while 1:
+ nbytes, data = pgconn.get_copy_data(1)
+ if nbytes != 0:
+ break
+
+ # would block
+ yield pgconn.socket, Wait.R
+ pgconn.consume_input()
+
+ if nbytes > 0:
+ # some data
+ return data
+
+ # Retrieve the final result of copy
+ results = yield from fetch(pgconn)
+ if len(results) != 1:
+ raise e.InternalError(
+ f"1 result expected from copy end, got {len(results)}"
+ )
+ if results[0].status != pq.ExecStatus.COMMAND_OK:
+ encoding = pq.py_codecs.get(
+ pgconn.parameter_status(b"client_encoding"), "utf8"
+ )
+ raise e.error_from_result(results[0], encoding=encoding)
+
+ return None
+
+
def copy_to(pgconn: pq.proto.PGconn, buffer: bytes) -> PQGen[None]:
# Retry enqueuing data until successful
while pgconn.put_copy_data(buffer) == 0:
yield pgconn.socket, Wait.W
-def copy_end(
- pgconn: pq.proto.PGconn, error: Optional[bytes]
-) -> PQGen[pq.proto.PGresult]:
+def copy_end(pgconn: pq.proto.PGconn, error: Optional[bytes]) -> PQGen[None]:
# Retry enqueuing end copy message until successful
while pgconn.put_copy_end(error) == 0:
yield pgconn.socket, Wait.W
# Retrieve the final result of copy
results = yield from fetch(pgconn)
- if len(results) == 1:
- return results[0]
- else:
+ if len(results) != 1:
raise e.InternalError(
f"1 result expected from copy end, got {len(results)}"
)
+ if results[0].status != pq.ExecStatus.COMMAND_OK:
+ encoding = pq.py_codecs.get(
+ pgconn.parameter_status(b"client_encoding"), "utf8"
+ )
+ raise e.error_from_result(results[0], encoding=encoding)
# Copyright (C) 2020 The Psycopg Team
-py_codecs = {
+from typing import Dict, Union
+
+_py_codecs = {
"BIG5": "big5",
"EUC_CN": "gb2312",
"EUC_JIS_2004": "euc_jis_2004",
"WIN866": "cp866",
"WIN874": "cp874",
}
+
+py_codecs: Dict[Union[bytes, str, None], str] = {}
+py_codecs.update((k, v) for k, v in _py_codecs.items())
+py_codecs.update((k.encode("ascii"), v) for k, v in _py_codecs.items())
"""
sample_binary = """
-5047 434f 5059 0aff 0d0a 0000 0000 0000
-0000 0000 0300 0000 0400 0000 0a00 0000
-0400 0000 1400 0000 0568 656c 6c6f 0003
-0000 0004 0000 0028 ffff ffff 0000 0005
-776f 726c 64ff ff
+5047 434f 5059 0aff 0d0a 00
+00 0000 0000 0000 00
+00 0300 0000 0400 0000 0a00 0000 0400 0000 1400 0000 0568 656c 6c6f
+
+0003 0000 0004 0000 0028 ffff ffff 0000 0005 776f 726c 64
+
+ff ff
"""
-sample_binary = bytes.fromhex("".join(sample_binary.split()))
+
+sample_binary_rows = [
+ bytes.fromhex("".join(row.split())) for row in sample_binary.split("\n\n")
+]
+
+sample_binary = b"".join(sample_binary_rows)
def set_sample_attributes(res, format):
assert list(copy.buffers(sample_records)) == [globals()[buffer]]
-@pytest.mark.xfail
-@pytest.mark.parametrize(
- "format, buffer",
- [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
-)
-def test_copy_out_read(conn, format, buffer):
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+def test_copy_out_read(conn, format):
cur = conn.cursor()
copy = cur.copy(f"copy ({sample_values}) to stdout (format {format.name})")
- assert copy.read() == globals()[buffer]
+
+ if format == pq.Format.TEXT:
+ want = [row + b"\n" for row in sample_text.splitlines()]
+ else:
+ want = sample_binary_rows
+
+ for row in want:
+ got = copy.read()
+ assert got == row
+
assert copy.read() is None
assert copy.read() is None
-@pytest.mark.xfail
@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
-def test_iter(conn, format):
+def test_copy_out_iter(conn, format):
cur = conn.cursor()
copy = cur.copy(f"copy ({sample_values}) to stdout (format {format.name})")
- assert list(copy) == sample_records
+ if format == pq.Format.TEXT:
+ want = [row + b"\n" for row in sample_text.splitlines()]
+ else:
+ want = sample_binary_rows
+ assert list(copy) == want
@pytest.mark.parametrize(
import pytest
+from psycopg3 import pq
from psycopg3 import errors as e
from psycopg3.adapt import Format
-from .test_copy import sample_text, sample_binary # noqa
+from .test_copy import sample_text, sample_binary, sample_binary_rows # noqa
from .test_copy import sample_values, sample_records, sample_tabledef
pytestmark = pytest.mark.asyncio
-@pytest.mark.xfail
-@pytest.mark.parametrize(
- "format, buffer",
- [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
-)
-async def test_copy_out_read(aconn, format, buffer):
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+async def test_copy_out_read(aconn, format):
cur = aconn.cursor()
copy = await cur.copy(
f"copy ({sample_values}) to stdout (format {format.name})"
)
- assert await copy.read() == globals()[buffer]
+
+ if format == pq.Format.TEXT:
+ want = [row + b"\n" for row in sample_text.splitlines()]
+ else:
+ want = sample_binary_rows
+
+ for row in want:
+ got = await copy.read()
+ assert got == row
+
assert await copy.read() is None
assert await copy.read() is None
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+async def test_copy_out_iter(aconn, format):
+ cur = aconn.cursor()
+ copy = await cur.copy(
+ f"copy ({sample_values}) to stdout (format {format.name})"
+ )
+ if format == pq.Format.TEXT:
+ want = [row + b"\n" for row in sample_text.splitlines()]
+ else:
+ want = sample_binary_rows
+ got = []
+ async for row in copy:
+ got.append(row)
+ assert got == want
+
+
@pytest.mark.parametrize(
"format, buffer",
[(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],