]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix: forbid COPY in pipeline mode
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 29 Mar 2022 23:54:56 +0000 (01:54 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 2 Apr 2022 23:23:22 +0000 (01:23 +0200)
COPY is not supported. Attempting it puts the connection in
unrecoverable state, with pipeline sync failing and pipeline exit
complaining that there are still results. So let's try to not get in
that state.

psycopg/psycopg/cursor.py
tests/test_pipeline.py
tests/test_pipeline_async.py

index 4fd756ff9da01d3893253cdf704734a40fecd2ae..6f5a96e0e53eb05cfaa13915886337d652f4ce17 100644 (file)
@@ -387,6 +387,12 @@ class BaseCursor(Generic[ConnectionType, Row]):
 
     def _start_copy_gen(self, statement: Query) -> PQGen[None]:
         """Generator implementing sending a command for `Cursor.copy()."""
+
+        # The connection gets in an unrecoverable state if we attempt COPY in
+        # pipeline mode. Forbid it explicitly.
+        if self._conn._pipeline:
+            raise e.NotSupportedError("COPY cannot be used in pipeline mode")
+
         yield from self._start_query()
         query = self._convert_query(statement)
 
index 52a022ffec580607a8e990fd9e13ad4740d0a83e..43c862d389a37004c974cab0af24d3e710faa52b 100644 (file)
@@ -98,6 +98,14 @@ def test_cannot_insert_multiple_commands(conn):
     assert cm.value.sqlstate == "42601"
 
 
+def test_copy(conn):
+    with conn.pipeline():
+        cur = conn.cursor()
+        with pytest.raises(e.NotSupportedError):
+            with cur.copy("copy (select 1) to stdout"):
+                pass
+
+
 def test_pipeline_processed_at_exit(conn):
     with conn.cursor() as cur:
         with conn.pipeline() as p:
index 60338062ce8c27c35ddf586ff08b5a8fb8f2634d..668a8b3e0126e86024f609148451bfef0e62f048 100644 (file)
@@ -101,6 +101,14 @@ async def test_cannot_insert_multiple_commands(aconn):
     assert cm.value.sqlstate == "42601"
 
 
+async def test_copy(aconn):
+    async with aconn.pipeline():
+        cur = aconn.cursor()
+        with pytest.raises(e.NotSupportedError):
+            async with cur.copy("copy (select 1) to stdout") as copy:
+                await copy.read()
+
+
 async def test_pipeline_processed_at_exit(aconn):
     async with aconn.cursor() as cur:
         async with aconn.pipeline() as p: