From: Daniele Varrazzo Date: Thu, 12 May 2022 23:18:59 +0000 (+0200) Subject: feat: add params to Cursor.copy() X-Git-Tag: 3.1~99^2~4 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=f34614cfff48e26024fd0302a85949c0470ab56f;p=thirdparty%2Fpsycopg.git feat: add params to Cursor.copy() --- diff --git a/docs/api/cursors.rst b/docs/api/cursors.rst index 24fb1b743..57ffadbb8 100644 --- a/docs/api/cursors.rst +++ b/docs/api/cursors.rst @@ -110,6 +110,8 @@ The `!Cursor` class :param statement: The copy operation to execute :type statement: `!str`, `!bytes`, or `sql.Composable` + :param params: The parameters to pass to the statement, if any. + :type params: Sequence or Mapping .. note:: @@ -120,6 +122,9 @@ The `!Cursor` class See :ref:`copy` for information about :sql:`COPY`. + .. versionchanged:: 3.1 + Added parameters support. + .. automethod:: stream This command is similar to execute + iter; however it supports endless diff --git a/docs/basic/copy.rst b/docs/basic/copy.rst index e83c3b3b9..61cd8c03b 100644 --- a/docs/basic/copy.rst +++ b/docs/basic/copy.rst @@ -33,6 +33,15 @@ You can compose a COPY statement dynamically by using objects from the ) as copy: # read data from the 'copy' object using read()/read_row() +.. versionchanged:: 3.1 + + You can also pass parameters to `!copy()`, like in `~Cursor.execute()`: + + .. code:: python + + with cur.copy("COPY (SELECT * FROM table_name LIMIT %s) TO STDOUT", (3,)) as copy: + # expect no more than three records + The connection is subject to the usual transaction behaviour, so, unless the connection is in autocommit, at the end of the COPY operation you will still have to commit the pending changes and you can still roll them back. See diff --git a/docs/news.rst b/docs/news.rst index 4f8f8cc21..3076554e9 100644 --- a/docs/news.rst +++ b/docs/news.rst @@ -20,6 +20,7 @@ Psycopg 3.1 (unreleased) results (:ticket:`#164`). - `~Cursor.executemany()` performance improved by using batch mode internally (:ticket:`#145`). +- Add parameters to `~Cursor.copy()`. - Add `pq.PGconn.trace()` and related trace functions (:ticket:`#167`). - Add ``prepare_threshold`` parameter to `Connection` init (:ticket:`#200`). - Add ``cursor_factory`` parameter to `Connection` init. diff --git a/psycopg/psycopg/cursor.py b/psycopg/psycopg/cursor.py index 26d161143..557556037 100644 --- a/psycopg/psycopg/cursor.py +++ b/psycopg/psycopg/cursor.py @@ -20,7 +20,7 @@ from .copy import Copy from .rows import Row, RowMaker, RowFactory from ._column import Column from ._cmodule import _psycopg -from ._queries import PostgresQuery +from ._queries import PostgresQuery, PostgresClientQuery from ._pipeline import Pipeline from ._encodings import pgconn_encoding from ._preparing import Prepare @@ -390,7 +390,9 @@ class BaseCursor(Generic[ConnectionType, Row]): self._tx = adapt.Transformer(self) yield from self._conn._start_query() - def _start_copy_gen(self, statement: Query) -> PQGen[None]: + def _start_copy_gen( + self, statement: Query, params: Optional[Params] = None + ) -> PQGen[None]: """Generator implementing sending a command for `Cursor.copy().""" # The connection gets in an unrecoverable state if we attempt COPY in @@ -399,6 +401,13 @@ class BaseCursor(Generic[ConnectionType, Row]): raise e.NotSupportedError("COPY cannot be used in pipeline mode") yield from self._start_query() + + # Merge the params client-side + if params: + pgq = PostgresClientQuery(self._tx) + pgq.convert(statement, params) + statement = pgq.query + query = self._convert_query(statement) self._execute_send(query, binary=False) @@ -850,14 +859,14 @@ class Cursor(BaseCursor["Connection[Any]", Row]): self._scroll(value, mode) @contextmanager - def copy(self, statement: Query) -> Iterator[Copy]: + def copy(self, statement: Query, params: Optional[Params] = None) -> Iterator[Copy]: """ Initiate a :sql:`COPY` operation and return an object to manage it. :rtype: Copy """ with self._conn.lock: - self._conn.wait(self._start_copy_gen(statement)) + self._conn.wait(self._start_copy_gen(statement, params)) with Copy(self) as copy: yield copy diff --git a/psycopg/psycopg/cursor_async.py b/psycopg/psycopg/cursor_async.py index d6c7dbbba..142845abc 100644 --- a/psycopg/psycopg/cursor_async.py +++ b/psycopg/psycopg/cursor_async.py @@ -190,12 +190,14 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]): self._scroll(value, mode) @asynccontextmanager - async def copy(self, statement: Query) -> AsyncIterator[AsyncCopy]: + async def copy( + self, statement: Query, params: Optional[Params] = None + ) -> AsyncIterator[AsyncCopy]: """ :rtype: AsyncCopy """ async with self._conn.lock: - await self._conn.wait(self._start_copy_gen(statement)) + await self._conn.wait(self._start_copy_gen(statement, params)) async with AsyncCopy(self) as copy: yield copy diff --git a/tests/test_client_cursor.py b/tests/test_client_cursor.py index e4a308328..df202ea7b 100644 --- a/tests/test_client_cursor.py +++ b/tests/test_client_cursor.py @@ -584,6 +584,18 @@ def test_query_params_executemany(conn): assert cur._query.params == (b"3", b"4") +@pytest.mark.parametrize("ph, params", [("%s", (10,)), ("%(n)s", {"n": 10})]) +def test_copy_out_param(conn, ph, params): + cur = conn.cursor() + with cur.copy( + f"copy (select * from generate_series(1, {ph})) to stdout", params + ) as copy: + copy.set_types(["int4"]) + assert list(copy.rows()) == [(i + 1,) for i in range(10)] + + assert conn.info.transaction_status == conn.TransactionStatus.INTRANS + + def test_stream(conn): cur = conn.cursor() recs = [] diff --git a/tests/test_client_cursor_async.py b/tests/test_client_cursor_async.py index e6981d5d7..988ea57ec 100644 --- a/tests/test_client_cursor_async.py +++ b/tests/test_client_cursor_async.py @@ -8,7 +8,7 @@ import psycopg from psycopg import sql, rows from psycopg.adapt import PyFormat -from .utils import gc_collect +from .utils import alist, gc_collect from .test_cursor import my_row_factory from .test_cursor import execmany, _execmany # noqa: F401 @@ -579,6 +579,18 @@ async def test_query_params_executemany(aconn): assert cur._query.params == (b"3", b"4") +@pytest.mark.parametrize("ph, params", [("%s", (10,)), ("%(n)s", {"n": 10})]) +async def test_copy_out_param(aconn, ph, params): + cur = aconn.cursor() + async with cur.copy( + f"copy (select * from generate_series(1, {ph})) to stdout", params + ) as copy: + copy.set_types(["int4"]) + assert await alist(copy.rows()) == [(i + 1,) for i in range(10)] + + assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS + + async def test_stream(aconn): cur = aconn.cursor() recs = [] diff --git a/tests/test_copy.py b/tests/test_copy.py index 2d2826d79..15187f102 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -84,6 +84,18 @@ def test_copy_out_iter(conn, format): assert conn.info.transaction_status == conn.TransactionStatus.INTRANS +@pytest.mark.parametrize("ph, params", [("%s", (10,)), ("%(n)s", {"n": 10})]) +def test_copy_out_param(conn, ph, params): + cur = conn.cursor() + with cur.copy( + f"copy (select * from generate_series(1, {ph})) to stdout", params + ) as copy: + copy.set_types(["int4"]) + assert list(copy.rows()) == [(i + 1,) for i in range(10)] + + assert conn.info.transaction_status == conn.TransactionStatus.INTRANS + + @pytest.mark.parametrize("format", Format) @pytest.mark.parametrize("typetype", ["names", "oids"]) def test_read_rows(conn, format, typetype): diff --git a/tests/test_copy_async.py b/tests/test_copy_async.py index 0c0683da8..c5df84f03 100644 --- a/tests/test_copy_async.py +++ b/tests/test_copy_async.py @@ -17,7 +17,7 @@ from psycopg.adapt import PyFormat from psycopg.types.hstore import register_hstore from psycopg.types.numeric import Int4 -from .utils import gc_collect +from .utils import alist, gc_collect from .test_copy import sample_text, sample_binary, sample_binary_rows # noqa from .test_copy import eur, sample_values, sample_records, sample_tabledef from .test_copy import py_to_raw @@ -64,6 +64,18 @@ async def test_copy_out_iter(aconn, format): assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS +@pytest.mark.parametrize("ph, params", [("%s", (10,)), ("%(n)s", {"n": 10})]) +async def test_copy_out_param(aconn, ph, params): + cur = aconn.cursor() + async with cur.copy( + f"copy (select * from generate_series(1, {ph})) to stdout", params + ) as copy: + copy.set_types(["int4"]) + assert await alist(copy.rows()) == [(i + 1,) for i in range(10)] + + assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS + + @pytest.mark.parametrize("format", Format) @pytest.mark.parametrize("typetype", ["names", "oids"]) async def test_read_rows(aconn, format, typetype): @@ -792,7 +804,3 @@ class DataGenerator: block = block.encode() m.update(block) return m.hexdigest() - - -async def alist(it): - return [i async for i in it] diff --git a/tests/utils.py b/tests/utils.py index a02827f98..677df2688 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -72,3 +72,7 @@ def gc_collect(): """ for i in range(3): gc.collect() + + +async def alist(it): + return [i async for i in it]