git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Load arrays of types with delimiters different than comma (e.g. box)
author: Daniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 24 Aug 2021 19:14:50 +0000 (21:14 +0200)
committer: Daniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 25 Aug 2021 01:05:17 +0000 (03:05 +0200)
psycopg/psycopg/types/array.py
tests/types/test_array.py

index 31900355961b0ad9dc890de6bea490caca8f0e9b..997c2aae15bf9ac80de140ec650d7ea7ea1fe7d5 100644 (file)
@@ -7,8 +7,9 @@ Adapters for arrays
 import re
 import struct
 from decimal import Decimal
-from typing import Any, Callable, Iterator, List, Optional, Set, Tuple, Type
-from typing import cast
+from typing import Any, cast, Callable, Iterator, List
+from typing import Optional, Pattern, Set, Tuple, Type
+from functools import lru_cache
 
 from .. import pq
 from .. import errors as e
@@ -305,24 +306,15 @@ class BaseArrayLoader(RecursiveLoader):
 class ArrayLoader(BaseArrayLoader):
 
     format = pq.Format.TEXT
-
-    # Tokenize an array representation into item and brackets
-    # TODO: currently recognise only , as delimiter. Should be configured
-    _re_parse = re.compile(
-        br"""(?xi)
-        (     [{}]                        # open or closed bracket
-            | " (?: [^"\\] | \\. )* "     # or a quoted string
-            | [^"{},\\]+                  # or an unquoted non-empty string
-        ) ,?
-        """
-    )
+    delimiter = b","
 
     def load(self, data: Buffer) -> List[Any]:
         rv = None
         stack: List[Any] = []
         cast = self._tx.get_loader(self.base_oid, self.format).load
 
-        for m in self._re_parse.finditer(data):
+        re_parse = _get_array_parse_regexp(self.delimiter)
+        for m in re_parse.finditer(data):
             t = m.group(1)
             if t == b"{":
                 a: List[Any] = []
@@ -360,6 +352,22 @@ class ArrayLoader(BaseArrayLoader):
     _re_unescape = re.compile(br"\\(.)")
 
 
+@lru_cache()
+def _get_array_parse_regexp(delimiter: bytes) -> Pattern[bytes]:
+    """
+    Return a regexp to tokenize an array representation into item and brackets.
+    The delimiter is escaped: bytes special in a class (- ] \\) are literal.
+    """
+    return re.compile(
+        br"""(?xi)
+        (     [{}]                        # open or closed bracket
+            | " (?: [^"\\] | \\. )* "     # or a quoted string
+            | [^"{}%s\\]+                 # or an unquoted non-empty string
+        ) ,?
+        """ % re.escape(delimiter)
+    )
+
+
 class ArrayBinaryLoader(BaseArrayLoader):
 
     format = pq.Format.BINARY
@@ -400,12 +408,21 @@ def register_adapters(
     info: TypeInfo, context: Optional[AdaptContext] = None
 ) -> None:
     adapters = context.adapters if context else postgres.adapters
-    for base in (ArrayLoader, ArrayBinaryLoader):
-        lname = f"{info.name.title()}{base.__name__}"
-        loader: Type[BaseArrayLoader] = type(
-            lname, (base,), {"base_oid": info.oid}
-        )
-        adapters.register_loader(info.array_oid, loader)
+
+    base: Type[BaseArrayLoader] = ArrayLoader
+    lname = f"{info.name.title()}{base.__name__}"
+    attribs = {
+        "base_oid": info.oid,
+        "delimiter": info.delimiter.encode("utf-8"),
+    }
+    loader = type(lname, (base,), attribs)
+    adapters.register_loader(info.array_oid, loader)
+
+    base = ArrayBinaryLoader
+    lname = f"{info.name.title()}{base.__name__}"
+    attribs = {"base_oid": info.oid}
+    loader = type(lname, (base,), attribs)
+    adapters.register_loader(info.array_oid, loader)
 
 
 def register_default_adapters(context: AdaptContext) -> None:
@@ -423,6 +440,5 @@ def register_all_arrays(context: AdaptContext) -> None:
     registered all the base loaders.
     """
     for t in context.adapters.types:
-        # TODO: handle different delimiters (box)
-        if t.array_oid and getattr(t, "delimiter", None) == ",":
+        if t.array_oid:
             t.register(context)
index ba31553f72f2c523cdb1b633d121809dc2c8675c..8d205508716ee212bf1273ea61d694daccb09579 100644 (file)
@@ -220,3 +220,9 @@ def test_empty_list_after_choice(conn, fmt_in):
     )
     cur.execute("select data from test order by id")
     assert cur.fetchall() == [([1.0],), ([],)]
+
+
+def test_array_no_comma_separator(conn):
+    cur = conn.execute("select '{(2,2),(1,1);(5,6),(3,4)}'::box[]")
+    # Boxes are loaded as unparsed strings, but the array is split on ";"
+    assert cur.fetchone()[0] == ["(2,2),(1,1)", "(5,6),(3,4)"]