From: Daniele Varrazzo Date: Wed, 30 Dec 2020 17:23:06 +0000 (+0100) Subject: Fixed binary copy with empty content X-Git-Tag: 3.0.dev0~222 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=0338907b8d111c738b551be47c50d3596531cb47;p=thirdparty%2Fpsycopg.git Fixed binary copy with empty content --- diff --git a/psycopg3/psycopg3/copy.py b/psycopg3/psycopg3/copy.py index 1b0bc17a9..fe39b4656 100644 --- a/psycopg3/psycopg3/copy.py +++ b/psycopg3/psycopg3/copy.py @@ -34,7 +34,8 @@ class BaseCopy(Generic[ConnectionType]): self.format = Format(self._pgresult.binary_tuples) self._encoding = self.connection.client_encoding - self._first_row = True + self._signature_sent = False + self._row_mode = False # true if the user is using send_row() self._finished = False if self.format == Format.TEXT: @@ -66,6 +67,11 @@ class BaseCopy(Generic[ConnectionType]): def _write_gen(self, buffer: Union[str, bytes]) -> PQGen[None]: conn = self.connection + # if write() was called, assume the header was sent together with the + # first block of data (either because we added it to the first row + # or, if the user is copying blocks, assume the blocks contain + # the header). + self._signature_sent = True yield from copy_to(conn.pgconn, self._ensure_bytes(buffer)) def _finish_gen(self, error: str = "") -> PQGen[None]: @@ -85,15 +91,29 @@ class BaseCopy(Generic[ConnectionType]): if self._pgresult.status == ExecStatus.COPY_OUT: return - if not exc_type: - if self.format == Format.BINARY and not self._first_row: - # send EOF only if we copied binary rows (_first_row is False) - yield from self._write_gen(b"\xff\xff") - yield from self._finish_gen() - else: + # In case of error in Python let's quit it here + if exc_type: yield from self._finish_gen( f"error from Python: {exc_type.__qualname__} - {exc_val}" ) + return + + if self.format == Format.BINARY: + # If we have sent no data we need to send the signature + # and the trailer + if not self._signature_sent: + yield from self._write_gen(self._binary_signature) + yield from self._write_gen(self._binary_trailer) + elif self._row_mode: + # if we have sent data already, we have sent the signature too + # (either with the first row, or we assume that in block mode + # the signature is included). + # Write the trailer only if we are sending rows (with the + # assumption that who is copying binary data is sending the + # whole format). + yield from self._write_gen(self._binary_trailer) + + yield from self._finish_gen() # Support methods @@ -106,6 +126,7 @@ class BaseCopy(Generic[ConnectionType]): out.append(dumper.dump(item)) else: out.append(None) + return self._format_copy_row(out) def _format_row_text(self, row: Sequence[Optional[bytes]]) -> bytes: @@ -128,14 +149,13 @@ class BaseCopy(Generic[ConnectionType]): ) -> bytes: """Convert a row of adapted data to the data to send for binary copy""" out = [] - if self._first_row: - out.append( - # Signature, flags, extra length - b"PGCOPY\n\xff\r\n\0" - b"\x00\x00\x00\x00" - b"\x00\x00\x00\x00" - ) - self._first_row = False + if not self._signature_sent: + out.append(self._binary_signature) + self._signature_sent = True + + # Note down that we are writing in row mode: it means we will have + # to take care of the end-of-copy marker too + self._row_mode = True out.append(__int2_struct.pack(len(row))) for item in row: @@ -143,10 +163,19 @@ class BaseCopy(Generic[ConnectionType]): out.append(__int4_struct.pack(len(item))) out.append(item) else: - out.append(b"\xff\xff\xff\xff") + out.append(self._binary_null) return b"".join(out) + _binary_signature = ( + # Signature, flags, extra length + b"PGCOPY\n\xff\r\n\0" + b"\x00\x00\x00\x00" + b"\x00\x00\x00\x00" + ) + _binary_trailer = b"\xff\xff" + _binary_null = b"\xff\xff\xff\xff" + def _ensure_bytes(self, data: Union[bytes, str]) -> bytes: if isinstance(data, bytes): return data diff --git a/tests/test_copy.py b/tests/test_copy.py index eb2c17433..af1547eae 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -136,6 +136,29 @@ def test_copy_in_str_binary(conn): assert conn.pgconn.transaction_status == conn.TransactionStatus.INERROR +@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) +def test_copy_in_empty(conn, format): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + with cur.copy(f"copy copy_in from stdin (format {format.name})"): + pass + + assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS + assert cur.rowcount == 0 + + +@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) +def test_copy_in_error_empty(conn, format): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + with pytest.raises(e.QueryCanceled) as exc: + with cur.copy(f"copy copy_in from stdin (format {format.name})"): + raise Exception("mannaggiamiseria") + + assert "mannaggiamiseria" in str(exc.value) + assert conn.pgconn.transaction_status == conn.TransactionStatus.INERROR + + def test_copy_in_buffers_with_pg_error(conn): cur = conn.cursor() ensure_table(cur, sample_tabledef) diff --git a/tests/test_copy_async.py b/tests/test_copy_async.py index 57027064b..b1ca5ce13 100644 --- a/tests/test_copy_async.py +++ b/tests/test_copy_async.py @@ -119,6 +119,29 @@ async def test_copy_in_str_binary(aconn): assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INERROR +@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) +async def test_copy_in_empty(aconn, format): + cur = await aconn.cursor() + await ensure_table(cur, sample_tabledef) + async with cur.copy(f"copy copy_in from stdin (format {format.name})"): + pass + + assert cur.rowcount == 0 + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS + + +@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) +async def test_copy_in_error_empty(aconn, format): + cur = await aconn.cursor() + await ensure_table(cur, sample_tabledef) + with pytest.raises(e.QueryCanceled) as exc: + async with cur.copy(f"copy copy_in from stdin (format {format.name})"): + raise Exception("mannaggiamiseria") + + assert "mannaggiamiseria" in str(exc.value) + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INERROR + + async def test_copy_in_buffers_with_pg_error(aconn): cur = await aconn.cursor() await ensure_table(cur, sample_tabledef)