from .. import pq
from .. import errors as e
from .. import postgres
-from ..abc import AdaptContext, Buffer, Dumper, DumperKey, NoneType
+from ..abc import AdaptContext, Buffer, Dumper, DumperKey, NoneType, LoadFunc
from ..adapt import RecursiveDumper, RecursiveLoader, PyFormat
from .._compat import cache
from .._struct import pack_len, unpack_len
delimiter = b","
def load(self, data: Buffer) -> List[Any]:
- rv = None
- stack: List[Any] = []
- cast = self._tx.get_loader(self.base_oid, self.format).load
-
- # Remove the dimensions information prefix (``[...]=``)
- if data and data[0] == b"["[0]:
- if isinstance(data, memoryview):
- data = bytes(data)
- idx = data.find(b"=")
- if idx == -1:
- raise e.DataError("malformed array, no '=' after dimension information")
- data = data[idx + 1 :]
-
- re_parse = _get_array_parse_regexp(self.delimiter)
- for m in re_parse.finditer(data):
- t = m.group(1)
- if t == b"{":
- a: List[Any] = []
- if rv is None:
- rv = a
- if stack:
- stack[-1].append(a)
- stack.append(a)
-
- elif t == b"}":
- if not stack:
- raise e.DataError("malformed array, unexpected '}'")
- rv = stack.pop()
-
- else:
- if not stack:
- wat = (
- t[:10].decode("utf8", "replace") + "..." if len(t) > 10 else ""
- )
- raise e.DataError(f"malformed array, unexpected '{wat}'")
- if t == b"NULL":
- v = None
- else:
- if t.startswith(b'"'):
- t = self._re_unescape.sub(rb"\1", t[1:-1])
- v = cast(t)
-
- stack[-1].append(v)
-
- assert rv is not None
- return rv
-
- _re_unescape = re.compile(rb"\\(.)")
-
-
-@cache
-def _get_array_parse_regexp(delimiter: bytes) -> Pattern[bytes]:
- """
- Return a regexp to tokenize an array representation into item and brackets
- """
- return re.compile(
- rb"""(?xi)
- ( [{}] # open or closed bracket
- | " (?: [^"\\] | \\. )* " # or a quoted string
- | [^"{}%s\\]+ # or an unquoted non-empty string
- ) ,?
- """
- % delimiter
- )
+ load = self._tx.get_loader(self.base_oid, self.format).load
+ return load_text(data, load, self.delimiter)
class ArrayBinaryLoader(BaseArrayLoader):
format = pq.Format.BINARY
def load(self, data: Buffer) -> List[Any]:
- ndims, hasnull, oid = _unpack_head(data)
- if not ndims:
- return []
-
- fcast = self._tx.get_loader(oid, self.format).load
-
- p = 12 + 8 * ndims
- dims = [_unpack_dim(data, i)[0] for i in list(range(12, p, 8))]
-
- def consume(p: int) -> Iterator[Any]:
- while True:
- size = unpack_len(data, p)[0]
- p += 4
- if size != -1:
- yield fcast(data[p : p + size])
- p += size
- else:
- yield None
-
- items = consume(p)
-
- def agg(dims: List[int]) -> List[Any]:
- if not dims:
- return next(items)
- else:
- dim, dims = dims[0], dims[1:]
- return [agg(dims) for _ in range(dim)]
-
- return agg(dims)
+ load = self._tx.get_loader(self.base_oid, self.format).load
+ return load_binary(data, load)
def register_array(info: TypeInfo, context: Optional[AdaptContext] = None) -> None:
for t in context.adapters.types:
if t.array_oid:
t.register(context)
+
+
+def load_text(
+ data: Buffer,
+ load: LoadFunc,
+ delimiter: bytes = b",",
+ __re_unescape: Pattern[bytes] = re.compile(rb"\\(.)"),
+) -> List[Any]:
+ rv = None
+ stack: List[Any] = []
+
+ # Remove the dimensions information prefix (``[...]=``)
+ if data and data[0] == b"["[0]:
+ if isinstance(data, memoryview):
+ data = bytes(data)
+ idx = data.find(b"=")
+ if idx == -1:
+ raise e.DataError("malformed array, no '=' after dimension information")
+ data = data[idx + 1 :]
+
+ re_parse = _get_array_parse_regexp(delimiter)
+ for m in re_parse.finditer(data):
+ t = m.group(1)
+ if t == b"{":
+ a: List[Any] = []
+ if rv is None:
+ rv = a
+ if stack:
+ stack[-1].append(a)
+ stack.append(a)
+
+ elif t == b"}":
+ if not stack:
+ raise e.DataError("malformed array, unexpected '}'")
+ rv = stack.pop()
+
+ else:
+ if not stack:
+ wat = t[:10].decode("utf8", "replace") + "..." if len(t) > 10 else ""
+ raise e.DataError(f"malformed array, unexpected '{wat}'")
+ if t == b"NULL":
+ v = None
+ else:
+ if t.startswith(b'"'):
+ t = __re_unescape.sub(rb"\1", t[1:-1])
+ v = load(t)
+
+ stack[-1].append(v)
+
+ assert rv is not None
+ return rv
+
+
+@cache
+def _get_array_parse_regexp(delimiter: bytes) -> Pattern[bytes]:
+ """
+ Return a regexp to tokenize an array representation into item and brackets
+ """
+ return re.compile(
+ rb"""(?xi)
+ ( [{}] # open or closed bracket
+ | " (?: [^"\\] | \\. )* " # or a quoted string
+ | [^"{}%s\\]+ # or an unquoted non-empty string
+ ) ,?
+ """
+ % delimiter
+ )
+
+
+def load_binary(data: Buffer, load: LoadFunc) -> List[Any]:
+ ndims, hasnull, oid = _unpack_head(data)
+ if not ndims:
+ return []
+
+ p = 12 + 8 * ndims
+ dims = [_unpack_dim(data, i)[0] for i in list(range(12, p, 8))]
+
+ def consume(p: int) -> Iterator[Any]:
+ while True:
+ size = unpack_len(data, p)[0]
+ p += 4
+ if size != -1:
+ yield load(data[p : p + size])
+ p += size
+ else:
+ yield None
+
+ items = consume(p)
+
+ def agg(dims: List[int]) -> List[Any]:
+ if not dims:
+ return next(items)
+ else:
+ dim, dims = dims[0], dims[1:]
+ return [agg(dims) for _ in range(dim)]
+
+ return agg(dims)