:param statement: The copy operation to execute
:type statement: `!str`, `!bytes`, or `sql.Composable`
+ :param params: The parameters to pass to the statement, if any.
+ :type params: Sequence or Mapping
.. note::
See :ref:`copy` for information about :sql:`COPY`.
+ .. versionchanged:: 3.1
+ Added parameters support.
+
.. automethod:: stream
This command is similar to execute + iter; however it supports endless
) as copy:
# read data from the 'copy' object using read()/read_row()
+.. versionchanged:: 3.1
+
+ You can also pass parameters to `!copy()`, like in `~Cursor.execute()`:
+
+ .. code:: python
+
+ with cur.copy("COPY (SELECT * FROM table_name LIMIT %s) TO STDOUT", (3,)) as copy:
+ # expect no more than three records
+
The connection is subject to the usual transaction behaviour, so, unless the
connection is in autocommit, at the end of the COPY operation you will still
have to commit the pending changes and you can still roll them back. See
results (:ticket:`#164`).
- `~Cursor.executemany()` performance improved by using batch mode internally
(:ticket:`#145`).
+- Add parameters to `~Cursor.copy()`.
- Add `pq.PGconn.trace()` and related trace functions (:ticket:`#167`).
- Add ``prepare_threshold`` parameter to `Connection` init (:ticket:`#200`).
- Add ``cursor_factory`` parameter to `Connection` init.
from .rows import Row, RowMaker, RowFactory
from ._column import Column
from ._cmodule import _psycopg
-from ._queries import PostgresQuery
+from ._queries import PostgresQuery, PostgresClientQuery
from ._pipeline import Pipeline
from ._encodings import pgconn_encoding
from ._preparing import Prepare
self._tx = adapt.Transformer(self)
yield from self._conn._start_query()
- def _start_copy_gen(self, statement: Query) -> PQGen[None]:
+ def _start_copy_gen(
+ self, statement: Query, params: Optional[Params] = None
+ ) -> PQGen[None]:
"""Generator implementing sending a command for `Cursor.copy()."""
# The connection gets in an unrecoverable state if we attempt COPY in
raise e.NotSupportedError("COPY cannot be used in pipeline mode")
yield from self._start_query()
+
+ # Merge the params client-side
+ if params:
+ pgq = PostgresClientQuery(self._tx)
+ pgq.convert(statement, params)
+ statement = pgq.query
+
query = self._convert_query(statement)
self._execute_send(query, binary=False)
self._scroll(value, mode)
@contextmanager
- def copy(self, statement: Query) -> Iterator[Copy]:
+ def copy(self, statement: Query, params: Optional[Params] = None) -> Iterator[Copy]:
"""
Initiate a :sql:`COPY` operation and return an object to manage it.
:rtype: Copy
"""
with self._conn.lock:
- self._conn.wait(self._start_copy_gen(statement))
+ self._conn.wait(self._start_copy_gen(statement, params))
with Copy(self) as copy:
yield copy
self._scroll(value, mode)
@asynccontextmanager
- async def copy(self, statement: Query) -> AsyncIterator[AsyncCopy]:
+ async def copy(
+ self, statement: Query, params: Optional[Params] = None
+ ) -> AsyncIterator[AsyncCopy]:
"""
:rtype: AsyncCopy
"""
async with self._conn.lock:
- await self._conn.wait(self._start_copy_gen(statement))
+ await self._conn.wait(self._start_copy_gen(statement, params))
async with AsyncCopy(self) as copy:
yield copy
assert cur._query.params == (b"3", b"4")
+@pytest.mark.parametrize("ph, params", [("%s", (10,)), ("%(n)s", {"n": 10})])
+def test_copy_out_param(conn, ph, params):
+ cur = conn.cursor()
+ with cur.copy(
+ f"copy (select * from generate_series(1, {ph})) to stdout", params
+ ) as copy:
+ copy.set_types(["int4"])
+ assert list(copy.rows()) == [(i + 1,) for i in range(10)]
+
+ assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+
+
def test_stream(conn):
cur = conn.cursor()
recs = []
from psycopg import sql, rows
from psycopg.adapt import PyFormat
-from .utils import gc_collect
+from .utils import alist, gc_collect
from .test_cursor import my_row_factory
from .test_cursor import execmany, _execmany # noqa: F401
assert cur._query.params == (b"3", b"4")
+@pytest.mark.parametrize("ph, params", [("%s", (10,)), ("%(n)s", {"n": 10})])
+async def test_copy_out_param(aconn, ph, params):
+ cur = aconn.cursor()
+ async with cur.copy(
+ f"copy (select * from generate_series(1, {ph})) to stdout", params
+ ) as copy:
+ copy.set_types(["int4"])
+ assert await alist(copy.rows()) == [(i + 1,) for i in range(10)]
+
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS
+
+
async def test_stream(aconn):
cur = aconn.cursor()
recs = []
assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+@pytest.mark.parametrize("ph, params", [("%s", (10,)), ("%(n)s", {"n": 10})])
+def test_copy_out_param(conn, ph, params):
+ cur = conn.cursor()
+ with cur.copy(
+ f"copy (select * from generate_series(1, {ph})) to stdout", params
+ ) as copy:
+ copy.set_types(["int4"])
+ assert list(copy.rows()) == [(i + 1,) for i in range(10)]
+
+ assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+
+
@pytest.mark.parametrize("format", Format)
@pytest.mark.parametrize("typetype", ["names", "oids"])
def test_read_rows(conn, format, typetype):
from psycopg.types.hstore import register_hstore
from psycopg.types.numeric import Int4
-from .utils import gc_collect
+from .utils import alist, gc_collect
from .test_copy import sample_text, sample_binary, sample_binary_rows # noqa
from .test_copy import eur, sample_values, sample_records, sample_tabledef
from .test_copy import py_to_raw
assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS
+@pytest.mark.parametrize("ph, params", [("%s", (10,)), ("%(n)s", {"n": 10})])
+async def test_copy_out_param(aconn, ph, params):
+ cur = aconn.cursor()
+ async with cur.copy(
+ f"copy (select * from generate_series(1, {ph})) to stdout", params
+ ) as copy:
+ copy.set_types(["int4"])
+ assert await alist(copy.rows()) == [(i + 1,) for i in range(10)]
+
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS
+
+
@pytest.mark.parametrize("format", Format)
@pytest.mark.parametrize("typetype", ["names", "oids"])
async def test_read_rows(aconn, format, typetype):
block = block.encode()
m.update(block)
return m.hexdigest()
-
-
-async def alist(it):
- return [i async for i in it]
"""
for i in range(3):
gc.collect()
+
+
+async def alist(it):
+ return [i async for i in it]