From: Daniele Varrazzo Date: Sat, 16 Jan 2021 02:06:56 +0000 (+0100) Subject: Declare all loaders to receive an object supporting the buffer interface X-Git-Tag: 3.0.dev0~150 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=422f53ff13170dfa95abbd7be7a25302025b8c64;p=thirdparty%2Fpsycopg.git Declare all loaders to receive an object supporting the buffer interface There isn't in mypy such an object, so just use `Union[bytes, bytearray, memoryview]` Avoid a memory copy passing data from the Transformer to the loaders. I had started exercising this code path by adding copy tests, but after dropping the extra copy the same path is now exercised by any select, so I've stopped doing that. --- diff --git a/psycopg3/psycopg3/adapt.py b/psycopg3/psycopg3/adapt.py index aa16e251f..ffe90431e 100644 --- a/psycopg3/psycopg3/adapt.py +++ b/psycopg3/psycopg3/adapt.py @@ -12,7 +12,7 @@ from . import pq from . import proto from ._enums import Format as Format from .oids import builtins -from .proto import AdaptContext +from .proto import AdaptContext, Buffer as Buffer if TYPE_CHECKING: from .connection import BaseConnection @@ -41,13 +41,11 @@ class Dumper(ABC): """The oid to pass to the server, if known.""" @abstractmethod - def dump(self, obj: Any) -> bytes: + def dump(self, obj: Any) -> Buffer: """Convert the object *obj* to PostgreSQL representation.""" ... - # TODO: the protocol signature should probably return a Buffer like object - # (the C implementation may return bytearray) - def quote(self, obj: Any) -> bytes: + def quote(self, obj: Any) -> Buffer: """Convert the object *obj* to escaped representation.""" value = self.dump(obj) @@ -83,7 +81,7 @@ class Loader(ABC): ) @abstractmethod - def load(self, data: bytes) -> Any: + def load(self, data: Buffer) -> Any: """Convert a PostgreSQL value to a Python object.""" ... @@ -179,7 +177,6 @@ class AdaptersMap(AdaptContext): Return None if not found. """ - # TODO: auto selection if format == Format.AUTO: dmaps = [ self._dumpers[pq.Format.BINARY], diff --git a/psycopg3/psycopg3/proto.py b/psycopg3/psycopg3/proto.py index 3c0b66748..81e4061ac 100644 --- a/psycopg3/psycopg3/proto.py +++ b/psycopg3/psycopg3/proto.py @@ -18,6 +18,9 @@ if TYPE_CHECKING: from .waiting import Wait, Ready from .sql import Composable +# An object implementing the buffer protocol +Buffer = Union[bytes, bytearray, memoryview] + Query = Union[str, bytes, "Composable"] Params = Union[Sequence[Any], Mapping[str, Any]] ConnectionType = TypeVar("ConnectionType", bound="BaseConnection") diff --git a/psycopg3/psycopg3/types/array.py b/psycopg3/psycopg3/types/array.py index 0a931022e..75ae3895c 100644 --- a/psycopg3/psycopg3/types/array.py +++ b/psycopg3/psycopg3/types/array.py @@ -12,7 +12,7 @@ from .. import pq from .._enums import Format from .. import errors as e from ..oids import builtins, TEXT_OID, TEXT_ARRAY_OID, INVALID_OID -from ..adapt import Dumper, Loader, Transformer +from ..adapt import Buffer, Dumper, Loader, Transformer from ..proto import AdaptContext @@ -178,7 +178,7 @@ class ArrayLoader(BaseArrayLoader): """ ) - def load(self, data: bytes) -> List[Any]: + def load(self, data: Buffer) -> List[Any]: rv = None stack: List[Any] = [] cast = self._tx.get_loader(self.base_oid, self.format).load @@ -230,7 +230,7 @@ class ArrayBinaryLoader(BaseArrayLoader): format = pq.Format.BINARY - def load(self, data: bytes) -> List[Any]: + def load(self, data: Buffer) -> List[Any]: ndims, hasnull, oid = _struct_head.unpack_from(data[:12]) if not ndims: return [] diff --git a/psycopg3/psycopg3/types/composite.py b/psycopg3/psycopg3/types/composite.py index 2640cc827..074bfa38c 100644 --- a/psycopg3/psycopg3/types/composite.py +++ b/psycopg3/psycopg3/types/composite.py @@ -14,7 +14,7 @@ from .. import pq from .. import sql from .. import errors as e from ..oids import TypeInfo, TEXT_OID -from ..adapt import Format, Dumper, Loader, Transformer +from ..adapt import Buffer, Format, Dumper, Loader, Transformer from ..proto import AdaptContext from . import array @@ -232,7 +232,7 @@ class BaseCompositeLoader(Loader): class RecordLoader(BaseCompositeLoader): - def load(self, data: bytes) -> Tuple[Any, ...]: + def load(self, data: Buffer) -> Tuple[Any, ...]: if data == b"()": return () @@ -256,7 +256,7 @@ class RecordBinaryLoader(Loader): super().__init__(oid, context) self._tx = Transformer(context) - def load(self, data: bytes) -> Tuple[Any, ...]: + def load(self, data: Buffer) -> Tuple[Any, ...]: if not self._types_set: self._config_types(data) self._types_set = True @@ -291,7 +291,7 @@ class CompositeLoader(RecordLoader): fields_types: List[int] _types_set = False - def load(self, data: bytes) -> Any: + def load(self, data: Buffer) -> Any: if not self._types_set: self._config_types(data) self._types_set = True @@ -314,6 +314,6 @@ class CompositeBinaryLoader(RecordBinaryLoader): format = pq.Format.BINARY factory: Callable[..., Any] - def load(self, data: bytes) -> Any: + def load(self, data: Buffer) -> Any: r = super().load(data) return type(self).factory(*r) diff --git a/psycopg3/psycopg3/types/date.py b/psycopg3/psycopg3/types/date.py index 66738cad3..8e5a90380 100644 --- a/psycopg3/psycopg3/types/date.py +++ b/psycopg3/psycopg3/types/date.py @@ -11,7 +11,7 @@ from typing import cast, Optional from ..pq import Format from ..oids import builtins -from ..adapt import Dumper, Loader +from ..adapt import Buffer, Dumper, Loader from ..proto import AdaptContext from ..errors import InterfaceError, DataError @@ -82,7 +82,9 @@ class DateLoader(Loader): super().__init__(oid, context) self._format = self._format_from_context() - def load(self, data: bytes) -> date: + def load(self, data: Buffer) -> date: + if isinstance(data, memoryview): + data = bytes(data) try: return datetime.strptime(data.decode("utf8"), self._format).date() except ValueError as e: @@ -140,8 +142,10 @@ class TimeLoader(Loader): _format = "%H:%M:%S.%f" _format_no_micro = _format.replace(".%f", "") - def load(self, data: bytes) -> time: + def load(self, data: Buffer) -> time: # check if the data contains microseconds + if isinstance(data, memoryview): + data = bytes(data) fmt = self._format if b"." in data else self._format_no_micro try: return datetime.strptime(data.decode("utf8"), fmt).time() @@ -171,7 +175,10 @@ class TimeTzLoader(TimeLoader): super().__init__(oid, context) - def load(self, data: bytes) -> time: + def load(self, data: Buffer) -> time: + if isinstance(data, memoryview): + data = bytes(data) + # Hack to convert +HH in +HHMM if data[-3] in (43, 45): data += b"00" @@ -184,7 +191,9 @@ class TimeTzLoader(TimeLoader): return dt.time().replace(tzinfo=dt.tzinfo) - def _load_py36(self, data: bytes) -> time: + def _load_py36(self, data: Buffer) -> time: + if isinstance(data, memoryview): + data = bytes(data) # Drop seconds from timezone for Python 3.6 # Also, Python 3.6 doesn't support HHMM, only HH:MM if data[-6] in (43, 45): # +-HH:MM -> +-HHMM @@ -203,7 +212,10 @@ class TimestampLoader(DateLoader): super().__init__(oid, context) self._format_no_micro = self._format.replace(".%f", "") - def load(self, data: bytes) -> datetime: + def load(self, data: Buffer) -> datetime: + if isinstance(data, memoryview): + data = bytes(data) + # check if the data contains microseconds fmt = ( self._format if data.find(b".", 19) >= 0 else self._format_no_micro @@ -285,14 +297,19 @@ class TimestamptzLoader(TimestampLoader): setattr(self, "load", self._load_notimpl) return "" - def load(self, data: bytes) -> datetime: + def load(self, data: Buffer) -> datetime: + if isinstance(data, memoryview): + data = bytes(data) + # Hack to convert +HH in +HHMM if data[-3] in (43, 45): data += b"00" return super().load(data) - def _load_py36(self, data: bytes) -> datetime: + def _load_py36(self, data: Buffer) -> datetime: + if isinstance(data, memoryview): + data = bytes(data) # Drop seconds from timezone for Python 3.6 # Also, Python 3.6 doesn't support HHMM, only HH:MM tzsep = (43, 45) # + and - bytes @@ -305,7 +322,9 @@ class TimestamptzLoader(TimestampLoader): return super().load(data) - def _load_notimpl(self, data: bytes) -> datetime: + def _load_notimpl(self, data: Buffer) -> datetime: + if isinstance(data, memoryview): + data = bytes(data) raise NotImplementedError( "can't parse datetimetz with DateStyle" f" {self._get_datestyle().decode('ascii')}: {data.decode('ascii')}" @@ -337,7 +356,7 @@ class IntervalLoader(Loader): if ints != b"postgres": setattr(self, "load", self._load_notimpl) - def load(self, data: bytes) -> timedelta: + def load(self, data: Buffer) -> timedelta: m = self._re_interval.match(data) if not m: raise ValueError("can't parse interval: {data.decode('ascii')}") @@ -371,7 +390,9 @@ class IntervalLoader(Loader): except OverflowError as e: raise DataError(str(e)) - def _load_notimpl(self, data: bytes) -> timedelta: + def _load_notimpl(self, data: Buffer) -> timedelta: + if isinstance(data, memoryview): + data = bytes(data) ints = ( self.connection and self.connection.pgconn.parameter_status(b"IntervalStyle") diff --git a/psycopg3/psycopg3/types/json.py b/psycopg3/psycopg3/types/json.py index 917e2dbb1..db70e8afd 100644 --- a/psycopg3/psycopg3/types/json.py +++ b/psycopg3/psycopg3/types/json.py @@ -9,7 +9,7 @@ from typing import Any, Callable, Optional from ..pq import Format from ..oids import builtins -from ..adapt import Dumper, Loader +from ..adapt import Buffer, Dumper, Loader from ..errors import DataError JsonDumpsFunction = Callable[[Any], str] @@ -77,7 +77,10 @@ class JsonLoader(Loader): format = Format.TEXT - def load(self, data: bytes) -> Any: + def load(self, data: Buffer) -> Any: + # Json crashes on memoryview + if isinstance(data, memoryview): + data = bytes(data) return json.loads(data) @@ -90,7 +93,10 @@ class JsonbBinaryLoader(Loader): format = Format.BINARY - def load(self, data: bytes) -> Any: + def load(self, data: Buffer) -> Any: if data and data[0] != 1: raise DataError("unknown jsonb binary format: {data[0]}") - return json.loads(data[1:]) + data = data[1:] + if isinstance(data, memoryview): + data = bytes(data) + return json.loads(data) diff --git a/psycopg3/psycopg3/types/network.py b/psycopg3/psycopg3/types/network.py index bbe4913c9..a1b626053 100644 --- a/psycopg3/psycopg3/types/network.py +++ b/psycopg3/psycopg3/types/network.py @@ -8,7 +8,7 @@ from typing import Callable, Optional, Union, TYPE_CHECKING from ..pq import Format from ..oids import builtins -from ..adapt import Dumper, Loader +from ..adapt import Buffer, Dumper, Loader from ..proto import AdaptContext if TYPE_CHECKING: @@ -57,7 +57,10 @@ class InetLoader(_LazyIpaddress): format = Format.TEXT - def load(self, data: bytes) -> Union[Address, Interface]: + def load(self, data: Buffer) -> Union[Address, Interface]: + if isinstance(data, memoryview): + data = bytes(data) + if b"/" in data: return ip_interface(data.decode("utf8")) else: @@ -68,5 +71,8 @@ class CidrLoader(_LazyIpaddress): format = Format.TEXT - def load(self, data: bytes) -> Network: + def load(self, data: Buffer) -> Network: + if isinstance(data, memoryview): + data = bytes(data) + return ip_network(data.decode("utf8")) diff --git a/psycopg3/psycopg3/types/numeric.py b/psycopg3/psycopg3/types/numeric.py index 3fd829965..9d90e4794 100644 --- a/psycopg3/psycopg3/types/numeric.py +++ b/psycopg3/psycopg3/types/numeric.py @@ -10,7 +10,7 @@ from decimal import Decimal from ..pq import Format from ..oids import builtins -from ..adapt import Dumper, Loader +from ..adapt import Buffer, Dumper, Loader _PackInt = Callable[[int], bytes] _PackFloat = Callable[[float], bytes] @@ -166,7 +166,7 @@ class IntLoader(Loader): format = Format.TEXT - def load(self, data: bytes) -> int: + def load(self, data: Buffer) -> int: # it supports bytes directly return int(data) @@ -175,7 +175,7 @@ class Int2BinaryLoader(Loader): format = Format.BINARY - def load(self, data: bytes) -> int: + def load(self, data: Buffer) -> int: return _unpack_int2(data)[0] @@ -183,7 +183,7 @@ class Int4BinaryLoader(Loader): format = Format.BINARY - def load(self, data: bytes) -> int: + def load(self, data: Buffer) -> int: return _unpack_int4(data)[0] @@ -191,7 +191,7 @@ class Int8BinaryLoader(Loader): format = Format.BINARY - def load(self, data: bytes) -> int: + def load(self, data: Buffer) -> int: return _unpack_int8(data)[0] @@ -199,7 +199,7 @@ class OidBinaryLoader(Loader): format = Format.BINARY - def load(self, data: bytes) -> int: + def load(self, data: Buffer) -> int: return _unpack_uint4(data)[0] @@ -207,7 +207,7 @@ class FloatLoader(Loader): format = Format.TEXT - def load(self, data: bytes) -> float: + def load(self, data: Buffer) -> float: # it supports bytes directly return float(data) @@ -216,7 +216,7 @@ class Float4BinaryLoader(Loader): format = Format.BINARY - def load(self, data: bytes) -> float: + def load(self, data: Buffer) -> float: return _unpack_float4(data)[0] @@ -224,7 +224,7 @@ class Float8BinaryLoader(Loader): format = Format.BINARY - def load(self, data: bytes) -> float: + def load(self, data: Buffer) -> float: return _unpack_float8(data)[0] @@ -232,5 +232,7 @@ class NumericLoader(Loader): format = Format.TEXT - def load(self, data: bytes) -> Decimal: + def load(self, data: Buffer) -> Decimal: + if isinstance(data, memoryview): + data = bytes(data) return Decimal(data.decode("utf8")) diff --git a/psycopg3/psycopg3/types/range.py b/psycopg3/psycopg3/types/range.py index 2ae531ae8..0a4ea9f8f 100644 --- a/psycopg3/psycopg3/types/range.py +++ b/psycopg3/psycopg3/types/range.py @@ -14,7 +14,7 @@ from .. import sql from .. import errors as e from ..pq import Format from ..oids import builtins, TypeInfo -from ..adapt import Dumper, Loader +from ..adapt import Buffer, Dumper, Loader from ..proto import AdaptContext from . import array @@ -245,12 +245,12 @@ class RangeLoader(BaseCompositeLoader, Generic[T]): subtype_oid: int cls: Type[Range[T]] - def load(self, data: bytes) -> Range[T]: + def load(self, data: Buffer) -> Range[T]: if data == b"empty": return self.cls(empty=True) cast = self._tx.get_loader(self.subtype_oid, format=Format.TEXT).load - bounds = (data[:1] + data[-1:]).decode("utf-8") + bounds = _int2parens[data[0]] + _int2parens[data[-1]] min, max = ( cast(token) if token is not None else None for token in self._parse_record(data[1:-1]) @@ -258,6 +258,9 @@ class RangeLoader(BaseCompositeLoader, Generic[T]): return self.cls(min, max, bounds) +_int2parens = {ord(c): c for c in "[]()"} + + # Python wrappers for builtin range types diff --git a/psycopg3/psycopg3/types/singletons.py b/psycopg3/psycopg3/types/singletons.py index 2f10577cb..0b52f3989 100644 --- a/psycopg3/psycopg3/types/singletons.py +++ b/psycopg3/psycopg3/types/singletons.py @@ -6,7 +6,7 @@ Adapters for None and boolean. from ..pq import Format from ..oids import builtins -from ..adapt import Dumper, Loader +from ..adapt import Buffer, Dumper, Loader class BoolDumper(Dumper): @@ -49,7 +49,7 @@ class BoolLoader(Loader): format = Format.TEXT - def load(self, data: bytes) -> bool: + def load(self, data: Buffer) -> bool: return data == b"t" @@ -57,5 +57,5 @@ class BoolBinaryLoader(Loader): format = Format.BINARY - def load(self, data: bytes) -> bool: + def load(self, data: Buffer) -> bool: return data != b"\x00" diff --git a/psycopg3/psycopg3/types/text.py b/psycopg3/psycopg3/types/text.py index 47b676f8c..37010c9fc 100644 --- a/psycopg3/psycopg3/types/text.py +++ b/psycopg3/psycopg3/types/text.py @@ -8,7 +8,7 @@ from typing import Optional, Union, TYPE_CHECKING from ..pq import Format, Escaping from ..oids import builtins -from ..adapt import Dumper, Loader +from ..adapt import Buffer, Dumper, Loader from ..proto import AdaptContext from ..errors import DataError @@ -65,7 +65,7 @@ class TextLoader(Loader): enc = conn.client_encoding self._encoding = enc if enc != "ascii" else "" - def load(self, data: bytes) -> Union[bytes, str]: + def load(self, data: Buffer) -> Union[bytes, str]: if self._encoding: if isinstance(data, memoryview): return bytes(data).decode(self._encoding) @@ -120,7 +120,7 @@ class ByteaLoader(Loader): if not hasattr(self.__class__, "_escaping"): self.__class__._escaping = Escaping() - def load(self, data: bytes) -> bytes: + def load(self, data: Buffer) -> bytes: return self._escaping.unescape_bytea(data) @@ -128,5 +128,5 @@ class ByteaBinaryLoader(Loader): format = Format.BINARY - def load(self, data: bytes) -> bytes: + def load(self, data: Buffer) -> bytes: return data diff --git a/psycopg3/psycopg3/types/uuid.py b/psycopg3/psycopg3/types/uuid.py index ea833f2c1..4311ba24c 100644 --- a/psycopg3/psycopg3/types/uuid.py +++ b/psycopg3/psycopg3/types/uuid.py @@ -8,7 +8,7 @@ from typing import Callable, Optional, TYPE_CHECKING from ..pq import Format from ..oids import builtins -from ..adapt import Dumper, Loader +from ..adapt import Buffer, Dumper, Loader from ..proto import AdaptContext if TYPE_CHECKING: @@ -48,7 +48,9 @@ class UUIDLoader(Loader): imported = True - def load(self, data: bytes) -> "uuid.UUID": + def load(self, data: Buffer) -> "uuid.UUID": + if isinstance(data, memoryview): + data = bytes(data) return UUID(data.decode("utf8")) @@ -56,5 +58,7 @@ class UUIDBinaryLoader(UUIDLoader): format = Format.BINARY - def load(self, data: bytes) -> "uuid.UUID": + def load(self, data: Buffer) -> "uuid.UUID": + if isinstance(data, memoryview): + data = bytes(data) return UUID(bytes=data) diff --git a/psycopg3_c/psycopg3_c/_psycopg3/transform.pyx b/psycopg3_c/psycopg3_c/_psycopg3/transform.pyx index da88d9334..41ed1b22d 100644 --- a/psycopg3_c/psycopg3_c/_psycopg3/transform.pyx +++ b/psycopg3_c/psycopg3_c/_psycopg3/transform.pyx @@ -332,8 +332,10 @@ cdef class Transformer: if attval.len == -1: # NULL_LEN pyval = None else: - # TODO: no copy - b = attval.value[:attval.len] + b = PyMemoryView_FromObject( + ViewBuffer._from_buffer( + self._pgresult, + attval.value, attval.len)) pyval = PyObject_CallFunctionObjArgs( (loader).loadfunc, b, NULL) @@ -371,8 +373,10 @@ cdef class Transformer: pyval = (loader).cloader.cload( attval.value, attval.len) else: - # TODO: no copy - b = attval.value[:attval.len] + b = PyMemoryView_FromObject( + ViewBuffer._from_buffer( + self._pgresult, + attval.value, attval.len)) pyval = PyObject_CallFunctionObjArgs( (loader).loadfunc, b, NULL) diff --git a/tests/test_adapt.py b/tests/test_adapt.py index 6123b98bd..366a49b19 100644 --- a/tests/test_adapt.py +++ b/tests/test_adapt.py @@ -107,7 +107,7 @@ def test_subclass_loader(conn): class MyTextLoader(TextLoader): def load(self, data): - return (data * 2).decode("utf-8") + return (bytes(data) * 2).decode("utf-8") MyTextLoader.register("text", conn) assert conn.execute("select 'hello'::text").fetchone()[0] == "hellohello" @@ -321,7 +321,7 @@ def make_loader(suffix): format = pq.Format.TEXT def load(self, b): - return b.decode("ascii") + suffix + return bytes(b).decode("ascii") + suffix return TestLoader diff --git a/tests/types/test_array.py b/tests/types/test_array.py index a54330cae..46089c278 100644 --- a/tests/types/test_array.py +++ b/tests/types/test_array.py @@ -1,6 +1,7 @@ import pytest import psycopg3 from psycopg3 import pq +from psycopg3 import sql from psycopg3.oids import builtins from psycopg3.adapt import Format, Transformer from psycopg3.types import array @@ -99,6 +100,15 @@ def test_load_list_int(conn, obj, want, fmt_out): cur.execute("select %s::int[]", (obj,)) assert cur.fetchone()[0] == want + stmt = sql.SQL("copy (select {}::int[]) to stdout (format {})").format( + obj, sql.SQL(fmt_out.name) + ) + with cur.copy(stmt) as copy: + copy.set_types(["int4[]"]) + (got,) = copy.read_row() + + assert got == want + def test_array_register(conn): cur = conn.cursor() diff --git a/tests/types/test_json.py b/tests/types/test_json.py index 8309faee5..bea17f3da 100644 --- a/tests/types/test_json.py +++ b/tests/types/test_json.py @@ -4,6 +4,7 @@ import pytest import psycopg3.types from psycopg3 import pq +from psycopg3 import sql from psycopg3.types import Json, Jsonb from psycopg3.adapt import Format @@ -48,6 +49,21 @@ def test_json_load(conn, val, jtype, fmt_out): assert cur.fetchone()[0] == json.loads(val) +@pytest.mark.parametrize("val", samples) +@pytest.mark.parametrize("jtype", ["json", "jsonb"]) +@pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY]) +def test_json_load_copy(conn, val, jtype, fmt_out): + cur = conn.cursor() + stmt = sql.SQL("copy (select {}::{}) to stdout (format {})").format( + val, sql.Identifier(jtype), sql.SQL(fmt_out.name) + ) + with cur.copy(stmt) as copy: + copy.set_types([jtype]) + (got,) = copy.read_row() + + assert got == json.loads(val) + + @pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY]) @pytest.mark.parametrize("wrapper", ["Json", "Jsonb"]) def test_json_dump_customise(conn, wrapper, fmt_in): diff --git a/tests/types/test_network.py b/tests/types/test_network.py index ada082eea..3b88d38a1 100644 --- a/tests/types/test_network.py +++ b/tests/types/test_network.py @@ -5,6 +5,7 @@ import subprocess as sp import pytest from psycopg3 import pq +from psycopg3 import sql from psycopg3.adapt import Format @@ -60,34 +61,69 @@ def test_network_dump(conn, fmt_in, val): @pytest.mark.parametrize("val", ["127.0.0.1/32", "::ffff:102:300/128"]) def test_inet_load_address(conn, fmt_out, val): binary_check(fmt_out) + addr = ipaddress.ip_address(val.split("/", 1)[0]) cur = conn.cursor(binary=fmt_out) + cur.execute("select %s::inet", (val,)) - addr = ipaddress.ip_address(val.split("/", 1)[0]) assert cur.fetchone()[0] == addr + cur.execute("select array[null, %s::inet]", (val,)) assert cur.fetchone()[0] == [None, addr] + stmt = sql.SQL("copy (select {}::inet) to stdout (format {})").format( + val, sql.SQL(fmt_out.name) + ) + with cur.copy(stmt) as copy: + copy.set_types(["inet"]) + (got,) = copy.read_row() + + assert got == addr + @pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY]) @pytest.mark.parametrize("val", ["127.0.0.1/24", "::ffff:102:300/127"]) def test_inet_load_network(conn, fmt_out, val): binary_check(fmt_out) + pyval = ipaddress.ip_interface(val) cur = conn.cursor(binary=fmt_out) + cur.execute("select %s::inet", (val,)) - assert cur.fetchone()[0] == ipaddress.ip_interface(val) + assert cur.fetchone()[0] == pyval + cur.execute("select array[null, %s::inet]", (val,)) - assert cur.fetchone()[0] == [None, ipaddress.ip_interface(val)] + assert cur.fetchone()[0] == [None, pyval] + + stmt = sql.SQL("copy (select {}::inet) to stdout (format {})").format( + val, sql.SQL(fmt_out.name) + ) + with cur.copy(stmt) as copy: + copy.set_types(["inet"]) + (got,) = copy.read_row() + + assert got == pyval @pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY]) @pytest.mark.parametrize("val", ["127.0.0.0/24", "::ffff:102:300/128"]) def test_cidr_load(conn, fmt_out, val): binary_check(fmt_out) + pyval = ipaddress.ip_network(val) cur = conn.cursor(binary=fmt_out) + cur.execute("select %s::cidr", (val,)) - assert cur.fetchone()[0] == ipaddress.ip_network(val) + assert cur.fetchone()[0] == pyval + cur.execute("select array[null, %s::cidr]", (val,)) - assert cur.fetchone()[0] == [None, ipaddress.ip_network(val)] + assert cur.fetchone()[0] == [None, pyval] + + stmt = sql.SQL("copy (select {}::cidr) to stdout (format {})").format( + val, sql.SQL(fmt_out.name) + ) + with cur.copy(stmt) as copy: + copy.set_types(["cidr"]) + (got,) = copy.read_row() + + assert got == pyval def binary_check(fmt): diff --git a/tests/types/test_text.py b/tests/types/test_text.py index f0fd97ec0..c18dcd358 100644 --- a/tests/types/test_text.py +++ b/tests/types/test_text.py @@ -114,16 +114,34 @@ def test_load_enc(conn, typename, encoding, fmt_out): ).fetchone() assert res == eur + stmt = sql.SQL("copy (select chr({}::int)) to stdout (format {})").format( + ord(eur), sql.SQL(fmt_out.name) + ) + with cur.copy(stmt) as copy: + copy.set_types([typename]) + (res,) = copy.read_row() + + assert res == eur + @pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY]) @pytest.mark.parametrize("typename", ["text", "varchar", "name", "bpchar"]) def test_load_badenc(conn, typename, fmt_out): + conn.autocommit = True cur = conn.cursor(binary=fmt_out) conn.client_encoding = "latin1" - with pytest.raises(psycopg3.DatabaseError): + with pytest.raises(psycopg3.DataError): cur.execute(f"select chr(%s::int)::{typename}", (ord(eur),)) + stmt = sql.SQL("copy (select chr({}::int)) to stdout (format {})").format( + ord(eur), sql.SQL(fmt_out.name) + ) + with cur.copy(stmt) as copy: + copy.set_types([typename]) + with pytest.raises(psycopg3.DataError): + copy.read_row() + @pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY]) @pytest.mark.parametrize("typename", ["text", "varchar", "name", "bpchar"]) @@ -131,9 +149,16 @@ def test_load_ascii(conn, typename, fmt_out): cur = conn.cursor(binary=fmt_out) conn.client_encoding = "ascii" - (res,) = cur.execute( - f"select chr(%s::int)::{typename}", (ord(eur),) - ).fetchone() + cur.execute(f"select chr(%s::int)::{typename}", (ord(eur),)) + assert cur.fetchone()[0] == eur.encode("utf8") + + stmt = sql.SQL("copy (select chr({}::int)) to stdout (format {})").format( + ord(eur), sql.SQL(fmt_out.name) + ) + with cur.copy(stmt) as copy: + copy.set_types([typename]) + (res,) = copy.read_row() + assert res == eur.encode("utf8") diff --git a/tests/types/test_uuid.py b/tests/types/test_uuid.py index 548664668..6fbea2697 100644 --- a/tests/types/test_uuid.py +++ b/tests/types/test_uuid.py @@ -5,6 +5,7 @@ import subprocess as sp import pytest from psycopg3 import pq +from psycopg3 import sql from psycopg3.adapt import Format @@ -23,6 +24,15 @@ def test_uuid_load(conn, fmt_out): cur.execute("select %s::uuid", (val,)) assert cur.fetchone()[0] == UUID(val) + stmt = sql.SQL("copy (select {}::uuid) to stdout (format {})").format( + val, sql.SQL(fmt_out.name) + ) + with cur.copy(stmt) as copy: + copy.set_types(["uuid"]) + (res,) = copy.read_row() + + assert res == UUID(val) + @pytest.mark.subprocess def test_lazy_load(dsn):