From: Daniele Varrazzo Date: Tue, 30 Aug 2022 11:34:27 +0000 (+0100) Subject: perf(array) faster algorithm to load nested binary arrays X-Git-Tag: 3.1.5~12^2~15 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=d870170fd463f2a10795c18482603e5315560485;p=thirdparty%2Fpsycopg.git perf(array) faster algorithm to load nested binary arrays Avoid using generators in the algorithm. This also makes the function easier to port to C. --- diff --git a/psycopg/psycopg/types/array.py b/psycopg/psycopg/types/array.py index addcb4339..e128ea60b 100644 --- a/psycopg/psycopg/types/array.py +++ b/psycopg/psycopg/types/array.py @@ -6,8 +6,8 @@ Adapters for arrays import re import struct -from typing import Any, cast, Callable, Iterator, List -from typing import Optional, Pattern, Set, Tuple, Type +from math import prod +from typing import Any, cast, Callable, List, Optional, Pattern, Set, Tuple, Type from .. import pq from .. import errors as e @@ -441,29 +441,25 @@ def _get_array_parse_regexp(delimiter: bytes) -> Pattern[bytes]: def load_binary(data: Buffer, load: LoadFunc) -> List[Any]: ndims, hasnull, oid = _unpack_head(data) + if not ndims: return [] p = 12 + 8 * ndims - dims = [_unpack_dim(data, i)[0] for i in list(range(12, p, 8))] - - def consume(p: int) -> Iterator[Any]: - while True: - size = unpack_len(data, p)[0] - p += 4 - if size != -1: - yield load(data[p : p + size]) - p += size - else: - yield None - - items = consume(p) - - def agg(dims: List[int]) -> List[Any]: - if not dims: - return next(items) - else: - dim, dims = dims[0], dims[1:] - return [agg(dims) for _ in range(dim)] - - return agg(dims) + dims = [_unpack_dim(data, i)[0] for i in range(12, p, 8)] + nelems = prod(dims) + + out: List[Any] = [None] * nelems + for i in range(nelems): + size = unpack_len(data, p)[0] + p += 4 + if size == -1: + continue + out[i] = load(data[p : p + size]) + p += size + + # fon ndims > 1 we have to aggregate the array into sub-arrays + for dim in dims[-1:0:-1]: + out = [out[i : i + dim] for i in range(0, len(out), dim)] + + return out diff --git a/tests/types/test_array.py b/tests/types/test_array.py index 41756c879..0d80152b8 100644 --- a/tests/types/test_array.py +++ b/tests/types/test_array.py @@ -1,3 +1,5 @@ +from math import prod +from typing import List, Any from decimal import Decimal import pytest @@ -287,6 +289,21 @@ def test_load_array_no_comma_separator(conn): assert cur.fetchone()[0] == ["(2,2),(1,1)", "(5,6),(3,4)"] +@pytest.mark.crdb_skip("nested array") +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_nested_array(conn, fmt_out): + dims = [3, 4, 5, 6] + a: List[Any] = list(range(prod(dims))) + for dim in dims[-1:0:-1]: + a = [a[i : i + dim] for i in range(0, len(a), dim)] + + assert a[2][3][4][5] == prod(dims) - 1 + + sa = str(a).replace("[", "{").replace("]", "}") + got = conn.execute("select %s::int[][][][]", [sa], binary=fmt_out).fetchone()[0] + assert got == a + + @pytest.mark.crdb_skip("nested array") @pytest.mark.parametrize("fmt_out", pq.Format) @pytest.mark.parametrize(