From: Daniele Varrazzo Date: Sun, 15 Nov 2020 01:47:51 +0000 (+0000) Subject: Cursor.copy() made into a context manager X-Git-Tag: 3.0.dev0~351^2~15 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=f738b2b6b9e7ef2221179fdd0cc1a1abd3d39ea6;p=thirdparty%2Fpsycopg.git Cursor.copy() made into a context manager What was before was a factory function, however that forced to have a pattern like: async with (await cursor.copy()) as copy Now instead what should be used is: async with cursor.copy() as copy With this change the user pretty much is never exposed anymore to a Copy object in a non-entered state. This is actually useful because it reduces the surface of the API: now for instance Copy.finis() can become a private method. --- diff --git a/docs/cursor.rst b/docs/cursor.rst index 6debec8a7..8f0c651e7 100644 --- a/docs/cursor.rst +++ b/docs/cursor.rst @@ -145,12 +145,6 @@ Cursor support objects The data in the tuple will be converted as configured on the cursor; see :ref:`adaptation` for details. - .. automethod:: finish - - If an *error* is specified, the :sql:`COPY` operation is cancelled. - - The method is called automatically at the end of a `!with` block. - .. autoclass:: AsyncCopy @@ -161,4 +155,3 @@ Cursor support objects .. automethod:: read .. automethod:: write .. automethod:: write_row - .. automethod:: finish diff --git a/docs/usage.rst b/docs/usage.rst index 83934d3cb..d4872a5c6 100644 --- a/docs/usage.rst +++ b/docs/usage.rst @@ -202,8 +202,9 @@ produce `!bytes`: .. code:: python with open("data.out", "wb") as f: - for data in cursor.copy("COPY table_name TO STDOUT") as copy: - f.write(data) + with cursor.copy("COPY table_name TO STDOUT") as copy: + for data in copy: + f.write(data) Asynchronous operations are supported using the same patterns on an `AsyncConnection`. For instance, if `!f` is an object supporting an @@ -212,7 +213,7 @@ copy operation could be: .. code:: python - async with (await cursor.copy("COPY data FROM STDIN")) as copy: + async with cursor.copy("COPY data FROM STDIN") as copy: data = await f.read() if not data: break diff --git a/psycopg3/psycopg3/cursor.py b/psycopg3/psycopg3/cursor.py index bf865e410..4cef7efae 100644 --- a/psycopg3/psycopg3/cursor.py +++ b/psycopg3/psycopg3/cursor.py @@ -5,17 +5,10 @@ psycopg3 cursor objects # Copyright (C) 2020 The Psycopg Team from types import TracebackType -from typing import ( - Any, - AsyncIterator, - Callable, - Generic, - Iterator, - List, - Mapping, -) -from typing import Optional, Sequence, Type, TYPE_CHECKING, Union +from typing import Any, AsyncIterator, Callable, Generic, Iterator, List +from typing import Mapping, Optional, Sequence, Type, TYPE_CHECKING, Union from operator import attrgetter +from contextlib import asynccontextmanager, contextmanager from . import errors as e from . import pq @@ -556,10 +549,19 @@ class Cursor(BaseCursor["Connection"]): self._pos += 1 yield row - def copy(self, statement: Query, vars: Optional[Params] = None) -> Copy: + @contextmanager + def copy( + self, statement: Query, vars: Optional[Params] = None + ) -> Iterator[Copy]: """ Initiate a :sql:`COPY` operation and return a `Copy` object to manage it. """ + with self._start_copy(statement, vars) as copy: + yield copy + + def _start_copy( + self, statement: Query, vars: Optional[Params] = None + ) -> Copy: with self.connection.lock: self._start_query() self.connection._start_query() @@ -682,12 +684,20 @@ class AsyncCursor(BaseCursor["AsyncConnection"]): self._pos += 1 yield row + @asynccontextmanager async def copy( self, statement: Query, vars: Optional[Params] = None - ) -> AsyncCopy: + ) -> AsyncIterator[AsyncCopy]: """ Initiate a :sql:`COPY` operation and return an `AsyncCopy` object. """ + copy = await self._start_copy(statement, vars) + async with copy: + yield copy + + async def _start_copy( + self, statement: Query, vars: Optional[Params] = None + ) -> AsyncCopy: async with self.connection.lock: self._start_query() await self.connection._start_query() diff --git a/tests/test_copy.py b/tests/test_copy.py index 6dfc30a81..1d342b84f 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -35,31 +35,37 @@ sample_binary = b"".join(sample_binary_rows) @pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) -def test_copy_out_iter(conn, format): - cur = conn.cursor() - copy = cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") +def test_copy_out_read(conn, format): if format == pq.Format.TEXT: want = [row + b"\n" for row in sample_text.splitlines()] else: want = sample_binary_rows - assert list(copy) == want - -@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) -def test_copy_out_context(conn, format): cur = conn.cursor() - out = [] with cur.copy( f"copy ({sample_values}) to stdout (format {format.name})" ) as copy: - for row in copy: - out.append(row) + for row in want: + got = copy.read() + assert got == row + + assert copy.read() is None + assert copy.read() is None + assert copy.read() is None + + +@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) +def test_copy_out_iter(conn, format): if format == pq.Format.TEXT: want = [row + b"\n" for row in sample_text.splitlines()] else: want = sample_binary_rows - assert out == want + cur = conn.cursor() + with cur.copy( + f"copy ({sample_values}) to stdout (format {format.name})" + ) as copy: + assert list(copy) == want @pytest.mark.parametrize( @@ -69,9 +75,9 @@ def test_copy_out_context(conn, format): def test_copy_in_buffers(conn, format, buffer): cur = conn.cursor() ensure_table(cur, sample_tabledef) - copy = cur.copy(f"copy copy_in from stdin (format {format.name})") - copy.write(globals()[buffer]) - copy.finish() + with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy: + copy.write(globals()[buffer]) + data = cur.execute("select * from copy_in order by 1").fetchall() assert data == sample_records @@ -79,11 +85,10 @@ def test_copy_in_buffers(conn, format, buffer): def test_copy_in_buffers_pg_error(conn): cur = conn.cursor() ensure_table(cur, sample_tabledef) - copy = cur.copy("copy copy_in from stdin (format text)") - copy.write(sample_text) - copy.write(sample_text) with pytest.raises(e.UniqueViolation): - copy.finish() + with cur.copy("copy copy_in from stdin (format text)") as copy: + copy.write(sample_text) + copy.write(sample_text) assert conn.pgconn.transaction_status == conn.TransactionStatus.INERROR @@ -93,13 +98,16 @@ def test_copy_bad_result(conn): cur = conn.cursor() with pytest.raises(e.SyntaxError): - cur.copy("wat") + with cur.copy("wat"): + pass with pytest.raises(e.ProgrammingError): - cur.copy("select 1") + with cur.copy("select 1"): + pass with pytest.raises(e.ProgrammingError): - cur.copy("reset timezone") + with cur.copy("reset timezone"): + pass @pytest.mark.parametrize( diff --git a/tests/test_copy_async.py b/tests/test_copy_async.py index 9ac4f9dee..8cc0e05ae 100644 --- a/tests/test_copy_async.py +++ b/tests/test_copy_async.py @@ -12,57 +12,40 @@ pytestmark = pytest.mark.asyncio @pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) async def test_copy_out_read(aconn, format): - cur = await aconn.cursor() - copy = await cur.copy( - f"copy ({sample_values}) to stdout (format {format.name})" - ) - if format == pq.Format.TEXT: want = [row + b"\n" for row in sample_text.splitlines()] else: want = sample_binary_rows - for row in want: - got = await copy.read() - assert got == row + cur = await aconn.cursor() + async with cur.copy( + f"copy ({sample_values}) to stdout (format {format.name})" + ) as copy: + for row in want: + got = await copy.read() + assert got == row + + assert await copy.read() is None + assert await copy.read() is None assert await copy.read() is None - assert await copy.read() is None @pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) async def test_copy_out_iter(aconn, format): - cur = await aconn.cursor() - copy = await cur.copy( - f"copy ({sample_values}) to stdout (format {format.name})" - ) if format == pq.Format.TEXT: want = [row + b"\n" for row in sample_text.splitlines()] else: want = sample_binary_rows - got = [] - async for row in copy: - got.append(row) - assert got == want - -@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) -async def test_copy_out_context(aconn, format): cur = await aconn.cursor() - out = [] - async with ( - await cur.copy( - f"copy ({sample_values}) to stdout (format {format.name})" - ) + got = [] + async with cur.copy( + f"copy ({sample_values}) to stdout (format {format.name})" ) as copy: async for row in copy: - out.append(row) - - if format == pq.Format.TEXT: - want = [row + b"\n" for row in sample_text.splitlines()] - else: - want = sample_binary_rows - assert out == want + got.append(row) + assert got == want @pytest.mark.parametrize( @@ -72,9 +55,11 @@ async def test_copy_out_context(aconn, format): async def test_copy_in_buffers(aconn, format, buffer): cur = await aconn.cursor() await ensure_table(cur, sample_tabledef) - copy = await cur.copy(f"copy copy_in from stdin (format {format.name})") - await copy.write(globals()[buffer]) - await copy.finish() + async with cur.copy( + f"copy copy_in from stdin (format {format.name})" + ) as copy: + await copy.write(globals()[buffer]) + await cur.execute("select * from copy_in order by 1") data = await cur.fetchall() assert data == sample_records @@ -83,11 +68,10 @@ async def test_copy_in_buffers(aconn, format, buffer): async def test_copy_in_buffers_pg_error(aconn): cur = await aconn.cursor() await ensure_table(cur, sample_tabledef) - copy = await cur.copy("copy copy_in from stdin (format text)") - await copy.write(sample_text) - await copy.write(sample_text) with pytest.raises(e.UniqueViolation): - await copy.finish() + async with cur.copy("copy copy_in from stdin (format text)") as copy: + await copy.write(sample_text) + await copy.write(sample_text) assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INERROR @@ -97,38 +81,22 @@ async def test_copy_bad_result(aconn): cur = await aconn.cursor() with pytest.raises(e.SyntaxError): - await cur.copy("wat") + async with cur.copy("wat"): + pass with pytest.raises(e.ProgrammingError): - await cur.copy("select 1") + async with cur.copy("select 1"): + pass with pytest.raises(e.ProgrammingError): - await cur.copy("reset timezone") - - -@pytest.mark.parametrize( - "format, buffer", - [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")], -) -async def test_copy_in_buffers_with(aconn, format, buffer): - cur = await aconn.cursor() - await ensure_table(cur, sample_tabledef) - async with ( - await cur.copy(f"copy copy_in from stdin (format {format.name})") - ) as copy: - await copy.write(globals()[buffer]) - - await cur.execute("select * from copy_in order by 1") - data = await cur.fetchall() - assert data == sample_records + async with cur.copy("reset timezone"): + pass async def test_copy_in_str(aconn): cur = await aconn.cursor() await ensure_table(cur, sample_tabledef) - async with ( - await cur.copy("copy copy_in from stdin (format text)") - ) as copy: + async with cur.copy("copy copy_in from stdin (format text)") as copy: await copy.write(sample_text.decode("utf8")) await cur.execute("select * from copy_in order by 1") @@ -140,9 +108,7 @@ async def test_copy_in_str_binary(aconn): cur = await aconn.cursor() await ensure_table(cur, sample_tabledef) with pytest.raises(e.QueryCanceled): - async with ( - await cur.copy("copy copy_in from stdin (format binary)") - ) as copy: + async with cur.copy("copy copy_in from stdin (format binary)") as copy: await copy.write(sample_text.decode("utf8")) assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INERROR @@ -152,9 +118,7 @@ async def test_copy_in_buffers_with_pg_error(aconn): cur = await aconn.cursor() await ensure_table(cur, sample_tabledef) with pytest.raises(e.UniqueViolation): - async with ( - await cur.copy("copy copy_in from stdin (format text)") - ) as copy: + async with cur.copy("copy copy_in from stdin (format text)") as copy: await copy.write(sample_text) await copy.write(sample_text) @@ -165,9 +129,7 @@ async def test_copy_in_buffers_with_py_error(aconn): cur = await aconn.cursor() await ensure_table(cur, sample_tabledef) with pytest.raises(e.QueryCanceled) as exc: - async with ( - await cur.copy("copy copy_in from stdin (format text)") - ) as copy: + async with cur.copy("copy copy_in from stdin (format text)") as copy: await copy.write(sample_text) raise Exception("nuttengoggenio") @@ -183,8 +145,8 @@ async def test_copy_in_records(aconn, format): cur = await aconn.cursor() await ensure_table(cur, sample_tabledef) - async with ( - await cur.copy(f"copy copy_in from stdin (format {format.name})") + async with cur.copy( + f"copy copy_in from stdin (format {format.name})" ) as copy: for row in sample_records: await copy.write_row(row) @@ -202,10 +164,8 @@ async def test_copy_in_records_binary(aconn, format): cur = await aconn.cursor() await ensure_table(cur, "col1 serial primary key, col2 int, data text") - async with ( - await cur.copy( - f"copy copy_in (col2, data) from stdin (format {format.name})" - ) + async with cur.copy( + f"copy copy_in (col2, data) from stdin (format {format.name})" ) as copy: for row in sample_records: await copy.write_row((None, row[2])) @@ -220,9 +180,7 @@ async def test_copy_in_allchars(aconn): await ensure_table(cur, sample_tabledef) await aconn.set_client_encoding("utf8") - async with ( - await cur.copy("copy copy_in from stdin (format text)") - ) as copy: + async with cur.copy("copy copy_in from stdin (format text)") as copy: for i in range(1, 256): await copy.write_row((i, None, chr(i))) await copy.write_row((ord(eur), None, eur)) diff --git a/tests/test_sql.py b/tests/test_sql.py index 62ff15d0b..2d093c9a5 100755 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -182,12 +182,12 @@ class TestSqlFormat: copy.write_row((10, "a", "b", "c")) copy.write_row((20, "d", "e", "f")) - copy = cur.copy( + with cur.copy( sql.SQL("copy (select {f} from {t} order by id) to stdout").format( t=sql.Identifier("test_compose"), f=sql.Identifier("ba'z") ) - ) - assert list(copy) == [b"c\n", b"f\n"] + ) as copy: + assert list(copy) == [b"c\n", b"f\n"] class TestIdentifier: