self.format = Format(self._pgresult.binary_tuples)
self._encoding = self.connection.client_encoding
- self._first_row = True
+ self._signature_sent = False
+ self._row_mode = False # true if the user is using send_row()
self._finished = False
if self.format == Format.TEXT:
def _write_gen(self, buffer: Union[str, bytes]) -> PQGen[None]:
conn = self.connection
+ # if write() was called, assume the header was sent together with the
+ # first block of data (either because we added it to the first row
+ # or, if the user is copying blocks, assume the blocks contain
+ # the header).
+ self._signature_sent = True
yield from copy_to(conn.pgconn, self._ensure_bytes(buffer))
def _finish_gen(self, error: str = "") -> PQGen[None]:
if self._pgresult.status == ExecStatus.COPY_OUT:
return
- if not exc_type:
- if self.format == Format.BINARY and not self._first_row:
- # send EOF only if we copied binary rows (_first_row is False)
- yield from self._write_gen(b"\xff\xff")
- yield from self._finish_gen()
- else:
+ # In case of error in Python let's quit it here
+ if exc_type:
yield from self._finish_gen(
f"error from Python: {exc_type.__qualname__} - {exc_val}"
)
+ return
+
+ if self.format == Format.BINARY:
+ # If we have sent no data we need to send the signature
+ # and the trailer
+ if not self._signature_sent:
+ yield from self._write_gen(self._binary_signature)
+ yield from self._write_gen(self._binary_trailer)
+ elif self._row_mode:
+ # if we have sent data already, we have sent the signature too
+ # (either with the first row, or we assume that in block mode
+ # the signature is included).
+ # Write the trailer only if we are sending rows (with the
+ # assumption that who is copying binary data is sending the
+ # whole format).
+ yield from self._write_gen(self._binary_trailer)
+
+ yield from self._finish_gen()
# Support methods
out.append(dumper.dump(item))
else:
out.append(None)
+
return self._format_copy_row(out)
def _format_row_text(self, row: Sequence[Optional[bytes]]) -> bytes:
) -> bytes:
"""Convert a row of adapted data to the data to send for binary copy"""
out = []
- if self._first_row:
- out.append(
- # Signature, flags, extra length
- b"PGCOPY\n\xff\r\n\0"
- b"\x00\x00\x00\x00"
- b"\x00\x00\x00\x00"
- )
- self._first_row = False
+ if not self._signature_sent:
+ out.append(self._binary_signature)
+ self._signature_sent = True
+
+ # Note down that we are writing in row mode: it means we will have
+ # to take care of the end-of-copy marker too
+ self._row_mode = True
out.append(__int2_struct.pack(len(row)))
for item in row:
out.append(__int4_struct.pack(len(item)))
out.append(item)
else:
- out.append(b"\xff\xff\xff\xff")
+ out.append(self._binary_null)
return b"".join(out)
+ _binary_signature = (
+ # Signature, flags, extra length
+ b"PGCOPY\n\xff\r\n\0"
+ b"\x00\x00\x00\x00"
+ b"\x00\x00\x00\x00"
+ )
+ _binary_trailer = b"\xff\xff"
+ _binary_null = b"\xff\xff\xff\xff"
+
def _ensure_bytes(self, data: Union[bytes, str]) -> bytes:
if isinstance(data, bytes):
return data
assert conn.pgconn.transaction_status == conn.TransactionStatus.INERROR
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+def test_copy_in_empty(conn, format):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with cur.copy(f"copy copy_in from stdin (format {format.name})"):
+ pass
+
+ assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS
+ assert cur.rowcount == 0
+
+
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+def test_copy_in_error_empty(conn, format):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.QueryCanceled) as exc:
+ with cur.copy(f"copy copy_in from stdin (format {format.name})"):
+ raise Exception("mannaggiamiseria")
+
+ assert "mannaggiamiseria" in str(exc.value)
+ assert conn.pgconn.transaction_status == conn.TransactionStatus.INERROR
+
+
def test_copy_in_buffers_with_pg_error(conn):
cur = conn.cursor()
ensure_table(cur, sample_tabledef)
assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INERROR
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+async def test_copy_in_empty(aconn, format):
+ cur = await aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+ async with cur.copy(f"copy copy_in from stdin (format {format.name})"):
+ pass
+
+ assert cur.rowcount == 0
+ assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS
+
+
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+async def test_copy_in_error_empty(aconn, format):
+ cur = await aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.QueryCanceled) as exc:
+ async with cur.copy(f"copy copy_in from stdin (format {format.name})"):
+ raise Exception("mannaggiamiseria")
+
+ assert "mannaggiamiseria" in str(exc.value)
+ assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INERROR
+
+
async def test_copy_in_buffers_with_pg_error(aconn):
cur = await aconn.cursor()
await ensure_table(cur, sample_tabledef)