]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
perf(array) faster algorithm to load nested binary arrays
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 30 Aug 2022 11:34:27 +0000 (12:34 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 10 Dec 2022 13:01:55 +0000 (13:01 +0000)
Avoid using generators in the algorithm. This also makes the function
easier to port to C.

psycopg/psycopg/types/array.py
tests/types/test_array.py

index addcb433933c59ab5bb1d3fc6a9d01b79a458555..e128ea60bba327618e41fade0fa5938a8efbc8dc 100644 (file)
@@ -6,8 +6,8 @@ Adapters for arrays
 
 import re
 import struct
-from typing import Any, cast, Callable, Iterator, List
-from typing import Optional, Pattern, Set, Tuple, Type
+from math import prod
+from typing import Any, cast, Callable, List, Optional, Pattern, Set, Tuple, Type
 
 from .. import pq
 from .. import errors as e
@@ -441,29 +441,25 @@ def _get_array_parse_regexp(delimiter: bytes) -> Pattern[bytes]:
 
 def load_binary(data: Buffer, load: LoadFunc) -> List[Any]:
     ndims, hasnull, oid = _unpack_head(data)
+
     if not ndims:
         return []
 
     p = 12 + 8 * ndims
-    dims = [_unpack_dim(data, i)[0] for i in list(range(12, p, 8))]
-
-    def consume(p: int) -> Iterator[Any]:
-        while True:
-            size = unpack_len(data, p)[0]
-            p += 4
-            if size != -1:
-                yield load(data[p : p + size])
-                p += size
-            else:
-                yield None
-
-    items = consume(p)
-
-    def agg(dims: List[int]) -> List[Any]:
-        if not dims:
-            return next(items)
-        else:
-            dim, dims = dims[0], dims[1:]
-            return [agg(dims) for _ in range(dim)]
-
-    return agg(dims)
+    dims = [_unpack_dim(data, i)[0] for i in range(12, p, 8)]
+    nelems = prod(dims)
+
+    out: List[Any] = [None] * nelems
+    for i in range(nelems):
+        size = unpack_len(data, p)[0]
+        p += 4
+        if size == -1:
+            continue
+        out[i] = load(data[p : p + size])
+        p += size
+
+    # fon ndims > 1 we have to aggregate the array into sub-arrays
+    for dim in dims[-1:0:-1]:
+        out = [out[i : i + dim] for i in range(0, len(out), dim)]
+
+    return out
index 41756c87947fc2d5e9f2d35503b545597e989b89..0d80152b8367034fe5249c8c868b489d2ba9e303 100644 (file)
@@ -1,3 +1,5 @@
+from math import prod
+from typing import List, Any
 from decimal import Decimal
 
 import pytest
@@ -287,6 +289,21 @@ def test_load_array_no_comma_separator(conn):
     assert cur.fetchone()[0] == ["(2,2),(1,1)", "(5,6),(3,4)"]
 
 
+@pytest.mark.crdb_skip("nested array")
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_nested_array(conn, fmt_out):
+    dims = [3, 4, 5, 6]
+    a: List[Any] = list(range(prod(dims)))
+    for dim in dims[-1:0:-1]:
+        a = [a[i : i + dim] for i in range(0, len(a), dim)]
+
+    assert a[2][3][4][5] == prod(dims) - 1
+
+    sa = str(a).replace("[", "{").replace("]", "}")
+    got = conn.execute("select %s::int[][][][]", [sa], binary=fmt_out).fetchone()[0]
+    assert got == a
+
+
 @pytest.mark.crdb_skip("nested array")
 @pytest.mark.parametrize("fmt_out", pq.Format)
 @pytest.mark.parametrize(