From: Daniele Varrazzo Date: Fri, 3 Sep 2021 16:39:09 +0000 (+0200) Subject: Do not use the extended protocol in COPY X-Git-Tag: 3.0~75 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=54de5084515db285092a2d303075f000ba5ab4b0;p=thirdparty%2Fpsycopg.git Do not use the extended protocol in COPY Error recovery is reported to be problematic (see #78). Close #82. --- diff --git a/psycopg/psycopg/cursor.py b/psycopg/psycopg/cursor.py index 477ed84c6..d70fcd748 100644 --- a/psycopg/psycopg/cursor.py +++ b/psycopg/psycopg/cursor.py @@ -316,10 +316,14 @@ class BaseCursor(Generic[ConnectionType, Row]): yield from self._start_query() query = self._convert_query(statement) - # Make sure to avoid PQexec to avoid receiving a mix of COPY and - # other operations. - self._execute_send(query, no_pqexec=True) - (result,) = yield from execute(self._conn.pgconn) + self._execute_send(query, binary=False) + results = yield from execute(self._conn.pgconn) + if len(results) != 1: + raise e.ProgrammingError( + "COPY cannot be mixed with other operations" + ) + + result = results[0] self._check_copy_result(result) self.pgresult = result self._tx.set_pgresult(result) diff --git a/psycopg/psycopg/generators.py b/psycopg/psycopg/generators.py index 770e7f350..4675f0211 100644 --- a/psycopg/psycopg/generators.py +++ b/psycopg/psycopg/generators.py @@ -187,7 +187,11 @@ def copy_from(pgconn: PGconn) -> PQGen[Union[memoryview, PGresult]]: return data # Retrieve the final result of copy - (result,) = yield from fetch_many(pgconn) + results = yield from fetch_many(pgconn) + if len(results) > 1: + # TODO: too brutal? Copy worked. + raise e.ProgrammingError("you cannot mix COPY with other operations") + result = results[0] if result.status != ExecStatus.COMMAND_OK: encoding = py_codecs.get( pgconn.parameter_status(b"client_encoding") or "", "utf-8" diff --git a/tests/test_copy.py b/tests/test_copy.py index c3f7545fe..0d70140eb 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -244,6 +244,14 @@ def test_copy_bad_result(conn): with cur.copy("reset timezone"): pass + with pytest.raises(e.ProgrammingError): + with cur.copy("copy (select 1) to stdout; select 1") as copy: + list(copy) + + with pytest.raises(e.ProgrammingError): + with cur.copy("select 1; copy (select 1) to stdout"): + pass + def test_copy_in_str(conn): cur = conn.cursor() diff --git a/tests/test_copy_async.py b/tests/test_copy_async.py index 34f9dd2d4..5fe8e1980 100644 --- a/tests/test_copy_async.py +++ b/tests/test_copy_async.py @@ -224,6 +224,14 @@ async def test_copy_bad_result(aconn): async with cur.copy("reset timezone"): pass + with pytest.raises(e.ProgrammingError): + async with cur.copy("copy (select 1) to stdout; select 1") as copy: + [_ async for _ in copy] + + with pytest.raises(e.ProgrammingError): + async with cur.copy("select 1; copy (select 1) to stdout"): + pass + async def test_copy_in_str(aconn): cur = aconn.cursor()