From: Daniele Varrazzo Date: Tue, 12 Jan 2021 15:06:23 +0000 (+0100) Subject: Finish correctly a COPY TO operation reading copy.rows() X-Git-Tag: 3.0.dev0~179 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=76f81436f27adf6cd2a13e4e0d39e31ec7e0a037;p=thirdparty%2Fpsycopg.git Finish correctly a COPY TO operation reading copy.rows() Close #23 --- diff --git a/psycopg3/psycopg3/copy.py b/psycopg3/psycopg3/copy.py index 1dc3b6880..615fcad5d 100644 --- a/psycopg3/psycopg3/copy.py +++ b/psycopg3/psycopg3/copy.py @@ -83,6 +83,7 @@ class BaseCopy(Generic[ConnectionType]): data = yield from self._read_gen() if not data: return None + if self.format == Format.BINARY: if not self._signature_sent: if data[: len(_binary_signature)] != _binary_signature: @@ -91,8 +92,12 @@ class BaseCopy(Generic[ConnectionType]): ) self._signature_sent = True data = data[len(_binary_signature) :] + elif data == _binary_trailer: + yield from self._read_gen() + self._finished = True return None + return self._parse_row(data, self.transformer) def _write_gen(self, buffer: Union[str, bytes]) -> PQGen[None]: diff --git a/tests/test_copy.py b/tests/test_copy.py index 96bf4775e..299ecdb38 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -56,11 +56,15 @@ def test_copy_out_read(conn, format): for row in want: got = copy.read() assert got == row + assert ( + conn.pgconn.transaction_status == conn.TransactionStatus.ACTIVE + ) assert copy.read() == b"" assert copy.read() == b"" assert copy.read() == b"" + assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS @pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) @@ -75,6 +79,8 @@ def test_copy_out_iter(conn, format): ) as copy: assert list(copy) == want + assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS + @pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) def test_read_rows(conn, format): @@ -94,7 +100,9 @@ def test_read_rows(conn, format): if not row: break rows.append(row) + assert rows == sample_records + assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS @pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) @@ -107,7 +115,9 @@ def test_rows(conn, format): [builtins["int4"].oid, builtins["int4"].oid, builtins["text"].oid] ) rows = list(copy.rows()) + assert rows == sample_records + assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS @pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) diff --git a/tests/test_copy_async.py b/tests/test_copy_async.py index 9c6f33c30..608b1ecff 100644 --- a/tests/test_copy_async.py +++ b/tests/test_copy_async.py @@ -32,11 +32,16 @@ async def test_copy_out_read(aconn, format): for row in want: got = await copy.read() assert got == row + assert ( + aconn.pgconn.transaction_status + == aconn.TransactionStatus.ACTIVE + ) assert await copy.read() == b"" assert await copy.read() == b"" assert await copy.read() == b"" + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS @pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) @@ -53,7 +58,9 @@ async def test_copy_out_iter(aconn, format): ) as copy: async for row in copy: got.append(row) + assert got == want + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS @pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) @@ -74,7 +81,9 @@ async def test_read_rows(aconn, format): if not row: break rows.append(row) + assert rows == sample_records + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS @pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) @@ -89,7 +98,9 @@ async def test_rows(aconn, format): rows = [] async for row in copy.rows(): rows.append(row) + assert rows == sample_records + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS @pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])