From: Daniele Varrazzo Date: Sat, 4 Apr 2020 14:12:10 +0000 (+1200) Subject: Added cast of binary arrays X-Git-Tag: 3.0.dev0~611 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=320e7fc4bca303b783e4bda8bab308834191edb0;p=thirdparty%2Fpsycopg.git Added cast of binary arrays --- diff --git a/psycopg3/types/array.py b/psycopg3/types/array.py index 39f33a127..8b85c68ad 100644 --- a/psycopg3/types/array.py +++ b/psycopg3/types/array.py @@ -5,7 +5,8 @@ Adapters for arrays # Copyright (C) 2020 The Psycopg Team import re -from typing import Any, List, Optional +import struct +from typing import Any, Generator, List, Optional from .. import errors as e from ..pq import Format @@ -99,6 +100,8 @@ class ArrayCasterBase(TypeCaster): else: self.caster_func = type(self).base_caster + +class ArrayCasterText(ArrayCasterBase): def cast(self, data: bytes) -> List[Any]: rv = None stack: List[Any] = [] @@ -119,10 +122,12 @@ class ArrayCasterBase(TypeCaster): else: if not stack: - raise e.DataError( - f"malformed array, unexpected" - f" '{t.decode('utf8', 'replace')}'" + wat = ( + t[:10].decode("utf8", "replace") + "..." + if len(t) > 10 + else "" ) + raise e.DataError(f"malformed array, unexpected '{wat}'") if t == b"NULL": v = None else: @@ -136,6 +141,48 @@ class ArrayCasterBase(TypeCaster): return rv +_unpack_head = struct.Struct("!III").unpack_from +_unpack_dim = struct.Struct("!II").unpack_from +_unpack_len = struct.Struct("!i").unpack_from + + +class ArrayCasterBinary(ArrayCasterBase): + def __init__(self, oid: int, context: AdaptContext = None): + super().__init__(oid, context) + self.tx = Transformer(context) + + def cast(self, data: bytes) -> List[Any]: + ndims, hasnull, oid = _unpack_head(data[:12]) + if not ndims: + return [] + + fcast = self.tx.get_cast_function(oid, Format.BINARY) + + p = 12 + 8 * ndims + dims = [_unpack_dim(data, i)[0] for i in list(range(12, p, 8))] + + def consume(p: int) -> Generator[Any, None, None]: + while 1: + size = _unpack_len(data, p)[0] + p += 4 + if size != -1: + yield fcast(data[p : p + size]) + p += size + else: + yield None + + items = consume(p) + + def agg(dims: List[int]) -> List[Any]: + if not dims: + return next(items) + else: + dim, dims = dims[0], dims[1:] + return [agg(dims) for _ in range(dim)] + + return agg(dims) + + class ArrayCaster(TypeCaster): @staticmethod def register( @@ -144,13 +191,11 @@ class ArrayCaster(TypeCaster): context: AdaptContext = None, format: Format = Format.TEXT, ) -> TypeCasterType: - t = type( - caster.__name__ + "_array", - (ArrayCasterBase,), - {"base_caster": caster}, - ) + base = ArrayCasterText if format == Format.TEXT else ArrayCasterBinary + name = f"{caster.__name__}_{format.name.lower()}_array" + t = type(name, (base,), {"base_caster": caster}) return TypeCaster.register(oid, t, context=context, format=format) -class UnknownArrayCaster(ArrayCasterBase): +class UnknownArrayCaster(ArrayCasterText): base_caster = UnknownCaster diff --git a/tests/types/test_array.py b/tests/types/test_array.py index c69733719..a2c2acdc2 100644 --- a/tests/types/test_array.py +++ b/tests/types/test_array.py @@ -1,6 +1,6 @@ import pytest from psycopg3.types import builtins -from psycopg3.adapt import TypeCaster, UnknownCaster +from psycopg3.adapt import TypeCaster, UnknownCaster, Format from psycopg3.types.array import UnknownArrayCaster, ArrayCaster @@ -9,7 +9,10 @@ tests_str = [ (["foo", "bar", "baz"], "{foo,bar,baz}"), (["foo", None, "baz"], "{foo,null,baz}"), (["foo", "null", "", "baz"], '{foo,"null","",baz}'), - ([["foo", "bar"], ["baz", "qux"]], "{{foo,bar},{baz,qux}}"), + ( + [["foo", "bar"], ["baz", "qux"], ["quux", "quuux"]], + "{{foo,bar},{baz,qux},{quux,quuux}}", + ), ( [[["fo{o", "ba}r"], ['ba"z', "qu'x"], ["qu ux", " "]]], r'{{{"fo{o","ba}r"},{"ba\"z",qu\'x},{"qu ux"," "}}}', @@ -17,17 +20,20 @@ tests_str = [ ] +@pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY]) @pytest.mark.parametrize("obj, want", tests_str) -def test_adapt_list_str(conn, obj, want): +def test_adapt_list_str(conn, obj, want, fmt_in): cur = conn.cursor() cur.execute("select %s::text[] = %s::text[]", (obj, want)) assert cur.fetchone()[0] +@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY]) @pytest.mark.parametrize("want, obj", tests_str) -def test_cast_list_str(conn, obj, want): - cur = conn.cursor() - cur.execute("select %s::text[]", (obj,)) +def test_cast_list_str(conn, obj, want, fmt_out): + cur = conn.cursor(binary=fmt_out == Format.BINARY) + ph = "%s" if format == Format.TEXT else "%b" + cur.execute("select %s::text[]" % ph, (obj,)) assert cur.fetchone()[0] == want