From: Daniele Varrazzo Date: Sun, 5 Apr 2020 01:10:44 +0000 (+1200) Subject: Added adaptation of binary arrays X-Git-Tag: 3.0.dev0~608 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=33eaeff2016b7a515f43a8a5ca9c0ab787f1b1fe;p=thirdparty%2Fpsycopg.git Added adaptation of binary arrays Added binary array adaptation for test types and some numeric types. --- diff --git a/psycopg3/types/array.py b/psycopg3/types/array.py index e4a51bd1d..43a5a4408 100644 --- a/psycopg3/types/array.py +++ b/psycopg3/types/array.py @@ -6,12 +6,16 @@ Adapters for arrays import re import struct -from typing import Any, Generator, List, Optional +from typing import Any, Generator, List, Optional, Tuple from .. import errors as e from ..pq import Format from ..adapt import Adapter, TypeCaster, Transformer, UnknownCaster from ..adapt import AdaptContext, TypeCasterType, TypeCasterFunc +from .oids import builtins + +TEXT_OID = builtins["text"].oid +TEXT_ARRAY_OID = builtins["text"].array_oid # from https://www.postgresql.org/docs/current/arrays.html#ARRAYS-IO @@ -55,37 +59,122 @@ def escape_item(item: Optional[bytes]) -> bytes: return b'"' + _re_escape.sub(br"\\\1", item) + b'"' -@Adapter.text(list) -class ListAdapter(Adapter): +class BaseListAdapter(Adapter): def __init__(self, src: type, context: AdaptContext = None): super().__init__(src, context) self.tx = Transformer(context) - def adapt(self, obj: List[Any]) -> bytes: + def _array_oid(self, base_oid: int) -> int: + """ + Return the oid of the array from the oid of the base item. + + Fall back on text[]. + TODO: we shouldn't consider builtins only, but other adaptation + contexts too + """ + oid = 0 + if base_oid: + info = builtins.get(base_oid) + if info is not None: + oid = info.array_oid + + return oid or TEXT_ARRAY_OID + + +@Adapter.text(list) +class ListAdapter(BaseListAdapter): + def adapt(self, obj: List[Any]) -> Tuple[bytes, int]: tokens: List[bytes] = [] - self.adapt_list(obj, tokens) - return b"".join(tokens) - def adapt_list(self, obj: List[Any], tokens: List[bytes]) -> None: + oid = 0 + + def adapt_list(obj: List[Any]) -> None: + nonlocal oid + + if not obj: + tokens.append(b"{}") + return + + tokens.append(b"{") + for item in obj: + if isinstance(item, list): + adapt_list(item) + elif item is None: + tokens.append(b"NULL") + else: + ad = self.tx.adapt(item) + if isinstance(ad, tuple): + if oid == 0: + oid = ad[1] + ad = ad[0] + tokens.append(escape_item(ad)) + + tokens.append(b",") + + tokens[-1] = b"}" + + adapt_list(obj) + + return b"".join(tokens), self._array_oid(oid) + + +@Adapter.binary(list) +class BinaryListAdapter(BaseListAdapter): + def adapt(self, obj: List[Any]) -> Tuple[bytes, int]: if not obj: - tokens.append(b"{}") - return - - tokens.append(b"{") - for item in obj: - if isinstance(item, list): - self.adapt_list(item, tokens) - elif item is None: - tokens.append(b"NULL") + return _struct_head.pack(0, 0, TEXT_OID), TEXT_ARRAY_OID + + data: List[bytes] = [] + head = [0, 0, 0] # to fill: ndims, hasnull, base_oid + dims = [] + data = [] + + def calc_dims(L: List[Any]) -> None: + if isinstance(L, self.src): + if not L: + raise e.DataError("lists cannot contain empty lists") + dims.append(len(L)) + calc_dims(L[0]) + + calc_dims(obj) + + def adapt_list(L: List[Any], dim: int) -> None: + if len(L) != dims[dim]: + raise e.DataError("nested lists have inconsistent lengths") + + if dim == len(dims) - 1: + for item in L: + ad = self.tx.adapt(item, Format.BINARY) + if isinstance(ad, tuple): + if head[2] == 0: + head[2] = ad[1] + ad = ad[0] + if ad is None: + head[1] = 1 + data.append(b"\xff\xff\xff\xff") + else: + data.append(_struct_len.pack(len(ad))) + data.append(ad) else: - ad = self.tx.adapt(item) - if isinstance(ad, tuple): - ad = ad[0] - tokens.append(escape_item(ad)) + for item in L: + if not isinstance(item, self.src): + raise e.DataError( + "nested lists have inconsistent depths" + ) + adapt_list(item, dim + 1) # type: ignore + + adapt_list(obj, 0) + + head[0] = len(dims) + if head[2] == 0: + head[2] = TEXT_OID - tokens.append(b",") + oid = self._array_oid(head[2]) - tokens[-1] = b"}" + bhead = _struct_head.pack(*head) + b"".join( + _struct_dim.pack(dim, 1) for dim in dims + ) + return bhead + b"".join(data), oid class ArrayCasterBase(TypeCaster): @@ -141,9 +230,9 @@ class ArrayCasterText(ArrayCasterBase): return rv -_unpack_head = struct.Struct("!III").unpack_from -_unpack_dim = struct.Struct("!II").unpack_from -_unpack_len = struct.Struct("!i").unpack_from +_struct_head = struct.Struct("!III") +_struct_dim = struct.Struct("!II") +_struct_len = struct.Struct("!i") class ArrayCasterBinary(ArrayCasterBase): @@ -152,18 +241,20 @@ class ArrayCasterBinary(ArrayCasterBase): self.tx = Transformer(context) def cast(self, data: bytes) -> List[Any]: - ndims, hasnull, oid = _unpack_head(data[:12]) + ndims, hasnull, oid = _struct_head.unpack_from(data[:12]) if not ndims: return [] fcast = self.tx.get_cast_function(oid, Format.BINARY) p = 12 + 8 * ndims - dims = [_unpack_dim(data, i)[0] for i in list(range(12, p, 8))] + dims = [ + _struct_dim.unpack_from(data, i)[0] for i in list(range(12, p, 8)) + ] def consume(p: int) -> Generator[Any, None, None]: while 1: - size = _unpack_len(data, p)[0] + size = _struct_len.unpack_from(data, p)[0] p += 4 if size != -1: yield fcast(data[p : p + size]) diff --git a/psycopg3/types/numeric.py b/psycopg3/types/numeric.py index b06a99cf2..d11908792 100644 --- a/psycopg3/types/numeric.py +++ b/psycopg3/types/numeric.py @@ -49,6 +49,10 @@ _bool_adapt = { True: (b"t", builtins["bool"].oid), False: (b"f", builtins["bool"].oid), } +_bool_binary_adapt = { + True: (b"\x01", builtins["bool"].oid), + False: (b"\x00", builtins["bool"].oid), +} @Adapter.text(bool) @@ -56,6 +60,11 @@ def adapt_bool(obj: bool) -> Tuple[bytes, int]: return _bool_adapt[obj] +@Adapter.binary(bool) +def adapt_binary_bool(obj: bool) -> Tuple[bytes, int]: + return _bool_binary_adapt[obj] + + @TypeCaster.text(builtins["int2"].oid) @TypeCaster.text(builtins["int4"].oid) @TypeCaster.text(builtins["int8"].oid) @@ -69,24 +78,28 @@ def cast_int(data: bytes) -> int: @TypeCaster.binary(builtins["int2"].oid) +@ArrayCaster.binary(builtins["int2"].array_oid) def cast_binary_int2(data: bytes) -> int: rv: int = _int2_struct.unpack(data)[0] return rv @TypeCaster.binary(builtins["int4"].oid) +@ArrayCaster.binary(builtins["int4"].array_oid) def cast_binary_int4(data: bytes) -> int: rv: int = _int4_struct.unpack(data)[0] return rv @TypeCaster.binary(builtins["int8"].oid) +@ArrayCaster.binary(builtins["int8"].array_oid) def cast_binary_int8(data: bytes) -> int: rv: int = _int8_struct.unpack(data)[0] return rv @TypeCaster.binary(builtins["oid"].oid) +@ArrayCaster.binary(builtins["oid"].array_oid) def cast_binary_oid(data: bytes) -> int: rv: int = _oid_struct.unpack(data)[0] return rv diff --git a/psycopg3/types/text.py b/psycopg3/types/text.py index f9820b4f3..adfbf686b 100644 --- a/psycopg3/types/text.py +++ b/psycopg3/types/text.py @@ -87,5 +87,6 @@ def cast_bytea(data: bytes) -> bytes: @TypeCaster.binary(builtins["bytea"].oid) +@ArrayCaster.binary(builtins["bytea"].array_oid) def cast_bytea_binary(data: bytes) -> bytes: return data diff --git a/tests/types/test_array.py b/tests/types/test_array.py index a2c2acdc2..b6e79a8d4 100644 --- a/tests/types/test_array.py +++ b/tests/types/test_array.py @@ -24,7 +24,8 @@ tests_str = [ @pytest.mark.parametrize("obj, want", tests_str) def test_adapt_list_str(conn, obj, want, fmt_in): cur = conn.cursor() - cur.execute("select %s::text[] = %s::text[]", (obj, want)) + ph = "%s" if fmt_in == Format.TEXT else "%b" + cur.execute(f"select {ph}::text[] = %s::text[]", (obj, want)) assert cur.fetchone()[0] @@ -32,25 +33,27 @@ def test_adapt_list_str(conn, obj, want, fmt_in): @pytest.mark.parametrize("want, obj", tests_str) def test_cast_list_str(conn, obj, want, fmt_out): cur = conn.cursor(binary=fmt_out == Format.BINARY) - ph = "%s" if format == Format.TEXT else "%b" - cur.execute("select %s::text[]" % ph, (obj,)) + cur.execute("select %s::text[]", (obj,)) assert cur.fetchone()[0] == want -def test_all_chars(conn): - cur = conn.cursor() +@pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY]) +@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY]) +def test_all_chars(conn, fmt_in, fmt_out): + cur = conn.cursor(binary=fmt_out == Format.BINARY) + ph = "%s" if fmt_in == Format.TEXT else "%b" for i in range(1, 256): c = chr(i) - cur.execute("select %s::text[]", ([c],)) + cur.execute(f"select {ph}::text[]", ([c],)) assert cur.fetchone()[0] == [c] a = list(map(chr, range(1, 256))) a.append("\u20ac") - cur.execute("select %s::text[]", (a,)) + cur.execute(f"select {ph}::text[]", (a,)) assert cur.fetchone()[0] == a a = "".join(a) - cur.execute("select %s::text[]", ([a],)) + cur.execute(f"select {ph}::text[]", ([a],)) assert cur.fetchone()[0] == [a] @@ -69,9 +72,10 @@ def test_adapt_list_int(conn, obj, want): assert cur.fetchone()[0] +@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY]) @pytest.mark.parametrize("want, obj", tests_int) -def test_cast_list_int(conn, obj, want): - cur = conn.cursor() +def test_cast_list_int(conn, obj, want, fmt_out): + cur = conn.cursor(binary=fmt_out == Format.BINARY) cur.execute("select %s::int[]", (obj,)) assert cur.fetchone()[0] == want diff --git a/tests/types/test_numeric.py b/tests/types/test_numeric.py index e639e2637..1de3196f1 100644 --- a/tests/types/test_numeric.py +++ b/tests/types/test_numeric.py @@ -53,11 +53,11 @@ def test_adapt_int(conn, val, expr): ("4294967295", "oid", 4294967295), ], ) -@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) -def test_cast_int(conn, val, pgtype, want, format): - cur = conn.cursor(binary=format == Format.BINARY) +@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY]) +def test_cast_int(conn, val, pgtype, want, fmt_out): + cur = conn.cursor(binary=fmt_out == Format.BINARY) cur.execute("select %%s::%s" % pgtype, (val,)) - assert cur.pgresult.fformat(0) == format + assert cur.pgresult.fformat(0) == fmt_out assert cur.pgresult.ftype(0) == builtins[pgtype].oid result = cur.fetchone()[0] assert result == want @@ -132,11 +132,11 @@ def test_adapt_float_approx(conn, val, expr): ("-inf", "float8", -float("inf")), ], ) -@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) -def test_cast_float(conn, val, pgtype, want, format): - cur = conn.cursor(binary=format == Format.BINARY) +@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY]) +def test_cast_float(conn, val, pgtype, want, fmt_out): + cur = conn.cursor(binary=fmt_out == Format.BINARY) cur.execute("select %%s::%s" % pgtype, (val,)) - assert cur.pgresult.fformat(0) == format + assert cur.pgresult.fformat(0) == fmt_out result = cur.fetchone()[0] assert type(result) is type(want) if isnan(want): @@ -161,11 +161,11 @@ def test_cast_float(conn, val, pgtype, want, format): ("-1.42e40", "float8", -1.42e40), ], ) -@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) -def test_cast_float_approx(conn, expr, pgtype, want, format): - cur = conn.cursor(binary=format == Format.BINARY) +@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY]) +def test_cast_float_approx(conn, expr, pgtype, want, fmt_out): + cur = conn.cursor(binary=fmt_out == Format.BINARY) cur.execute("select %s::%s" % (expr, pgtype)) - assert cur.pgresult.fformat(0) == format + assert cur.pgresult.fformat(0) == fmt_out result = cur.fetchone()[0] assert result == pytest.approx(want) @@ -226,12 +226,14 @@ def test_numeric_as_float(conn, val): # -@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) +@pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY]) +@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY]) @pytest.mark.parametrize("b", [True, False, None]) -def test_roundtrip_bool(conn, b, format): - cur = conn.cursor(binary=format == Format.BINARY) - result = cur.execute("select %s", (b,)).fetchone()[0] - assert cur.pgresult.fformat(0) == format +def test_roundtrip_bool(conn, b, fmt_in, fmt_out): + cur = conn.cursor(binary=fmt_out == Format.BINARY) + ph = "%s" if fmt_in == Format.TEXT else "%b" + result = cur.execute(f"select {ph}", (b,)).fetchone()[0] + assert cur.pgresult.fformat(0) == fmt_out assert result is b diff --git a/tests/types/test_text.py b/tests/types/test_text.py index a31015e4a..37ee843ee 100644 --- a/tests/types/test_text.py +++ b/tests/types/test_text.py @@ -11,97 +11,104 @@ eur = "\u20ac" # -@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) -def test_adapt_1char(conn, format): +@pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY]) +def test_adapt_1char(conn, fmt_in): cur = conn.cursor() - ph = "%s" if format == Format.TEXT else "%b" + ph = "%s" if fmt_in == Format.TEXT else "%b" for i in range(1, 256): - cur.execute("select %s = chr(%%s::int)" % ph, (chr(i), i)) + cur.execute(f"select {ph} = chr(%s::int)", (chr(i), i)) assert cur.fetchone()[0], chr(i) -@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) -def test_cast_1char(conn, format): - cur = conn.cursor(binary=format == Format.BINARY) +@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY]) +def test_cast_1char(conn, fmt_out): + cur = conn.cursor(binary=fmt_out == Format.BINARY) for i in range(1, 256): cur.execute("select chr(%s::int)", (i,)) assert cur.fetchone()[0] == chr(i) - assert cur.pgresult.fformat(0) == format + assert cur.pgresult.fformat(0) == fmt_out -@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) +@pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY]) @pytest.mark.parametrize("encoding", ["utf8", "latin9"]) -def test_adapt_enc(conn, format, encoding): +def test_adapt_enc(conn, fmt_in, encoding): cur = conn.cursor() - ph = "%s" if format == Format.TEXT else "%b" + ph = "%s" if fmt_in == Format.TEXT else "%b" conn.encoding = encoding - (res,) = cur.execute("select %s::bytea" % ph, (eur,)).fetchone() + (res,) = cur.execute(f"select {ph}::bytea", (eur,)).fetchone() assert res == eur.encode("utf8") -@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) -def test_adapt_ascii(conn, format): - cur = conn.cursor(binary=format == Format.BINARY) - ph = "%s" if format == Format.TEXT else "%b" +@pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY]) +def test_adapt_ascii(conn, fmt_in): + cur = conn.cursor() + ph = "%s" if fmt_in == Format.TEXT else "%b" conn.encoding = "sql_ascii" - (res,) = cur.execute("select ascii(%s)" % ph, (eur,)).fetchone() + (res,) = cur.execute(f"select ascii({ph})", (eur,)).fetchone() assert res == ord(eur) -@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) -def test_adapt_badenc(conn, format): +@pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY]) +def test_adapt_badenc(conn, fmt_in): cur = conn.cursor() - ph = "%s" if format == Format.TEXT else "%b" + ph = "%s" if fmt_in == Format.TEXT else "%b" conn.encoding = "latin1" with pytest.raises(UnicodeEncodeError): - cur.execute("select %s::bytea" % ph, (eur,)) + cur.execute(f"select {ph}::bytea", (eur,)) -@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) +@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY]) @pytest.mark.parametrize("encoding", ["utf8", "latin9"]) -def test_cast_enc(conn, format, encoding): - cur = conn.cursor(binary=format == Format.BINARY) +def test_cast_enc(conn, fmt_out, encoding): + cur = conn.cursor(binary=fmt_out == Format.BINARY) conn.encoding = encoding (res,) = cur.execute("select chr(%s::int)", (ord(eur),)).fetchone() assert res == eur -@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) -def test_cast_badenc(conn, format): - cur = conn.cursor(binary=format == Format.BINARY) +@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY]) +def test_cast_badenc(conn, fmt_out): + cur = conn.cursor(binary=fmt_out == Format.BINARY) conn.encoding = "latin1" with pytest.raises(psycopg3.DatabaseError): cur.execute("select chr(%s::int)", (ord(eur),)) -@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) -def test_cast_ascii(conn, format): - cur = conn.cursor(binary=format == Format.BINARY) +@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY]) +def test_cast_ascii(conn, fmt_out): + cur = conn.cursor(binary=fmt_out == Format.BINARY) conn.encoding = "sql_ascii" (res,) = cur.execute("select chr(%s::int)", (ord(eur),)).fetchone() assert res == eur.encode("utf8") -def test_text_array(conn): - cur = conn.cursor() +@pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY]) +@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY]) +def test_text_array(conn, fmt_in, fmt_out): + cur = conn.cursor(binary=fmt_out == Format.BINARY) + ph = "%s" if fmt_in == Format.TEXT else "%b" a = list(map(chr, range(1, 256))) + [eur] - (res,) = cur.execute("select %s::text[]", (a,)).fetchone() + + (res,) = cur.execute(f"select {ph}::text[]", (a,)).fetchone() assert res == a -def test_text_array_ascii(conn): +@pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY]) +@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY]) +def test_text_array_ascii(conn, fmt_in, fmt_out): conn.encoding = "sql_ascii" - cur = conn.cursor() + cur = conn.cursor(binary=fmt_out == Format.BINARY) a = list(map(chr, range(1, 256))) + [eur] exp = [s.encode("utf8") for s in a] - (res,) = cur.execute("select %s::text[]", (a,)).fetchone() + ph = "%s" if fmt_in == Format.TEXT else "%b" + (res,) = cur.execute(f"select {ph}::text[]", (a,)).fetchone() assert res == exp @@ -110,28 +117,30 @@ def test_text_array_ascii(conn): # -@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) -def test_adapt_1byte(conn, format): +@pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY]) +def test_adapt_1byte(conn, fmt_in): cur = conn.cursor() - ph = "%s" if format == Format.TEXT else "%b" - "select %s = %%s::bytea" % ph + ph = "%s" if fmt_in == Format.TEXT else "%b" for i in range(0, 256): - cur.execute("select %s = %%s::bytea" % ph, (bytes([i]), fr"\x{i:02x}")) + cur.execute(f"select {ph} = %s::bytea", (bytes([i]), fr"\x{i:02x}")) assert cur.fetchone()[0], i -@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) -def test_cast_1byte(conn, format): - cur = conn.cursor(binary=format == Format.BINARY) +@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY]) +def test_cast_1byte(conn, fmt_out): + cur = conn.cursor(binary=fmt_out == Format.BINARY) for i in range(0, 256): cur.execute("select %s::bytea", (fr"\x{i:02x}",)) assert cur.fetchone()[0] == bytes([i]) - assert cur.pgresult.fformat(0) == format + assert cur.pgresult.fformat(0) == fmt_out -def test_bytea_array(conn): - cur = conn.cursor() +@pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY]) +@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY]) +def test_bytea_array(conn, fmt_in, fmt_out): + cur = conn.cursor(binary=fmt_out == Format.BINARY) a = [bytes(range(0, 256))] - (res,) = cur.execute("select %s::bytea[]", (a,)).fetchone() + ph = "%s" if fmt_in == Format.TEXT else "%b" + (res,) = cur.execute(f"select {ph}::bytea[]", (a,)).fetchone() assert res == a