From: Daniele Varrazzo Date: Tue, 30 Aug 2022 09:36:46 +0000 (+0100) Subject: refactor(array): make load function static X-Git-Tag: 3.1.5~12^2~16 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=43ff49a99ebbee107503e9d4007dfbaecb579da4;p=thirdparty%2Fpsycopg.git refactor(array): make load function static This will allow replacing them with a fast version as with the helper functions in the copy module. --- diff --git a/psycopg/psycopg/types/array.py b/psycopg/psycopg/types/array.py index 202cd5e3e..addcb4339 100644 --- a/psycopg/psycopg/types/array.py +++ b/psycopg/psycopg/types/array.py @@ -12,7 +12,7 @@ from typing import Optional, Pattern, Set, Tuple, Type from .. import pq from .. import errors as e from .. import postgres -from ..abc import AdaptContext, Buffer, Dumper, DumperKey, NoneType +from ..abc import AdaptContext, Buffer, Dumper, DumperKey, NoneType, LoadFunc from ..adapt import RecursiveDumper, RecursiveLoader, PyFormat from .._compat import cache from .._struct import pack_len, unpack_len @@ -299,70 +299,8 @@ class ArrayLoader(BaseArrayLoader): delimiter = b"," def load(self, data: Buffer) -> List[Any]: - rv = None - stack: List[Any] = [] - cast = self._tx.get_loader(self.base_oid, self.format).load - - # Remove the dimensions information prefix (``[...]=``) - if data and data[0] == b"["[0]: - if isinstance(data, memoryview): - data = bytes(data) - idx = data.find(b"=") - if idx == -1: - raise e.DataError("malformed array, no '=' after dimension information") - data = data[idx + 1 :] - - re_parse = _get_array_parse_regexp(self.delimiter) - for m in re_parse.finditer(data): - t = m.group(1) - if t == b"{": - a: List[Any] = [] - if rv is None: - rv = a - if stack: - stack[-1].append(a) - stack.append(a) - - elif t == b"}": - if not stack: - raise e.DataError("malformed array, unexpected '}'") - rv = stack.pop() - - else: - if not stack: - wat = ( - t[:10].decode("utf8", "replace") + "..." if len(t) > 10 else "" - ) - raise e.DataError(f"malformed array, unexpected '{wat}'") - if t == b"NULL": - v = None - else: - if t.startswith(b'"'): - t = self._re_unescape.sub(rb"\1", t[1:-1]) - v = cast(t) - - stack[-1].append(v) - - assert rv is not None - return rv - - _re_unescape = re.compile(rb"\\(.)") - - -@cache -def _get_array_parse_regexp(delimiter: bytes) -> Pattern[bytes]: - """ - Return a regexp to tokenize an array representation into item and brackets - """ - return re.compile( - rb"""(?xi) - ( [{}] # open or closed bracket - | " (?: [^"\\] | \\. )* " # or a quoted string - | [^"{}%s\\]+ # or an unquoted non-empty string - ) ,? - """ - % delimiter - ) + load = self._tx.get_loader(self.base_oid, self.format).load + return load_text(data, load, self.delimiter) class ArrayBinaryLoader(BaseArrayLoader): @@ -370,35 +308,8 @@ class ArrayBinaryLoader(BaseArrayLoader): format = pq.Format.BINARY def load(self, data: Buffer) -> List[Any]: - ndims, hasnull, oid = _unpack_head(data) - if not ndims: - return [] - - fcast = self._tx.get_loader(oid, self.format).load - - p = 12 + 8 * ndims - dims = [_unpack_dim(data, i)[0] for i in list(range(12, p, 8))] - - def consume(p: int) -> Iterator[Any]: - while True: - size = unpack_len(data, p)[0] - p += 4 - if size != -1: - yield fcast(data[p : p + size]) - p += size - else: - yield None - - items = consume(p) - - def agg(dims: List[int]) -> List[Any]: - if not dims: - return next(items) - else: - dim, dims = dims[0], dims[1:] - return [agg(dims) for _ in range(dim)] - - return agg(dims) + load = self._tx.get_loader(self.base_oid, self.format).load + return load_binary(data, load) def register_array(info: TypeInfo, context: Optional[AdaptContext] = None) -> None: @@ -459,3 +370,100 @@ def register_all_arrays(context: AdaptContext) -> None: for t in context.adapters.types: if t.array_oid: t.register(context) + + +def load_text( + data: Buffer, + load: LoadFunc, + delimiter: bytes = b",", + __re_unescape: Pattern[bytes] = re.compile(rb"\\(.)"), +) -> List[Any]: + rv = None + stack: List[Any] = [] + + # Remove the dimensions information prefix (``[...]=``) + if data and data[0] == b"["[0]: + if isinstance(data, memoryview): + data = bytes(data) + idx = data.find(b"=") + if idx == -1: + raise e.DataError("malformed array, no '=' after dimension information") + data = data[idx + 1 :] + + re_parse = _get_array_parse_regexp(delimiter) + for m in re_parse.finditer(data): + t = m.group(1) + if t == b"{": + a: List[Any] = [] + if rv is None: + rv = a + if stack: + stack[-1].append(a) + stack.append(a) + + elif t == b"}": + if not stack: + raise e.DataError("malformed array, unexpected '}'") + rv = stack.pop() + + else: + if not stack: + wat = t[:10].decode("utf8", "replace") + "..." if len(t) > 10 else "" + raise e.DataError(f"malformed array, unexpected '{wat}'") + if t == b"NULL": + v = None + else: + if t.startswith(b'"'): + t = __re_unescape.sub(rb"\1", t[1:-1]) + v = load(t) + + stack[-1].append(v) + + assert rv is not None + return rv + + +@cache +def _get_array_parse_regexp(delimiter: bytes) -> Pattern[bytes]: + """ + Return a regexp to tokenize an array representation into item and brackets + """ + return re.compile( + rb"""(?xi) + ( [{}] # open or closed bracket + | " (?: [^"\\] | \\. )* " # or a quoted string + | [^"{}%s\\]+ # or an unquoted non-empty string + ) ,? + """ + % delimiter + ) + + +def load_binary(data: Buffer, load: LoadFunc) -> List[Any]: + ndims, hasnull, oid = _unpack_head(data) + if not ndims: + return [] + + p = 12 + 8 * ndims + dims = [_unpack_dim(data, i)[0] for i in list(range(12, p, 8))] + + def consume(p: int) -> Iterator[Any]: + while True: + size = unpack_len(data, p)[0] + p += 4 + if size != -1: + yield load(data[p : p + size]) + p += size + else: + yield None + + items = consume(p) + + def agg(dims: List[int]) -> List[Any]: + if not dims: + return next(items) + else: + dim, dims = dims[0], dims[1:] + return [agg(dims) for _ in range(dim)] + + return agg(dims)