]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Sketching an interface for a copy object
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 21 Jun 2020 05:51:11 +0000 (17:51 +1200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 21 Jun 2020 05:51:11 +0000 (17:51 +1200)
psycopg3/copy.py [new file with mode: 0644]
psycopg3/cursor.py
tests/test_copy.py [new file with mode: 0644]

diff --git a/psycopg3/copy.py b/psycopg3/copy.py
new file mode 100644 (file)
index 0000000..39b010f
--- /dev/null
@@ -0,0 +1,41 @@
+"""
+psycopg3 copy support
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from typing import Optional
+
+from .proto import AdaptContext
+from . import pq
+
+
+class BaseCopy:
+    def __init__(
+        self,
+        context: AdaptContext,
+        result: pq.proto.PGresult,
+        format: pq.Format = pq.Format.TEXT,
+    ):
+        from .transform import Transformer
+
+        self._transformer = Transformer(context)
+        self.format = format  # TODO: maybe not needed
+        self.pgresult = result
+
+    @property
+    def pgresult(self) -> Optional[pq.proto.PGresult]:
+        return self._pgresult
+
+    @pgresult.setter
+    def pgresult(self, result: Optional[pq.proto.PGresult]) -> None:
+        self._pgresult = result
+        self._transformer.pgresult = result
+
+
+class Copy(BaseCopy):
+    pass
+
+
+class AsyncCopy(BaseCopy):
+    pass
index eb59831e0057387f9ee03fe7c2592bfa0317aba2..d85f23966de0eb4867d09e0f4d908a70318bcefc 100644 (file)
@@ -12,6 +12,7 @@ from . import pq
 from . import proto
 from .proto import Query, Params, DumpersMap, LoadersMap, PQGen
 from .utils.queries import PostgresQuery
+from .copy import Copy, AsyncCopy
 
 if TYPE_CHECKING:
     from .connection import BaseConnection, Connection, AsyncConnection
@@ -154,14 +155,16 @@ class BaseCursor:
         self._reset()
         self._transformer = Transformer(self)
 
-    def _execute_send(self, query: Query, vars: Optional[Params]) -> None:
+    def _execute_send(
+        self, query: Query, vars: Optional[Params], no_pqexec: bool = False
+    ) -> None:
         """
         Implement part of execute() before waiting common to sync and async
         """
         pgq = PostgresQuery(self._transformer)
         pgq.convert(query, vars)
 
-        if pgq.params:
+        if pgq.params or no_pqexec or self.format == pq.Format.BINARY:
             self.connection.pgconn.send_query_params(
                 pgq.query,
                 pgq.params,
@@ -169,16 +172,10 @@ class BaseCursor:
                 param_types=pgq.types,
                 result_format=self.format,
             )
-
         else:
             # if we don't have to, let's use exec_ as it can run more than
             # one query in one go
-            if self.format == pq.Format.BINARY:
-                self.connection.pgconn.send_query_params(
-                    pgq.query, None, result_format=self.format
-                )
-            else:
-                self.connection.pgconn.send_query(pgq.query)
+            self.connection.pgconn.send_query(pgq.query)
 
     def _execute_results(self, results: Sequence[pq.proto.PGresult]) -> None:
         """
@@ -251,6 +248,25 @@ class BaseCursor:
                 "the last operation didn't produce a result"
             )
 
+    def _check_copy_results(
+        self, results: Sequence[pq.proto.PGresult]
+    ) -> None:
+        """
+        Check that the value returned in a copy() operation is a legit COPY.
+        """
+        if len(results) != 1:
+            raise e.InternalError(
+                f"expected 1 result from copy, got {len(results)} instead"
+            )
+
+        result = results[0]
+        status = result.status
+        if status not in (pq.ExecStatus.COPY_IN, pq.ExecStatus.COPY_OUT):
+            raise e.ProgrammingError(
+                "copy() should be used only with COPY ... TO STDOUT"
+                " or COPY ... FROM STDIN statements"
+            )
+
 
 class Cursor(BaseCursor):
     connection: "Connection"
@@ -343,6 +359,20 @@ class Cursor(BaseCursor):
         self._pos = pos
         return rv
 
+    def copy(self, statement: Query, vars: Optional[Params] = None) -> Copy:
+        with self.connection.lock:
+            self._start_query()
+            self.connection._start_query()
+            # Make sure to avoid PQexec to avoid sending a mix of COPY and
+            # other operations.
+            self._execute_send(statement, vars, no_pqexec=True)
+            gen = execute(self.connection.pgconn)
+            results = self.connection.wait(gen)
+            tx = self._transformer
+
+        self._check_copy_results(results)
+        return Copy(context=tx, result=results[0], format=self.format)
+
 
 class AsyncCursor(BaseCursor):
     connection: "AsyncConnection"
@@ -439,6 +469,22 @@ class AsyncCursor(BaseCursor):
         self._pos = pos
         return rv
 
+    async def copy(
+        self, statement: Query, vars: Optional[Params] = None
+    ) -> AsyncCopy:
+        async with self.connection.lock:
+            self._start_query()
+            await self.connection._start_query()
+            # Make sure to avoid PQexec to avoid sending a mix of COPY and
+            # other operations.
+            self._execute_send(statement, vars, no_pqexec=True)
+            gen = execute(self.connection.pgconn)
+            results = await self.connection.wait(gen)
+            tx = self._transformer
+
+        self._check_copy_results(results)
+        return AsyncCopy(context=tx, result=results[0], format=self.format)
+
 
 class NamedCursorMixin:
     pass
diff --git a/tests/test_copy.py b/tests/test_copy.py
new file mode 100644 (file)
index 0000000..69b431e
--- /dev/null
@@ -0,0 +1,119 @@
+import pytest
+
+sample_records = [(10, 20, "hello"), (40, None, "world")]
+
+sample_values = "values (10::int, 20::int, 'hello'::text), (40, NULL, 'world')"
+
+sample_tabledef = "col1 int primary key, col2 int, date text"
+
+sample_text = b"""
+10\t20\thello
+40\t\\N\tworld
+"""
+
+sample_binary = """
+5047 434f 5059 0aff 0d0a 0000 0000 0000
+0000 0000 0300 0000 0400 0000 0a00 0000
+0400 0000 1400 0000 0568 656c 6c6f 0003
+0000 0004 0000 0028 ffff ffff 0000 0005
+776f 726c 64ff ff
+"""
+
+
+@pytest.mark.parametrize(
+    "format, block", [("text", sample_text), ("binary", sample_binary)]
+)
+def test_load(format, block):
+    from psycopg3.copy import Copy
+
+    copy = Copy(format=format)
+    records = copy.load(block)
+    assert records == sample_records
+
+
+@pytest.mark.parametrize(
+    "format, block", [("text", sample_text), ("binary", sample_binary)]
+)
+def test_dump(format, block):
+    from psycopg3.copy import Copy
+
+    copy = Copy(format=format)
+    assert copy.get_buffer() is None
+    for row in sample_records:
+        copy.dump(row)
+    assert copy.get_buffer() == block
+    assert copy.get_buffer() is None
+
+
+@pytest.mark.parametrize(
+    "format, block", [("text", sample_text), ("binary", sample_binary)]
+)
+def test_buffers(format, block):
+    from psycopg3.copy import Copy
+
+    copy = Copy(format=format)
+    assert list(copy.buffers(sample_records)) == [block]
+
+
+@pytest.mark.parametrize(
+    "format, want", [("text", sample_text), ("binary", sample_binary)]
+)
+def test_copy_out_read(conn, format, want):
+    cur = conn.cursor()
+    copy = cur.copy(f"copy ({sample_values}) to stdout (format {format})")
+    assert copy.read() == want
+    assert copy.read() is None
+    assert copy.read() is None
+
+
+@pytest.mark.parametrize("format", ["text", "binary"])
+def test_iter(conn, format):
+    cur = conn.cursor()
+    copy = cur.copy(f"copy ({sample_values}) to stdout (format {format})")
+    assert list(copy) == sample_records
+
+
+@pytest.mark.parametrize(
+    "format, buffer", [("text", sample_text), ("binary", sample_binary)]
+)
+def test_copy_in_buffers(conn, format, buffer):
+    cur = conn.cursor()
+    ensure_table(cur, sample_tabledef)
+    copy = cur.copy(f"copy copy_in from stdin (format {format})")
+    copy.write(buffer)
+    copy.end()
+    data = cur.execute("select * from copy_in order by 1").fetchall()
+    assert data == sample_records
+
+
+@pytest.mark.parametrize(
+    "format, buffer", [("text", sample_text), ("binary", sample_binary)]
+)
+def test_copy_in_buffers_with(conn, format, buffer):
+    cur = conn.cursor()
+    ensure_table(cur, sample_tabledef)
+    with cur.copy(f"copy copy_in from stdin (format {format})") as copy:
+        copy.write(buffer)
+
+    data = cur.execute("select * from copy_in order by 1").fetchall()
+    assert data == sample_records
+
+
+@pytest.mark.parametrize(
+    "format, buffer", [("text", sample_text), ("binary", sample_binary)]
+)
+def test_copy_in_records(conn, format, buffer):
+    cur = conn.cursor()
+    ensure_table(cur, sample_tabledef)
+
+    with cur.copy(f"copy copy_in from stdin (format {format})") as copy:
+        for row in sample_records:
+            copy.write(row)
+
+    data = cur.execute("select * from copy_in order by 1").fetchall()
+    assert data == sample_records
+
+
+def ensure_table(cur, tabledef, name="copy_in"):
+    cur.execute(f"drop table if exists {name}")
+    cur.execute(f"create table {name} ({tabledef})")