import re
import struct
-from typing import Any, cast, Callable, Iterator, List
-from typing import Optional, Pattern, Set, Tuple, Type
+from math import prod
+from typing import Any, cast, Callable, List, Optional, Pattern, Set, Tuple, Type
from .. import pq
from .. import errors as e
def load_binary(data: Buffer, load: LoadFunc) -> List[Any]:
ndims, hasnull, oid = _unpack_head(data)
+
if not ndims:
return []
p = 12 + 8 * ndims
- dims = [_unpack_dim(data, i)[0] for i in list(range(12, p, 8))]
-
- def consume(p: int) -> Iterator[Any]:
- while True:
- size = unpack_len(data, p)[0]
- p += 4
- if size != -1:
- yield load(data[p : p + size])
- p += size
- else:
- yield None
-
- items = consume(p)
-
- def agg(dims: List[int]) -> List[Any]:
- if not dims:
- return next(items)
- else:
- dim, dims = dims[0], dims[1:]
- return [agg(dims) for _ in range(dim)]
-
- return agg(dims)
+ dims = [_unpack_dim(data, i)[0] for i in range(12, p, 8)]
+ nelems = prod(dims)
+
+ out: List[Any] = [None] * nelems
+ for i in range(nelems):
+ size = unpack_len(data, p)[0]
+ p += 4
+ if size == -1:
+ continue
+ out[i] = load(data[p : p + size])
+ p += size
+
+ # fon ndims > 1 we have to aggregate the array into sub-arrays
+ for dim in dims[-1:0:-1]:
+ out = [out[i : i + dim] for i in range(0, len(out), dim)]
+
+ return out
+from math import prod
+from typing import List, Any
from decimal import Decimal
import pytest
assert cur.fetchone()[0] == ["(2,2),(1,1)", "(5,6),(3,4)"]
+@pytest.mark.crdb_skip("nested array")
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_nested_array(conn, fmt_out):
+ dims = [3, 4, 5, 6]
+ a: List[Any] = list(range(prod(dims)))
+ for dim in dims[-1:0:-1]:
+ a = [a[i : i + dim] for i in range(0, len(a), dim)]
+
+ assert a[2][3][4][5] == prod(dims) - 1
+
+ sa = str(a).replace("[", "{").replace("]", "}")
+ got = conn.execute("select %s::int[][][][]", [sa], binary=fmt_out).fetchone()[0]
+ assert got == a
+
+
@pytest.mark.crdb_skip("nested array")
@pytest.mark.parametrize("fmt_out", pq.Format)
@pytest.mark.parametrize(