From: Daniele Varrazzo Date: Fri, 13 Nov 2020 00:49:04 +0000 (+0000) Subject: Don't throw an error using COPY TO in a block X-Git-Tag: 3.0.dev0~369 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=a119ff0bca333f91522f3648cd2134babe407db4;p=thirdparty%2Fpsycopg.git Don't throw an error using COPY TO in a block --- diff --git a/psycopg3/psycopg3/copy.py b/psycopg3/psycopg3/copy.py index e3831a6fc..f616ee7da 100644 --- a/psycopg3/psycopg3/copy.py +++ b/psycopg3/psycopg3/copy.py @@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, AsyncIterator, Iterator, Generic from typing import Any, Dict, List, Match, Optional, Sequence, Type, Union from types import TracebackType -from .pq import Format +from .pq import Format, ExecStatus from .proto import ConnectionType, Transformer from .generators import copy_from, copy_to, copy_end @@ -166,6 +166,10 @@ class Copy(BaseCopy["Connection"]): exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], ) -> None: + # no-op in COPY TO + if self.pgresult.status == ExecStatus.COPY_OUT: + return + if exc_val is None: if self.format == Format.BINARY and not self._first_row: # send EOF only if we copied binary rows (_first_row is False) @@ -217,6 +221,10 @@ class AsyncCopy(BaseCopy["AsyncConnection"]): exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], ) -> None: + # no-op in COPY TO + if self.pgresult.status == ExecStatus.COPY_OUT: + return + if exc_val is None: if self.format == Format.BINARY and not self._first_row: # send EOF only if we copied binary rows (_first_row is False) diff --git a/tests/test_copy.py b/tests/test_copy.py index d0a5acaa0..6dfc30a81 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -45,6 +45,23 @@ def test_copy_out_iter(conn, format): assert list(copy) == want +@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) +def test_copy_out_context(conn, format): + cur = conn.cursor() + out = [] + with cur.copy( + f"copy ({sample_values}) to stdout (format {format.name})" + ) as copy: + for row in copy: + out.append(row) + + if format == pq.Format.TEXT: + want = [row + b"\n" for row in sample_text.splitlines()] + else: + want = sample_binary_rows + assert out == want + + @pytest.mark.parametrize( "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")], diff --git a/tests/test_copy_async.py b/tests/test_copy_async.py index 19987973e..9ac4f9dee 100644 --- a/tests/test_copy_async.py +++ b/tests/test_copy_async.py @@ -46,6 +46,25 @@ async def test_copy_out_iter(aconn, format): assert got == want +@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) +async def test_copy_out_context(aconn, format): + cur = await aconn.cursor() + out = [] + async with ( + await cur.copy( + f"copy ({sample_values}) to stdout (format {format.name})" + ) + ) as copy: + async for row in copy: + out.append(row) + + if format == pq.Format.TEXT: + want = [row + b"\n" for row in sample_text.splitlines()] + else: + want = sample_binary_rows + assert out == want + + @pytest.mark.parametrize( "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],