]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Do not use the extended protocol in COPY
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 3 Sep 2021 16:39:09 +0000 (18:39 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 21 Sep 2021 16:59:39 +0000 (17:59 +0100)
Error recovery is reported to be problematic (see #78).

Close #82.

psycopg/psycopg/cursor.py
psycopg/psycopg/generators.py
tests/test_copy.py
tests/test_copy_async.py

index 477ed84c6c234b4da16a0817329767c31e083218..d70fcd748b4140e2d8de335ffc5adef737f2dd65 100644 (file)
@@ -316,10 +316,14 @@ class BaseCursor(Generic[ConnectionType, Row]):
         yield from self._start_query()
         query = self._convert_query(statement)
 
-        # Make sure to avoid PQexec to avoid receiving a mix of COPY and
-        # other operations.
-        self._execute_send(query, no_pqexec=True)
-        (result,) = yield from execute(self._conn.pgconn)
+        self._execute_send(query, binary=False)
+        results = yield from execute(self._conn.pgconn)
+        if len(results) != 1:
+            raise e.ProgrammingError(
+                "COPY cannot be mixed with other operations"
+            )
+
+        result = results[0]
         self._check_copy_result(result)
         self.pgresult = result
         self._tx.set_pgresult(result)
index 770e7f350ea1a0f4e119c9a9d56cf156487fe4b2..4675f02117581ed1cba91da7a6a73bdab1e09325 100644 (file)
@@ -187,7 +187,11 @@ def copy_from(pgconn: PGconn) -> PQGen[Union[memoryview, PGresult]]:
         return data
 
     # Retrieve the final result of copy
-    (result,) = yield from fetch_many(pgconn)
+    results = yield from fetch_many(pgconn)
+    if len(results) > 1:
+        # TODO: too brutal? Copy worked.
+        raise e.ProgrammingError("you cannot mix COPY with other operations")
+    result = results[0]
     if result.status != ExecStatus.COMMAND_OK:
         encoding = py_codecs.get(
             pgconn.parameter_status(b"client_encoding") or "", "utf-8"
index c3f7545fe06aba6e0b847e6ecd6af63ba194f3ca..0d70140ebef8f8731fd0cbcf7ea1f9111184e511 100644 (file)
@@ -244,6 +244,14 @@ def test_copy_bad_result(conn):
         with cur.copy("reset timezone"):
             pass
 
+    with pytest.raises(e.ProgrammingError):
+        with cur.copy("copy (select 1) to stdout; select 1") as copy:
+            list(copy)
+
+    with pytest.raises(e.ProgrammingError):
+        with cur.copy("select 1; copy (select 1) to stdout"):
+            pass
+
 
 def test_copy_in_str(conn):
     cur = conn.cursor()
index 34f9dd2d459d82ef88514694a8d37d2f6f6f4499..5fe8e1980f87eb211f26d1a28cc0241d671e76c4 100644 (file)
@@ -224,6 +224,14 @@ async def test_copy_bad_result(aconn):
         async with cur.copy("reset timezone"):
             pass
 
+    with pytest.raises(e.ProgrammingError):
+        async with cur.copy("copy (select 1) to stdout; select 1") as copy:
+            [_ async for _ in copy]
+
+    with pytest.raises(e.ProgrammingError):
+        async with cur.copy("select 1; copy (select 1) to stdout"):
+            pass
+
 
 async def test_copy_in_str(aconn):
     cur = aconn.cursor()