Copy.read() changed to return b"" on EOF, consistent with file.read().
Also changed the copy generators to return the final result of the
operation, and pgconn.get_copy_data() to always return bytes as the
second argument, because it will never return an empty string except on
error.
With this changeset all psycopg2 test_copy pass, both sync and async.
from types import TracebackType
from .pq import Format, ExecStatus
-from .proto import ConnectionType, Transformer
+from .proto import ConnectionType
from .generators import copy_from, copy_to, copy_end
if TYPE_CHECKING:
from .pq.proto import PGresult
+ from .cursor import BaseCursor # noqa: F401
from .connection import Connection, AsyncConnection # noqa: F401
class BaseCopy(Generic[ConnectionType]):
- def __init__(self, connection: ConnectionType, transformer: Transformer):
- self.connection = connection
- self.transformer = transformer
+ def __init__(self, cursor: "BaseCursor[ConnectionType]"):
+ self.cursor = cursor
+ self.connection = cursor.connection
+ self.transformer = cursor._transformer
assert (
self.transformer.pgresult
class Copy(BaseCopy["Connection"]):
"""Manage a :sql:`COPY` operation."""
- def read(self) -> Optional[bytes]:
+ def read(self) -> bytes:
"""Read a row after a :sql:`COPY TO` operation.
- Return `None` when the data is finished.
+ Return an empty bytes string when the data is finished.
"""
if self._finished:
- return None
+ return b""
conn = self.connection
- rv = conn.wait(copy_from(conn.pgconn))
- if rv is None:
- self._finished = True
+ res = conn.wait(copy_from(conn.pgconn))
+ if isinstance(res, bytes):
+ return res
- return rv
+ # res is the final PGresult
+ self._finished = True
+ nrows = res.command_tuples
+ self.cursor._rowcount = nrows if nrows is not None else -1
+ return b""
def write(self, buffer: Union[str, bytes]) -> None:
"""Write a block of data after a :sql:`COPY FROM` operation."""
"""Terminate a :sql:`COPY FROM` operation."""
conn = self.connection
berr = error.encode(conn.client_encoding, "replace") if error else None
- conn.wait(copy_end(conn.pgconn, berr))
+ res = conn.wait(copy_end(conn.pgconn, berr))
+ nrows = res.command_tuples
+ self.cursor._rowcount = nrows if nrows is not None else -1
self._finished = True
def __enter__(self) -> "Copy":
+ assert not self._finished
return self
def __exit__(
def __iter__(self) -> Iterator[bytes]:
while True:
data = self.read()
- if data is None:
+ if not data:
break
yield data
class AsyncCopy(BaseCopy["AsyncConnection"]):
"""Manage an asynchronous :sql:`COPY` operation."""
- async def read(self) -> Optional[bytes]:
+ async def read(self) -> bytes:
if self._finished:
- return None
+ return b""
conn = self.connection
- rv = await conn.wait(copy_from(conn.pgconn))
- if rv is None:
- self._finished = True
+ res = await conn.wait(copy_from(conn.pgconn))
+ if isinstance(res, bytes):
+ return res
- return rv
+ # res is the final PGresult
+ self._finished = True
+ nrows = res.command_tuples
+ self.cursor._rowcount = nrows if nrows is not None else -1
+ return b""
async def write(self, buffer: Union[str, bytes]) -> None:
conn = self.connection
async def _finish(self, error: str = "") -> None:
conn = self.connection
berr = error.encode(conn.client_encoding, "replace") if error else None
- await conn.wait(copy_end(conn.pgconn, berr))
+ res = await conn.wait(copy_end(conn.pgconn, berr))
+ nrows = res.command_tuples
+ self.cursor._rowcount = nrows if nrows is not None else -1
self._finished = True
async def __aenter__(self) -> "AsyncCopy":
+ assert not self._finished
return self
async def __aexit__(
async def __aiter__(self) -> AsyncIterator[bytes]:
while True:
data = await self.read()
- if data is None:
+ if not data:
break
yield data
ExecStatus = pq.ExecStatus
_transformer: "Transformer"
+ _rowcount: int
def __init__(
self,
self._check_copy_results(results)
self.pgresult = results[0] # will set it on the transformer too
- return Copy(connection=self.connection, transformer=self._transformer)
+ return Copy(self)
class AsyncCursor(BaseCursor["AsyncConnection"]):
self._check_copy_results(results)
self.pgresult = results[0] # will set it on the transformer too
- return AsyncCopy(
- connection=self.connection,
- transformer=self._transformer,
- )
+ return AsyncCopy(self)
class NamedCursorMixin:
# Copyright (C) 2020 The Psycopg Team
import logging
-from typing import List, Optional
+from typing import List, Optional, Union
from . import pq
from . import errors as e
from .proto import PQGen
from .waiting import Wait, Ready
from .encodings import py_codecs
+from .pq.proto import PGconn, PGresult
logger = logging.getLogger(__name__)
-def connect(conninfo: str) -> PQGen[pq.proto.PGconn]:
+def connect(conninfo: str) -> PQGen[PGconn]:
"""
Generator to create a database connection without blocking.
return conn
-def execute(pgconn: pq.proto.PGconn) -> PQGen[List[pq.proto.PGresult]]:
+def execute(pgconn: PGconn) -> PQGen[List[PGresult]]:
"""
Generator sending a query and returning results without blocking.
return rv
-def send(pgconn: pq.proto.PGconn) -> PQGen[None]:
+def send(pgconn: PGconn) -> PQGen[None]:
"""
Generator to send a query to the server without blocking.
continue
-def fetch(pgconn: pq.proto.PGconn) -> PQGen[List[pq.proto.PGresult]]:
+def fetch(pgconn: PGconn) -> PQGen[List[PGresult]]:
"""
Generator retrieving results from the database without blocking.
or error).
"""
S = pq.ExecStatus
- results: List[pq.proto.PGresult] = []
+ results: List[PGresult] = []
while 1:
pgconn.consume_input()
if pgconn.is_busy():
return results
-def notifies(pgconn: pq.proto.PGconn) -> PQGen[List[pq.PGnotify]]:
+def notifies(pgconn: PGconn) -> PQGen[List[pq.PGnotify]]:
yield pgconn.socket, Wait.R
pgconn.consume_input()
return ns
-def copy_from(pgconn: pq.proto.PGconn) -> PQGen[Optional[bytes]]:
+def copy_from(pgconn: PGconn) -> PQGen[Union[bytes, PGresult]]:
while 1:
nbytes, data = pgconn.get_copy_data(1)
if nbytes != 0:
return data
# Retrieve the final result of copy
- results = yield from fetch(pgconn)
- if len(results) != 1:
- raise e.InternalError(
- f"1 result expected from copy end, got {len(results)}"
- )
- if results[0].status != pq.ExecStatus.COMMAND_OK:
+ (result,) = yield from fetch(pgconn)
+ if result.status != pq.ExecStatus.COMMAND_OK:
encoding = py_codecs.get(
pgconn.parameter_status(b"client_encoding") or "", "utf-8"
)
- raise e.error_from_result(results[0], encoding=encoding)
+ raise e.error_from_result(result, encoding=encoding)
- return None
+ return result
-def copy_to(pgconn: pq.proto.PGconn, buffer: bytes) -> PQGen[None]:
+def copy_to(pgconn: PGconn, buffer: bytes) -> PQGen[None]:
# Retry enqueuing data until successful
while pgconn.put_copy_data(buffer) == 0:
yield pgconn.socket, Wait.W
-def copy_end(pgconn: pq.proto.PGconn, error: Optional[bytes]) -> PQGen[None]:
+def copy_end(pgconn: PGconn, error: Optional[bytes]) -> PQGen[PGresult]:
# Retry enqueuing end copy message until successful
while pgconn.put_copy_end(error) == 0:
yield pgconn.socket, Wait.W
break
# Retrieve the final result of copy
- results = yield from fetch(pgconn)
- if len(results) != 1:
- raise e.InternalError(
- f"1 result expected from copy end, got {len(results)}"
- )
- if results[0].status != pq.ExecStatus.COMMAND_OK:
+ (result,) = yield from fetch(pgconn)
+ if result.status != pq.ExecStatus.COMMAND_OK:
encoding = py_codecs.get(
pgconn.parameter_status(b"client_encoding") or "", "utf-8"
)
- raise e.error_from_result(results[0], encoding=encoding)
+ raise e.error_from_result(result, encoding=encoding)
+
+ return result
raise PQerror(f"sending copy end failed: {error_message(self)}")
return rv
- def get_copy_data(self, async_: int) -> Tuple[int, Optional[bytes]]:
+ def get_copy_data(self, async_: int) -> Tuple[int, bytes]:
buffer_ptr = c_char_p()
nbytes = impl.PQgetCopyData(self.pgconn_ptr, byref(buffer_ptr), async_)
if nbytes == -2:
impl.PQfreemem(buffer_ptr)
return nbytes, data
else:
- return nbytes, None
+ return nbytes, b""
def make_empty_result(self, exec_status: ExecStatus) -> "PGresult":
rv = impl.PQmakeEmptyPGresult(self.pgconn_ptr, exec_status)
def put_copy_end(self, error: Optional[bytes] = None) -> int:
...
- def get_copy_data(self, async_: int) -> Tuple[int, Optional[bytes]]:
+ def get_copy_data(self, async_: int) -> Tuple[int, bytes]:
...
def make_empty_result(self, exec_status: ExecStatus) -> "PGresult":
raise PQerror(f"sending copy end failed: {error_message(self)}")
return rv
- def get_copy_data(self, async_: int) -> Tuple[int, Optional[bytes]]:
+ def get_copy_data(self, async_: int) -> Tuple[int, bytes]:
cdef char *buffer_ptr = NULL
cdef int nbytes
nbytes = impl.PQgetCopyData(self.pgconn_ptr, &buffer_ptr, async_)
impl.PQfreemem(buffer_ptr)
return nbytes, data
else:
- return nbytes, None
+ return nbytes, b""
def make_empty_result(self, exec_status: ExecStatus) -> PGresult:
cdef impl.PGresult *rv = impl.PQmakeEmptyPGresult(
assert nbytes == len(data)
assert data == row
- assert pgconn.get_copy_data(0) == (-1, None)
+ assert pgconn.get_copy_data(0) == (-1, b"")
res = pgconn.get_result()
assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+import string
+import hashlib
+from io import BytesIO, StringIO
+from itertools import cycle
+
import pytest
from psycopg3 import pq
got = copy.read()
assert got == row
- assert copy.read() is None
- assert copy.read() is None
+ assert copy.read() == b""
+ assert copy.read() == b""
- assert copy.read() is None
+ assert copy.read() == b""
@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
assert data == [(True, True, 1, 256)]
+@pytest.mark.slow
+def test_copy_from_to(conn):
+ # Roundtrip from file to database to file blockwise
+ gen = DataGenerator(conn, nrecs=1024, srec=10 * 1024)
+ gen.ensure_table()
+ cur = conn.cursor()
+ with cur.copy("copy copy_in from stdin") as copy:
+ for block in gen.blocks():
+ copy.write(block)
+
+ gen.assert_data()
+
+ f = StringIO()
+ with cur.copy("copy copy_in to stdout") as copy:
+ for block in copy:
+ f.write(block.decode("utf8"))
+
+ f.seek(0)
+ assert gen.sha(f) == gen.sha(gen.file())
+
+
+@pytest.mark.slow
+def test_copy_from_to_bytes(conn):
+ # Roundtrip from file to database to file blockwise
+ gen = DataGenerator(conn, nrecs=1024, srec=10 * 1024)
+ gen.ensure_table()
+ cur = conn.cursor()
+ with cur.copy("copy copy_in from stdin") as copy:
+ for block in gen.blocks():
+ copy.write(block.encode("utf8"))
+
+ gen.assert_data()
+
+ f = BytesIO()
+ with cur.copy("copy copy_in to stdout") as copy:
+ for block in copy:
+ f.write(block)
+
+ f.seek(0)
+ assert gen.sha(f) == gen.sha(gen.file())
+
+
+@pytest.mark.slow
+def test_copy_from_insane_size(conn):
+ # Trying to trigger a "would block" error
+ gen = DataGenerator(
+ conn, nrecs=4 * 1024, srec=10 * 1024, block_size=20 * 1024 * 1024
+ )
+ gen.ensure_table()
+ cur = conn.cursor()
+ with cur.copy("copy copy_in from stdin") as copy:
+ for block in gen.blocks():
+ copy.write(block)
+
+ gen.assert_data()
+
+
+def test_copy_rowcount(conn):
+ gen = DataGenerator(conn, nrecs=3, srec=10)
+ gen.ensure_table()
+
+ cur = conn.cursor()
+ with cur.copy("copy copy_in from stdin") as copy:
+ for block in gen.blocks():
+ copy.write(block)
+ assert cur.rowcount == 3
+
+ gen = DataGenerator(conn, nrecs=2, srec=10, offset=3)
+ with cur.copy("copy copy_in from stdin") as copy:
+ for rec in gen.records():
+ copy.write_row(rec)
+ assert cur.rowcount == 2
+
+ with cur.copy("copy copy_in to stdout") as copy:
+ for block in copy:
+ pass
+ assert cur.rowcount == 5
+
+ with pytest.raises(e.BadCopyFileFormat):
+ with cur.copy("copy copy_in (id) from stdin") as copy:
+ for rec in gen.records():
+ copy.write_row(rec)
+ assert cur.rowcount == -1
+
+
def ensure_table(cur, tabledef, name="copy_in"):
cur.execute(f"drop table if exists {name}")
cur.execute(f"create table {name} ({tabledef})")
+
+
+class DataGenerator:
+ def __init__(self, conn, nrecs, srec, offset=0, block_size=8192):
+ self.conn = conn
+ self.nrecs = nrecs
+ self.srec = srec
+ self.offset = offset
+ self.block_size = block_size
+
+ def ensure_table(self):
+ cur = self.conn.cursor()
+ ensure_table(cur, "id integer primary key, data text")
+
+ def records(self):
+ for i, c in zip(range(self.nrecs), cycle(string.ascii_letters)):
+ s = c * self.srec
+ yield (i + self.offset, s)
+
+ def file(self):
+ f = StringIO()
+ for i, s in self.records():
+ f.write("%s\t%s\n" % (i, s))
+
+ f.seek(0)
+ return f
+
+ def blocks(self):
+ f = self.file()
+ while True:
+ block = f.read(self.block_size)
+ if not block:
+ break
+ yield block
+
+ def assert_data(self):
+ cur = self.conn.cursor()
+ cur.execute("select id, data from copy_in order by id")
+ for record in self.records():
+ assert record == cur.fetchone()
+
+ assert cur.fetchone() is None
+
+ def sha(self, f):
+ m = hashlib.sha256()
+ while 1:
+ block = f.read()
+ if not block:
+ break
+ if isinstance(block, str):
+ block = block.encode("utf8")
+ m.update(block)
+ return m.hexdigest()
+import string
+import hashlib
+from io import BytesIO, StringIO
+from itertools import cycle
+
import pytest
from psycopg3 import pq
got = await copy.read()
assert got == row
- assert await copy.read() is None
- assert await copy.read() is None
+ assert await copy.read() == b""
+ assert await copy.read() == b""
- assert await copy.read() is None
+ assert await copy.read() == b""
@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
assert data == [(True, True, 1, 256)]
+@pytest.mark.slow
+async def test_copy_from_to(aconn):
+ # Roundtrip from file to database to file blockwise
+ gen = DataGenerator(aconn, nrecs=1024, srec=10 * 1024)
+ await gen.ensure_table()
+ cur = await aconn.cursor()
+ async with cur.copy("copy copy_in from stdin") as copy:
+ for block in gen.blocks():
+ await copy.write(block)
+
+ await gen.assert_data()
+
+ f = StringIO()
+ async with cur.copy("copy copy_in to stdout") as copy:
+ async for block in copy:
+ f.write(block.decode("utf8"))
+
+ f.seek(0)
+ assert gen.sha(f) == gen.sha(gen.file())
+
+
+@pytest.mark.slow
+async def test_copy_from_to_bytes(aconn):
+ # Roundtrip from file to database to file blockwise
+ gen = DataGenerator(aconn, nrecs=1024, srec=10 * 1024)
+ await gen.ensure_table()
+ cur = await aconn.cursor()
+ async with cur.copy("copy copy_in from stdin") as copy:
+ for block in gen.blocks():
+ await copy.write(block.encode("utf8"))
+
+ await gen.assert_data()
+
+ f = BytesIO()
+ async with cur.copy("copy copy_in to stdout") as copy:
+ async for block in copy:
+ f.write(block)
+
+ f.seek(0)
+ assert gen.sha(f) == gen.sha(gen.file())
+
+
+@pytest.mark.slow
+async def test_copy_from_insane_size(aconn):
+ # Trying to trigger a "would block" error
+ gen = DataGenerator(
+ aconn, nrecs=4 * 1024, srec=10 * 1024, block_size=20 * 1024 * 1024
+ )
+ await gen.ensure_table()
+ cur = await aconn.cursor()
+ async with cur.copy("copy copy_in from stdin") as copy:
+ for block in gen.blocks():
+ await copy.write(block)
+
+ await gen.assert_data()
+
+
+async def test_copy_rowcount(aconn):
+ gen = DataGenerator(aconn, nrecs=3, srec=10)
+ await gen.ensure_table()
+
+ cur = await aconn.cursor()
+ async with cur.copy("copy copy_in from stdin") as copy:
+ for block in gen.blocks():
+ await copy.write(block)
+ assert cur.rowcount == 3
+
+ gen = DataGenerator(aconn, nrecs=2, srec=10, offset=3)
+ async with cur.copy("copy copy_in from stdin") as copy:
+ for rec in gen.records():
+ await copy.write_row(rec)
+ assert cur.rowcount == 2
+
+ async with cur.copy("copy copy_in to stdout") as copy:
+ async for block in copy:
+ pass
+ assert cur.rowcount == 5
+
+ with pytest.raises(e.BadCopyFileFormat):
+ async with cur.copy("copy copy_in (id) from stdin") as copy:
+ for rec in gen.records():
+ await copy.write_row(rec)
+ assert cur.rowcount == -1
+
+
async def ensure_table(cur, tabledef, name="copy_in"):
await cur.execute(f"drop table if exists {name}")
await cur.execute(f"create table {name} ({tabledef})")
+
+
+class DataGenerator:
+ def __init__(self, conn, nrecs, srec, offset=0, block_size=8192):
+ self.conn = conn
+ self.nrecs = nrecs
+ self.srec = srec
+ self.offset = offset
+ self.block_size = block_size
+
+ async def ensure_table(self):
+ cur = await self.conn.cursor()
+ await ensure_table(cur, "id integer primary key, data text")
+
+ def records(self):
+ for i, c in zip(range(self.nrecs), cycle(string.ascii_letters)):
+ s = c * self.srec
+ yield (i + self.offset, s)
+
+ def file(self):
+ f = StringIO()
+ for i, s in self.records():
+ f.write("%s\t%s\n" % (i, s))
+
+ f.seek(0)
+ return f
+
+ def blocks(self):
+ f = self.file()
+ while True:
+ block = f.read(self.block_size)
+ if not block:
+ break
+ yield block
+
+ async def assert_data(self):
+ cur = await self.conn.cursor()
+ await cur.execute("select id, data from copy_in order by id")
+ for record in self.records():
+ assert record == await cur.fetchone()
+
+ assert await cur.fetchone() is None
+
+ def sha(self, f):
+ m = hashlib.sha256()
+ while 1:
+ block = f.read()
+ if not block:
+ break
+ if isinstance(block, str):
+ block = block.encode("utf8")
+ m.update(block)
+ return m.hexdigest()