]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Dump lists keeping the type delimiter into account
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 24 Aug 2021 20:55:09 +0000 (22:55 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 25 Aug 2021 01:05:17 +0000 (03:05 +0200)
psycopg/psycopg/types/array.py
tests/types/test_array.py

index 997c2aae15bf9ac80de140ec650d7ea7ea1fe7d5..bd3fa9b66c9df5eedccfeade58ba0dba4bfb74d6 100644 (file)
@@ -64,24 +64,24 @@ class BaseListDumper(RecursiveDumper):
 
         return None
 
-    def _get_array_oid(self, base_oid: int) -> int:
+    def _get_base_type_info(self, base_oid: int) -> TypeInfo:
         """
-        Return the oid of the array from the oid of the base item.
+        Return info about the base type.
 
-        Fall back on text[].
+        Return text info as fallback.
         """
-        oid = 0
         if base_oid:
             info = self._tx.adapters.types.get(base_oid)
             if info:
-                oid = info.array_oid
+                return info
 
-        return oid or TEXT_ARRAY_OID
+        return self._tx.adapters.types["text"]
 
 
 class ListDumper(BaseListDumper):
 
     format = pq.Format.TEXT
+    delimiter = b","
 
     def get_key(self, obj: List[Any], format: PyFormat) -> DumperKey:
         if self.oid:
@@ -120,32 +120,21 @@ class ListDumper(BaseListDumper):
         # We consider an array of unknowns as unknown, so we can dump empty
         # lists or lists containing only None elements.
         if sd.oid != INVALID_OID:
-            dumper.oid = self._get_array_oid(sd.oid)
+            info = self._get_base_type_info(sd.oid)
+            dumper.oid = info.array_oid or TEXT_ARRAY_OID
+            dumper.delimiter = info.delimiter.encode("utf-8")
         else:
             dumper.oid = INVALID_OID
 
         return dumper
 
-    # from https://www.postgresql.org/docs/current/arrays.html#ARRAYS-IO
-    #
-    # The array output routine will put double quotes around element values if
-    # they are empty strings, contain curly braces, delimiter characters,
-    # double quotes, backslashes, or white space, or match the word NULL.
-    # TODO: recognise only , as delimiter. Should be configured
-    _re_needs_quotes = re.compile(
-        br"""(?xi)
-          ^$              # the empty string
-        | ["{},\\\s]      # or a char to escape
-        | ^null$          # or the word NULL
-        """
-    )
-
     # Double quotes and backslashes embedded in element values will be
     # backslash-escaped.
     _re_esc = re.compile(br'(["\\])')
 
     def dump(self, obj: List[Any]) -> bytes:
         tokens: List[bytes] = []
+        needs_quotes = _get_needs_quotes_regexp(self.delimiter).search
 
         def dump_list(obj: List[Any]) -> None:
             if not obj:
@@ -158,7 +147,7 @@ class ListDumper(BaseListDumper):
                     dump_list(item)
                 elif item is not None:
                     ad = self._dump_item(item)
-                    if self._re_needs_quotes.search(ad):
+                    if needs_quotes(ad):
                         if not isinstance(ad, bytes):
                             ad = bytes(ad)
                         ad = b'"' + self._re_esc.sub(br"\\\1", ad) + b'"'
@@ -166,7 +155,7 @@ class ListDumper(BaseListDumper):
                 else:
                     tokens.append(b"NULL")
 
-                tokens.append(b",")
+                tokens.append(self.delimiter)
 
             tokens[-1] = b"}"
 
@@ -181,6 +170,26 @@ class ListDumper(BaseListDumper):
             return self._tx.get_dumper(item, PyFormat.TEXT).dump(item)
 
 
+@lru_cache()
+def _get_needs_quotes_regexp(delimiter: bytes) -> Pattern[bytes]:
+    """Return a regexp to recognise when a value needs quotes
+
+    from https://www.postgresql.org/docs/current/arrays.html#ARRAYS-IO
+
+    The array output routine will put double quotes around element values if
+    they are empty strings, contain curly braces, delimiter characters,
+    double quotes, backslashes, or white space, or match the word NULL.
+    """
+    return re.compile(
+        br"""(?xi)
+          ^$              # the empty string
+        | ["{}%s\\\s]      # or a char to escape
+        | ^null$          # or the word NULL
+        """
+        % delimiter
+    )
+
+
 class MixedItemsListDumper(ListDumper):
     """
     An array dumper that doesn't assume that all the items are the same type.
@@ -234,7 +243,8 @@ class ListBinaryDumper(BaseListDumper):
         sd = self._tx.get_dumper(item, format.from_pq(self.format))
         dumper = type(self)(self.cls, self._tx)
         dumper.sub_dumper = sd
-        dumper.oid = self._get_array_oid(sd.oid)
+        info = self._get_base_type_info(sd.oid)
+        dumper.oid = info.array_oid or TEXT_ARRAY_OID
 
         return dumper
 
index 8d205508716ee212bf1273ea61d694daccb09579..d3c084c779248bf6ddb4e438d071a0c55669c29a 100644 (file)
@@ -5,7 +5,7 @@ import pytest
 import psycopg
 from psycopg import pq
 from psycopg import sql
-from psycopg.adapt import PyFormat as Format, Transformer
+from psycopg.adapt import PyFormat as Format, Transformer, Dumper
 from psycopg.types import TypeInfo
 from psycopg.postgres import types as builtins
 
@@ -222,7 +222,33 @@ def test_empty_list_after_choice(conn, fmt_in):
     assert cur.fetchall() == [([1.0],), ([],)]
 
 
-def test_array_no_comma_separator(conn):
+def test_dump_list_no_comma_separator(conn):
+    class Box:
+        def __init__(self, x1, y1, x2, y2):
+            self.coords = (x1, y1, x2, y2)
+
+    class BoxDumper(Dumper):
+
+        format = pq.Format.TEXT
+        _oid = psycopg.postgres.types["box"].oid
+
+        def dump(self, box):
+            return ("(%s,%s),(%s,%s)" % box.coords).encode("utf8")
+
+    conn.adapters.register_dumper(Box, BoxDumper)
+
+    cur = conn.execute("select (%s::box)::text", (Box(1, 2, 3, 4),))
+    got = cur.fetchone()[0]
+    assert got == "(3,4),(1,2)"
+
+    cur = conn.execute(
+        "select (%s::box[])::text", ([Box(1, 2, 3, 4), Box(5, 4, 3, 2)],)
+    )
+    got = cur.fetchone()[0]
+    assert got == "{(3,4),(1,2);(5,4),(3,2)}"
+
+
+def test_load_array_no_comma_separator(conn):
     cur = conn.execute("select '{(2,2),(1,1);(5,6),(3,4)}'::box[]")
     # Not parsed at the moment, but split ok on ; separator
     assert cur.fetchone()[0] == ["(2,2),(1,1)", "(5,6),(3,4)"]