]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added reading from copy
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 23 Jun 2020 10:31:40 +0000 (22:31 +1200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 23 Jun 2020 10:31:40 +0000 (22:31 +1200)
psycopg3/connection.py
psycopg3/copy.py
psycopg3/generators.py
psycopg3/pq/encodings.py
tests/test_copy.py
tests/test_copy_async.py

index ad53864761c54adc89be5441d261554bc4d9901e..4d331299952b7660327970afe316837d5cb1d27f 100644 (file)
@@ -127,7 +127,7 @@ class BaseConnection:
         if self._pgenc != pgenc:
             if pgenc:
                 try:
-                    pyenc = pq.py_codecs[pgenc.decode("ascii")]
+                    pyenc = pq.py_codecs[pgenc]
                 except KeyError:
                     raise e.NotSupportedError(
                         f"encoding {pgenc.decode('ascii')} not available in Python"
index e07c1b6bbddbf563b0e83c2306552d62bd16bbab..026da51e5a1898488c7e43a3608a7474bbed1f42 100644 (file)
@@ -5,7 +5,7 @@ psycopg3 copy support
 # Copyright (C) 2020 The Psycopg Team
 
 import re
-from typing import cast, TYPE_CHECKING
+from typing import cast, TYPE_CHECKING, AsyncGenerator, Generator
 from typing import Any, Deque, Dict, List, Match, Optional, Tuple, Type
 from types import TracebackType
 from collections import deque
@@ -13,7 +13,7 @@ from collections import deque
 from . import pq
 from . import errors as e
 from .proto import AdaptContext
-from .generators import copy_to, copy_end
+from .generators import copy_from, copy_to, copy_end
 
 if TYPE_CHECKING:
     from .connection import Connection, AsyncConnection
@@ -136,6 +136,17 @@ class Copy(BaseCopy):
 
         return self._connection
 
+    def read(self) -> Optional[bytes]:
+        if self._finished:
+            return None
+
+        conn = self.connection
+        rv = conn.wait(copy_from(conn.pgconn))
+        if rv is None:
+            self._finished = True
+
+        return rv
+
     def write(self, buffer: bytes) -> None:
         conn = self.connection
         conn.wait(copy_to(conn.pgconn, buffer))
@@ -147,11 +158,8 @@ class Copy(BaseCopy):
             if error is not None
             else None
         )
-        result = conn.wait(copy_end(conn.pgconn, berr))
-        if result.status != pq.ExecStatus.COMMAND_OK:
-            raise e.error_from_result(
-                result, encoding=self.connection.codec.name
-            )
+        conn.wait(copy_end(conn.pgconn, berr))
+        self._finished = True
 
     def __enter__(self) -> "Copy":
         return self
@@ -167,6 +175,13 @@ class Copy(BaseCopy):
         else:
             self.finish(str(exc_val))
 
+    def __iter__(self) -> Generator[bytes, None, None]:
+        while 1:
+            data = self.read()
+            if data is None:
+                break
+            yield data
+
 
 class AsyncCopy(BaseCopy):
     def __init__(
@@ -188,6 +203,17 @@ class AsyncCopy(BaseCopy):
 
         return self._connection
 
+    async def read(self) -> Optional[bytes]:
+        if self._finished:
+            return None
+
+        conn = self.connection
+        rv = await conn.wait(copy_from(conn.pgconn))
+        if rv is None:
+            self._finished = True
+
+        return rv
+
     async def write(self, buffer: bytes) -> None:
         conn = self.connection
         await conn.wait(copy_to(conn.pgconn, buffer))
@@ -199,11 +225,8 @@ class AsyncCopy(BaseCopy):
             if error is not None
             else None
         )
-        result = await conn.wait(copy_end(conn.pgconn, berr))
-        if result.status != pq.ExecStatus.COMMAND_OK:
-            raise e.error_from_result(
-                result, encoding=self.connection.codec.name
-            )
+        await conn.wait(copy_end(conn.pgconn, berr))
+        self._finished = True
 
     async def __aenter__(self) -> "AsyncCopy":
         return self
@@ -218,3 +241,10 @@ class AsyncCopy(BaseCopy):
             await self.finish()
         else:
             await self.finish(str(exc_val))
+
+    async def __aiter__(self) -> AsyncGenerator[bytes, None]:
+        while 1:
+            data = await self.read()
+            if data is None:
+                break
+            yield data
index b4269a84afc59b2e2851152921da8adde8e28e0b..7ecb11ad04f54ba1f2681d9193b92641d8a1006d 100644 (file)
@@ -151,15 +151,42 @@ def notifies(pgconn: pq.proto.PGconn) -> PQGen[List[pq.PGnotify]]:
     return ns
 
 
+def copy_from(pgconn: pq.proto.PGconn) -> PQGen[Optional[bytes]]:
+    while 1:
+        nbytes, data = pgconn.get_copy_data(1)
+        if nbytes != 0:
+            break
+
+        # would block
+        yield pgconn.socket, Wait.R
+        pgconn.consume_input()
+
+    if nbytes > 0:
+        # some data
+        return data
+
+    # Retrieve the final result of copy
+    results = yield from fetch(pgconn)
+    if len(results) != 1:
+        raise e.InternalError(
+            f"1 result expected from copy end, got {len(results)}"
+        )
+    if results[0].status != pq.ExecStatus.COMMAND_OK:
+        encoding = pq.py_codecs.get(
+            pgconn.parameter_status(b"client_encoding"), "utf8"
+        )
+        raise e.error_from_result(results[0], encoding=encoding)
+
+    return None
+
+
 def copy_to(pgconn: pq.proto.PGconn, buffer: bytes) -> PQGen[None]:
     # Retry enqueuing data until successful
     while pgconn.put_copy_data(buffer) == 0:
         yield pgconn.socket, Wait.W
 
 
-def copy_end(
-    pgconn: pq.proto.PGconn, error: Optional[bytes]
-) -> PQGen[pq.proto.PGresult]:
+def copy_end(pgconn: pq.proto.PGconn, error: Optional[bytes]) -> PQGen[None]:
     # Retry enqueuing end copy message until successful
     while pgconn.put_copy_end(error) == 0:
         yield pgconn.socket, Wait.W
@@ -173,9 +200,12 @@ def copy_end(
 
     # Retrieve the final result of copy
     results = yield from fetch(pgconn)
-    if len(results) == 1:
-        return results[0]
-    else:
+    if len(results) != 1:
         raise e.InternalError(
             f"1 result expected from copy end, got {len(results)}"
         )
+    if results[0].status != pq.ExecStatus.COMMAND_OK:
+        encoding = pq.py_codecs.get(
+            pgconn.parameter_status(b"client_encoding"), "utf8"
+        )
+        raise e.error_from_result(results[0], encoding=encoding)
index 64997b3cd0bc7c20ebd58f10283422ca6352a1f8..28eef6c139eec6c6094df0a726c79149295a1678 100644 (file)
@@ -4,7 +4,9 @@ Mappings between PostgreSQL and Python encodings.
 
 # Copyright (C) 2020 The Psycopg Team
 
-py_codecs = {
+from typing import Dict, Union
+
+_py_codecs = {
     "BIG5": "big5",
     "EUC_CN": "gb2312",
     "EUC_JIS_2004": "euc_jis_2004",
@@ -50,3 +52,7 @@ py_codecs = {
     "WIN866": "cp866",
     "WIN874": "cp874",
 }
+
+py_codecs: Dict[Union[bytes, str, None], str] = {}
+py_codecs.update((k, v) for k, v in _py_codecs.items())
+py_codecs.update((k.encode("ascii"), v) for k, v in _py_codecs.items())
index 3bdec57b8cf9683b14c663232686e86478a526b3..2f6173a384a5a13dd4aaf215bad12f18091d068a 100644 (file)
@@ -18,13 +18,20 @@ sample_text = b"""\
 """
 
 sample_binary = """
-5047 434f 5059 0aff 0d0a 0000 0000 0000
-0000 0000 0300 0000 0400 0000 0a00 0000
-0400 0000 1400 0000 0568 656c 6c6f 0003
-0000 0004 0000 0028 ffff ffff 0000 0005
-776f 726c 64ff ff
+5047 434f 5059 0aff 0d0a 00
+00 0000 0000 0000 00
+00 0300 0000 0400 0000 0a00 0000 0400 0000 1400 0000 0568 656c 6c6f
+
+0003 0000 0004 0000 0028 ffff ffff 0000 0005 776f 726c 64
+
+ff ff
 """
-sample_binary = bytes.fromhex("".join(sample_binary.split()))
+
+sample_binary_rows = [
+    bytes.fromhex("".join(row.split())) for row in sample_binary.split("\n\n")
+]
+
+sample_binary = b"".join(sample_binary_rows)
 
 
 def set_sample_attributes(res, format):
@@ -96,25 +103,33 @@ def test_buffers(format, buffer):
     assert list(copy.buffers(sample_records)) == [globals()[buffer]]
 
 
-@pytest.mark.xfail
-@pytest.mark.parametrize(
-    "format, buffer",
-    [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
-)
-def test_copy_out_read(conn, format, buffer):
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+def test_copy_out_read(conn, format):
     cur = conn.cursor()
     copy = cur.copy(f"copy ({sample_values}) to stdout (format {format.name})")
-    assert copy.read() == globals()[buffer]
+
+    if format == pq.Format.TEXT:
+        want = [row + b"\n" for row in sample_text.splitlines()]
+    else:
+        want = sample_binary_rows
+
+    for row in want:
+        got = copy.read()
+        assert got == row
+
     assert copy.read() is None
     assert copy.read() is None
 
 
-@pytest.mark.xfail
 @pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
-def test_iter(conn, format):
+def test_copy_out_iter(conn, format):
     cur = conn.cursor()
     copy = cur.copy(f"copy ({sample_values}) to stdout (format {format.name})")
-    assert list(copy) == sample_records
+    if format == pq.Format.TEXT:
+        want = [row + b"\n" for row in sample_text.splitlines()]
+    else:
+        want = sample_binary_rows
+    assert list(copy) == want
 
 
 @pytest.mark.parametrize(
index c5906453ea761f298ff9dbb2558876c07b6e6f70..fff8bd03498d5a408c97a8c7a26758ae68ff323a 100644 (file)
@@ -1,29 +1,51 @@
 import pytest
 
+from psycopg3 import pq
 from psycopg3 import errors as e
 from psycopg3.adapt import Format
 
-from .test_copy import sample_text, sample_binary  # noqa
+from .test_copy import sample_text, sample_binary, sample_binary_rows  # noqa
 from .test_copy import sample_values, sample_records, sample_tabledef
 
 pytestmark = pytest.mark.asyncio
 
 
-@pytest.mark.xfail
-@pytest.mark.parametrize(
-    "format, buffer",
-    [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
-)
-async def test_copy_out_read(aconn, format, buffer):
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+async def test_copy_out_read(aconn, format):
     cur = aconn.cursor()
     copy = await cur.copy(
         f"copy ({sample_values}) to stdout (format {format.name})"
     )
-    assert await copy.read() == globals()[buffer]
+
+    if format == pq.Format.TEXT:
+        want = [row + b"\n" for row in sample_text.splitlines()]
+    else:
+        want = sample_binary_rows
+
+    for row in want:
+        got = await copy.read()
+        assert got == row
+
     assert await copy.read() is None
     assert await copy.read() is None
 
 
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+async def test_copy_out_iter(aconn, format):
+    cur = aconn.cursor()
+    copy = await cur.copy(
+        f"copy ({sample_values}) to stdout (format {format.name})"
+    )
+    if format == pq.Format.TEXT:
+        want = [row + b"\n" for row in sample_text.splitlines()]
+    else:
+        want = sample_binary_rows
+    got = []
+    async for row in copy:
+        got.append(row)
+    assert got == want
+
+
 @pytest.mark.parametrize(
     "format, buffer",
     [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],