]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Finish correctly a COPY TO operation reading copy.rows()
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 12 Jan 2021 15:06:23 +0000 (16:06 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 12 Jan 2021 15:15:41 +0000 (16:15 +0100)
Close #23

psycopg3/psycopg3/copy.py
tests/test_copy.py
tests/test_copy_async.py

index 1dc3b6880a270580cac60992201bb40696fe0925..615fcad5d02ad286499e1aa8f4567a6ca6bc23a9 100644 (file)
@@ -83,6 +83,7 @@ class BaseCopy(Generic[ConnectionType]):
         data = yield from self._read_gen()
         if not data:
             return None
+
         if self.format == Format.BINARY:
             if not self._signature_sent:
                 if data[: len(_binary_signature)] != _binary_signature:
@@ -91,8 +92,12 @@ class BaseCopy(Generic[ConnectionType]):
                     )
                 self._signature_sent = True
                 data = data[len(_binary_signature) :]
+
             elif data == _binary_trailer:
+                yield from self._read_gen()
+                self._finished = True
                 return None
+
         return self._parse_row(data, self.transformer)
 
     def _write_gen(self, buffer: Union[str, bytes]) -> PQGen[None]:
index 96bf4775e96e3765ae4ea9bb24a8c58893364a82..299ecdb3892ef725178337610b67b6fef9f9b5a4 100644 (file)
@@ -56,11 +56,15 @@ def test_copy_out_read(conn, format):
         for row in want:
             got = copy.read()
             assert got == row
+            assert (
+                conn.pgconn.transaction_status == conn.TransactionStatus.ACTIVE
+            )
 
         assert copy.read() == b""
         assert copy.read() == b""
 
     assert copy.read() == b""
+    assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS
 
 
 @pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
@@ -75,6 +79,8 @@ def test_copy_out_iter(conn, format):
     ) as copy:
         assert list(copy) == want
 
+    assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS
+
 
 @pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
 def test_read_rows(conn, format):
@@ -94,7 +100,9 @@ def test_read_rows(conn, format):
             if not row:
                 break
             rows.append(row)
+
     assert rows == sample_records
+    assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS
 
 
 @pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
@@ -107,7 +115,9 @@ def test_rows(conn, format):
             [builtins["int4"].oid, builtins["int4"].oid, builtins["text"].oid]
         )
         rows = list(copy.rows())
+
     assert rows == sample_records
+    assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS
 
 
 @pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
index 9c6f33c30ada969af68caae65f0190907fdd18c3..608b1ecff1ec4a8eb834f39d170494623f594046 100644 (file)
@@ -32,11 +32,16 @@ async def test_copy_out_read(aconn, format):
         for row in want:
             got = await copy.read()
             assert got == row
+            assert (
+                aconn.pgconn.transaction_status
+                == aconn.TransactionStatus.ACTIVE
+            )
 
         assert await copy.read() == b""
         assert await copy.read() == b""
 
     assert await copy.read() == b""
+    assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS
 
 
 @pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
@@ -53,7 +58,9 @@ async def test_copy_out_iter(aconn, format):
     ) as copy:
         async for row in copy:
             got.append(row)
+
     assert got == want
+    assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS
 
 
 @pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
@@ -74,7 +81,9 @@ async def test_read_rows(aconn, format):
             if not row:
                 break
             rows.append(row)
+
     assert rows == sample_records
+    assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS
 
 
 @pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
@@ -89,7 +98,9 @@ async def test_rows(aconn, format):
         rows = []
         async for row in copy.rows():
             rows.append(row)
+
     assert rows == sample_records
+    assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS
 
 
 @pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])