]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added cast of binary arrays
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 4 Apr 2020 14:12:10 +0000 (02:12 +1200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 4 Apr 2020 14:12:10 +0000 (02:12 +1200)
psycopg3/types/array.py
tests/types/test_array.py

index 39f33a1276c7dbe54ac7200e1127b8f246677cb8..8b85c68ad6541ce8cb723eb0bcd70a2ce2acab9e 100644 (file)
@@ -5,7 +5,8 @@ Adapters for arrays
 # Copyright (C) 2020 The Psycopg Team
 
 import re
-from typing import Any, List, Optional
+import struct
+from typing import Any, Generator, List, Optional
 
 from .. import errors as e
 from ..pq import Format
@@ -99,6 +100,8 @@ class ArrayCasterBase(TypeCaster):
         else:
             self.caster_func = type(self).base_caster
 
+
+class ArrayCasterText(ArrayCasterBase):
     def cast(self, data: bytes) -> List[Any]:
         rv = None
         stack: List[Any] = []
@@ -119,10 +122,12 @@ class ArrayCasterBase(TypeCaster):
 
             else:
                 if not stack:
-                    raise e.DataError(
-                        f"malformed array, unexpected"
-                        f" '{t.decode('utf8', 'replace')}'"
+                    wat = (
+                        t[:10].decode("utf8", "replace") + "..."
+                        if len(t) > 10
+                        else ""
                     )
+                    raise e.DataError(f"malformed array, unexpected '{wat}'")
                 if t == b"NULL":
                     v = None
                 else:
@@ -136,6 +141,48 @@ class ArrayCasterBase(TypeCaster):
         return rv
 
 
+_unpack_head = struct.Struct("!III").unpack_from
+_unpack_dim = struct.Struct("!II").unpack_from
+_unpack_len = struct.Struct("!i").unpack_from
+
+
+class ArrayCasterBinary(ArrayCasterBase):
+    def __init__(self, oid: int, context: AdaptContext = None):
+        super().__init__(oid, context)
+        self.tx = Transformer(context)
+
+    def cast(self, data: bytes) -> List[Any]:
+        ndims, hasnull, oid = _unpack_head(data[:12])
+        if not ndims:
+            return []
+
+        fcast = self.tx.get_cast_function(oid, Format.BINARY)
+
+        p = 12 + 8 * ndims
+        dims = [_unpack_dim(data, i)[0] for i in list(range(12, p, 8))]
+
+        def consume(p: int) -> Generator[Any, None, None]:
+            while 1:
+                size = _unpack_len(data, p)[0]
+                p += 4
+                if size != -1:
+                    yield fcast(data[p : p + size])
+                    p += size
+                else:
+                    yield None
+
+        items = consume(p)
+
+        def agg(dims: List[int]) -> List[Any]:
+            if not dims:
+                return next(items)
+            else:
+                dim, dims = dims[0], dims[1:]
+                return [agg(dims) for _ in range(dim)]
+
+        return agg(dims)
+
+
 class ArrayCaster(TypeCaster):
     @staticmethod
     def register(
@@ -144,13 +191,11 @@ class ArrayCaster(TypeCaster):
         context: AdaptContext = None,
         format: Format = Format.TEXT,
     ) -> TypeCasterType:
-        t = type(
-            caster.__name__ + "_array",
-            (ArrayCasterBase,),
-            {"base_caster": caster},
-        )
+        base = ArrayCasterText if format == Format.TEXT else ArrayCasterBinary
+        name = f"{caster.__name__}_{format.name.lower()}_array"
+        t = type(name, (base,), {"base_caster": caster})
         return TypeCaster.register(oid, t, context=context, format=format)
 
 
-class UnknownArrayCaster(ArrayCasterBase):
+class UnknownArrayCaster(ArrayCasterText):
     base_caster = UnknownCaster
index c69733719b4a6b7492bb705677ef8a99a417c32d..a2c2acdc28f341333dc808b3a73331ce5067b56c 100644 (file)
@@ -1,6 +1,6 @@
 import pytest
 from psycopg3.types import builtins
-from psycopg3.adapt import TypeCaster, UnknownCaster
+from psycopg3.adapt import TypeCaster, UnknownCaster, Format
 from psycopg3.types.array import UnknownArrayCaster, ArrayCaster
 
 
@@ -9,7 +9,10 @@ tests_str = [
     (["foo", "bar", "baz"], "{foo,bar,baz}"),
     (["foo", None, "baz"], "{foo,null,baz}"),
     (["foo", "null", "", "baz"], '{foo,"null","",baz}'),
-    ([["foo", "bar"], ["baz", "qux"]], "{{foo,bar},{baz,qux}}"),
+    (
+        [["foo", "bar"], ["baz", "qux"], ["quux", "quuux"]],
+        "{{foo,bar},{baz,qux},{quux,quuux}}",
+    ),
     (
         [[["fo{o", "ba}r"], ['ba"z', "qu'x"], ["qu ux", " "]]],
         r'{{{"fo{o","ba}r"},{"ba\"z",qu\'x},{"qu ux"," "}}}',
@@ -17,17 +20,20 @@ tests_str = [
 ]
 
 
+@pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY])
 @pytest.mark.parametrize("obj, want", tests_str)
-def test_adapt_list_str(conn, obj, want):
+def test_adapt_list_str(conn, obj, want, fmt_in):
     cur = conn.cursor()
     cur.execute("select %s::text[] = %s::text[]", (obj, want))
     assert cur.fetchone()[0]
 
 
+@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
 @pytest.mark.parametrize("want, obj", tests_str)
-def test_cast_list_str(conn, obj, want):
-    cur = conn.cursor()
-    cur.execute("select %s::text[]", (obj,))
+def test_cast_list_str(conn, obj, want, fmt_out):
+    cur = conn.cursor(binary=fmt_out == Format.BINARY)
+    ph = "%s" if format == Format.TEXT else "%b"
+    cur.execute("select %s::text[]" % ph, (obj,))
     assert cur.fetchone()[0] == want