]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added Loader.format and Dumper.format
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 30 Dec 2020 19:27:56 +0000 (20:27 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 8 Jan 2021 01:26:53 +0000 (02:26 +0100)
12 files changed:
psycopg3/psycopg3/adapt.py
psycopg3/psycopg3/dbapi20.py
psycopg3/psycopg3/types/array.py
psycopg3/psycopg3/types/composite.py
psycopg3/psycopg3/types/date.py
psycopg3/psycopg3/types/json.py
psycopg3/psycopg3/types/network.py
psycopg3/psycopg3/types/numeric.py
psycopg3/psycopg3/types/singletons.py
psycopg3/psycopg3/types/text.py
psycopg3/psycopg3/types/uuid.py
tests/test_adapt.py

index 9b5e89b03d3a8de0f940fca82cfa9fede3abbf0f..b32442b0669792ae29a0ae5f78188b889fadb651 100644 (file)
@@ -22,6 +22,7 @@ class Dumper(ABC):
     Convert Python object of the type *src* to PostgreSQL representation.
     """
 
+    format: Format
     connection: Optional["BaseConnection"] = None
 
     # A class-wide oid, which will be used by default by instances unless
@@ -98,6 +99,7 @@ class Loader(ABC):
     Convert PostgreSQL objects with OID *oid* to Python objects.
     """
 
+    format: Format
     connection: Optional["BaseConnection"]
 
     def __init__(self, oid: int, context: Optional[AdaptContext] = None):
index 7e883ba419c2b4e3dae8d1afe69bf4e3abd4ea91..24c5ad9c69655f904af59349cfa0c0846590dbcb 100644 (file)
@@ -10,7 +10,7 @@ from math import floor
 from typing import Any, Sequence
 
 from .oids import builtins
-from .adapt import Dumper
+from .adapt import Dumper, Format
 
 
 class DBAPITypeObject:
@@ -53,6 +53,7 @@ class Binary:
 @Dumper.text(Binary)
 class BinaryDumper(Dumper):
 
+    format = Format.TEXT
     oid = builtins["bytea"].oid
 
     def dump(self, obj: Binary) -> bytes:
index 129c62d9832f9a792a474188c16da32c030d9f9d..dcbec8efa143ec67b2e162b72ed10700bce5d1b1 100644 (file)
@@ -47,6 +47,9 @@ class ListDumper(BaseListDumper):
     # they are empty strings, contain curly braces, delimiter characters,
     # double quotes, backslashes, or white space, or match the word NULL.
     # TODO: recognise only , as delimiter. Should be configured
+
+    format = Format.TEXT
+
     _re_needs_quotes = re.compile(
         br"""(?xi)
           ^$              # the empty string
@@ -101,6 +104,9 @@ class ListDumper(BaseListDumper):
 
 @Dumper.binary(list)
 class ListBinaryDumper(BaseListDumper):
+
+    format = Format.BINARY
+
     def dump(self, obj: List[Any]) -> bytes:
         if not obj:
             return _struct_head.pack(0, 0, TEXT_OID)
@@ -166,6 +172,8 @@ class BaseArrayLoader(Loader):
 
 class ArrayLoader(BaseArrayLoader):
 
+    format = Format.TEXT
+
     # Tokenize an array representation into item and brackets
     # TODO: currently recognise only , as delimiter. Should be configured
     _re_parse = re.compile(
@@ -226,6 +234,9 @@ _struct_len = struct.Struct("!i")
 
 
 class ArrayBinaryLoader(BaseArrayLoader):
+
+    format = Format.BINARY
+
     def load(self, data: bytes) -> List[Any]:
         ndims, hasnull, oid = _struct_head.unpack_from(data[:12])
         if not ndims:
index 53abd2aa46696f26315755195e88a8b9777f15e8..e5065c5edbf2ca13b453ac96f67fae9e9ea88b8a 100644 (file)
@@ -144,6 +144,9 @@ where t.oid = %(name)s::regtype
 
 
 class SequenceDumper(Dumper):
+
+    format = Format.TEXT
+
     def __init__(self, src: type, context: Optional[AdaptContext] = None):
         super().__init__(src, context)
         self._tx = Transformer(context)
@@ -190,6 +193,9 @@ class TupleDumper(SequenceDumper):
 
 
 class BaseCompositeLoader(Loader):
+
+    format = Format.TEXT
+
     def __init__(self, oid: int, context: Optional[AdaptContext] = None):
         super().__init__(oid, context)
         self._tx = Transformer(context)
@@ -243,9 +249,15 @@ _struct_oidlen = struct.Struct("!Ii")
 
 
 @Loader.binary(builtins["record"].oid)
-class RecordBinaryLoader(BaseCompositeLoader):
+class RecordBinaryLoader(Loader):
+
+    format = Format.BINARY
     _types_set = False
 
+    def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+        super().__init__(oid, context)
+        self._tx = Transformer(context)
+
     def load(self, data: bytes) -> Tuple[Any, ...]:
         if not self._types_set:
             self._config_types(data)
@@ -275,6 +287,8 @@ class RecordBinaryLoader(BaseCompositeLoader):
 
 
 class CompositeLoader(RecordLoader):
+
+    format = Format.TEXT
     factory: Callable[..., Any]
     fields_types: List[int]
     _types_set = False
@@ -298,6 +312,8 @@ class CompositeLoader(RecordLoader):
 
 
 class CompositeBinaryLoader(RecordBinaryLoader):
+
+    format = Format.BINARY
     factory: Callable[..., Any]
 
     def load(self, data: bytes) -> Any:
index d41855c802bd03cb1efa338b83482ac6d1fdf678..fbf74834d0817a60328ca7dfd80cd3b5169070eb 100644 (file)
@@ -10,7 +10,7 @@ from datetime import date, datetime, time, timedelta
 from typing import cast, Optional
 
 from ..oids import builtins
-from ..adapt import Dumper, Loader
+from ..adapt import Dumper, Loader, Format
 from ..proto import AdaptContext
 from ..errors import InterfaceError, DataError
 
@@ -18,6 +18,7 @@ from ..errors import InterfaceError, DataError
 @Dumper.text(date)
 class DateDumper(Dumper):
 
+    format = Format.TEXT
     _oid = builtins["date"].oid
 
     def dump(self, obj: date) -> bytes:
@@ -29,6 +30,7 @@ class DateDumper(Dumper):
 @Dumper.text(time)
 class TimeDumper(Dumper):
 
+    format = Format.TEXT
     _oid = builtins["timetz"].oid
 
     def dump(self, obj: time) -> bytes:
@@ -38,6 +40,7 @@ class TimeDumper(Dumper):
 @Dumper.text(datetime)
 class DateTimeDumper(Dumper):
 
+    format = Format.TEXT
     _oid = builtins["timestamptz"].oid
 
     def dump(self, obj: date) -> bytes:
@@ -49,6 +52,7 @@ class DateTimeDumper(Dumper):
 @Dumper.text(timedelta)
 class TimeDeltaDumper(Dumper):
 
+    format = Format.TEXT
     _oid = builtins["interval"].oid
 
     def __init__(self, src: type, context: Optional[AdaptContext] = None):
@@ -75,6 +79,9 @@ class TimeDeltaDumper(Dumper):
 
 @Loader.text(builtins["date"].oid)
 class DateLoader(Loader):
+
+    format = Format.TEXT
+
     def __init__(self, oid: int, context: Optional[AdaptContext] = None):
         super().__init__(oid, context)
         self._format = self._format_from_context()
@@ -134,6 +141,7 @@ class DateLoader(Loader):
 @Loader.text(builtins["time"].oid)
 class TimeLoader(Loader):
 
+    format = Format.TEXT
     _format = "%H:%M:%S.%f"
     _format_no_micro = _format.replace(".%f", "")
 
@@ -158,6 +166,8 @@ class TimeLoader(Loader):
 
 @Loader.text(builtins["timetz"].oid)
 class TimeTzLoader(TimeLoader):
+
+    format = Format.TEXT
     _format = "%H:%M:%S.%f%z"
     _format_no_micro = _format.replace(".%f", "")
 
@@ -193,6 +203,9 @@ class TimeTzLoader(TimeLoader):
 
 @Loader.text(builtins["timestamp"].oid)
 class TimestampLoader(DateLoader):
+
+    format = Format.TEXT
+
     def __init__(self, oid: int, context: Optional[AdaptContext] = None):
         super().__init__(oid, context)
         self._format_no_micro = self._format.replace(".%f", "")
@@ -245,6 +258,9 @@ class TimestampLoader(DateLoader):
 
 @Loader.text(builtins["timestamptz"].oid)
 class TimestamptzLoader(TimestampLoader):
+
+    format = Format.TEXT
+
     def __init__(self, oid: int, context: Optional[AdaptContext] = None):
         if sys.version_info < (3, 7):
             setattr(self, "load", self._load_py36)
@@ -307,6 +323,8 @@ class TimestamptzLoader(TimestampLoader):
 @Loader.text(builtins["interval"].oid)
 class IntervalLoader(Loader):
 
+    format = Format.TEXT
+
     _re_interval = re.compile(
         br"""
         (?: (?P<years> [-+]?\d+) \s+ years? \s* )?
index 3fe79d0c53218116d5d1d0beaffd916ebf179d52..272e2b650a0050f4e1ea55fea2828fae2d2e4cbc 100644 (file)
@@ -8,7 +8,7 @@ import json
 from typing import Any, Callable, Optional
 
 from ..oids import builtins
-from ..adapt import Dumper, Loader
+from ..adapt import Dumper, Loader, Format
 from ..errors import DataError
 
 JsonDumpsFunction = Callable[[Any], str]
@@ -32,37 +32,63 @@ class Jsonb(_JsonWrapper):
 
 
 class _JsonDumper(Dumper):
+
+    format = Format.TEXT
+
     def dump(self, obj: _JsonWrapper) -> bytes:
         return obj.dumps().encode("utf-8")
 
 
 @Dumper.text(Json)
-@Dumper.binary(Json)
 class JsonDumper(_JsonDumper):
+
+    format = Format.TEXT
     _oid = builtins["json"].oid
 
 
+@Dumper.binary(Json)
+class JsonBinaryDumper(JsonDumper):
+
+    format = Format.BINARY
+
+
 @Dumper.text(Jsonb)
 class JsonbDumper(_JsonDumper):
+
+    format = Format.TEXT
     _oid = builtins["jsonb"].oid
 
 
 @Dumper.binary(Jsonb)
 class JsonbBinaryDumper(JsonbDumper):
+
+    format = Format.BINARY
+
     def dump(self, obj: _JsonWrapper) -> bytes:
         return b"\x01" + obj.dumps().encode("utf-8")
 
 
 @Loader.text(builtins["json"].oid)
 @Loader.text(builtins["jsonb"].oid)
-@Loader.binary(builtins["json"].oid)
 class JsonLoader(Loader):
+
+    format = Format.TEXT
+
     def load(self, data: bytes) -> Any:
         return json.loads(data)
 
 
+@Loader.binary(builtins["json"].oid)
+class JsonBinaryLoader(JsonLoader):
+
+    format = Format.BINARY
+
+
 @Loader.binary(builtins["jsonb"].oid)
 class JsonbBinaryLoader(Loader):
+
+    format = Format.BINARY
+
     def load(self, data: bytes) -> Any:
         if data and data[0] != 1:
             raise DataError("unknown jsonb binary format: {data[0]}")
index ff5df320b97753ea723e36517d12c574e926fc40..85c34bc299e422bb60cc318dc7bb22ff8dbcf3e8 100644 (file)
@@ -7,7 +7,7 @@ Adapters for network types.
 from typing import Callable, Optional, Union, TYPE_CHECKING
 
 from ..oids import builtins
-from ..adapt import Dumper, Loader
+from ..adapt import Dumper, Loader, Format
 from ..proto import AdaptContext
 
 if TYPE_CHECKING:
@@ -29,6 +29,7 @@ ip_network: Callable[[str], Network]
 @Dumper.text("ipaddress.IPv6Interface")
 class InterfaceDumper(Dumper):
 
+    format = Format.TEXT
     _oid = builtins["inet"].oid
 
     def dump(self, obj: Interface) -> bytes:
@@ -39,6 +40,7 @@ class InterfaceDumper(Dumper):
 @Dumper.text("ipaddress.IPv6Network")
 class NetworkDumper(Dumper):
 
+    format = Format.TEXT
     _oid = builtins["cidr"].oid
 
     def dump(self, obj: Network) -> bytes:
@@ -54,6 +56,9 @@ class _LazyIpaddress(Loader):
 
 @Loader.text(builtins["inet"].oid)
 class InetLoader(_LazyIpaddress):
+
+    format = Format.TEXT
+
     def load(self, data: bytes) -> Union[Address, Interface]:
         if b"/" in data:
             return ip_interface(data.decode("utf8"))
@@ -63,5 +68,8 @@ class InetLoader(_LazyIpaddress):
 
 @Loader.text(builtins["cidr"].oid)
 class CidrLoader(_LazyIpaddress):
+
+    format = Format.TEXT
+
     def load(self, data: bytes) -> Network:
         return ip_network(data.decode("utf8"))
index 573df51975519a80777bad840111ea4c64239937..476b16cd000af444df6eca1e9e995a8fb0fdc3fb 100644 (file)
@@ -9,7 +9,7 @@ from typing import Any, Callable, Dict, Tuple, cast
 from decimal import Decimal
 
 from ..oids import builtins
-from ..adapt import Dumper, Loader
+from ..adapt import Dumper, Loader, Format
 
 _PackInt = Callable[[int], bytes]
 _PackFloat = Callable[[float], bytes]
@@ -57,6 +57,9 @@ class Oid(int):
 
 
 class NumberDumper(Dumper):
+
+    format = Format.TEXT
+
     def dump(self, obj: Any) -> bytes:
         return str(obj).encode("utf8")
 
@@ -66,6 +69,7 @@ class NumberDumper(Dumper):
 
 
 class SpecialValuesDumper(NumberDumper):
+
     _special: Dict[bytes, bytes] = {}
 
     def quote(self, obj: Any) -> bytes:
@@ -84,12 +88,17 @@ class IntDumper(NumberDumper):
 
 @Dumper.binary(int)
 class IntBinaryDumper(IntDumper):
+
+    format = Format.BINARY
+
     def dump(self, obj: int) -> bytes:
         return _pack_int8(obj)
 
 
 @Dumper.text(float)
 class FloatDumper(SpecialValuesDumper):
+
+    format = Format.TEXT
     _oid = builtins["float8"].oid
 
     _special = {
@@ -100,7 +109,9 @@ class FloatDumper(SpecialValuesDumper):
 
 
 @Dumper.binary(float)
-class FloatBinaryDumper(NumberDumper):
+class FloatBinaryDumper(Dumper):
+
+    format = Format.BINARY
     _oid = builtins["float8"].oid
 
     def dump(self, obj: float) -> bytes:
@@ -140,24 +151,36 @@ class OidDumper(NumberDumper):
 
 @Dumper.binary(Int2)
 class Int2BinaryDumper(Int2Dumper):
+
+    format = Format.BINARY
+
     def dump(self, obj: int) -> bytes:
         return _pack_int2(obj)
 
 
 @Dumper.binary(Int4)
 class Int4BinaryDumper(Int4Dumper):
+
+    format = Format.BINARY
+
     def dump(self, obj: int) -> bytes:
         return _pack_int4(obj)
 
 
 @Dumper.binary(Int8)
 class Int8BinaryDumper(Int8Dumper):
+
+    format = Format.BINARY
+
     def dump(self, obj: int) -> bytes:
         return _pack_int8(obj)
 
 
 @Dumper.binary(Oid)
 class OidBinaryDumper(OidDumper):
+
+    format = Format.BINARY
+
     def dump(self, obj: int) -> bytes:
         return _pack_uint4(obj)
 
@@ -167,6 +190,9 @@ class OidBinaryDumper(OidDumper):
 @Loader.text(builtins["int8"].oid)
 @Loader.text(builtins["oid"].oid)
 class IntLoader(Loader):
+
+    format = Format.TEXT
+
     def load(self, data: bytes) -> int:
         # it supports bytes directly
         return int(data)
@@ -174,24 +200,36 @@ class IntLoader(Loader):
 
 @Loader.binary(builtins["int2"].oid)
 class Int2BinaryLoader(Loader):
+
+    format = Format.BINARY
+
     def load(self, data: bytes) -> int:
         return _unpack_int2(data)[0]
 
 
 @Loader.binary(builtins["int4"].oid)
 class Int4BinaryLoader(Loader):
+
+    format = Format.BINARY
+
     def load(self, data: bytes) -> int:
         return _unpack_int4(data)[0]
 
 
 @Loader.binary(builtins["int8"].oid)
 class Int8BinaryLoader(Loader):
+
+    format = Format.BINARY
+
     def load(self, data: bytes) -> int:
         return _unpack_int8(data)[0]
 
 
 @Loader.binary(builtins["oid"].oid)
 class OidBinaryLoader(Loader):
+
+    format = Format.BINARY
+
     def load(self, data: bytes) -> int:
         return _unpack_uint4(data)[0]
 
@@ -199,6 +237,9 @@ class OidBinaryLoader(Loader):
 @Loader.text(builtins["float4"].oid)
 @Loader.text(builtins["float8"].oid)
 class FloatLoader(Loader):
+
+    format = Format.TEXT
+
     def load(self, data: bytes) -> float:
         # it supports bytes directly
         return float(data)
@@ -206,17 +247,26 @@ class FloatLoader(Loader):
 
 @Loader.binary(builtins["float4"].oid)
 class Float4BinaryLoader(Loader):
+
+    format = Format.BINARY
+
     def load(self, data: bytes) -> float:
         return _unpack_float4(data)[0]
 
 
 @Loader.binary(builtins["float8"].oid)
 class Float8BinaryLoader(Loader):
+
+    format = Format.BINARY
+
     def load(self, data: bytes) -> float:
         return _unpack_float8(data)[0]
 
 
 @Loader.text(builtins["numeric"].oid)
 class NumericLoader(Loader):
+
+    format = Format.TEXT
+
     def load(self, data: bytes) -> Decimal:
         return Decimal(data.decode("utf8"))
index 4d695a92326b36832ee23f8f65b7b50be336e3ae..79176138621acb7516f6487a9f19a09760956f3c 100644 (file)
@@ -5,12 +5,13 @@ Adapters for None and boolean.
 # Copyright (C) 2020 The Psycopg Team
 
 from ..oids import builtins
-from ..adapt import Dumper, Loader
+from ..adapt import Dumper, Loader, Format
 
 
 @Dumper.text(bool)
 class BoolDumper(Dumper):
 
+    format = Format.TEXT
     _oid = builtins["bool"].oid
 
     def dump(self, obj: bool) -> bytes:
@@ -23,6 +24,7 @@ class BoolDumper(Dumper):
 @Dumper.binary(bool)
 class BoolBinaryDumper(Dumper):
 
+    format = Format.BINARY
     _oid = builtins["bool"].oid
 
     def dump(self, obj: bool) -> bytes:
@@ -36,6 +38,8 @@ class NoneDumper(Dumper):
     quote(), so it can be used in sql composition.
     """
 
+    format = Format.TEXT
+
     def dump(self, obj: None) -> bytes:
         raise NotImplementedError("NULL is passed to Postgres in other ways")
 
@@ -45,11 +49,17 @@ class NoneDumper(Dumper):
 
 @Loader.text(builtins["bool"].oid)
 class BoolLoader(Loader):
+
+    format = Format.TEXT
+
     def load(self, data: bytes) -> bool:
         return data == b"t"
 
 
 @Loader.binary(builtins["bool"].oid)
 class BoolBinaryLoader(Loader):
+
+    format = Format.BINARY
+
     def load(self, data: bytes) -> bool:
         return data != b"\x00"
index b215989209559d6b1fbf4a6dd18aeaacfedd8d0d..ed927e476a0b479dc3a69d985eb1868395d44118 100644 (file)
@@ -8,7 +8,7 @@ from typing import Optional, Union, TYPE_CHECKING
 
 from ..pq import Escaping
 from ..oids import builtins, INVALID_OID
-from ..adapt import Dumper, Loader
+from ..adapt import Dumper, Loader, Format
 from ..proto import AdaptContext
 from ..errors import DataError
 
@@ -32,6 +32,9 @@ class _StringDumper(Dumper):
 
 @Dumper.binary(str)
 class StringBinaryDumper(_StringDumper):
+
+    format = Format.BINARY
+
     def dump(self, obj: str) -> bytes:
         # the server will raise DataError subclass if the string contains 0x00
         return obj.encode(self._encoding)
@@ -39,6 +42,9 @@ class StringBinaryDumper(_StringDumper):
 
 @Dumper.text(str)
 class StringDumper(_StringDumper):
+
+    format = Format.TEXT
+
     def dump(self, obj: str) -> bytes:
         if "\x00" in obj:
             raise DataError(
@@ -53,12 +59,9 @@ class StringDumper(_StringDumper):
 @Loader.text(builtins["name"].oid)
 @Loader.text(builtins["text"].oid)
 @Loader.text(builtins["varchar"].oid)
-@Loader.binary(builtins["bpchar"].oid)
-@Loader.binary(builtins["name"].oid)
-@Loader.binary(builtins["text"].oid)
-@Loader.binary(builtins["varchar"].oid)
 class TextLoader(Loader):
 
+    format = Format.TEXT
     _encoding = "utf-8"
 
     def __init__(self, oid: int, context: Optional[AdaptContext] = None):
@@ -76,11 +79,21 @@ class TextLoader(Loader):
             return data
 
 
+@Loader.binary(builtins["bpchar"].oid)
+@Loader.binary(builtins["name"].oid)
+@Loader.binary(builtins["text"].oid)
+@Loader.binary(builtins["varchar"].oid)
+class TextBinaryLoader(TextLoader):
+
+    format = Format.BINARY
+
+
 @Dumper.text(bytes)
 @Dumper.text(bytearray)
 @Dumper.text(memoryview)
 class BytesDumper(Dumper):
 
+    format = Format.TEXT
     _oid = builtins["bytea"].oid
 
     def __init__(self, src: type, context: Optional[AdaptContext] = None):
@@ -100,6 +113,7 @@ class BytesDumper(Dumper):
 @Dumper.binary(memoryview)
 class BytesBinaryDumper(Dumper):
 
+    format = Format.BINARY
     _oid = builtins["bytea"].oid
 
     def dump(
@@ -111,6 +125,8 @@ class BytesBinaryDumper(Dumper):
 
 @Loader.text(builtins["bytea"].oid)
 class ByteaLoader(Loader):
+
+    format = Format.TEXT
     _escaping: "EscapingProto"
 
     def __init__(self, oid: int, context: Optional[AdaptContext] = None):
@@ -125,5 +141,8 @@ class ByteaLoader(Loader):
 @Loader.binary(builtins["bytea"].oid)
 @Loader.binary(INVALID_OID)
 class ByteaBinaryLoader(Loader):
+
+    format = Format.BINARY
+
     def load(self, data: bytes) -> bytes:
         return data
index 819c4b3fe163b797473a818d1af741c3340cb09d..4dc284d838949ea0ee191eff6985fd64509e870a 100644 (file)
@@ -7,7 +7,7 @@ Adapters for the UUID type.
 from typing import Callable, Optional, TYPE_CHECKING
 
 from ..oids import builtins
-from ..adapt import Dumper, Loader
+from ..adapt import Dumper, Loader, Format
 from ..proto import AdaptContext
 
 if TYPE_CHECKING:
@@ -20,6 +20,7 @@ UUID: Callable[..., "uuid.UUID"]
 @Dumper.text("uuid.UUID")
 class UUIDDumper(Dumper):
 
+    format = Format.TEXT
     _oid = builtins["uuid"].oid
 
     def dump(self, obj: "uuid.UUID") -> bytes:
@@ -28,12 +29,18 @@ class UUIDDumper(Dumper):
 
 @Dumper.binary("uuid.UUID")
 class UUIDBinaryDumper(UUIDDumper):
+
+    format = Format.BINARY
+
     def dump(self, obj: "uuid.UUID") -> bytes:
         return obj.bytes
 
 
 @Loader.text(builtins["uuid"].oid)
 class UUIDLoader(Loader):
+
+    format = Format.TEXT
+
     def __init__(self, oid: int, context: Optional[AdaptContext] = None):
         super().__init__(oid, context)
         global UUID
@@ -45,5 +52,8 @@ class UUIDLoader(Loader):
 
 @Loader.binary(builtins["uuid"].oid)
 class UUIDBinaryLoader(UUIDLoader):
+
+    format = Format.BINARY
+
     def load(self, data: bytes) -> "uuid.UUID":
         return UUID(bytes=data)
index 33dc4bdc61f3e3341c74b17fe9032a7729fddb37..da0bea7eb45a402d825752bd62d0bcdac000d089 100644 (file)
@@ -37,7 +37,7 @@ def test_quote(data, result):
 
 def test_dump_connection_ctx(conn):
     make_dumper("t").register(str, conn)
-    make_dumper("b").register(str, conn, format=Format.BINARY)
+    make_bin_dumper("b").register(str, conn, format=Format.BINARY)
 
     cur = conn.cursor()
     cur.execute("select %s, %b", ["hello", "world"])
@@ -46,11 +46,11 @@ def test_dump_connection_ctx(conn):
 
 def test_dump_cursor_ctx(conn):
     make_dumper("t").register(str, conn)
-    make_dumper("b").register(str, conn, format=Format.BINARY)
+    make_bin_dumper("b").register(str, conn, format=Format.BINARY)
 
     cur = conn.cursor()
     make_dumper("tc").register(str, cur)
-    make_dumper("bc").register(str, cur, format=Format.BINARY)
+    make_bin_dumper("bc").register(str, cur, format=Format.BINARY)
 
     cur.execute("select %s, %b", ["hello", "world"])
     assert cur.fetchone() == ("hellotc", "worldbc")
@@ -88,7 +88,7 @@ def test_cast(data, format, type, result):
 
 def test_load_connection_ctx(conn):
     make_loader("t").register(TEXT_OID, conn)
-    make_loader("b").register(TEXT_OID, conn, format=Format.BINARY)
+    make_bin_loader("b").register(TEXT_OID, conn, format=Format.BINARY)
 
     r = conn.cursor().execute("select 'hello'::text").fetchone()
     assert r == ("hellot",)
@@ -98,11 +98,11 @@ def test_load_connection_ctx(conn):
 
 def test_load_cursor_ctx(conn):
     make_loader("t").register(TEXT_OID, conn)
-    make_loader("b").register(TEXT_OID, conn, format=Format.BINARY)
+    make_bin_loader("b").register(TEXT_OID, conn, format=Format.BINARY)
 
     cur = conn.cursor()
     make_loader("tc").register(TEXT_OID, cur)
-    make_loader("bc").register(TEXT_OID, cur, format=Format.BINARY)
+    make_bin_loader("bc").register(TEXT_OID, cur, format=Format.BINARY)
 
     r = cur.execute("select 'hello'::text").fetchone()
     assert r == ("hellotc",)
@@ -125,7 +125,11 @@ def test_load_cursor_ctx(conn):
 @pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
 def test_load_cursor_ctx_nested(conn, sql, obj, fmt_out):
     cur = conn.cursor(format=fmt_out)
-    make_loader("c").register(TEXT_OID, cur, format=fmt_out)
+    if format == Format.TEXT:
+        make_loader("c").register(TEXT_OID, cur, format=fmt_out)
+    else:
+        make_bin_loader("c").register(TEXT_OID, cur, format=fmt_out)
+
     cur.execute(f"select {sql}")
     res = cur.fetchone()[0]
     assert res == obj
@@ -178,6 +182,7 @@ def make_dumper(suffix):
 
     class TestDumper(Dumper):
         oid = TEXT_OID
+        format = Format.TEXT
 
         def dump(self, s):
             return (s + suffix).encode("ascii")
@@ -185,11 +190,25 @@ def make_dumper(suffix):
     return TestDumper
 
 
+def make_bin_dumper(suffix):
+    cls = make_dumper(suffix)
+    cls.format = Format.BINARY
+    return cls
+
+
 def make_loader(suffix):
     """Create a test loader appending a suffix to the data returned."""
 
     class TestLoader(Loader):
+        format = Format.TEXT
+
         def load(self, b):
             return b.decode("ascii") + suffix
 
     return TestLoader
+
+
+def make_bin_loader(suffix):
+    cls = make_loader(suffix)
+    cls.format = Format.BINARY
+    return cls