]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added adaptation of binary arrays
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 5 Apr 2020 01:10:44 +0000 (13:10 +1200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 5 Apr 2020 01:10:44 +0000 (13:10 +1200)
Added binary array adaptation for test types and some numeric types.

psycopg3/types/array.py
psycopg3/types/numeric.py
psycopg3/types/text.py
tests/types/test_array.py
tests/types/test_numeric.py
tests/types/test_text.py

index e4a51bd1d090ce0d88cc04fac1d76375b9495580..43a5a4408b5ee5a0116008c4439aa54a67499a99 100644 (file)
@@ -6,12 +6,16 @@ Adapters for arrays
 
 import re
 import struct
-from typing import Any, Generator, List, Optional
+from typing import Any, Generator, List, Optional, Tuple
 
 from .. import errors as e
 from ..pq import Format
 from ..adapt import Adapter, TypeCaster, Transformer, UnknownCaster
 from ..adapt import AdaptContext, TypeCasterType, TypeCasterFunc
+from .oids import builtins
+
+TEXT_OID = builtins["text"].oid
+TEXT_ARRAY_OID = builtins["text"].array_oid
 
 
 # from https://www.postgresql.org/docs/current/arrays.html#ARRAYS-IO
@@ -55,37 +59,122 @@ def escape_item(item: Optional[bytes]) -> bytes:
         return b'"' + _re_escape.sub(br"\\\1", item) + b'"'
 
 
-@Adapter.text(list)
-class ListAdapter(Adapter):
+class BaseListAdapter(Adapter):
     def __init__(self, src: type, context: AdaptContext = None):
         super().__init__(src, context)
         self.tx = Transformer(context)
 
-    def adapt(self, obj: List[Any]) -> bytes:
+    def _array_oid(self, base_oid: int) -> int:
+        """
+        Return the oid of the array from the oid of the base item.
+
+        Fall back on text[].
+        TODO: we shouldn't consider builtins only, but other adaptation
+        contexts too
+        """
+        oid = 0
+        if base_oid:
+            info = builtins.get(base_oid)
+            if info is not None:
+                oid = info.array_oid
+
+        return oid or TEXT_ARRAY_OID
+
+
+@Adapter.text(list)
+class ListAdapter(BaseListAdapter):
+    def adapt(self, obj: List[Any]) -> Tuple[bytes, int]:
         tokens: List[bytes] = []
-        self.adapt_list(obj, tokens)
-        return b"".join(tokens)
 
-    def adapt_list(self, obj: List[Any], tokens: List[bytes]) -> None:
+        oid = 0
+
+        def adapt_list(obj: List[Any]) -> None:
+            nonlocal oid
+
+            if not obj:
+                tokens.append(b"{}")
+                return
+
+            tokens.append(b"{")
+            for item in obj:
+                if isinstance(item, list):
+                    adapt_list(item)
+                elif item is None:
+                    tokens.append(b"NULL")
+                else:
+                    ad = self.tx.adapt(item)
+                    if isinstance(ad, tuple):
+                        if oid == 0:
+                            oid = ad[1]
+                        ad = ad[0]
+                    tokens.append(escape_item(ad))
+
+                tokens.append(b",")
+
+            tokens[-1] = b"}"
+
+        adapt_list(obj)
+
+        return b"".join(tokens), self._array_oid(oid)
+
+
+@Adapter.binary(list)
+class BinaryListAdapter(BaseListAdapter):
+    def adapt(self, obj: List[Any]) -> Tuple[bytes, int]:
         if not obj:
-            tokens.append(b"{}")
-            return
-
-        tokens.append(b"{")
-        for item in obj:
-            if isinstance(item, list):
-                self.adapt_list(item, tokens)
-            elif item is None:
-                tokens.append(b"NULL")
+            return _struct_head.pack(0, 0, TEXT_OID), TEXT_ARRAY_OID
+
+        data: List[bytes] = []
+        head = [0, 0, 0]  # to fill: ndims, hasnull, base_oid
+        dims = []
+        data = []
+
+        def calc_dims(L: List[Any]) -> None:
+            if isinstance(L, self.src):
+                if not L:
+                    raise e.DataError("lists cannot contain empty lists")
+                dims.append(len(L))
+                calc_dims(L[0])
+
+        calc_dims(obj)
+
+        def adapt_list(L: List[Any], dim: int) -> None:
+            if len(L) != dims[dim]:
+                raise e.DataError("nested lists have inconsistent lengths")
+
+            if dim == len(dims) - 1:
+                for item in L:
+                    ad = self.tx.adapt(item, Format.BINARY)
+                    if isinstance(ad, tuple):
+                        if head[2] == 0:
+                            head[2] = ad[1]
+                        ad = ad[0]
+                    if ad is None:
+                        head[1] = 1
+                        data.append(b"\xff\xff\xff\xff")
+                    else:
+                        data.append(_struct_len.pack(len(ad)))
+                        data.append(ad)
             else:
-                ad = self.tx.adapt(item)
-                if isinstance(ad, tuple):
-                    ad = ad[0]
-                tokens.append(escape_item(ad))
+                for item in L:
+                    if not isinstance(item, self.src):
+                        raise e.DataError(
+                            "nested lists have inconsistent depths"
+                        )
+                    adapt_list(item, dim + 1)  # type: ignore
+
+        adapt_list(obj, 0)
+
+        head[0] = len(dims)
+        if head[2] == 0:
+            head[2] = TEXT_OID
 
-            tokens.append(b",")
+        oid = self._array_oid(head[2])
 
-        tokens[-1] = b"}"
+        bhead = _struct_head.pack(*head) + b"".join(
+            _struct_dim.pack(dim, 1) for dim in dims
+        )
+        return bhead + b"".join(data), oid
 
 
 class ArrayCasterBase(TypeCaster):
@@ -141,9 +230,9 @@ class ArrayCasterText(ArrayCasterBase):
         return rv
 
 
-_unpack_head = struct.Struct("!III").unpack_from
-_unpack_dim = struct.Struct("!II").unpack_from
-_unpack_len = struct.Struct("!i").unpack_from
+_struct_head = struct.Struct("!III")
+_struct_dim = struct.Struct("!II")
+_struct_len = struct.Struct("!i")
 
 
 class ArrayCasterBinary(ArrayCasterBase):
@@ -152,18 +241,20 @@ class ArrayCasterBinary(ArrayCasterBase):
         self.tx = Transformer(context)
 
     def cast(self, data: bytes) -> List[Any]:
-        ndims, hasnull, oid = _unpack_head(data[:12])
+        ndims, hasnull, oid = _struct_head.unpack_from(data[:12])
         if not ndims:
             return []
 
         fcast = self.tx.get_cast_function(oid, Format.BINARY)
 
         p = 12 + 8 * ndims
-        dims = [_unpack_dim(data, i)[0] for i in list(range(12, p, 8))]
+        dims = [
+            _struct_dim.unpack_from(data, i)[0] for i in list(range(12, p, 8))
+        ]
 
         def consume(p: int) -> Generator[Any, None, None]:
             while 1:
-                size = _unpack_len(data, p)[0]
+                size = _struct_len.unpack_from(data, p)[0]
                 p += 4
                 if size != -1:
                     yield fcast(data[p : p + size])
index b06a99cf27b1a36845d3c9d1708f8590d96c2e29..d119087929b9f1e992e76d834054a749b3a77319 100644 (file)
@@ -49,6 +49,10 @@ _bool_adapt = {
     True: (b"t", builtins["bool"].oid),
     False: (b"f", builtins["bool"].oid),
 }
+_bool_binary_adapt = {
+    True: (b"\x01", builtins["bool"].oid),
+    False: (b"\x00", builtins["bool"].oid),
+}
 
 
 @Adapter.text(bool)
@@ -56,6 +60,11 @@ def adapt_bool(obj: bool) -> Tuple[bytes, int]:
     return _bool_adapt[obj]
 
 
+@Adapter.binary(bool)
+def adapt_binary_bool(obj: bool) -> Tuple[bytes, int]:
+    return _bool_binary_adapt[obj]
+
+
 @TypeCaster.text(builtins["int2"].oid)
 @TypeCaster.text(builtins["int4"].oid)
 @TypeCaster.text(builtins["int8"].oid)
@@ -69,24 +78,28 @@ def cast_int(data: bytes) -> int:
 
 
 @TypeCaster.binary(builtins["int2"].oid)
+@ArrayCaster.binary(builtins["int2"].array_oid)
 def cast_binary_int2(data: bytes) -> int:
     rv: int = _int2_struct.unpack(data)[0]
     return rv
 
 
 @TypeCaster.binary(builtins["int4"].oid)
+@ArrayCaster.binary(builtins["int4"].array_oid)
 def cast_binary_int4(data: bytes) -> int:
     rv: int = _int4_struct.unpack(data)[0]
     return rv
 
 
 @TypeCaster.binary(builtins["int8"].oid)
+@ArrayCaster.binary(builtins["int8"].array_oid)
 def cast_binary_int8(data: bytes) -> int:
     rv: int = _int8_struct.unpack(data)[0]
     return rv
 
 
 @TypeCaster.binary(builtins["oid"].oid)
+@ArrayCaster.binary(builtins["oid"].array_oid)
 def cast_binary_oid(data: bytes) -> int:
     rv: int = _oid_struct.unpack(data)[0]
     return rv
index f9820b4f3fdf532f4f28458c90a38898f31266c6..adfbf686b5e152e6bcf60b268efc5e2351529eeb 100644 (file)
@@ -87,5 +87,6 @@ def cast_bytea(data: bytes) -> bytes:
 
 
 @TypeCaster.binary(builtins["bytea"].oid)
+@ArrayCaster.binary(builtins["bytea"].array_oid)
 def cast_bytea_binary(data: bytes) -> bytes:
     return data
index a2c2acdc28f341333dc808b3a73331ce5067b56c..b6e79a8d4a63acb725740f8d73546b5c3c8c2acd 100644 (file)
@@ -24,7 +24,8 @@ tests_str = [
 @pytest.mark.parametrize("obj, want", tests_str)
 def test_adapt_list_str(conn, obj, want, fmt_in):
     cur = conn.cursor()
-    cur.execute("select %s::text[] = %s::text[]", (obj, want))
+    ph = "%s" if fmt_in == Format.TEXT else "%b"
+    cur.execute(f"select {ph}::text[] = %s::text[]", (obj, want))
     assert cur.fetchone()[0]
 
 
@@ -32,25 +33,27 @@ def test_adapt_list_str(conn, obj, want, fmt_in):
 @pytest.mark.parametrize("want, obj", tests_str)
 def test_cast_list_str(conn, obj, want, fmt_out):
     cur = conn.cursor(binary=fmt_out == Format.BINARY)
-    ph = "%s" if format == Format.TEXT else "%b"
-    cur.execute("select %s::text[]" % ph, (obj,))
+    cur.execute("select %s::text[]", (obj,))
     assert cur.fetchone()[0] == want
 
 
-def test_all_chars(conn):
-    cur = conn.cursor()
+@pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY])
+@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
+def test_all_chars(conn, fmt_in, fmt_out):
+    cur = conn.cursor(binary=fmt_out == Format.BINARY)
+    ph = "%s" if fmt_in == Format.TEXT else "%b"
     for i in range(1, 256):
         c = chr(i)
-        cur.execute("select %s::text[]", ([c],))
+        cur.execute(f"select {ph}::text[]", ([c],))
         assert cur.fetchone()[0] == [c]
 
     a = list(map(chr, range(1, 256)))
     a.append("\u20ac")
-    cur.execute("select %s::text[]", (a,))
+    cur.execute(f"select {ph}::text[]", (a,))
     assert cur.fetchone()[0] == a
 
     a = "".join(a)
-    cur.execute("select %s::text[]", ([a],))
+    cur.execute(f"select {ph}::text[]", ([a],))
     assert cur.fetchone()[0] == [a]
 
 
@@ -69,9 +72,10 @@ def test_adapt_list_int(conn, obj, want):
     assert cur.fetchone()[0]
 
 
+@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
 @pytest.mark.parametrize("want, obj", tests_int)
-def test_cast_list_int(conn, obj, want):
-    cur = conn.cursor()
+def test_cast_list_int(conn, obj, want, fmt_out):
+    cur = conn.cursor(binary=fmt_out == Format.BINARY)
     cur.execute("select %s::int[]", (obj,))
     assert cur.fetchone()[0] == want
 
index e639e26370385f66ff5e185d53323925a48d4fbd..1de3196f1dfa1562e2fd76250561ddc995fa0709 100644 (file)
@@ -53,11 +53,11 @@ def test_adapt_int(conn, val, expr):
         ("4294967295", "oid", 4294967295),
     ],
 )
-@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
-def test_cast_int(conn, val, pgtype, want, format):
-    cur = conn.cursor(binary=format == Format.BINARY)
+@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
+def test_cast_int(conn, val, pgtype, want, fmt_out):
+    cur = conn.cursor(binary=fmt_out == Format.BINARY)
     cur.execute("select %%s::%s" % pgtype, (val,))
-    assert cur.pgresult.fformat(0) == format
+    assert cur.pgresult.fformat(0) == fmt_out
     assert cur.pgresult.ftype(0) == builtins[pgtype].oid
     result = cur.fetchone()[0]
     assert result == want
@@ -132,11 +132,11 @@ def test_adapt_float_approx(conn, val, expr):
         ("-inf", "float8", -float("inf")),
     ],
 )
-@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
-def test_cast_float(conn, val, pgtype, want, format):
-    cur = conn.cursor(binary=format == Format.BINARY)
+@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
+def test_cast_float(conn, val, pgtype, want, fmt_out):
+    cur = conn.cursor(binary=fmt_out == Format.BINARY)
     cur.execute("select %%s::%s" % pgtype, (val,))
-    assert cur.pgresult.fformat(0) == format
+    assert cur.pgresult.fformat(0) == fmt_out
     result = cur.fetchone()[0]
     assert type(result) is type(want)
     if isnan(want):
@@ -161,11 +161,11 @@ def test_cast_float(conn, val, pgtype, want, format):
         ("-1.42e40", "float8", -1.42e40),
     ],
 )
-@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
-def test_cast_float_approx(conn, expr, pgtype, want, format):
-    cur = conn.cursor(binary=format == Format.BINARY)
+@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
+def test_cast_float_approx(conn, expr, pgtype, want, fmt_out):
+    cur = conn.cursor(binary=fmt_out == Format.BINARY)
     cur.execute("select %s::%s" % (expr, pgtype))
-    assert cur.pgresult.fformat(0) == format
+    assert cur.pgresult.fformat(0) == fmt_out
     result = cur.fetchone()[0]
     assert result == pytest.approx(want)
 
@@ -226,12 +226,14 @@ def test_numeric_as_float(conn, val):
 #
 
 
-@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+@pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY])
+@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
 @pytest.mark.parametrize("b", [True, False, None])
-def test_roundtrip_bool(conn, b, format):
-    cur = conn.cursor(binary=format == Format.BINARY)
-    result = cur.execute("select %s", (b,)).fetchone()[0]
-    assert cur.pgresult.fformat(0) == format
+def test_roundtrip_bool(conn, b, fmt_in, fmt_out):
+    cur = conn.cursor(binary=fmt_out == Format.BINARY)
+    ph = "%s" if fmt_in == Format.TEXT else "%b"
+    result = cur.execute(f"select {ph}", (b,)).fetchone()[0]
+    assert cur.pgresult.fformat(0) == fmt_out
     assert result is b
 
 
index a31015e4a28d923e50a9413ce43a8649e9267fde..37ee843ee0c059a870901db7f9a9384831c97570 100644 (file)
@@ -11,97 +11,104 @@ eur = "\u20ac"
 #
 
 
-@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
-def test_adapt_1char(conn, format):
+@pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY])
+def test_adapt_1char(conn, fmt_in):
     cur = conn.cursor()
-    ph = "%s" if format == Format.TEXT else "%b"
+    ph = "%s" if fmt_in == Format.TEXT else "%b"
     for i in range(1, 256):
-        cur.execute("select %s = chr(%%s::int)" % ph, (chr(i), i))
+        cur.execute(f"select {ph} = chr(%s::int)", (chr(i), i))
         assert cur.fetchone()[0], chr(i)
 
 
-@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
-def test_cast_1char(conn, format):
-    cur = conn.cursor(binary=format == Format.BINARY)
+@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
+def test_cast_1char(conn, fmt_out):
+    cur = conn.cursor(binary=fmt_out == Format.BINARY)
     for i in range(1, 256):
         cur.execute("select chr(%s::int)", (i,))
         assert cur.fetchone()[0] == chr(i)
 
-    assert cur.pgresult.fformat(0) == format
+    assert cur.pgresult.fformat(0) == fmt_out
 
 
-@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+@pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY])
 @pytest.mark.parametrize("encoding", ["utf8", "latin9"])
-def test_adapt_enc(conn, format, encoding):
+def test_adapt_enc(conn, fmt_in, encoding):
     cur = conn.cursor()
-    ph = "%s" if format == Format.TEXT else "%b"
+    ph = "%s" if fmt_in == Format.TEXT else "%b"
 
     conn.encoding = encoding
-    (res,) = cur.execute("select %s::bytea" % ph, (eur,)).fetchone()
+    (res,) = cur.execute(f"select {ph}::bytea", (eur,)).fetchone()
     assert res == eur.encode("utf8")
 
 
-@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
-def test_adapt_ascii(conn, format):
-    cur = conn.cursor(binary=format == Format.BINARY)
-    ph = "%s" if format == Format.TEXT else "%b"
+@pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY])
+def test_adapt_ascii(conn, fmt_in):
+    cur = conn.cursor()
+    ph = "%s" if fmt_in == Format.TEXT else "%b"
 
     conn.encoding = "sql_ascii"
-    (res,) = cur.execute("select ascii(%s)" % ph, (eur,)).fetchone()
+    (res,) = cur.execute(f"select ascii({ph})", (eur,)).fetchone()
     assert res == ord(eur)
 
 
-@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
-def test_adapt_badenc(conn, format):
+@pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY])
+def test_adapt_badenc(conn, fmt_in):
     cur = conn.cursor()
-    ph = "%s" if format == Format.TEXT else "%b"
+    ph = "%s" if fmt_in == Format.TEXT else "%b"
 
     conn.encoding = "latin1"
     with pytest.raises(UnicodeEncodeError):
-        cur.execute("select %s::bytea" % ph, (eur,))
+        cur.execute(f"select {ph}::bytea", (eur,))
 
 
-@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
 @pytest.mark.parametrize("encoding", ["utf8", "latin9"])
-def test_cast_enc(conn, format, encoding):
-    cur = conn.cursor(binary=format == Format.BINARY)
+def test_cast_enc(conn, fmt_out, encoding):
+    cur = conn.cursor(binary=fmt_out == Format.BINARY)
 
     conn.encoding = encoding
     (res,) = cur.execute("select chr(%s::int)", (ord(eur),)).fetchone()
     assert res == eur
 
 
-@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
-def test_cast_badenc(conn, format):
-    cur = conn.cursor(binary=format == Format.BINARY)
+@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
+def test_cast_badenc(conn, fmt_out):
+    cur = conn.cursor(binary=fmt_out == Format.BINARY)
 
     conn.encoding = "latin1"
     with pytest.raises(psycopg3.DatabaseError):
         cur.execute("select chr(%s::int)", (ord(eur),))
 
 
-@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
-def test_cast_ascii(conn, format):
-    cur = conn.cursor(binary=format == Format.BINARY)
+@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
+def test_cast_ascii(conn, fmt_out):
+    cur = conn.cursor(binary=fmt_out == Format.BINARY)
 
     conn.encoding = "sql_ascii"
     (res,) = cur.execute("select chr(%s::int)", (ord(eur),)).fetchone()
     assert res == eur.encode("utf8")
 
 
-def test_text_array(conn):
-    cur = conn.cursor()
+@pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY])
+@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
+def test_text_array(conn, fmt_in, fmt_out):
+    cur = conn.cursor(binary=fmt_out == Format.BINARY)
+    ph = "%s" if fmt_in == Format.TEXT else "%b"
     a = list(map(chr, range(1, 256))) + [eur]
-    (res,) = cur.execute("select %s::text[]", (a,)).fetchone()
+
+    (res,) = cur.execute(f"select {ph}::text[]", (a,)).fetchone()
     assert res == a
 
 
-def test_text_array_ascii(conn):
+@pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY])
+@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
+def test_text_array_ascii(conn, fmt_in, fmt_out):
     conn.encoding = "sql_ascii"
-    cur = conn.cursor()
+    cur = conn.cursor(binary=fmt_out == Format.BINARY)
     a = list(map(chr, range(1, 256))) + [eur]
     exp = [s.encode("utf8") for s in a]
-    (res,) = cur.execute("select %s::text[]", (a,)).fetchone()
+    ph = "%s" if fmt_in == Format.TEXT else "%b"
+    (res,) = cur.execute(f"select {ph}::text[]", (a,)).fetchone()
     assert res == exp
 
 
@@ -110,28 +117,30 @@ def test_text_array_ascii(conn):
 #
 
 
-@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
-def test_adapt_1byte(conn, format):
+@pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY])
+def test_adapt_1byte(conn, fmt_in):
     cur = conn.cursor()
-    ph = "%s" if format == Format.TEXT else "%b"
-    "select %s = %%s::bytea" % ph
+    ph = "%s" if fmt_in == Format.TEXT else "%b"
     for i in range(0, 256):
-        cur.execute("select %s = %%s::bytea" % ph, (bytes([i]), fr"\x{i:02x}"))
+        cur.execute(f"select {ph} = %s::bytea", (bytes([i]), fr"\x{i:02x}"))
         assert cur.fetchone()[0], i
 
 
-@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
-def test_cast_1byte(conn, format):
-    cur = conn.cursor(binary=format == Format.BINARY)
+@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
+def test_cast_1byte(conn, fmt_out):
+    cur = conn.cursor(binary=fmt_out == Format.BINARY)
     for i in range(0, 256):
         cur.execute("select %s::bytea", (fr"\x{i:02x}",))
         assert cur.fetchone()[0] == bytes([i])
 
-    assert cur.pgresult.fformat(0) == format
+    assert cur.pgresult.fformat(0) == fmt_out
 
 
-def test_bytea_array(conn):
-    cur = conn.cursor()
+@pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY])
+@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
+def test_bytea_array(conn, fmt_in, fmt_out):
+    cur = conn.cursor(binary=fmt_out == Format.BINARY)
     a = [bytes(range(0, 256))]
-    (res,) = cur.execute("select %s::bytea[]", (a,)).fetchone()
+    ph = "%s" if fmt_in == Format.TEXT else "%b"
+    (res,) = cur.execute(f"select {ph}::bytea[]", (a,)).fetchone()
     assert res == a