]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor(array): make load function static
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 30 Aug 2022 09:36:46 +0000 (10:36 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 10 Dec 2022 13:01:55 +0000 (13:01 +0000)
This will allow replacing them with a fast version as with the helper
functions in the copy module.

psycopg/psycopg/types/array.py

index 202cd5e3e3857f5f5ff40bc07ce5c7bbfd442669..addcb433933c59ab5bb1d3fc6a9d01b79a458555 100644 (file)
@@ -12,7 +12,7 @@ from typing import Optional, Pattern, Set, Tuple, Type
 from .. import pq
 from .. import errors as e
 from .. import postgres
-from ..abc import AdaptContext, Buffer, Dumper, DumperKey, NoneType
+from ..abc import AdaptContext, Buffer, Dumper, DumperKey, NoneType, LoadFunc
 from ..adapt import RecursiveDumper, RecursiveLoader, PyFormat
 from .._compat import cache
 from .._struct import pack_len, unpack_len
@@ -299,70 +299,8 @@ class ArrayLoader(BaseArrayLoader):
     delimiter = b","
 
     def load(self, data: Buffer) -> List[Any]:
-        rv = None
-        stack: List[Any] = []
-        cast = self._tx.get_loader(self.base_oid, self.format).load
-
-        # Remove the dimensions information prefix (``[...]=``)
-        if data and data[0] == b"["[0]:
-            if isinstance(data, memoryview):
-                data = bytes(data)
-            idx = data.find(b"=")
-            if idx == -1:
-                raise e.DataError("malformed array, no '=' after dimension information")
-            data = data[idx + 1 :]
-
-        re_parse = _get_array_parse_regexp(self.delimiter)
-        for m in re_parse.finditer(data):
-            t = m.group(1)
-            if t == b"{":
-                a: List[Any] = []
-                if rv is None:
-                    rv = a
-                if stack:
-                    stack[-1].append(a)
-                stack.append(a)
-
-            elif t == b"}":
-                if not stack:
-                    raise e.DataError("malformed array, unexpected '}'")
-                rv = stack.pop()
-
-            else:
-                if not stack:
-                    wat = (
-                        t[:10].decode("utf8", "replace") + "..." if len(t) > 10 else ""
-                    )
-                    raise e.DataError(f"malformed array, unexpected '{wat}'")
-                if t == b"NULL":
-                    v = None
-                else:
-                    if t.startswith(b'"'):
-                        t = self._re_unescape.sub(rb"\1", t[1:-1])
-                    v = cast(t)
-
-                stack[-1].append(v)
-
-        assert rv is not None
-        return rv
-
-    _re_unescape = re.compile(rb"\\(.)")
-
-
-@cache
-def _get_array_parse_regexp(delimiter: bytes) -> Pattern[bytes]:
-    """
-    Return a regexp to tokenize an array representation into item and brackets
-    """
-    return re.compile(
-        rb"""(?xi)
-        (     [{}]                        # open or closed bracket
-            | " (?: [^"\\] | \\. )* "     # or a quoted string
-            | [^"{}%s\\]+                 # or an unquoted non-empty string
-        ) ,?
-        """
-        % delimiter
-    )
+        load = self._tx.get_loader(self.base_oid, self.format).load
+        return load_text(data, load, self.delimiter)
 
 
 class ArrayBinaryLoader(BaseArrayLoader):
@@ -370,35 +308,8 @@ class ArrayBinaryLoader(BaseArrayLoader):
     format = pq.Format.BINARY
 
     def load(self, data: Buffer) -> List[Any]:
-        ndims, hasnull, oid = _unpack_head(data)
-        if not ndims:
-            return []
-
-        fcast = self._tx.get_loader(oid, self.format).load
-
-        p = 12 + 8 * ndims
-        dims = [_unpack_dim(data, i)[0] for i in list(range(12, p, 8))]
-
-        def consume(p: int) -> Iterator[Any]:
-            while True:
-                size = unpack_len(data, p)[0]
-                p += 4
-                if size != -1:
-                    yield fcast(data[p : p + size])
-                    p += size
-                else:
-                    yield None
-
-        items = consume(p)
-
-        def agg(dims: List[int]) -> List[Any]:
-            if not dims:
-                return next(items)
-            else:
-                dim, dims = dims[0], dims[1:]
-                return [agg(dims) for _ in range(dim)]
-
-        return agg(dims)
+        load = self._tx.get_loader(self.base_oid, self.format).load
+        return load_binary(data, load)
 
 
 def register_array(info: TypeInfo, context: Optional[AdaptContext] = None) -> None:
@@ -459,3 +370,100 @@ def register_all_arrays(context: AdaptContext) -> None:
     for t in context.adapters.types:
         if t.array_oid:
             t.register(context)
+
+
+def load_text(
+    data: Buffer,
+    load: LoadFunc,
+    delimiter: bytes = b",",
+    __re_unescape: Pattern[bytes] = re.compile(rb"\\(.)"),
+) -> List[Any]:
+    rv = None
+    stack: List[Any] = []
+
+    # Remove the dimensions information prefix (``[...]=``)
+    if data and data[0] == b"["[0]:
+        if isinstance(data, memoryview):
+            data = bytes(data)
+        idx = data.find(b"=")
+        if idx == -1:
+            raise e.DataError("malformed array, no '=' after dimension information")
+        data = data[idx + 1 :]
+
+    re_parse = _get_array_parse_regexp(delimiter)
+    for m in re_parse.finditer(data):
+        t = m.group(1)
+        if t == b"{":
+            a: List[Any] = []
+            if rv is None:
+                rv = a
+            if stack:
+                stack[-1].append(a)
+            stack.append(a)
+
+        elif t == b"}":
+            if not stack:
+                raise e.DataError("malformed array, unexpected '}'")
+            rv = stack.pop()
+
+        else:
+            if not stack:
+                wat = t[:10].decode("utf8", "replace") + "..." if len(t) > 10 else ""
+                raise e.DataError(f"malformed array, unexpected '{wat}'")
+            if t == b"NULL":
+                v = None
+            else:
+                if t.startswith(b'"'):
+                    t = __re_unescape.sub(rb"\1", t[1:-1])
+                v = load(t)
+
+            stack[-1].append(v)
+
+    assert rv is not None
+    return rv
+
+
+@cache
+def _get_array_parse_regexp(delimiter: bytes) -> Pattern[bytes]:
+    """
+    Return a regexp to tokenize an array representation into item and brackets
+    """
+    return re.compile(
+        rb"""(?xi)
+        (     [{}]                        # open or closed bracket
+            | " (?: [^"\\] | \\. )* "     # or a quoted string
+            | [^"{}%s\\]+                 # or an unquoted non-empty string
+        ) ,?
+        """
+        % delimiter
+    )
+
+
+def load_binary(data: Buffer, load: LoadFunc) -> List[Any]:
+    ndims, hasnull, oid = _unpack_head(data)
+    if not ndims:
+        return []
+
+    p = 12 + 8 * ndims
+    dims = [_unpack_dim(data, i)[0] for i in list(range(12, p, 8))]
+
+    def consume(p: int) -> Iterator[Any]:
+        while True:
+            size = unpack_len(data, p)[0]
+            p += 4
+            if size != -1:
+                yield load(data[p : p + size])
+                p += size
+            else:
+                yield None
+
+    items = consume(p)
+
+    def agg(dims: List[int]) -> List[Any]:
+        if not dims:
+            return next(items)
+        else:
+            dim, dims = dims[0], dims[1:]
+            return [agg(dims) for _ in range(dim)]
+
+    return agg(dims)