]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
feat: add params to Cursor.copy()
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 12 May 2022 23:18:59 +0000 (01:18 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 15 May 2022 09:15:30 +0000 (11:15 +0200)
docs/api/cursors.rst
docs/basic/copy.rst
docs/news.rst
psycopg/psycopg/cursor.py
psycopg/psycopg/cursor_async.py
tests/test_client_cursor.py
tests/test_client_cursor_async.py
tests/test_copy.py
tests/test_copy_async.py
tests/utils.py

index 24fb1b7437d1c522dc80836d0a2d1f3a2403c4d4..57ffadbb8e29c4eab18b552e88ea821980ec3502 100644 (file)
@@ -110,6 +110,8 @@ The `!Cursor` class
 
         :param statement: The copy operation to execute
         :type statement: `!str`, `!bytes`, or `sql.Composable`
+        :param params: The parameters to pass to the statement, if any.
+        :type params: Sequence or Mapping
 
         .. note::
 
@@ -120,6 +122,9 @@ The `!Cursor` class
 
         See :ref:`copy` for information about :sql:`COPY`.
 
+        .. versionchanged:: 3.1
+            Added parameters support.
+
     .. automethod:: stream
 
         This command is similar to execute + iter; however it supports endless
index e83c3b3b998cafad29a39b6c9831737d726332ac..61cd8c03be1bf00e843e7ee538ea35f68c94027b 100644 (file)
@@ -33,6 +33,15 @@ You can compose a COPY statement dynamically by using objects from the
     ) as copy:
         # read data from the 'copy' object using read()/read_row()
 
+.. versionchanged:: 3.1
+
+    You can also pass parameters to `!copy()`, like in `~Cursor.execute()`:
+
+    .. code:: python
+
+        with cur.copy("COPY (SELECT * FROM table_name LIMIT %s) TO STDOUT", (3,)) as copy:
+            # expect no more than three records
+
 The connection is subject to the usual transaction behaviour, so, unless the
 connection is in autocommit, at the end of the COPY operation you will still
 have to commit the pending changes and you can still roll them back. See
index 4f8f8cc2195839a829c573d80241c41e312fc36c..3076554e95b08355001ebe81fe0f755737761462 100644 (file)
@@ -20,6 +20,7 @@ Psycopg 3.1 (unreleased)
   results (:ticket:`#164`).
 - `~Cursor.executemany()` performance improved by using batch mode internally
   (:ticket:`#145`).
+- Add parameters to `~Cursor.copy()`.
 - Add `pq.PGconn.trace()` and related trace functions (:ticket:`#167`).
 - Add ``prepare_threshold`` parameter to `Connection` init (:ticket:`#200`).
 - Add ``cursor_factory`` parameter to `Connection` init.
index 26d1611434558ef78bfaf04ad91f2499339e294b..557556037bfd29d652091af37ae45066a65b8e65 100644 (file)
@@ -20,7 +20,7 @@ from .copy import Copy
 from .rows import Row, RowMaker, RowFactory
 from ._column import Column
 from ._cmodule import _psycopg
-from ._queries import PostgresQuery
+from ._queries import PostgresQuery, PostgresClientQuery
 from ._pipeline import Pipeline
 from ._encodings import pgconn_encoding
 from ._preparing import Prepare
@@ -390,7 +390,9 @@ class BaseCursor(Generic[ConnectionType, Row]):
             self._tx = adapt.Transformer(self)
         yield from self._conn._start_query()
 
-    def _start_copy_gen(self, statement: Query) -> PQGen[None]:
+    def _start_copy_gen(
+        self, statement: Query, params: Optional[Params] = None
+    ) -> PQGen[None]:
         """Generator implementing sending a command for `Cursor.copy()."""
 
         # The connection gets in an unrecoverable state if we attempt COPY in
@@ -399,6 +401,13 @@ class BaseCursor(Generic[ConnectionType, Row]):
             raise e.NotSupportedError("COPY cannot be used in pipeline mode")
 
         yield from self._start_query()
+
+        # Merge the params client-side
+        if params:
+            pgq = PostgresClientQuery(self._tx)
+            pgq.convert(statement, params)
+            statement = pgq.query
+
         query = self._convert_query(statement)
 
         self._execute_send(query, binary=False)
@@ -850,14 +859,14 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
         self._scroll(value, mode)
 
     @contextmanager
-    def copy(self, statement: Query) -> Iterator[Copy]:
+    def copy(self, statement: Query, params: Optional[Params] = None) -> Iterator[Copy]:
         """
         Initiate a :sql:`COPY` operation and return an object to manage it.
 
         :rtype: Copy
         """
         with self._conn.lock:
-            self._conn.wait(self._start_copy_gen(statement))
+            self._conn.wait(self._start_copy_gen(statement, params))
 
         with Copy(self) as copy:
             yield copy
index d6c7dbbba0e0f863a9df117868599d6c500b89b9..142845abcf60de4df26414f72b9bae98fb4026bd 100644 (file)
@@ -190,12 +190,14 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
         self._scroll(value, mode)
 
     @asynccontextmanager
-    async def copy(self, statement: Query) -> AsyncIterator[AsyncCopy]:
+    async def copy(
+        self, statement: Query, params: Optional[Params] = None
+    ) -> AsyncIterator[AsyncCopy]:
         """
         :rtype: AsyncCopy
         """
         async with self._conn.lock:
-            await self._conn.wait(self._start_copy_gen(statement))
+            await self._conn.wait(self._start_copy_gen(statement, params))
 
         async with AsyncCopy(self) as copy:
             yield copy
index e4a308328ae52ce416913de56e76ee969ce25cdb..df202ea7b3438748127d07c29f319a739fa46bdf 100644 (file)
@@ -584,6 +584,18 @@ def test_query_params_executemany(conn):
     assert cur._query.params == (b"3", b"4")
 
 
+@pytest.mark.parametrize("ph, params", [("%s", (10,)), ("%(n)s", {"n": 10})])
+def test_copy_out_param(conn, ph, params):
+    cur = conn.cursor()
+    with cur.copy(
+        f"copy (select * from generate_series(1, {ph})) to stdout", params
+    ) as copy:
+        copy.set_types(["int4"])
+        assert list(copy.rows()) == [(i + 1,) for i in range(10)]
+
+    assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+
+
 def test_stream(conn):
     cur = conn.cursor()
     recs = []
index e6981d5d791eaabb007bc6b4b51deab15fa9e9b5..988ea57ec1c998dae471b14d4378d2542b364908 100644 (file)
@@ -8,7 +8,7 @@ import psycopg
 from psycopg import sql, rows
 from psycopg.adapt import PyFormat
 
-from .utils import gc_collect
+from .utils import alist, gc_collect
 from .test_cursor import my_row_factory
 from .test_cursor import execmany, _execmany  # noqa: F401
 
@@ -579,6 +579,18 @@ async def test_query_params_executemany(aconn):
     assert cur._query.params == (b"3", b"4")
 
 
+@pytest.mark.parametrize("ph, params", [("%s", (10,)), ("%(n)s", {"n": 10})])
+async def test_copy_out_param(aconn, ph, params):
+    cur = aconn.cursor()
+    async with cur.copy(
+        f"copy (select * from generate_series(1, {ph})) to stdout", params
+    ) as copy:
+        copy.set_types(["int4"])
+        assert await alist(copy.rows()) == [(i + 1,) for i in range(10)]
+
+    assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS
+
+
 async def test_stream(aconn):
     cur = aconn.cursor()
     recs = []
index 2d2826d793eb5ff8bd75483d4b00c39c06a65337..15187f10227b974568fd3d77f5b28e87480502a7 100644 (file)
@@ -84,6 +84,18 @@ def test_copy_out_iter(conn, format):
     assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
 
 
+@pytest.mark.parametrize("ph, params", [("%s", (10,)), ("%(n)s", {"n": 10})])
+def test_copy_out_param(conn, ph, params):
+    cur = conn.cursor()
+    with cur.copy(
+        f"copy (select * from generate_series(1, {ph})) to stdout", params
+    ) as copy:
+        copy.set_types(["int4"])
+        assert list(copy.rows()) == [(i + 1,) for i in range(10)]
+
+    assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+
+
 @pytest.mark.parametrize("format", Format)
 @pytest.mark.parametrize("typetype", ["names", "oids"])
 def test_read_rows(conn, format, typetype):
index 0c0683da84015930888e19e3411c8e2add3c48b3..c5df84f03ff09b854a716dffcd72caeb07649d02 100644 (file)
@@ -17,7 +17,7 @@ from psycopg.adapt import PyFormat
 from psycopg.types.hstore import register_hstore
 from psycopg.types.numeric import Int4
 
-from .utils import gc_collect
+from .utils import alist, gc_collect
 from .test_copy import sample_text, sample_binary, sample_binary_rows  # noqa
 from .test_copy import eur, sample_values, sample_records, sample_tabledef
 from .test_copy import py_to_raw
@@ -64,6 +64,18 @@ async def test_copy_out_iter(aconn, format):
     assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS
 
 
+@pytest.mark.parametrize("ph, params", [("%s", (10,)), ("%(n)s", {"n": 10})])
+async def test_copy_out_param(aconn, ph, params):
+    cur = aconn.cursor()
+    async with cur.copy(
+        f"copy (select * from generate_series(1, {ph})) to stdout", params
+    ) as copy:
+        copy.set_types(["int4"])
+        assert await alist(copy.rows()) == [(i + 1,) for i in range(10)]
+
+    assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS
+
+
 @pytest.mark.parametrize("format", Format)
 @pytest.mark.parametrize("typetype", ["names", "oids"])
 async def test_read_rows(aconn, format, typetype):
@@ -792,7 +804,3 @@ class DataGenerator:
                 block = block.encode()
             m.update(block)
         return m.hexdigest()
-
-
-async def alist(it):
-    return [i async for i in it]
index a02827f985c003126769f3e79860d352b86bb987..677df268880a69d5617d4c0ba8d94ff94eccd44c 100644 (file)
@@ -72,3 +72,7 @@ def gc_collect():
     """
     for i in range(3):
         gc.collect()
+
+
+async def alist(it):
+    return [i async for i in it]