From: Daniele Varrazzo Date: Fri, 7 Jan 2022 21:06:20 +0000 (+0100) Subject: Don't leave the connection ACTIVE on error in COPY_OUT X-Git-Tag: pool-3.1~30 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=fac39ff6ed5ea1d9fd14789a1471ddfe4c6aa7fd;p=thirdparty%2Fpsycopg.git Don't leave the connection ACTIVE on error in COPY_OUT Cancel the active COPY operation if the server has not finished sending the data yet. Close #203. Also fix the tests which were based on this broken behaviour. A case of self-administered Hyrum's law. --- diff --git a/docs/news.rst b/docs/news.rst index fcfcdb105..d1b4c7ea4 100644 --- a/docs/news.rst +++ b/docs/news.rst @@ -27,6 +27,8 @@ Psycopg 3.0.8 (unreleased) connection string, if available (:ticket:`#194`). - Fix possible warnings in objects deletion on interpreter shutdown (:ticket:`#198`). +- Don't leave connections in ACTIVE state in case of error during COPY ... TO + STDOUT (:ticket:`#203`). Psycopg 3.0.7 diff --git a/psycopg/psycopg/copy.py b/psycopg/psycopg/copy.py index 189f70d75..71ce1ea83 100644 --- a/psycopg/psycopg/copy.py +++ b/psycopg/psycopg/copy.py @@ -146,7 +146,7 @@ class BaseCopy(Generic[ConnectionType]): return row - def _end_copy_gen(self, exc: Optional[BaseException]) -> PQGen[None]: + def _end_copy_in_gen(self, exc: Optional[BaseException]) -> PQGen[None]: bmsg: Optional[bytes] if exc: msg = f"error from Python: {type(exc).__qualname__} - {exc}" @@ -160,6 +160,29 @@ class BaseCopy(Generic[ConnectionType]): self.cursor._rowcount = nrows if nrows is not None else -1 self._finished = True + def _end_copy_out_gen(self, exc: Optional[BaseException]) -> PQGen[None]: + if not exc: + return + + if ( + self.connection.pgconn.transaction_status + != pq.TransactionStatus.ACTIVE + ): + # The server has already finished to send copy data. The connection + # is already in a good state. + return + + # Throw a cancel to the server, then consume the rest of the copy data + # (which might or might not have been already transferred entirely to + # the client, so we won't necessary see the exception associated with + # canceling). + self.connection.cancel() + try: + while (yield from self._read_gen()): + pass + except e.QueryCanceled: + pass + class Copy(BaseCopy["Connection[Any]"]): """Manage a :sql:`COPY` operation.""" @@ -247,12 +270,11 @@ class Copy(BaseCopy["Connection[Any]"]): by exit. It is available if, despite what is documented, you end up using the `Copy` object outside a block. """ - # no-op in COPY TO - if self._pgresult.status == ExecStatus.COPY_OUT: - return - - self._write_end() - self.connection.wait(self._end_copy_gen(exc)) + if self._pgresult.status == ExecStatus.COPY_IN: + self._write_end() + self.connection.wait(self._end_copy_in_gen(exc)) + else: + self.connection.wait(self._end_copy_out_gen(exc)) # Concurrent copy support @@ -263,7 +285,7 @@ class Copy(BaseCopy["Connection[Any]"]): The function is designed to be run in a separate thread. """ - while 1: + while True: data = self._queue.get(block=True, timeout=24 * 60 * 60) if not data: break @@ -344,12 +366,11 @@ class AsyncCopy(BaseCopy["AsyncConnection[Any]"]): await self._write(data) async def finish(self, exc: Optional[BaseException]) -> None: - # no-op in COPY TO - if self._pgresult.status == ExecStatus.COPY_OUT: - return - - await self._write_end() - await self.connection.wait(self._end_copy_gen(exc)) + if self._pgresult.status == ExecStatus.COPY_IN: + await self._write_end() + await self.connection.wait(self._end_copy_in_gen(exc)) + else: + await self.connection.wait(self._end_copy_out_gen(exc)) # Concurrent copy support @@ -360,7 +381,7 @@ class AsyncCopy(BaseCopy["AsyncConnection[Any]"]): The function is designed to be run in a separate thread. """ - while 1: + while True: data = await self._queue.get() if not data: break diff --git a/tests/test_connection.py b/tests/test_connection.py index dac8557c8..6fcd62aa4 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -181,17 +181,17 @@ def test_context_inerror_rollback_no_clobber(conn, dsn, caplog): assert "in rollback" in rec.message -def test_context_active_rollback_no_clobber(conn, dsn, caplog): +def test_context_active_rollback_no_clobber(dsn, caplog): caplog.set_level(logging.WARNING, logger="psycopg") with pytest.raises(ZeroDivisionError): - with psycopg.connect(dsn) as conn2: - with conn2.cursor() as cur: - with cur.copy( - "copy (select generate_series(1, 10)) to stdout" - ) as copy: - for row in copy.rows(): - 1 / 0 + with psycopg.connect(dsn) as conn: + conn.pgconn.exec_( + b"copy (select generate_series(1, 10)) to stdout" + ) + status = conn.info.transaction_status + assert status == conn.TransactionStatus.ACTIVE + 1 / 0 assert len(caplog.records) == 1 rec = caplog.records[0] diff --git a/tests/test_connection_async.py b/tests/test_connection_async.py index beacc3e7d..a0606af74 100644 --- a/tests/test_connection_async.py +++ b/tests/test_connection_async.py @@ -180,17 +180,17 @@ async def test_context_inerror_rollback_no_clobber(conn, dsn, caplog): assert "in rollback" in rec.message -async def test_context_active_rollback_no_clobber(conn, dsn, caplog): +async def test_context_active_rollback_no_clobber(dsn, caplog): caplog.set_level(logging.WARNING, logger="psycopg") with pytest.raises(ZeroDivisionError): - async with await psycopg.AsyncConnection.connect(dsn) as conn2: - async with conn2.cursor() as cur: - async with cur.copy( - "copy (select generate_series(1, 10)) to stdout" - ) as copy: - async for row in copy.rows(): - 1 / 0 + async with await psycopg.AsyncConnection.connect(dsn) as conn: + conn.pgconn.exec_( + b"copy (select generate_series(1, 10)) to stdout" + ) + status = conn.info.transaction_status + assert status == conn.TransactionStatus.ACTIVE + 1 / 0 assert len(caplog.records) == 1 rec = caplog.records[0] diff --git a/tests/test_copy.py b/tests/test_copy.py index f730e76bc..72b0abfbf 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -347,6 +347,40 @@ def test_copy_in_buffers_with_py_error(conn): assert conn.info.transaction_status == conn.TransactionStatus.INERROR +def test_copy_out_error_with_copy_finished(conn): + cur = conn.cursor() + with pytest.raises(ZeroDivisionError): + with cur.copy("copy (select generate_series(1, 2)) to stdout") as copy: + copy.read_row() + 1 / 0 + + assert conn.info.transaction_status == conn.TransactionStatus.INTRANS + + +def test_copy_out_error_with_copy_not_finished(conn): + cur = conn.cursor() + with pytest.raises(ZeroDivisionError): + with cur.copy( + "copy (select generate_series(1, 1000000)) to stdout" + ) as copy: + copy.read_row() + 1 / 0 + + assert conn.info.transaction_status == conn.TransactionStatus.INERROR + + +def test_copy_out_server_error(conn): + cur = conn.cursor() + with pytest.raises(e.DivisionByZero): + with cur.copy( + "copy (select 1/n from generate_series(-10, 10) x(n)) to stdout" + ) as copy: + for block in copy: + pass + + assert conn.info.transaction_status == conn.TransactionStatus.INERROR + + @pytest.mark.parametrize("format", Format) def test_copy_in_records(conn, format): cur = conn.cursor() diff --git a/tests/test_copy_async.py b/tests/test_copy_async.py index d03e0ef2b..717da3fa6 100644 --- a/tests/test_copy_async.py +++ b/tests/test_copy_async.py @@ -327,6 +327,42 @@ async def test_copy_in_buffers_with_py_error(aconn): assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR +async def test_copy_out_error_with_copy_finished(aconn): + cur = aconn.cursor() + with pytest.raises(ZeroDivisionError): + async with cur.copy( + "copy (select generate_series(1, 2)) to stdout" + ) as copy: + await copy.read_row() + 1 / 0 + + assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS + + +async def test_copy_out_error_with_copy_not_finished(aconn): + cur = aconn.cursor() + with pytest.raises(ZeroDivisionError): + async with cur.copy( + "copy (select generate_series(1, 1000000)) to stdout" + ) as copy: + await copy.read_row() + 1 / 0 + + assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR + + +async def test_copy_out_server_error(aconn): + cur = aconn.cursor() + with pytest.raises(e.DivisionByZero): + async with cur.copy( + "copy (select 1/n from generate_series(-10, 10) x(n)) to stdout" + ) as copy: + async for block in copy: + pass + + assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR + + @pytest.mark.parametrize("format", Format) async def test_copy_in_records(aconn, format): cur = aconn.cursor() diff --git a/tests/test_transaction.py b/tests/test_transaction.py index 74dc8c5c3..b2b5e0fb9 100644 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -151,24 +151,26 @@ def test_context_inerror_rollback_no_clobber(conn, dsn, caplog): assert "in rollback" in rec.message -def test_context_active_rollback_no_clobber(conn, dsn, caplog): +def test_context_active_rollback_no_clobber(dsn, caplog): caplog.set_level(logging.WARNING, logger="psycopg") - with pytest.raises(ZeroDivisionError): - conn2 = Connection.connect(dsn) - with conn2.transaction(): - with conn2.cursor() as cur: - with cur.copy( - "copy (select generate_series(1, 10)) to stdout" - ) as copy: - for row in copy.rows(): - 1 / 0 + conn = Connection.connect(dsn) + try: + with pytest.raises(ZeroDivisionError): + with conn.transaction(): + conn.pgconn.exec_( + b"copy (select generate_series(1, 10)) to stdout" + ) + status = conn.info.transaction_status + assert status == conn.TransactionStatus.ACTIVE + 1 / 0 - assert len(caplog.records) == 1 - rec = caplog.records[0] - assert rec.levelno == logging.WARNING - assert "in rollback" in rec.message - conn2.close() + assert len(caplog.records) == 1 + rec = caplog.records[0] + assert rec.levelno == logging.WARNING + assert "in rollback" in rec.message + finally: + conn.close() def test_interaction_dbapi_transaction(conn): diff --git a/tests/test_transaction_async.py b/tests/test_transaction_async.py index c36552c11..4335eb837 100644 --- a/tests/test_transaction_async.py +++ b/tests/test_transaction_async.py @@ -94,24 +94,26 @@ async def test_context_inerror_rollback_no_clobber(aconn, dsn, caplog): assert "in rollback" in rec.message -async def test_context_active_rollback_no_clobber(aconn, dsn, caplog): +async def test_context_active_rollback_no_clobber(dsn, caplog): caplog.set_level(logging.WARNING, logger="psycopg") - with pytest.raises(ZeroDivisionError): - conn2 = await AsyncConnection.connect(dsn) - async with conn2.transaction(): - async with conn2.cursor() as cur: - async with cur.copy( - "copy (select generate_series(1, 10)) to stdout" - ) as copy: - async for row in copy.rows(): - 1 / 0 + conn = await AsyncConnection.connect(dsn) + try: + with pytest.raises(ZeroDivisionError): + async with conn.transaction(): + conn.pgconn.exec_( + b"copy (select generate_series(1, 10)) to stdout" + ) + status = conn.info.transaction_status + assert status == conn.TransactionStatus.ACTIVE + 1 / 0 - assert len(caplog.records) == 1 - rec = caplog.records[0] - assert rec.levelno == logging.WARNING - assert "in rollback" in rec.message - await conn2.close() + assert len(caplog.records) == 1 + rec = caplog.records[0] + assert rec.levelno == logging.WARNING + assert "in rollback" in rec.message + finally: + await conn.close() async def test_interaction_dbapi_transaction(aconn):