]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Fixed binary copy with empty content
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 30 Dec 2020 17:23:06 +0000 (18:23 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 8 Jan 2021 01:26:53 +0000 (02:26 +0100)
psycopg3/psycopg3/copy.py
tests/test_copy.py
tests/test_copy_async.py

index 1b0bc17a99c358885f41e94210a86c9d43db2b6a..fe39b4656be1000646d51c2d5cf5af6bfeaa0372 100644 (file)
@@ -34,7 +34,8 @@ class BaseCopy(Generic[ConnectionType]):
 
         self.format = Format(self._pgresult.binary_tuples)
         self._encoding = self.connection.client_encoding
-        self._first_row = True
+        self._signature_sent = False
+        self._row_mode = False  # true if the user is using send_row()
         self._finished = False
 
         if self.format == Format.TEXT:
@@ -66,6 +67,11 @@ class BaseCopy(Generic[ConnectionType]):
 
     def _write_gen(self, buffer: Union[str, bytes]) -> PQGen[None]:
         conn = self.connection
+        # if write() was called, assume the header was sent together with the
+        # first block of data (either because we added it to the first row
+        # or, if the user is copying blocks, assume the blocks contain
+        # the header).
+        self._signature_sent = True
         yield from copy_to(conn.pgconn, self._ensure_bytes(buffer))
 
     def _finish_gen(self, error: str = "") -> PQGen[None]:
@@ -85,15 +91,29 @@ class BaseCopy(Generic[ConnectionType]):
         if self._pgresult.status == ExecStatus.COPY_OUT:
             return
 
-        if not exc_type:
-            if self.format == Format.BINARY and not self._first_row:
-                # send EOF only if we copied binary rows (_first_row is False)
-                yield from self._write_gen(b"\xff\xff")
-            yield from self._finish_gen()
-        else:
+        # In case of error in Python let's quit it here
+        if exc_type:
             yield from self._finish_gen(
                 f"error from Python: {exc_type.__qualname__} - {exc_val}"
             )
+            return
+
+        if self.format == Format.BINARY:
+            # If we have sent no data we need to send the signature
+            # and the trailer
+            if not self._signature_sent:
+                yield from self._write_gen(self._binary_signature)
+                yield from self._write_gen(self._binary_trailer)
+            elif self._row_mode:
+                # if we have sent data already, we have sent the signature too
+                # (either with the first row, or we assume that in block mode
+                # the signature is included).
+                # Write the trailer only if we are sending rows (with the
+                # assumption that who is copying binary data is sending the
+                # whole format).
+                yield from self._write_gen(self._binary_trailer)
+
+        yield from self._finish_gen()
 
     # Support methods
 
@@ -106,6 +126,7 @@ class BaseCopy(Generic[ConnectionType]):
                 out.append(dumper.dump(item))
             else:
                 out.append(None)
+
         return self._format_copy_row(out)
 
     def _format_row_text(self, row: Sequence[Optional[bytes]]) -> bytes:
@@ -128,14 +149,13 @@ class BaseCopy(Generic[ConnectionType]):
     ) -> bytes:
         """Convert a row of adapted data to the data to send for binary copy"""
         out = []
-        if self._first_row:
-            out.append(
-                # Signature, flags, extra length
-                b"PGCOPY\n\xff\r\n\0"
-                b"\x00\x00\x00\x00"
-                b"\x00\x00\x00\x00"
-            )
-            self._first_row = False
+        if not self._signature_sent:
+            out.append(self._binary_signature)
+            self._signature_sent = True
+
+            # Note down that we are writing in row mode: it means we will have
+            # to take care of the end-of-copy marker too
+            self._row_mode = True
 
         out.append(__int2_struct.pack(len(row)))
         for item in row:
@@ -143,10 +163,19 @@ class BaseCopy(Generic[ConnectionType]):
                 out.append(__int4_struct.pack(len(item)))
                 out.append(item)
             else:
-                out.append(b"\xff\xff\xff\xff")
+                out.append(self._binary_null)
 
         return b"".join(out)
 
+    _binary_signature = (
+        # Signature, flags, extra length
+        b"PGCOPY\n\xff\r\n\0"
+        b"\x00\x00\x00\x00"
+        b"\x00\x00\x00\x00"
+    )
+    _binary_trailer = b"\xff\xff"
+    _binary_null = b"\xff\xff\xff\xff"
+
     def _ensure_bytes(self, data: Union[bytes, str]) -> bytes:
         if isinstance(data, bytes):
             return data
index eb2c17433fe1f25d09bfb8359ab9462f8b9f1a3d..af1547eae5e635634878b77e9004ef01cc146a21 100644 (file)
@@ -136,6 +136,29 @@ def test_copy_in_str_binary(conn):
     assert conn.pgconn.transaction_status == conn.TransactionStatus.INERROR
 
 
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+def test_copy_in_empty(conn, format):
+    cur = conn.cursor()
+    ensure_table(cur, sample_tabledef)
+    with cur.copy(f"copy copy_in from stdin (format {format.name})"):
+        pass
+
+    assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS
+    assert cur.rowcount == 0
+
+
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+def test_copy_in_error_empty(conn, format):
+    cur = conn.cursor()
+    ensure_table(cur, sample_tabledef)
+    with pytest.raises(e.QueryCanceled) as exc:
+        with cur.copy(f"copy copy_in from stdin (format {format.name})"):
+            raise Exception("mannaggiamiseria")
+
+    assert "mannaggiamiseria" in str(exc.value)
+    assert conn.pgconn.transaction_status == conn.TransactionStatus.INERROR
+
+
 def test_copy_in_buffers_with_pg_error(conn):
     cur = conn.cursor()
     ensure_table(cur, sample_tabledef)
index 57027064b7e988f1cc5d46057914e0eb98982822..b1ca5ce136a6f2c10dbef3e222bc06b2d88042ba 100644 (file)
@@ -119,6 +119,29 @@ async def test_copy_in_str_binary(aconn):
     assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INERROR
 
 
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+async def test_copy_in_empty(aconn, format):
+    cur = await aconn.cursor()
+    await ensure_table(cur, sample_tabledef)
+    async with cur.copy(f"copy copy_in from stdin (format {format.name})"):
+        pass
+
+    assert cur.rowcount == 0
+    assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS
+
+
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+async def test_copy_in_error_empty(aconn, format):
+    cur = await aconn.cursor()
+    await ensure_table(cur, sample_tabledef)
+    with pytest.raises(e.QueryCanceled) as exc:
+        async with cur.copy(f"copy copy_in from stdin (format {format.name})"):
+            raise Exception("mannaggiamiseria")
+
+    assert "mannaggiamiseria" in str(exc.value)
+    assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INERROR
+
+
 async def test_copy_in_buffers_with_pg_error(aconn):
     cur = await aconn.cursor()
     await ensure_table(cur, sample_tabledef)