]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Don't throw an error using COPY TO in a block
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 13 Nov 2020 00:49:04 +0000 (00:49 +0000)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 13 Nov 2020 00:49:04 +0000 (00:49 +0000)
psycopg3/psycopg3/copy.py
tests/test_copy.py
tests/test_copy_async.py

index e3831a6fc9e905a00c17095db1c11771f9466b8e..f616ee7daa57d888f6ff6c6c6264261d82b9654e 100644 (file)
@@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, AsyncIterator, Iterator, Generic
 from typing import Any, Dict, List, Match, Optional, Sequence, Type, Union
 from types import TracebackType
 
-from .pq import Format
+from .pq import Format, ExecStatus
 from .proto import ConnectionType, Transformer
 from .generators import copy_from, copy_to, copy_end
 
@@ -166,6 +166,10 @@ class Copy(BaseCopy["Connection"]):
         exc_val: Optional[BaseException],
         exc_tb: Optional[TracebackType],
     ) -> None:
+        # no-op in COPY TO
+        if self.pgresult.status == ExecStatus.COPY_OUT:
+            return
+
         if exc_val is None:
             if self.format == Format.BINARY and not self._first_row:
                 # send EOF only if we copied binary rows (_first_row is False)
@@ -217,6 +221,10 @@ class AsyncCopy(BaseCopy["AsyncConnection"]):
         exc_val: Optional[BaseException],
         exc_tb: Optional[TracebackType],
     ) -> None:
+        # no-op in COPY TO
+        if self.pgresult.status == ExecStatus.COPY_OUT:
+            return
+
         if exc_val is None:
             if self.format == Format.BINARY and not self._first_row:
                 # send EOF only if we copied binary rows (_first_row is False)
index d0a5acaa072eef3ab5816fe6fd55fa57c1c48613..6dfc30a81376d888d5df7e7b1213b9a21c815119 100644 (file)
@@ -45,6 +45,23 @@ def test_copy_out_iter(conn, format):
     assert list(copy) == want
 
 
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+def test_copy_out_context(conn, format):
+    cur = conn.cursor()
+    out = []
+    with cur.copy(
+        f"copy ({sample_values}) to stdout (format {format.name})"
+    ) as copy:
+        for row in copy:
+            out.append(row)
+
+    if format == pq.Format.TEXT:
+        want = [row + b"\n" for row in sample_text.splitlines()]
+    else:
+        want = sample_binary_rows
+    assert out == want
+
+
 @pytest.mark.parametrize(
     "format, buffer",
     [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
index 19987973efcf859f57095aab54e667d8e0e88f61..9ac4f9dee771b158b42f74d179b0ae5a56d1fc3d 100644 (file)
@@ -46,6 +46,25 @@ async def test_copy_out_iter(aconn, format):
     assert got == want
 
 
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+async def test_copy_out_context(aconn, format):
+    cur = await aconn.cursor()
+    out = []
+    async with (
+        await cur.copy(
+            f"copy ({sample_values}) to stdout (format {format.name})"
+        )
+    ) as copy:
+        async for row in copy:
+            out.append(row)
+
+    if format == pq.Format.TEXT:
+        want = [row + b"\n" for row in sample_text.splitlines()]
+    else:
+        want = sample_binary_rows
+    assert out == want
+
+
 @pytest.mark.parametrize(
     "format, buffer",
     [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],