from . import proto
from .proto import Query, Params, DumpersMap, LoadersMap, PQGen
from .utils.queries import PostgresQuery
+from .copy import Copy, AsyncCopy
if TYPE_CHECKING:
from .connection import BaseConnection, Connection, AsyncConnection
self._reset()
self._transformer = Transformer(self)
- def _execute_send(self, query: Query, vars: Optional[Params]) -> None:
+ def _execute_send(
+ self, query: Query, vars: Optional[Params], no_pqexec: bool = False
+ ) -> None:
"""
Implement part of execute() before waiting common to sync and async
"""
pgq = PostgresQuery(self._transformer)
pgq.convert(query, vars)
- if pgq.params:
+ if pgq.params or no_pqexec or self.format == pq.Format.BINARY:
self.connection.pgconn.send_query_params(
pgq.query,
pgq.params,
param_types=pgq.types,
result_format=self.format,
)
-
else:
# if we don't have to, let's use exec_ as it can run more than
# one query in one go
- if self.format == pq.Format.BINARY:
- self.connection.pgconn.send_query_params(
- pgq.query, None, result_format=self.format
- )
- else:
- self.connection.pgconn.send_query(pgq.query)
+ self.connection.pgconn.send_query(pgq.query)
def _execute_results(self, results: Sequence[pq.proto.PGresult]) -> None:
"""
"the last operation didn't produce a result"
)
+ def _check_copy_results(
+ self, results: Sequence[pq.proto.PGresult]
+ ) -> None:
+ """
+ Check that the value returned in a copy() operation is a legit COPY.
+ """
+ if len(results) != 1:
+ raise e.InternalError(
+ f"expected 1 result from copy, got {len(results)} instead"
+ )
+
+ result = results[0]
+ status = result.status
+ if status not in (pq.ExecStatus.COPY_IN, pq.ExecStatus.COPY_OUT):
+ raise e.ProgrammingError(
+ "copy() should be used only with COPY ... TO STDOUT"
+ " or COPY ... FROM STDIN statements"
+ )
+
class Cursor(BaseCursor):
connection: "Connection"
self._pos = pos
return rv
+ def copy(self, statement: Query, vars: Optional[Params] = None) -> Copy:
+ with self.connection.lock:
+ self._start_query()
+ self.connection._start_query()
+ # Make sure to avoid PQexec to avoid sending a mix of COPY and
+ # other operations.
+ self._execute_send(statement, vars, no_pqexec=True)
+ gen = execute(self.connection.pgconn)
+ results = self.connection.wait(gen)
+ tx = self._transformer
+
+ self._check_copy_results(results)
+ return Copy(context=tx, result=results[0], format=self.format)
+
class AsyncCursor(BaseCursor):
connection: "AsyncConnection"
self._pos = pos
return rv
+ async def copy(
+ self, statement: Query, vars: Optional[Params] = None
+ ) -> AsyncCopy:
+ async with self.connection.lock:
+ self._start_query()
+ await self.connection._start_query()
+ # Make sure to avoid PQexec to avoid sending a mix of COPY and
+ # other operations.
+ self._execute_send(statement, vars, no_pqexec=True)
+ gen = execute(self.connection.pgconn)
+ results = await self.connection.wait(gen)
+ tx = self._transformer
+
+ self._check_copy_results(results)
+ return AsyncCopy(context=tx, result=results[0], format=self.format)
+
class NamedCursorMixin:
pass
--- /dev/null
+import pytest
+
+sample_records = [(10, 20, "hello"), (40, None, "world")]
+
+sample_values = "values (10::int, 20::int, 'hello'::text), (40, NULL, 'world')"
+
+sample_tabledef = "col1 int primary key, col2 int, date text"
+
+sample_text = b"""
+10\t20\thello
+40\t\\N\tworld
+"""
+
+sample_binary = """
+5047 434f 5059 0aff 0d0a 0000 0000 0000
+0000 0000 0300 0000 0400 0000 0a00 0000
+0400 0000 1400 0000 0568 656c 6c6f 0003
+0000 0004 0000 0028 ffff ffff 0000 0005
+776f 726c 64ff ff
+"""
+
+
+@pytest.mark.parametrize(
+ "format, block", [("text", sample_text), ("binary", sample_binary)]
+)
+def test_load(format, block):
+ from psycopg3.copy import Copy
+
+ copy = Copy(format=format)
+ records = copy.load(block)
+ assert records == sample_records
+
+
+@pytest.mark.parametrize(
+ "format, block", [("text", sample_text), ("binary", sample_binary)]
+)
+def test_dump(format, block):
+ from psycopg3.copy import Copy
+
+ copy = Copy(format=format)
+ assert copy.get_buffer() is None
+ for row in sample_records:
+ copy.dump(row)
+ assert copy.get_buffer() == block
+ assert copy.get_buffer() is None
+
+
+@pytest.mark.parametrize(
+ "format, block", [("text", sample_text), ("binary", sample_binary)]
+)
+def test_buffers(format, block):
+ from psycopg3.copy import Copy
+
+ copy = Copy(format=format)
+ assert list(copy.buffers(sample_records)) == [block]
+
+
+@pytest.mark.parametrize(
+ "format, want", [("text", sample_text), ("binary", sample_binary)]
+)
+def test_copy_out_read(conn, format, want):
+ cur = conn.cursor()
+ copy = cur.copy(f"copy ({sample_values}) to stdout (format {format})")
+ assert copy.read() == want
+ assert copy.read() is None
+ assert copy.read() is None
+
+
+@pytest.mark.parametrize("format", ["text", "binary"])
+def test_iter(conn, format):
+ cur = conn.cursor()
+ copy = cur.copy(f"copy ({sample_values}) to stdout (format {format})")
+ assert list(copy) == sample_records
+
+
+@pytest.mark.parametrize(
+ "format, buffer", [("text", sample_text), ("binary", sample_binary)]
+)
+def test_copy_in_buffers(conn, format, buffer):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ copy = cur.copy(f"copy copy_in from stdin (format {format})")
+ copy.write(buffer)
+ copy.end()
+ data = cur.execute("select * from copy_in order by 1").fetchall()
+ assert data == sample_records
+
+
+@pytest.mark.parametrize(
+ "format, buffer", [("text", sample_text), ("binary", sample_binary)]
+)
+def test_copy_in_buffers_with(conn, format, buffer):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with cur.copy(f"copy copy_in from stdin (format {format})") as copy:
+ copy.write(buffer)
+
+ data = cur.execute("select * from copy_in order by 1").fetchall()
+ assert data == sample_records
+
+
+@pytest.mark.parametrize(
+ "format, buffer", [("text", sample_text), ("binary", sample_binary)]
+)
+def test_copy_in_records(conn, format, buffer):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+
+ with cur.copy(f"copy copy_in from stdin (format {format})") as copy:
+ for row in sample_records:
+ copy.write(row)
+
+ data = cur.execute("select * from copy_in order by 1").fetchall()
+ assert data == sample_records
+
+
+def ensure_table(cur, tabledef, name="copy_in"):
+ cur.execute(f"drop table if exists {name}")
+ cur.execute(f"create table {name} ({tabledef})")