From: Daniele Varrazzo
Date: Sun, 21 Jun 2020 05:51:11 +0000 (+1200)
Subject: Sketching an interface for a copy object
X-Git-Tag: 3.0.dev0~485
X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=059c635688b2fb2cd7003dc0cda014d8a032a583;p=thirdparty%2Fpsycopg.git

Sketching an interface for a copy object
---

diff --git a/psycopg3/copy.py b/psycopg3/copy.py
new file mode 100644
index 000000000..39b010ffc
--- /dev/null
+++ b/psycopg3/copy.py
@@ -0,0 +1,41 @@
+"""
+psycopg3 copy support
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from typing import Optional
+
+from .proto import AdaptContext
+from . import pq
+
+
+class BaseCopy:
+    def __init__(
+        self,
+        context: AdaptContext,
+        result: pq.proto.PGresult,
+        format: pq.Format = pq.Format.TEXT,
+    ):
+        from .transform import Transformer
+
+        self._transformer = Transformer(context)
+        self.format = format  # TODO: maybe not needed
+        self.pgresult = result
+
+    @property
+    def pgresult(self) -> Optional[pq.proto.PGresult]:
+        return self._pgresult
+
+    @pgresult.setter
+    def pgresult(self, result: Optional[pq.proto.PGresult]) -> None:
+        self._pgresult = result
+        self._transformer.pgresult = result
+
+
+class Copy(BaseCopy):
+    pass
+
+
+class AsyncCopy(BaseCopy):
+    pass
diff --git a/psycopg3/cursor.py b/psycopg3/cursor.py
index eb59831e0..d85f23966 100644
--- a/psycopg3/cursor.py
+++ b/psycopg3/cursor.py
@@ -12,6 +12,7 @@ from . import pq
 from . import proto
 from .proto import Query, Params, DumpersMap, LoadersMap, PQGen
 from .utils.queries import PostgresQuery
+from .copy import Copy, AsyncCopy
 
 if TYPE_CHECKING:
     from .connection import BaseConnection, Connection, AsyncConnection
@@ -154,14 +155,16 @@ class BaseCursor:
         self._reset()
         self._transformer = Transformer(self)
 
-    def _execute_send(self, query: Query, vars: Optional[Params]) -> None:
+    def _execute_send(
+        self, query: Query, vars: Optional[Params], no_pqexec: bool = False
+    ) -> None:
         """
         Implement part of execute() before waiting common to sync and async
         """
         pgq = PostgresQuery(self._transformer)
         pgq.convert(query, vars)
 
-        if pgq.params:
+        if pgq.params or no_pqexec or self.format == pq.Format.BINARY:
             self.connection.pgconn.send_query_params(
                 pgq.query,
                 pgq.params,
@@ -169,16 +172,10 @@ class BaseCursor:
                 param_types=pgq.types,
                 result_format=self.format,
             )
-
         else:
             # if we don't have to, let's use exec_ as it can run more than
             # one query in one go
-            if self.format == pq.Format.BINARY:
-                self.connection.pgconn.send_query_params(
-                    pgq.query, None, result_format=self.format
-                )
-            else:
-                self.connection.pgconn.send_query(pgq.query)
+            self.connection.pgconn.send_query(pgq.query)
 
     def _execute_results(self, results: Sequence[pq.proto.PGresult]) -> None:
         """
@@ -251,6 +248,25 @@ class BaseCursor:
                 "the last operation didn't produce a result"
             )
 
+    def _check_copy_results(
+        self, results: Sequence[pq.proto.PGresult]
+    ) -> None:
+        """
+        Check that the value returned in a copy() operation is a legit COPY.
+        """
+        if len(results) != 1:
+            raise e.InternalError(
+                f"expected 1 result from copy, got {len(results)} instead"
+            )
+
+        result = results[0]
+        status = result.status
+        if status not in (pq.ExecStatus.COPY_IN, pq.ExecStatus.COPY_OUT):
+            raise e.ProgrammingError(
+                "copy() should be used only with COPY ... TO STDOUT"
+                " or COPY ... FROM STDIN statements"
+            )
+
 
 class Cursor(BaseCursor):
     connection: "Connection"
@@ -343,6 +359,20 @@ class Cursor(BaseCursor):
         self._pos = pos
         return rv
 
+    def copy(self, statement: Query, vars: Optional[Params] = None) -> Copy:
+        with self.connection.lock:
+            self._start_query()
+            self.connection._start_query()
+            # Make sure to avoid PQexec to avoid sending a mix of COPY and
+            # other operations.
+            self._execute_send(statement, vars, no_pqexec=True)
+            gen = execute(self.connection.pgconn)
+            results = self.connection.wait(gen)
+            tx = self._transformer
+
+        self._check_copy_results(results)
+        return Copy(context=tx, result=results[0], format=self.format)
+
 
 class AsyncCursor(BaseCursor):
     connection: "AsyncConnection"
@@ -439,6 +469,22 @@ class AsyncCursor(BaseCursor):
         self._pos = pos
         return rv
 
+    async def copy(
+        self, statement: Query, vars: Optional[Params] = None
+    ) -> AsyncCopy:
+        async with self.connection.lock:
+            self._start_query()
+            await self.connection._start_query()
+            # Make sure to avoid PQexec to avoid sending a mix of COPY and
+            # other operations.
+            self._execute_send(statement, vars, no_pqexec=True)
+            gen = execute(self.connection.pgconn)
+            results = await self.connection.wait(gen)
+            tx = self._transformer
+
+        self._check_copy_results(results)
+        return AsyncCopy(context=tx, result=results[0], format=self.format)
+
 
 class NamedCursorMixin:
     pass
diff --git a/tests/test_copy.py b/tests/test_copy.py
new file mode 100644
index 000000000..69b431efa
--- /dev/null
+++ b/tests/test_copy.py
@@ -0,0 +1,119 @@
+import pytest
+
+sample_records = [(10, 20, "hello"), (40, None, "world")]
+
+sample_values = "values (10::int, 20::int, 'hello'::text), (40, NULL, 'world')"
+
+sample_tabledef = "col1 int primary key, col2 int, date text"
+
+sample_text = b"""
+10\t20\thello
+40\t\\N\tworld
+"""
+
+sample_binary = """
+5047 434f 5059 0aff 0d0a 0000 0000 0000
+0000 0000 0300 0000 0400 0000 0a00 0000
+0400 0000 1400 0000 0568 656c 6c6f 0003
+0000 0004 0000 0028 ffff ffff 0000 0005
+776f 726c 64ff ff
+"""
+
+
+@pytest.mark.parametrize(
+    "format, block", [("text", sample_text), ("binary", sample_binary)]
+)
+def test_load(format, block):
+    from psycopg3.copy import Copy
+
+    copy = Copy(format=format)
+    records = copy.load(block)
+    assert records == sample_records
+
+
+@pytest.mark.parametrize(
+    "format, block", [("text", sample_text), ("binary", sample_binary)]
+)
+def test_dump(format, block):
+    from psycopg3.copy import Copy
+
+    copy = Copy(format=format)
+    assert copy.get_buffer() is None
+    for row in sample_records:
+        copy.dump(row)
+    assert copy.get_buffer() == block
+    assert copy.get_buffer() is None
+
+
+@pytest.mark.parametrize(
+    "format, block", [("text", sample_text), ("binary", sample_binary)]
+)
+def test_buffers(format, block):
+    from psycopg3.copy import Copy
+
+    copy = Copy(format=format)
+    assert list(copy.buffers(sample_records)) == [block]
+
+
+@pytest.mark.parametrize(
+    "format, want", [("text", sample_text), ("binary", sample_binary)]
+)
+def test_copy_out_read(conn, format, want):
+    cur = conn.cursor()
+    copy = cur.copy(f"copy ({sample_values}) to stdout (format {format})")
+    assert copy.read() == want
+    assert copy.read() is None
+    assert copy.read() is None
+
+
+@pytest.mark.parametrize("format", ["text", "binary"])
+def test_iter(conn, format):
+    cur = conn.cursor()
+    copy = cur.copy(f"copy ({sample_values}) to stdout (format {format})")
+    assert list(copy) == sample_records
+
+
+@pytest.mark.parametrize(
+    "format, buffer", [("text", sample_text), ("binary", sample_binary)]
+)
+def test_copy_in_buffers(conn, format, buffer):
+    cur = conn.cursor()
+    ensure_table(cur, sample_tabledef)
+    copy = cur.copy(f"copy copy_in from stdin (format {format})")
+    copy.write(buffer)
+    copy.end()
+    data = cur.execute("select * from copy_in order by 1").fetchall()
+    assert data == sample_records
+
+
+@pytest.mark.parametrize(
+    "format, buffer", [("text", sample_text), ("binary", sample_binary)]
+)
+def test_copy_in_buffers_with(conn, format, buffer):
+    cur = conn.cursor()
+    ensure_table(cur, sample_tabledef)
+    with cur.copy(f"copy copy_in from stdin (format {format})") as copy:
+        copy.write(buffer)
+
+    data = cur.execute("select * from copy_in order by 1").fetchall()
+    assert data == sample_records
+
+
+@pytest.mark.parametrize(
+    "format, buffer", [("text", sample_text), ("binary", sample_binary)]
+)
+def test_copy_in_records(conn, format, buffer):
+    cur = conn.cursor()
+    ensure_table(cur, sample_tabledef)
+
+    with cur.copy(f"copy copy_in from stdin (format {format})") as copy:
+        for row in sample_records:
+            copy.write(row)
+
+    data = cur.execute("select * from copy_in order by 1").fetchall()
+    assert data == sample_records
+
+
+def ensure_table(cur, tabledef, name="copy_in"):
+    cur.execute(f"drop table if exists {name}")
+    cur.execute(f"create table {name} ({tabledef})")