]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Handle a bad command passed to copy
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 30 Jun 2020 17:58:08 +0000 (05:58 +1200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 30 Jun 2020 18:42:05 +0000 (06:42 +1200)
psycopg3/cursor.py
tests/test_copy.py
tests/test_copy_async.py

index 8ee944658c157eefc50dd2c049e3233f66a5d4f7..fd4e60252ed59f047591e28cee8cb9eb387db75f 100644 (file)
@@ -261,10 +261,16 @@ class BaseCursor:
 
         result = results[0]
         status = result.status
-        if status not in (pq.ExecStatus.COPY_IN, pq.ExecStatus.COPY_OUT):
+        if status in (pq.ExecStatus.COPY_IN, pq.ExecStatus.COPY_OUT):
+            return
+        elif status == pq.ExecStatus.FATAL_ERROR:
+            raise e.error_from_result(
+                result, encoding=self.connection.codec.name
+            )
+        else:
             raise e.ProgrammingError(
-                "copy() should be used only with COPY ... TO STDOUT"
-                " or COPY ... FROM STDIN statements"
+                "copy() should be used only with COPY ... TO STDOUT or COPY ..."
+                f" FROM STDIN statements, got {pq.ExecStatus(status).name}"
             )
 
 
index 3e256df6ae2ec84ff2deca1109dc974dcee06dcf..e739f1985b60be872cbfc2df52a8f565877c4a6a 100644 (file)
@@ -69,6 +69,21 @@ def test_copy_in_buffers_pg_error(conn):
     assert conn.pgconn.transaction_status == conn.TransactionStatus.INERROR
 
 
+def test_copy_bad_result(conn):
+    conn.autocommit = True
+
+    cur = conn.cursor()
+
+    with pytest.raises(e.SyntaxError):
+        cur.copy("wat")
+
+    with pytest.raises(e.ProgrammingError):
+        cur.copy("select 1")
+
+    with pytest.raises(e.ProgrammingError):
+        cur.copy("reset timezone")
+
+
 @pytest.mark.parametrize(
     "format, buffer",
     [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
index cce63aab5bea431c72d3fa3b941f27eb70cc72d6..289a6c428156dbbd85e2cb4f8669fde2e915245c 100644 (file)
@@ -72,6 +72,21 @@ async def test_copy_in_buffers_pg_error(aconn):
     assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INERROR
 
 
+async def test_copy_bad_result(conn):
+    conn.autocommit = True
+
+    cur = conn.cursor()
+
+    with pytest.raises(e.SyntaxError):
+        await cur.copy("wat")
+
+    with pytest.raises(e.ProgrammingError):
+        await cur.copy("select 1")
+
+    with pytest.raises(e.ProgrammingError):
+        await cur.copy("reset timezone")
+
+
 @pytest.mark.parametrize(
     "format, buffer",
     [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],