]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Don't leave the connection ACTIVE on error in COPY_OUT
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 7 Jan 2022 21:06:20 +0000 (22:06 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 8 Jan 2022 00:03:25 +0000 (01:03 +0100)
Cancel the active COPY operation if the server has not finished sending
the data yet.

Close #203.

Also fix the tests which were based on this broken behaviour. A case of
self-administered Hyrum's law.

docs/news.rst
psycopg/psycopg/copy.py
tests/test_connection.py
tests/test_connection_async.py
tests/test_copy.py
tests/test_copy_async.py
tests/test_transaction.py
tests/test_transaction_async.py

index fcfcdb1055ab0a5866cc1b049c1910f1af121711..d1b4c7ea4da9de3da7aee3bed2c0c453e574860f 100644 (file)
@@ -27,6 +27,8 @@ Psycopg 3.0.8 (unreleased)
   connection string, if available (:ticket:`#194`).
 - Fix possible warnings in objects deletion on interpreter shutdown
   (:ticket:`#198`).
+- Don't leave connections in ACTIVE state in case of error during COPY ... TO
+  STDOUT (:ticket:`#203`).
 
 
 Psycopg 3.0.7
index 189f70d75c31ef434191c4e241b59347ff5e2d93..71ce1ea83bbbc68ece58c93cd880e41a99985b6d 100644 (file)
@@ -146,7 +146,7 @@ class BaseCopy(Generic[ConnectionType]):
 
         return row
 
-    def _end_copy_gen(self, exc: Optional[BaseException]) -> PQGen[None]:
+    def _end_copy_in_gen(self, exc: Optional[BaseException]) -> PQGen[None]:
         bmsg: Optional[bytes]
         if exc:
             msg = f"error from Python: {type(exc).__qualname__} - {exc}"
@@ -160,6 +160,29 @@ class BaseCopy(Generic[ConnectionType]):
         self.cursor._rowcount = nrows if nrows is not None else -1
         self._finished = True
 
+    def _end_copy_out_gen(self, exc: Optional[BaseException]) -> PQGen[None]:
+        if not exc:
+            return
+
+        if (
+            self.connection.pgconn.transaction_status
+            != pq.TransactionStatus.ACTIVE
+        ):
+            # The server has already finished to send copy data. The connection
+            # is already in a good state.
+            return
+
+        # Throw a cancel to the server, then consume the rest of the copy data
+        # (which might or might not have been already transferred entirely to
+        # the client, so we won't necessary see the exception associated with
+        # canceling).
+        self.connection.cancel()
+        try:
+            while (yield from self._read_gen()):
+                pass
+        except e.QueryCanceled:
+            pass
+
 
 class Copy(BaseCopy["Connection[Any]"]):
     """Manage a :sql:`COPY` operation."""
@@ -247,12 +270,11 @@ class Copy(BaseCopy["Connection[Any]"]):
         by exit. It is available if, despite what is documented, you end up
         using the `Copy` object outside a block.
         """
-        # no-op in COPY TO
-        if self._pgresult.status == ExecStatus.COPY_OUT:
-            return
-
-        self._write_end()
-        self.connection.wait(self._end_copy_gen(exc))
+        if self._pgresult.status == ExecStatus.COPY_IN:
+            self._write_end()
+            self.connection.wait(self._end_copy_in_gen(exc))
+        else:
+            self.connection.wait(self._end_copy_out_gen(exc))
 
     # Concurrent copy support
 
@@ -263,7 +285,7 @@ class Copy(BaseCopy["Connection[Any]"]):
 
         The function is designed to be run in a separate thread.
         """
-        while 1:
+        while True:
             data = self._queue.get(block=True, timeout=24 * 60 * 60)
             if not data:
                 break
@@ -344,12 +366,11 @@ class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
         await self._write(data)
 
     async def finish(self, exc: Optional[BaseException]) -> None:
-        # no-op in COPY TO
-        if self._pgresult.status == ExecStatus.COPY_OUT:
-            return
-
-        await self._write_end()
-        await self.connection.wait(self._end_copy_gen(exc))
+        if self._pgresult.status == ExecStatus.COPY_IN:
+            await self._write_end()
+            await self.connection.wait(self._end_copy_in_gen(exc))
+        else:
+            await self.connection.wait(self._end_copy_out_gen(exc))
 
     # Concurrent copy support
 
@@ -360,7 +381,7 @@ class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
 
         The function is designed to be run in a separate thread.
         """
-        while 1:
+        while True:
             data = await self._queue.get()
             if not data:
                 break
index dac8557c892dd92e448ffcf6ea51f9dbbda0431d..6fcd62aa4036411312d5691f68990a97c3426306 100644 (file)
@@ -181,17 +181,17 @@ def test_context_inerror_rollback_no_clobber(conn, dsn, caplog):
     assert "in rollback" in rec.message
 
 
-def test_context_active_rollback_no_clobber(conn, dsn, caplog):
+def test_context_active_rollback_no_clobber(dsn, caplog):
     caplog.set_level(logging.WARNING, logger="psycopg")
 
     with pytest.raises(ZeroDivisionError):
-        with psycopg.connect(dsn) as conn2:
-            with conn2.cursor() as cur:
-                with cur.copy(
-                    "copy (select generate_series(1, 10)) to stdout"
-                ) as copy:
-                    for row in copy.rows():
-                        1 / 0
+        with psycopg.connect(dsn) as conn:
+            conn.pgconn.exec_(
+                b"copy (select generate_series(1, 10)) to stdout"
+            )
+            status = conn.info.transaction_status
+            assert status == conn.TransactionStatus.ACTIVE
+            1 / 0
 
     assert len(caplog.records) == 1
     rec = caplog.records[0]
index beacc3e7d440c49172296723f428b4af7c3a8ac3..a0606af7473f0bf59a97c836054c782dab4bcbad 100644 (file)
@@ -180,17 +180,17 @@ async def test_context_inerror_rollback_no_clobber(conn, dsn, caplog):
     assert "in rollback" in rec.message
 
 
-async def test_context_active_rollback_no_clobber(conn, dsn, caplog):
+async def test_context_active_rollback_no_clobber(dsn, caplog):
     caplog.set_level(logging.WARNING, logger="psycopg")
 
     with pytest.raises(ZeroDivisionError):
-        async with await psycopg.AsyncConnection.connect(dsn) as conn2:
-            async with conn2.cursor() as cur:
-                async with cur.copy(
-                    "copy (select generate_series(1, 10)) to stdout"
-                ) as copy:
-                    async for row in copy.rows():
-                        1 / 0
+        async with await psycopg.AsyncConnection.connect(dsn) as conn:
+            conn.pgconn.exec_(
+                b"copy (select generate_series(1, 10)) to stdout"
+            )
+            status = conn.info.transaction_status
+            assert status == conn.TransactionStatus.ACTIVE
+            1 / 0
 
     assert len(caplog.records) == 1
     rec = caplog.records[0]
index f730e76bcd14aec447af82aedbc3b83f4455e4f2..72b0abfbf14e4ff7845639499aa6ab2609f39f36 100644 (file)
@@ -347,6 +347,40 @@ def test_copy_in_buffers_with_py_error(conn):
     assert conn.info.transaction_status == conn.TransactionStatus.INERROR
 
 
+def test_copy_out_error_with_copy_finished(conn):
+    cur = conn.cursor()
+    with pytest.raises(ZeroDivisionError):
+        with cur.copy("copy (select generate_series(1, 2)) to stdout") as copy:
+            copy.read_row()
+            1 / 0
+
+    assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+
+
+def test_copy_out_error_with_copy_not_finished(conn):
+    cur = conn.cursor()
+    with pytest.raises(ZeroDivisionError):
+        with cur.copy(
+            "copy (select generate_series(1, 1000000)) to stdout"
+        ) as copy:
+            copy.read_row()
+            1 / 0
+
+    assert conn.info.transaction_status == conn.TransactionStatus.INERROR
+
+
+def test_copy_out_server_error(conn):
+    cur = conn.cursor()
+    with pytest.raises(e.DivisionByZero):
+        with cur.copy(
+            "copy (select 1/n from generate_series(-10, 10) x(n)) to stdout"
+        ) as copy:
+            for block in copy:
+                pass
+
+    assert conn.info.transaction_status == conn.TransactionStatus.INERROR
+
+
 @pytest.mark.parametrize("format", Format)
 def test_copy_in_records(conn, format):
     cur = conn.cursor()
index d03e0ef2b87b6ebad3e098af8116f03f8e52ba74..717da3fa695e52c16c1855efc8427dcc021d806e 100644 (file)
@@ -327,6 +327,42 @@ async def test_copy_in_buffers_with_py_error(aconn):
     assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR
 
 
+async def test_copy_out_error_with_copy_finished(aconn):
+    cur = aconn.cursor()
+    with pytest.raises(ZeroDivisionError):
+        async with cur.copy(
+            "copy (select generate_series(1, 2)) to stdout"
+        ) as copy:
+            await copy.read_row()
+            1 / 0
+
+    assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS
+
+
+async def test_copy_out_error_with_copy_not_finished(aconn):
+    cur = aconn.cursor()
+    with pytest.raises(ZeroDivisionError):
+        async with cur.copy(
+            "copy (select generate_series(1, 1000000)) to stdout"
+        ) as copy:
+            await copy.read_row()
+            1 / 0
+
+    assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR
+
+
+async def test_copy_out_server_error(aconn):
+    cur = aconn.cursor()
+    with pytest.raises(e.DivisionByZero):
+        async with cur.copy(
+            "copy (select 1/n from generate_series(-10, 10) x(n)) to stdout"
+        ) as copy:
+            async for block in copy:
+                pass
+
+    assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR
+
+
 @pytest.mark.parametrize("format", Format)
 async def test_copy_in_records(aconn, format):
     cur = aconn.cursor()
index 74dc8c5c3058fe8c09a4aa8228381007646379ac..b2b5e0fb9b639c9d1586b5123e39eb8811ecaf1c 100644 (file)
@@ -151,24 +151,26 @@ def test_context_inerror_rollback_no_clobber(conn, dsn, caplog):
     assert "in rollback" in rec.message
 
 
-def test_context_active_rollback_no_clobber(conn, dsn, caplog):
+def test_context_active_rollback_no_clobber(dsn, caplog):
     caplog.set_level(logging.WARNING, logger="psycopg")
 
-    with pytest.raises(ZeroDivisionError):
-        conn2 = Connection.connect(dsn)
-        with conn2.transaction():
-            with conn2.cursor() as cur:
-                with cur.copy(
-                    "copy (select generate_series(1, 10)) to stdout"
-                ) as copy:
-                    for row in copy.rows():
-                        1 / 0
+    conn = Connection.connect(dsn)
+    try:
+        with pytest.raises(ZeroDivisionError):
+            with conn.transaction():
+                conn.pgconn.exec_(
+                    b"copy (select generate_series(1, 10)) to stdout"
+                )
+                status = conn.info.transaction_status
+                assert status == conn.TransactionStatus.ACTIVE
+                1 / 0
 
-    assert len(caplog.records) == 1
-    rec = caplog.records[0]
-    assert rec.levelno == logging.WARNING
-    assert "in rollback" in rec.message
-    conn2.close()
+        assert len(caplog.records) == 1
+        rec = caplog.records[0]
+        assert rec.levelno == logging.WARNING
+        assert "in rollback" in rec.message
+    finally:
+        conn.close()
 
 
 def test_interaction_dbapi_transaction(conn):
index c36552c11a5498f79b6d5db1fb66cd4746010e05..4335eb837d764d80d6470ab9c270286d75014c2d 100644 (file)
@@ -94,24 +94,26 @@ async def test_context_inerror_rollback_no_clobber(aconn, dsn, caplog):
     assert "in rollback" in rec.message
 
 
-async def test_context_active_rollback_no_clobber(aconn, dsn, caplog):
+async def test_context_active_rollback_no_clobber(dsn, caplog):
     caplog.set_level(logging.WARNING, logger="psycopg")
 
-    with pytest.raises(ZeroDivisionError):
-        conn2 = await AsyncConnection.connect(dsn)
-        async with conn2.transaction():
-            async with conn2.cursor() as cur:
-                async with cur.copy(
-                    "copy (select generate_series(1, 10)) to stdout"
-                ) as copy:
-                    async for row in copy.rows():
-                        1 / 0
+    conn = await AsyncConnection.connect(dsn)
+    try:
+        with pytest.raises(ZeroDivisionError):
+            async with conn.transaction():
+                conn.pgconn.exec_(
+                    b"copy (select generate_series(1, 10)) to stdout"
+                )
+                status = conn.info.transaction_status
+                assert status == conn.TransactionStatus.ACTIVE
+                1 / 0
 
-    assert len(caplog.records) == 1
-    rec = caplog.records[0]
-    assert rec.levelno == logging.WARNING
-    assert "in rollback" in rec.message
-    await conn2.close()
+        assert len(caplog.records) == 1
+        rec = caplog.records[0]
+        assert rec.levelno == logging.WARNING
+        assert "in rollback" in rec.message
+    finally:
+        await conn.close()
 
 
 async def test_interaction_dbapi_transaction(aconn):