From: Daniele Varrazzo Date: Tue, 29 Mar 2022 23:54:56 +0000 (+0200) Subject: fix: forbid COPY in pipeline mode X-Git-Tag: 3.1~145^2~3 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=67b4515c18aa1a812925119167e89ed95ceee6ba;p=thirdparty%2Fpsycopg.git fix: forbid COPY in pipeline mode COPY is not supported. Attempting it puts the connection in unrecoverable state, with pipeline sync failing and pipeline exit complaining that there are still results. So let's try to not get in that state. --- diff --git a/psycopg/psycopg/cursor.py b/psycopg/psycopg/cursor.py index 4fd756ff9..6f5a96e0e 100644 --- a/psycopg/psycopg/cursor.py +++ b/psycopg/psycopg/cursor.py @@ -387,6 +387,12 @@ class BaseCursor(Generic[ConnectionType, Row]): def _start_copy_gen(self, statement: Query) -> PQGen[None]: """Generator implementing sending a command for `Cursor.copy().""" + + # The connection gets in an unrecoverable state if we attempt COPY in + # pipeline mode. Forbid it explicitly. + if self._conn._pipeline: + raise e.NotSupportedError("COPY cannot be used in pipeline mode") + yield from self._start_query() query = self._convert_query(statement) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 52a022ffe..43c862d38 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -98,6 +98,14 @@ def test_cannot_insert_multiple_commands(conn): assert cm.value.sqlstate == "42601" +def test_copy(conn): + with conn.pipeline(): + cur = conn.cursor() + with pytest.raises(e.NotSupportedError): + with cur.copy("copy (select 1) to stdout"): + pass + + def test_pipeline_processed_at_exit(conn): with conn.cursor() as cur: with conn.pipeline() as p: diff --git a/tests/test_pipeline_async.py b/tests/test_pipeline_async.py index 60338062c..668a8b3e0 100644 --- a/tests/test_pipeline_async.py +++ b/tests/test_pipeline_async.py @@ -101,6 +101,14 @@ async def test_cannot_insert_multiple_commands(aconn): assert cm.value.sqlstate == "42601" +async def test_copy(aconn): + async with aconn.pipeline(): + cur = aconn.cursor() + with pytest.raises(e.NotSupportedError): + async with cur.copy("copy (select 1) to stdout") as copy: + await copy.read() + + async def test_pipeline_processed_at_exit(aconn): async with aconn.cursor() as cur: async with aconn.pipeline() as p: