From: Daniele Varrazzo Date: Tue, 23 Jun 2020 10:31:40 +0000 (+1200) Subject: Added reading from copy X-Git-Tag: 3.0.dev0~479 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=72155d47f5a86c89d7be68da0f342161fea84639;p=thirdparty%2Fpsycopg.git Added reading from copy --- diff --git a/psycopg3/connection.py b/psycopg3/connection.py index ad5386476..4d3312999 100644 --- a/psycopg3/connection.py +++ b/psycopg3/connection.py @@ -127,7 +127,7 @@ class BaseConnection: if self._pgenc != pgenc: if pgenc: try: - pyenc = pq.py_codecs[pgenc.decode("ascii")] + pyenc = pq.py_codecs[pgenc] except KeyError: raise e.NotSupportedError( f"encoding {pgenc.decode('ascii')} not available in Python" diff --git a/psycopg3/copy.py b/psycopg3/copy.py index e07c1b6bb..026da51e5 100644 --- a/psycopg3/copy.py +++ b/psycopg3/copy.py @@ -5,7 +5,7 @@ psycopg3 copy support # Copyright (C) 2020 The Psycopg Team import re -from typing import cast, TYPE_CHECKING +from typing import cast, TYPE_CHECKING, AsyncGenerator, Generator from typing import Any, Deque, Dict, List, Match, Optional, Tuple, Type from types import TracebackType from collections import deque @@ -13,7 +13,7 @@ from collections import deque from . import pq from . import errors as e from .proto import AdaptContext -from .generators import copy_to, copy_end +from .generators import copy_from, copy_to, copy_end if TYPE_CHECKING: from .connection import Connection, AsyncConnection @@ -136,6 +136,17 @@ class Copy(BaseCopy): return self._connection + def read(self) -> Optional[bytes]: + if self._finished: + return None + + conn = self.connection + rv = conn.wait(copy_from(conn.pgconn)) + if rv is None: + self._finished = True + + return rv + def write(self, buffer: bytes) -> None: conn = self.connection conn.wait(copy_to(conn.pgconn, buffer)) @@ -147,11 +158,8 @@ class Copy(BaseCopy): if error is not None else None ) - result = conn.wait(copy_end(conn.pgconn, berr)) - if result.status != pq.ExecStatus.COMMAND_OK: - raise e.error_from_result( - result, encoding=self.connection.codec.name - ) + conn.wait(copy_end(conn.pgconn, berr)) + self._finished = True def __enter__(self) -> "Copy": return self @@ -167,6 +175,13 @@ class Copy(BaseCopy): else: self.finish(str(exc_val)) + def __iter__(self) -> Generator[bytes, None, None]: + while 1: + data = self.read() + if data is None: + break + yield data + class AsyncCopy(BaseCopy): def __init__( @@ -188,6 +203,17 @@ class AsyncCopy(BaseCopy): return self._connection + async def read(self) -> Optional[bytes]: + if self._finished: + return None + + conn = self.connection + rv = await conn.wait(copy_from(conn.pgconn)) + if rv is None: + self._finished = True + + return rv + async def write(self, buffer: bytes) -> None: conn = self.connection await conn.wait(copy_to(conn.pgconn, buffer)) @@ -199,11 +225,8 @@ class AsyncCopy(BaseCopy): if error is not None else None ) - result = await conn.wait(copy_end(conn.pgconn, berr)) - if result.status != pq.ExecStatus.COMMAND_OK: - raise e.error_from_result( - result, encoding=self.connection.codec.name - ) + await conn.wait(copy_end(conn.pgconn, berr)) + self._finished = True async def __aenter__(self) -> "AsyncCopy": return self @@ -218,3 +241,10 @@ class AsyncCopy(BaseCopy): await self.finish() else: await self.finish(str(exc_val)) + + async def __aiter__(self) -> AsyncGenerator[bytes, None]: + while 1: + data = await self.read() + if data is None: + break + yield data diff --git a/psycopg3/generators.py b/psycopg3/generators.py index b4269a84a..7ecb11ad0 100644 --- a/psycopg3/generators.py +++ b/psycopg3/generators.py @@ -151,15 +151,42 @@ def notifies(pgconn: pq.proto.PGconn) -> PQGen[List[pq.PGnotify]]: return ns +def copy_from(pgconn: pq.proto.PGconn) -> PQGen[Optional[bytes]]: + while 1: + nbytes, data = pgconn.get_copy_data(1) + if nbytes != 0: + break + + # would block + yield pgconn.socket, Wait.R + pgconn.consume_input() + + if nbytes > 0: + # some data + return data + + # Retrieve the final result of copy + results = yield from fetch(pgconn) + if len(results) != 1: + raise e.InternalError( + f"1 result expected from copy end, got {len(results)}" + ) + if results[0].status != pq.ExecStatus.COMMAND_OK: + encoding = pq.py_codecs.get( + pgconn.parameter_status(b"client_encoding"), "utf8" + ) + raise e.error_from_result(results[0], encoding=encoding) + + return None + + def copy_to(pgconn: pq.proto.PGconn, buffer: bytes) -> PQGen[None]: # Retry enqueuing data until successful while pgconn.put_copy_data(buffer) == 0: yield pgconn.socket, Wait.W -def copy_end( - pgconn: pq.proto.PGconn, error: Optional[bytes] -) -> PQGen[pq.proto.PGresult]: +def copy_end(pgconn: pq.proto.PGconn, error: Optional[bytes]) -> PQGen[None]: # Retry enqueuing end copy message until successful while pgconn.put_copy_end(error) == 0: yield pgconn.socket, Wait.W @@ -173,9 +200,12 @@ def copy_end( # Retrieve the final result of copy results = yield from fetch(pgconn) - if len(results) == 1: - return results[0] - else: + if len(results) != 1: raise e.InternalError( f"1 result expected from copy end, got {len(results)}" ) + if results[0].status != pq.ExecStatus.COMMAND_OK: + encoding = pq.py_codecs.get( + pgconn.parameter_status(b"client_encoding"), "utf8" + ) + raise e.error_from_result(results[0], encoding=encoding) diff --git a/psycopg3/pq/encodings.py b/psycopg3/pq/encodings.py index 64997b3cd..28eef6c13 100644 --- a/psycopg3/pq/encodings.py +++ b/psycopg3/pq/encodings.py @@ -4,7 +4,9 @@ Mappings between PostgreSQL and Python encodings. # Copyright (C) 2020 The Psycopg Team -py_codecs = { +from typing import Dict, Union + +_py_codecs = { "BIG5": "big5", "EUC_CN": "gb2312", "EUC_JIS_2004": "euc_jis_2004", @@ -50,3 +52,7 @@ py_codecs = { "WIN866": "cp866", "WIN874": "cp874", } + +py_codecs: Dict[Union[bytes, str, None], str] = {} +py_codecs.update((k, v) for k, v in _py_codecs.items()) +py_codecs.update((k.encode("ascii"), v) for k, v in _py_codecs.items()) diff --git a/tests/test_copy.py b/tests/test_copy.py index 3bdec57b8..2f6173a38 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -18,13 +18,20 @@ sample_text = b"""\ """ sample_binary = """ -5047 434f 5059 0aff 0d0a 0000 0000 0000 -0000 0000 0300 0000 0400 0000 0a00 0000 -0400 0000 1400 0000 0568 656c 6c6f 0003 -0000 0004 0000 0028 ffff ffff 0000 0005 -776f 726c 64ff ff +5047 434f 5059 0aff 0d0a 00 +00 0000 0000 0000 00 +00 0300 0000 0400 0000 0a00 0000 0400 0000 1400 0000 0568 656c 6c6f + +0003 0000 0004 0000 0028 ffff ffff 0000 0005 776f 726c 64 + +ff ff """ -sample_binary = bytes.fromhex("".join(sample_binary.split())) + +sample_binary_rows = [ + bytes.fromhex("".join(row.split())) for row in sample_binary.split("\n\n") +] + +sample_binary = b"".join(sample_binary_rows) def set_sample_attributes(res, format): @@ -96,25 +103,33 @@ def test_buffers(format, buffer): assert list(copy.buffers(sample_records)) == [globals()[buffer]] -@pytest.mark.xfail -@pytest.mark.parametrize( - "format, buffer", - [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")], -) -def test_copy_out_read(conn, format, buffer): +@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) +def test_copy_out_read(conn, format): cur = conn.cursor() copy = cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") - assert copy.read() == globals()[buffer] + + if format == pq.Format.TEXT: + want = [row + b"\n" for row in sample_text.splitlines()] + else: + want = sample_binary_rows + + for row in want: + got = copy.read() + assert got == row + assert copy.read() is None assert copy.read() is None -@pytest.mark.xfail @pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) -def test_iter(conn, format): +def test_copy_out_iter(conn, format): cur = conn.cursor() copy = cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") - assert list(copy) == sample_records + if format == pq.Format.TEXT: + want = [row + b"\n" for row in sample_text.splitlines()] + else: + want = sample_binary_rows + assert list(copy) == want @pytest.mark.parametrize( diff --git a/tests/test_copy_async.py b/tests/test_copy_async.py index c5906453e..fff8bd034 100644 --- a/tests/test_copy_async.py +++ b/tests/test_copy_async.py @@ -1,29 +1,51 @@ import pytest +from psycopg3 import pq from psycopg3 import errors as e from psycopg3.adapt import Format -from .test_copy import sample_text, sample_binary # noqa +from .test_copy import sample_text, sample_binary, sample_binary_rows # noqa from .test_copy import sample_values, sample_records, sample_tabledef pytestmark = pytest.mark.asyncio -@pytest.mark.xfail -@pytest.mark.parametrize( - "format, buffer", - [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")], -) -async def test_copy_out_read(aconn, format, buffer): +@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) +async def test_copy_out_read(aconn, format): cur = aconn.cursor() copy = await cur.copy( f"copy ({sample_values}) to stdout (format {format.name})" ) - assert await copy.read() == globals()[buffer] + + if format == pq.Format.TEXT: + want = [row + b"\n" for row in sample_text.splitlines()] + else: + want = sample_binary_rows + + for row in want: + got = await copy.read() + assert got == row + assert await copy.read() is None assert await copy.read() is None +@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) +async def test_copy_out_iter(aconn, format): + cur = aconn.cursor() + copy = await cur.copy( + f"copy ({sample_values}) to stdout (format {format.name})" + ) + if format == pq.Format.TEXT: + want = [row + b"\n" for row in sample_text.splitlines()] + else: + want = sample_binary_rows + got = [] + async for row in copy: + got.append(row) + assert got == want + + @pytest.mark.parametrize( "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],