]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Allow passing a str buffer to copy
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 24 Jun 2020 08:35:58 +0000 (20:35 +1200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 24 Jun 2020 08:36:29 +0000 (20:36 +1200)
psycopg3/copy.py
tests/test_copy.py
tests/test_copy_async.py

index 026da51e5a1898488c7e43a3608a7474bbed1f42..96a2604e1beceb41b3d04ed86fdb11207bb1e0a1 100644 (file)
@@ -5,8 +5,9 @@ psycopg3 copy support
 # Copyright (C) 2020 The Psycopg Team
 
 import re
-from typing import cast, TYPE_CHECKING, AsyncGenerator, Generator
-from typing import Any, Deque, Dict, List, Match, Optional, Tuple, Type
+import codecs
+from typing import TYPE_CHECKING, AsyncGenerator, Generator
+from typing import Any, Deque, Dict, List, Match, Optional, Tuple, Type, Union
 from types import TracebackType
 from collections import deque
 
@@ -16,10 +17,12 @@ from .proto import AdaptContext
 from .generators import copy_from, copy_to, copy_end
 
 if TYPE_CHECKING:
-    from .connection import Connection, AsyncConnection
+    from .connection import BaseConnection, Connection, AsyncConnection
 
 
 class BaseCopy:
+    _connection: Optional["BaseConnection"]
+
     def __init__(
         self,
         context: AdaptContext,
@@ -28,6 +31,7 @@ class BaseCopy:
     ):
         from .adapt import Transformer
 
+        self._connection = None
         self._transformer = Transformer(context)
         self.format = format
         self.pgresult = result
@@ -35,11 +39,23 @@ class BaseCopy:
 
         self._partial: Deque[bytes] = deque()
         self._header_seen = False
+        self._codec: Optional[codecs.CodecInfo] = None
 
     @property
     def finished(self) -> bool:
         return self._finished
 
+    @property
+    def connection(self) -> "BaseConnection":
+        if self._connection is not None:
+            return self._connection
+
+        self._connection = conn = self._transformer.connection
+        if conn is not None:
+            return conn
+
+        raise ValueError("no connection available")
+
     @property
     def pgresult(self) -> Optional[pq.proto.PGresult]:
         return self._pgresult
@@ -97,6 +113,24 @@ class BaseCopy:
     def _load_binary(self, buffer: bytes) -> List[Tuple[Any, ...]]:
         raise NotImplementedError
 
+    def _ensure_bytes(self, data: Union[bytes, str]) -> bytes:
+        if isinstance(data, bytes):
+            return data
+
+        elif isinstance(data, str):
+            if self._codec is not None:
+                return self._codec.encode(data)[0]
+
+            if (
+                self.pgresult is None
+                or self.pgresult.binary_tuples == pq.Format.BINARY
+            ):
+                raise TypeError(
+                    "cannot copy str data in binary mode: use bytes instead"
+                )
+            self._codec = self.connection.codec
+            return self._codec.encode(data)[0]
+
 
 def _bsrepl_sub(
     m: Match[bytes],
@@ -117,24 +151,13 @@ _bsrepl_re = re.compile(rb"\\(.)")
 
 
 class Copy(BaseCopy):
-    def __init__(
-        self,
-        context: AdaptContext,
-        result: Optional[pq.proto.PGresult],
-        format: pq.Format = pq.Format.TEXT,
-    ):
-        super().__init__(context=context, result=result, format=format)
-        self._connection: Optional["Connection"] = None
+    _connection: Optional["Connection"]
 
     @property
     def connection(self) -> "Connection":
-        if self._connection is None:
-            conn = self._transformer.connection
-            if conn is None:
-                raise ValueError("no connection available")
-            self._connection = cast("Connection", conn)
-
-        return self._connection
+        # TODO: mypy error: "Callable[[BaseCopy], BaseConnection]" has no
+        # attribute "fget"
+        return BaseCopy.connection.fget(self)  # type: ignore
 
     def read(self) -> Optional[bytes]:
         if self._finished:
@@ -147,9 +170,9 @@ class Copy(BaseCopy):
 
         return rv
 
-    def write(self, buffer: bytes) -> None:
+    def write(self, buffer: Union[str, bytes]) -> None:
         conn = self.connection
-        conn.wait(copy_to(conn.pgconn, buffer))
+        conn.wait(copy_to(conn.pgconn, self._ensure_bytes(buffer)))
 
     def finish(self, error: Optional[str] = None) -> None:
         conn = self.connection
@@ -184,24 +207,11 @@ class Copy(BaseCopy):
 
 
 class AsyncCopy(BaseCopy):
-    def __init__(
-        self,
-        context: AdaptContext,
-        result: Optional[pq.proto.PGresult],
-        format: pq.Format = pq.Format.TEXT,
-    ):
-        super().__init__(context=context, result=result, format=format)
-        self._connection: Optional["AsyncConnection"] = None
+    _connection: Optional["AsyncConnection"]
 
     @property
     def connection(self) -> "AsyncConnection":
-        if self._connection is None:
-            conn = self._transformer.connection
-            if conn is None:
-                raise ValueError("no connection available")
-            self._connection = cast("AsyncConnection", conn)
-
-        return self._connection
+        return BaseCopy.connection.fget(self)  # type: ignore
 
     async def read(self) -> Optional[bytes]:
         if self._finished:
@@ -214,9 +224,9 @@ class AsyncCopy(BaseCopy):
 
         return rv
 
-    async def write(self, buffer: bytes) -> None:
+    async def write(self, buffer: Union[str, bytes]) -> None:
         conn = self.connection
-        await conn.wait(copy_to(conn.pgconn, buffer))
+        await conn.wait(copy_to(conn.pgconn, self._ensure_bytes(buffer)))
 
     async def finish(self, error: Optional[str] = None) -> None:
         conn = self.connection
index 2f6173a384a5a13dd4aaf215bad12f18091d068a..3c690a8b84f22bcd2c2459ec5f4901ec885b552f 100644 (file)
@@ -171,6 +171,26 @@ def test_copy_in_buffers_with(conn, format, buffer):
     assert data == sample_records
 
 
+def test_copy_in_str(conn):
+    cur = conn.cursor()
+    ensure_table(cur, sample_tabledef)
+    with cur.copy("copy copy_in from stdin (format text)") as copy:
+        copy.write(sample_text.decode("utf8"))
+
+    data = cur.execute("select * from copy_in order by 1").fetchall()
+    assert data == sample_records
+
+
+def test_copy_in_str_binary(conn):
+    cur = conn.cursor()
+    ensure_table(cur, sample_tabledef)
+    with pytest.raises(e.QueryCanceled):
+        with cur.copy("copy copy_in from stdin (format binary)") as copy:
+            copy.write(sample_text.decode("utf8"))
+
+    assert conn.pgconn.transaction_status == conn.TransactionStatus.INERROR
+
+
 def test_copy_in_buffers_with_pg_error(conn):
     cur = conn.cursor()
     ensure_table(cur, sample_tabledef)
index fff8bd03498d5a408c97a8c7a26758ae68ff323a..cce63aab5bea431c72d3fa3b941f27eb70cc72d6 100644 (file)
@@ -89,6 +89,31 @@ async def test_copy_in_buffers_with(aconn, format, buffer):
     assert data == sample_records
 
 
+async def test_copy_in_str(aconn):
+    cur = aconn.cursor()
+    await ensure_table(cur, sample_tabledef)
+    async with (
+        await cur.copy("copy copy_in from stdin (format text)")
+    ) as copy:
+        await copy.write(sample_text.decode("utf8"))
+
+    await cur.execute("select * from copy_in order by 1")
+    data = await cur.fetchall()
+    assert data == sample_records
+
+
+async def test_copy_in_str_binary(aconn):
+    cur = aconn.cursor()
+    await ensure_table(cur, sample_tabledef)
+    with pytest.raises(e.QueryCanceled):
+        async with (
+            await cur.copy("copy copy_in from stdin (format binary)")
+        ) as copy:
+            await copy.write(sample_text.decode("utf8"))
+
+    assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INERROR
+
+
 async def test_copy_in_buffers_with_pg_error(aconn):
     cur = aconn.cursor()
     await ensure_table(cur, sample_tabledef)