]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Declare all loaders to receive an object supporting the buffer interface
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 16 Jan 2021 02:06:56 +0000 (03:06 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 16 Jan 2021 10:26:19 +0000 (11:26 +0100)
There isn't in mypy such an object, so just use `Union[bytes, bytearray,
memoryview]`

Avoid a memory copy passing data from the Transformer to the loaders.

I had started exercising this code path by adding copy tests, but after
dropping the extra copy the same path is now exercised by any select, so
I've stopped doing that.

19 files changed:
psycopg3/psycopg3/adapt.py
psycopg3/psycopg3/proto.py
psycopg3/psycopg3/types/array.py
psycopg3/psycopg3/types/composite.py
psycopg3/psycopg3/types/date.py
psycopg3/psycopg3/types/json.py
psycopg3/psycopg3/types/network.py
psycopg3/psycopg3/types/numeric.py
psycopg3/psycopg3/types/range.py
psycopg3/psycopg3/types/singletons.py
psycopg3/psycopg3/types/text.py
psycopg3/psycopg3/types/uuid.py
psycopg3_c/psycopg3_c/_psycopg3/transform.pyx
tests/test_adapt.py
tests/types/test_array.py
tests/types/test_json.py
tests/types/test_network.py
tests/types/test_text.py
tests/types/test_uuid.py

index aa16e251f7d043991fc1ab3f7404ef7d0a25d7a8..ffe90431ee063c1de5cc915ec672d76b9fcf72d0 100644 (file)
@@ -12,7 +12,7 @@ from . import pq
 from . import proto
 from ._enums import Format as Format
 from .oids import builtins
-from .proto import AdaptContext
+from .proto import AdaptContext, Buffer as Buffer
 
 if TYPE_CHECKING:
     from .connection import BaseConnection
@@ -41,13 +41,11 @@ class Dumper(ABC):
         """The oid to pass to the server, if known."""
 
     @abstractmethod
-    def dump(self, obj: Any) -> bytes:
+    def dump(self, obj: Any) -> Buffer:
         """Convert the object *obj* to PostgreSQL representation."""
         ...
 
-    # TODO: the protocol signature should probably return a Buffer like object
-    # (the C implementation may return bytearray)
-    def quote(self, obj: Any) -> bytes:
+    def quote(self, obj: Any) -> Buffer:
         """Convert the object *obj* to escaped representation."""
         value = self.dump(obj)
 
@@ -83,7 +81,7 @@ class Loader(ABC):
         )
 
     @abstractmethod
-    def load(self, data: bytes) -> Any:
+    def load(self, data: Buffer) -> Any:
         """Convert a PostgreSQL value to a Python object."""
         ...
 
@@ -179,7 +177,6 @@ class AdaptersMap(AdaptContext):
 
         Return None if not found.
         """
-        # TODO: auto selection
         if format == Format.AUTO:
             dmaps = [
                 self._dumpers[pq.Format.BINARY],
index 3c0b66748bcbedc36a2430c6318c43814f3a2d31..81e4061acc903b6958ecb1747fc7e2647588e05e 100644 (file)
@@ -18,6 +18,9 @@ if TYPE_CHECKING:
     from .waiting import Wait, Ready
     from .sql import Composable
 
+# An object implementing the buffer protocol
+Buffer = Union[bytes, bytearray, memoryview]
+
 Query = Union[str, bytes, "Composable"]
 Params = Union[Sequence[Any], Mapping[str, Any]]
 ConnectionType = TypeVar("ConnectionType", bound="BaseConnection")
index 0a931022e5517a2bd175b810e8e8e56d63f00101..75ae3895ca0fb1489a994f6fe5c44b8d2738e439 100644 (file)
@@ -12,7 +12,7 @@ from .. import pq
 from .._enums import Format
 from .. import errors as e
 from ..oids import builtins, TEXT_OID, TEXT_ARRAY_OID, INVALID_OID
-from ..adapt import Dumper, Loader, Transformer
+from ..adapt import Buffer, Dumper, Loader, Transformer
 from ..proto import AdaptContext
 
 
@@ -178,7 +178,7 @@ class ArrayLoader(BaseArrayLoader):
         """
     )
 
-    def load(self, data: bytes) -> List[Any]:
+    def load(self, data: Buffer) -> List[Any]:
         rv = None
         stack: List[Any] = []
         cast = self._tx.get_loader(self.base_oid, self.format).load
@@ -230,7 +230,7 @@ class ArrayBinaryLoader(BaseArrayLoader):
 
     format = pq.Format.BINARY
 
-    def load(self, data: bytes) -> List[Any]:
+    def load(self, data: Buffer) -> List[Any]:
         ndims, hasnull, oid = _struct_head.unpack_from(data[:12])
         if not ndims:
             return []
index 2640cc827e34711e0d6943c3d76348f8e3e49d5e..074bfa38c363681b196c1b5348fe777d42ad7fb7 100644 (file)
@@ -14,7 +14,7 @@ from .. import pq
 from .. import sql
 from .. import errors as e
 from ..oids import TypeInfo, TEXT_OID
-from ..adapt import Format, Dumper, Loader, Transformer
+from ..adapt import Buffer, Format, Dumper, Loader, Transformer
 from ..proto import AdaptContext
 from . import array
 
@@ -232,7 +232,7 @@ class BaseCompositeLoader(Loader):
 
 
 class RecordLoader(BaseCompositeLoader):
-    def load(self, data: bytes) -> Tuple[Any, ...]:
+    def load(self, data: Buffer) -> Tuple[Any, ...]:
         if data == b"()":
             return ()
 
@@ -256,7 +256,7 @@ class RecordBinaryLoader(Loader):
         super().__init__(oid, context)
         self._tx = Transformer(context)
 
-    def load(self, data: bytes) -> Tuple[Any, ...]:
+    def load(self, data: Buffer) -> Tuple[Any, ...]:
         if not self._types_set:
             self._config_types(data)
             self._types_set = True
@@ -291,7 +291,7 @@ class CompositeLoader(RecordLoader):
     fields_types: List[int]
     _types_set = False
 
-    def load(self, data: bytes) -> Any:
+    def load(self, data: Buffer) -> Any:
         if not self._types_set:
             self._config_types(data)
             self._types_set = True
@@ -314,6 +314,6 @@ class CompositeBinaryLoader(RecordBinaryLoader):
     format = pq.Format.BINARY
     factory: Callable[..., Any]
 
-    def load(self, data: bytes) -> Any:
+    def load(self, data: Buffer) -> Any:
         r = super().load(data)
         return type(self).factory(*r)
index 66738cad3c266b8aa94a49513db58eb154d8242f..8e5a90380e35f21eeb33e58e642887b37face724 100644 (file)
@@ -11,7 +11,7 @@ from typing import cast, Optional
 
 from ..pq import Format
 from ..oids import builtins
-from ..adapt import Dumper, Loader
+from ..adapt import Buffer, Dumper, Loader
 from ..proto import AdaptContext
 from ..errors import InterfaceError, DataError
 
@@ -82,7 +82,9 @@ class DateLoader(Loader):
         super().__init__(oid, context)
         self._format = self._format_from_context()
 
-    def load(self, data: bytes) -> date:
+    def load(self, data: Buffer) -> date:
+        if isinstance(data, memoryview):
+            data = bytes(data)
         try:
             return datetime.strptime(data.decode("utf8"), self._format).date()
         except ValueError as e:
@@ -140,8 +142,10 @@ class TimeLoader(Loader):
     _format = "%H:%M:%S.%f"
     _format_no_micro = _format.replace(".%f", "")
 
-    def load(self, data: bytes) -> time:
+    def load(self, data: Buffer) -> time:
         # check if the data contains microseconds
+        if isinstance(data, memoryview):
+            data = bytes(data)
         fmt = self._format if b"." in data else self._format_no_micro
         try:
             return datetime.strptime(data.decode("utf8"), fmt).time()
@@ -171,7 +175,10 @@ class TimeTzLoader(TimeLoader):
 
         super().__init__(oid, context)
 
-    def load(self, data: bytes) -> time:
+    def load(self, data: Buffer) -> time:
+        if isinstance(data, memoryview):
+            data = bytes(data)
+
         # Hack to convert +HH in +HHMM
         if data[-3] in (43, 45):
             data += b"00"
@@ -184,7 +191,9 @@ class TimeTzLoader(TimeLoader):
 
         return dt.time().replace(tzinfo=dt.tzinfo)
 
-    def _load_py36(self, data: bytes) -> time:
+    def _load_py36(self, data: Buffer) -> time:
+        if isinstance(data, memoryview):
+            data = bytes(data)
         # Drop seconds from timezone for Python 3.6
         # Also, Python 3.6 doesn't support HHMM, only HH:MM
         if data[-6] in (43, 45):  # +-HH:MM -> +-HHMM
@@ -203,7 +212,10 @@ class TimestampLoader(DateLoader):
         super().__init__(oid, context)
         self._format_no_micro = self._format.replace(".%f", "")
 
-    def load(self, data: bytes) -> datetime:
+    def load(self, data: Buffer) -> datetime:
+        if isinstance(data, memoryview):
+            data = bytes(data)
+
         # check if the data contains microseconds
         fmt = (
             self._format if data.find(b".", 19) >= 0 else self._format_no_micro
@@ -285,14 +297,19 @@ class TimestamptzLoader(TimestampLoader):
             setattr(self, "load", self._load_notimpl)
             return ""
 
-    def load(self, data: bytes) -> datetime:
+    def load(self, data: Buffer) -> datetime:
+        if isinstance(data, memoryview):
+            data = bytes(data)
+
         # Hack to convert +HH in +HHMM
         if data[-3] in (43, 45):
             data += b"00"
 
         return super().load(data)
 
-    def _load_py36(self, data: bytes) -> datetime:
+    def _load_py36(self, data: Buffer) -> datetime:
+        if isinstance(data, memoryview):
+            data = bytes(data)
         # Drop seconds from timezone for Python 3.6
         # Also, Python 3.6 doesn't support HHMM, only HH:MM
         tzsep = (43, 45)  # + and - bytes
@@ -305,7 +322,9 @@ class TimestamptzLoader(TimestampLoader):
 
         return super().load(data)
 
-    def _load_notimpl(self, data: bytes) -> datetime:
+    def _load_notimpl(self, data: Buffer) -> datetime:
+        if isinstance(data, memoryview):
+            data = bytes(data)
         raise NotImplementedError(
             "can't parse datetimetz with DateStyle"
             f" {self._get_datestyle().decode('ascii')}: {data.decode('ascii')}"
@@ -337,7 +356,7 @@ class IntervalLoader(Loader):
             if ints != b"postgres":
                 setattr(self, "load", self._load_notimpl)
 
-    def load(self, data: bytes) -> timedelta:
+    def load(self, data: Buffer) -> timedelta:
         m = self._re_interval.match(data)
         if not m:
             raise ValueError("can't parse interval: {data.decode('ascii')}")
@@ -371,7 +390,9 @@ class IntervalLoader(Loader):
         except OverflowError as e:
             raise DataError(str(e))
 
-    def _load_notimpl(self, data: bytes) -> timedelta:
+    def _load_notimpl(self, data: Buffer) -> timedelta:
+        if isinstance(data, memoryview):
+            data = bytes(data)
         ints = (
             self.connection
             and self.connection.pgconn.parameter_status(b"IntervalStyle")
index 917e2dbb1d8113737b166ce06488c2b0d33a4e8b..db70e8afdf43f98dc6c8cf86cb281eb0f1921373 100644 (file)
@@ -9,7 +9,7 @@ from typing import Any, Callable, Optional
 
 from ..pq import Format
 from ..oids import builtins
-from ..adapt import Dumper, Loader
+from ..adapt import Buffer, Dumper, Loader
 from ..errors import DataError
 
 JsonDumpsFunction = Callable[[Any], str]
@@ -77,7 +77,10 @@ class JsonLoader(Loader):
 
     format = Format.TEXT
 
-    def load(self, data: bytes) -> Any:
+    def load(self, data: Buffer) -> Any:
+        # Json crashes on memoryview
+        if isinstance(data, memoryview):
+            data = bytes(data)
         return json.loads(data)
 
 
@@ -90,7 +93,10 @@ class JsonbBinaryLoader(Loader):
 
     format = Format.BINARY
 
-    def load(self, data: bytes) -> Any:
+    def load(self, data: Buffer) -> Any:
         if data and data[0] != 1:
             raise DataError("unknown jsonb binary format: {data[0]}")
-        return json.loads(data[1:])
+        data = data[1:]
+        if isinstance(data, memoryview):
+            data = bytes(data)
+        return json.loads(data)
index bbe4913c945eaa73c981dda6b75cbf4652879ff8..a1b62605381a463239da8f242d59660d8ba9a167 100644 (file)
@@ -8,7 +8,7 @@ from typing import Callable, Optional, Union, TYPE_CHECKING
 
 from ..pq import Format
 from ..oids import builtins
-from ..adapt import Dumper, Loader
+from ..adapt import Buffer, Dumper, Loader
 from ..proto import AdaptContext
 
 if TYPE_CHECKING:
@@ -57,7 +57,10 @@ class InetLoader(_LazyIpaddress):
 
     format = Format.TEXT
 
-    def load(self, data: bytes) -> Union[Address, Interface]:
+    def load(self, data: Buffer) -> Union[Address, Interface]:
+        if isinstance(data, memoryview):
+            data = bytes(data)
+
         if b"/" in data:
             return ip_interface(data.decode("utf8"))
         else:
@@ -68,5 +71,8 @@ class CidrLoader(_LazyIpaddress):
 
     format = Format.TEXT
 
-    def load(self, data: bytes) -> Network:
+    def load(self, data: Buffer) -> Network:
+        if isinstance(data, memoryview):
+            data = bytes(data)
+
         return ip_network(data.decode("utf8"))
index 3fd829965c0b30784bd776bc82399da6e3a70c21..9d90e4794562fd0c9a7804836ee83435ac279ffc 100644 (file)
@@ -10,7 +10,7 @@ from decimal import Decimal
 
 from ..pq import Format
 from ..oids import builtins
-from ..adapt import Dumper, Loader
+from ..adapt import Buffer, Dumper, Loader
 
 _PackInt = Callable[[int], bytes]
 _PackFloat = Callable[[float], bytes]
@@ -166,7 +166,7 @@ class IntLoader(Loader):
 
     format = Format.TEXT
 
-    def load(self, data: bytes) -> int:
+    def load(self, data: Buffer) -> int:
         # it supports bytes directly
         return int(data)
 
@@ -175,7 +175,7 @@ class Int2BinaryLoader(Loader):
 
     format = Format.BINARY
 
-    def load(self, data: bytes) -> int:
+    def load(self, data: Buffer) -> int:
         return _unpack_int2(data)[0]
 
 
@@ -183,7 +183,7 @@ class Int4BinaryLoader(Loader):
 
     format = Format.BINARY
 
-    def load(self, data: bytes) -> int:
+    def load(self, data: Buffer) -> int:
         return _unpack_int4(data)[0]
 
 
@@ -191,7 +191,7 @@ class Int8BinaryLoader(Loader):
 
     format = Format.BINARY
 
-    def load(self, data: bytes) -> int:
+    def load(self, data: Buffer) -> int:
         return _unpack_int8(data)[0]
 
 
@@ -199,7 +199,7 @@ class OidBinaryLoader(Loader):
 
     format = Format.BINARY
 
-    def load(self, data: bytes) -> int:
+    def load(self, data: Buffer) -> int:
         return _unpack_uint4(data)[0]
 
 
@@ -207,7 +207,7 @@ class FloatLoader(Loader):
 
     format = Format.TEXT
 
-    def load(self, data: bytes) -> float:
+    def load(self, data: Buffer) -> float:
         # it supports bytes directly
         return float(data)
 
@@ -216,7 +216,7 @@ class Float4BinaryLoader(Loader):
 
     format = Format.BINARY
 
-    def load(self, data: bytes) -> float:
+    def load(self, data: Buffer) -> float:
         return _unpack_float4(data)[0]
 
 
@@ -224,7 +224,7 @@ class Float8BinaryLoader(Loader):
 
     format = Format.BINARY
 
-    def load(self, data: bytes) -> float:
+    def load(self, data: Buffer) -> float:
         return _unpack_float8(data)[0]
 
 
@@ -232,5 +232,7 @@ class NumericLoader(Loader):
 
     format = Format.TEXT
 
-    def load(self, data: bytes) -> Decimal:
+    def load(self, data: Buffer) -> Decimal:
+        if isinstance(data, memoryview):
+            data = bytes(data)
         return Decimal(data.decode("utf8"))
index 2ae531ae8cb2e165d5585d2faca33ba242bf19c8..0a4ea9f8ff75a8e66ed56a045f309057fd614193 100644 (file)
@@ -14,7 +14,7 @@ from .. import sql
 from .. import errors as e
 from ..pq import Format
 from ..oids import builtins, TypeInfo
-from ..adapt import Dumper, Loader
+from ..adapt import Buffer, Dumper, Loader
 from ..proto import AdaptContext
 
 from . import array
@@ -245,12 +245,12 @@ class RangeLoader(BaseCompositeLoader, Generic[T]):
     subtype_oid: int
     cls: Type[Range[T]]
 
-    def load(self, data: bytes) -> Range[T]:
+    def load(self, data: Buffer) -> Range[T]:
         if data == b"empty":
             return self.cls(empty=True)
 
         cast = self._tx.get_loader(self.subtype_oid, format=Format.TEXT).load
-        bounds = (data[:1] + data[-1:]).decode("utf-8")
+        bounds = _int2parens[data[0]] + _int2parens[data[-1]]
         min, max = (
             cast(token) if token is not None else None
             for token in self._parse_record(data[1:-1])
@@ -258,6 +258,9 @@ class RangeLoader(BaseCompositeLoader, Generic[T]):
         return self.cls(min, max, bounds)
 
 
+_int2parens = {ord(c): c for c in "[]()"}
+
+
 # Python wrappers for builtin range types
 
 
index 2f10577cb8e1b4ab76cad7f88b1f516bd2ecd907..0b52f3989d1d06f052b51738c9eab5af198fab4f 100644 (file)
@@ -6,7 +6,7 @@ Adapters for None and boolean.
 
 from ..pq import Format
 from ..oids import builtins
-from ..adapt import Dumper, Loader
+from ..adapt import Buffer, Dumper, Loader
 
 
 class BoolDumper(Dumper):
@@ -49,7 +49,7 @@ class BoolLoader(Loader):
 
     format = Format.TEXT
 
-    def load(self, data: bytes) -> bool:
+    def load(self, data: Buffer) -> bool:
         return data == b"t"
 
 
@@ -57,5 +57,5 @@ class BoolBinaryLoader(Loader):
 
     format = Format.BINARY
 
-    def load(self, data: bytes) -> bool:
+    def load(self, data: Buffer) -> bool:
         return data != b"\x00"
index 47b676f8cbc9a639b0467aedde84743c40aab4e5..37010c9fc43b364948d65bfd8b782b4c4f39be40 100644 (file)
@@ -8,7 +8,7 @@ from typing import Optional, Union, TYPE_CHECKING
 
 from ..pq import Format, Escaping
 from ..oids import builtins
-from ..adapt import Dumper, Loader
+from ..adapt import Buffer, Dumper, Loader
 from ..proto import AdaptContext
 from ..errors import DataError
 
@@ -65,7 +65,7 @@ class TextLoader(Loader):
             enc = conn.client_encoding
             self._encoding = enc if enc != "ascii" else ""
 
-    def load(self, data: bytes) -> Union[bytes, str]:
+    def load(self, data: Buffer) -> Union[bytes, str]:
         if self._encoding:
             if isinstance(data, memoryview):
                 return bytes(data).decode(self._encoding)
@@ -120,7 +120,7 @@ class ByteaLoader(Loader):
         if not hasattr(self.__class__, "_escaping"):
             self.__class__._escaping = Escaping()
 
-    def load(self, data: bytes) -> bytes:
+    def load(self, data: Buffer) -> bytes:
         return self._escaping.unescape_bytea(data)
 
 
@@ -128,5 +128,5 @@ class ByteaBinaryLoader(Loader):
 
     format = Format.BINARY
 
-    def load(self, data: bytes) -> bytes:
+    def load(self, data: Buffer) -> bytes:
         return data
index ea833f2c1e7015a71e5069f31b8d0d02316272cb..4311ba24c213c5269c39bbb639d8dfd17fd389c9 100644 (file)
@@ -8,7 +8,7 @@ from typing import Callable, Optional, TYPE_CHECKING
 
 from ..pq import Format
 from ..oids import builtins
-from ..adapt import Dumper, Loader
+from ..adapt import Buffer, Dumper, Loader
 from ..proto import AdaptContext
 
 if TYPE_CHECKING:
@@ -48,7 +48,9 @@ class UUIDLoader(Loader):
 
             imported = True
 
-    def load(self, data: bytes) -> "uuid.UUID":
+    def load(self, data: Buffer) -> "uuid.UUID":
+        if isinstance(data, memoryview):
+            data = bytes(data)
         return UUID(data.decode("utf8"))
 
 
@@ -56,5 +58,7 @@ class UUIDBinaryLoader(UUIDLoader):
 
     format = Format.BINARY
 
-    def load(self, data: bytes) -> "uuid.UUID":
+    def load(self, data: Buffer) -> "uuid.UUID":
+        if isinstance(data, memoryview):
+            data = bytes(data)
         return UUID(bytes=data)
index da88d933407c7fde4acb1ec9a57d4df361ca6c74..41ed1b22d36a388cefdf6e30ecf2ae4846c272c8 100644 (file)
@@ -332,8 +332,10 @@ cdef class Transformer:
                     if attval.len == -1:  # NULL_LEN
                         pyval = None
                     else:
-                        # TODO: no copy
-                        b = attval.value[:attval.len]
+                        b = PyMemoryView_FromObject(
+                            ViewBuffer._from_buffer(
+                                self._pgresult,
+                                <unsigned char *>attval.value, attval.len))
                         pyval = PyObject_CallFunctionObjArgs(
                             (<RowLoader>loader).loadfunc, <PyObject *>b, NULL)
 
@@ -371,8 +373,10 @@ cdef class Transformer:
                     pyval = (<RowLoader>loader).cloader.cload(
                         attval.value, attval.len)
                 else:
-                    # TODO: no copy
-                    b = attval.value[:attval.len]
+                    b = PyMemoryView_FromObject(
+                        ViewBuffer._from_buffer(
+                            self._pgresult,
+                            <unsigned char *>attval.value, attval.len))
                     pyval = PyObject_CallFunctionObjArgs(
                         (<RowLoader>loader).loadfunc, <PyObject *>b, NULL)
 
index 6123b98bd697106d0114bf7c93382b59894e9438..366a49b1970e2c689bebfab7704aa5fa46319083 100644 (file)
@@ -107,7 +107,7 @@ def test_subclass_loader(conn):
 
     class MyTextLoader(TextLoader):
         def load(self, data):
-            return (data * 2).decode("utf-8")
+            return (bytes(data) * 2).decode("utf-8")
 
     MyTextLoader.register("text", conn)
     assert conn.execute("select 'hello'::text").fetchone()[0] == "hellohello"
@@ -321,7 +321,7 @@ def make_loader(suffix):
         format = pq.Format.TEXT
 
         def load(self, b):
-            return b.decode("ascii") + suffix
+            return bytes(b).decode("ascii") + suffix
 
     return TestLoader
 
index a54330caef53a5e6b67ab912a786b640f7683993..46089c2781f11511ff308a182764849b65845a9b 100644 (file)
@@ -1,6 +1,7 @@
 import pytest
 import psycopg3
 from psycopg3 import pq
+from psycopg3 import sql
 from psycopg3.oids import builtins
 from psycopg3.adapt import Format, Transformer
 from psycopg3.types import array
@@ -99,6 +100,15 @@ def test_load_list_int(conn, obj, want, fmt_out):
     cur.execute("select %s::int[]", (obj,))
     assert cur.fetchone()[0] == want
 
+    stmt = sql.SQL("copy (select {}::int[]) to stdout (format {})").format(
+        obj, sql.SQL(fmt_out.name)
+    )
+    with cur.copy(stmt) as copy:
+        copy.set_types(["int4[]"])
+        (got,) = copy.read_row()
+
+    assert got == want
+
 
 def test_array_register(conn):
     cur = conn.cursor()
index 8309faee5b85e42ba72983ff6d1b6750f01dba08..bea17f3da0409c5aa892110958eb7755438f3b48 100644 (file)
@@ -4,6 +4,7 @@ import pytest
 
 import psycopg3.types
 from psycopg3 import pq
+from psycopg3 import sql
 from psycopg3.types import Json, Jsonb
 from psycopg3.adapt import Format
 
@@ -48,6 +49,21 @@ def test_json_load(conn, val, jtype, fmt_out):
     assert cur.fetchone()[0] == json.loads(val)
 
 
+@pytest.mark.parametrize("val", samples)
+@pytest.mark.parametrize("jtype", ["json", "jsonb"])
+@pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY])
+def test_json_load_copy(conn, val, jtype, fmt_out):
+    cur = conn.cursor()
+    stmt = sql.SQL("copy (select {}::{}) to stdout (format {})").format(
+        val, sql.Identifier(jtype), sql.SQL(fmt_out.name)
+    )
+    with cur.copy(stmt) as copy:
+        copy.set_types([jtype])
+        (got,) = copy.read_row()
+
+    assert got == json.loads(val)
+
+
 @pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY])
 @pytest.mark.parametrize("wrapper", ["Json", "Jsonb"])
 def test_json_dump_customise(conn, wrapper, fmt_in):
index ada082eeadf55505cf835772ae348954443b4057..3b88d38a184b7ca263e309af5ecb61637365f5db 100644 (file)
@@ -5,6 +5,7 @@ import subprocess as sp
 import pytest
 
 from psycopg3 import pq
+from psycopg3 import sql
 from psycopg3.adapt import Format
 
 
@@ -60,34 +61,69 @@ def test_network_dump(conn, fmt_in, val):
 @pytest.mark.parametrize("val", ["127.0.0.1/32", "::ffff:102:300/128"])
 def test_inet_load_address(conn, fmt_out, val):
     binary_check(fmt_out)
+    addr = ipaddress.ip_address(val.split("/", 1)[0])
     cur = conn.cursor(binary=fmt_out)
+
     cur.execute("select %s::inet", (val,))
-    addr = ipaddress.ip_address(val.split("/", 1)[0])
     assert cur.fetchone()[0] == addr
+
     cur.execute("select array[null, %s::inet]", (val,))
     assert cur.fetchone()[0] == [None, addr]
 
+    stmt = sql.SQL("copy (select {}::inet) to stdout (format {})").format(
+        val, sql.SQL(fmt_out.name)
+    )
+    with cur.copy(stmt) as copy:
+        copy.set_types(["inet"])
+        (got,) = copy.read_row()
+
+    assert got == addr
+
 
 @pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY])
 @pytest.mark.parametrize("val", ["127.0.0.1/24", "::ffff:102:300/127"])
 def test_inet_load_network(conn, fmt_out, val):
     binary_check(fmt_out)
+    pyval = ipaddress.ip_interface(val)
     cur = conn.cursor(binary=fmt_out)
+
     cur.execute("select %s::inet", (val,))
-    assert cur.fetchone()[0] == ipaddress.ip_interface(val)
+    assert cur.fetchone()[0] == pyval
+
     cur.execute("select array[null, %s::inet]", (val,))
-    assert cur.fetchone()[0] == [None, ipaddress.ip_interface(val)]
+    assert cur.fetchone()[0] == [None, pyval]
+
+    stmt = sql.SQL("copy (select {}::inet) to stdout (format {})").format(
+        val, sql.SQL(fmt_out.name)
+    )
+    with cur.copy(stmt) as copy:
+        copy.set_types(["inet"])
+        (got,) = copy.read_row()
+
+    assert got == pyval
 
 
 @pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY])
 @pytest.mark.parametrize("val", ["127.0.0.0/24", "::ffff:102:300/128"])
 def test_cidr_load(conn, fmt_out, val):
     binary_check(fmt_out)
+    pyval = ipaddress.ip_network(val)
     cur = conn.cursor(binary=fmt_out)
+
     cur.execute("select %s::cidr", (val,))
-    assert cur.fetchone()[0] == ipaddress.ip_network(val)
+    assert cur.fetchone()[0] == pyval
+
     cur.execute("select array[null, %s::cidr]", (val,))
-    assert cur.fetchone()[0] == [None, ipaddress.ip_network(val)]
+    assert cur.fetchone()[0] == [None, pyval]
+
+    stmt = sql.SQL("copy (select {}::cidr) to stdout (format {})").format(
+        val, sql.SQL(fmt_out.name)
+    )
+    with cur.copy(stmt) as copy:
+        copy.set_types(["cidr"])
+        (got,) = copy.read_row()
+
+    assert got == pyval
 
 
 def binary_check(fmt):
index f0fd97ec0b271c67f5f8d54329142483c9483301..c18dcd35818792c8f3e732c823605c73bc9dce71 100644 (file)
@@ -114,16 +114,34 @@ def test_load_enc(conn, typename, encoding, fmt_out):
     ).fetchone()
     assert res == eur
 
+    stmt = sql.SQL("copy (select chr({}::int)) to stdout (format {})").format(
+        ord(eur), sql.SQL(fmt_out.name)
+    )
+    with cur.copy(stmt) as copy:
+        copy.set_types([typename])
+        (res,) = copy.read_row()
+
+    assert res == eur
+
 
 @pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY])
 @pytest.mark.parametrize("typename", ["text", "varchar", "name", "bpchar"])
 def test_load_badenc(conn, typename, fmt_out):
+    conn.autocommit = True
     cur = conn.cursor(binary=fmt_out)
 
     conn.client_encoding = "latin1"
-    with pytest.raises(psycopg3.DatabaseError):
+    with pytest.raises(psycopg3.DataError):
         cur.execute(f"select chr(%s::int)::{typename}", (ord(eur),))
 
+    stmt = sql.SQL("copy (select chr({}::int)) to stdout (format {})").format(
+        ord(eur), sql.SQL(fmt_out.name)
+    )
+    with cur.copy(stmt) as copy:
+        copy.set_types([typename])
+        with pytest.raises(psycopg3.DataError):
+            copy.read_row()
+
 
 @pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY])
 @pytest.mark.parametrize("typename", ["text", "varchar", "name", "bpchar"])
@@ -131,9 +149,16 @@ def test_load_ascii(conn, typename, fmt_out):
     cur = conn.cursor(binary=fmt_out)
 
     conn.client_encoding = "ascii"
-    (res,) = cur.execute(
-        f"select chr(%s::int)::{typename}", (ord(eur),)
-    ).fetchone()
+    cur.execute(f"select chr(%s::int)::{typename}", (ord(eur),))
+    assert cur.fetchone()[0] == eur.encode("utf8")
+
+    stmt = sql.SQL("copy (select chr({}::int)) to stdout (format {})").format(
+        ord(eur), sql.SQL(fmt_out.name)
+    )
+    with cur.copy(stmt) as copy:
+        copy.set_types([typename])
+        (res,) = copy.read_row()
+
     assert res == eur.encode("utf8")
 
 
index 548664668e147da38a268c9edb2844a950748f75..6fbea2697c00524aa6986a8426ea6d18e3950411 100644 (file)
@@ -5,6 +5,7 @@ import subprocess as sp
 import pytest
 
 from psycopg3 import pq
+from psycopg3 import sql
 from psycopg3.adapt import Format
 
 
@@ -23,6 +24,15 @@ def test_uuid_load(conn, fmt_out):
     cur.execute("select %s::uuid", (val,))
     assert cur.fetchone()[0] == UUID(val)
 
+    stmt = sql.SQL("copy (select {}::uuid) to stdout (format {})").format(
+        val, sql.SQL(fmt_out.name)
+    )
+    with cur.copy(stmt) as copy:
+        copy.set_types(["uuid"])
+        (res,) = copy.read_row()
+
+    assert res == UUID(val)
+
 
 @pytest.mark.subprocess
 def test_lazy_load(dsn):