]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added basic array adaptation infrastructure
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 3 Apr 2020 11:17:31 +0000 (00:17 +1300)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 3 Apr 2020 11:20:26 +0000 (00:20 +1300)
psycopg3/types/__init__.py
psycopg3/types/array.py [new file with mode: 0644]
psycopg3/types/numeric.py
psycopg3/types/text.py
tests/types/test_array.py [new file with mode: 0644]

index 5dbaf107354a695674c92c55f19e3b3312b69aef..8379dfd225dcb209ffdffdb2268cce18a82e0634 100644 (file)
@@ -8,6 +8,6 @@ psycopg3 types package
 from .oids import builtins
 
 # Register default adapters
-from . import numeric, text  # noqa
+from . import array, numeric, text  # noqa
 
 __all__ = ["builtins"]
diff --git a/psycopg3/types/array.py b/psycopg3/types/array.py
new file mode 100644 (file)
index 0000000..90b1bc6
--- /dev/null
@@ -0,0 +1,169 @@
+"""
+Adapters for arrays
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import re
+from typing import Any, Callable, List, Optional, cast, TYPE_CHECKING
+
+from .. import errors as e
+from ..pq import Format
+from ..adapt import Adapter, Typecaster, Transformer, UnknownCaster
+from ..adapt import AdaptContext, TypecasterType, TypecasterFunc
+
+if TYPE_CHECKING:
+    from ..connection import BaseConnection
+
+
+# from https://www.postgresql.org/docs/current/arrays.html#ARRAYS-IO
+#
+# The array output routine will put double quotes around element values if they
+# are empty strings, contain curly braces, delimiter characters, double quotes,
+# backslashes, or white space, or match the word NULL.
+# TODO: recognise only , as delimiter. Should be configured
+_re_needs_quote = re.compile(
+    br"""(?xi)
+      ^$              # the empty string
+    | ["{},\\\s]      # or a char to escape
+    | ^null$          # or the word NULL
+    """
+)
+
+# Double quotes and backslashes embedded in element values will be
+# backslash-escaped.
+_re_escape = re.compile(br'(["\\])')
+_re_unescape = re.compile(br"\\(.)")
+
+
+# Tokenize an array representation into item and brackets
+# TODO: currently recognise only , as delimiter. Should be configured
+_re_parse = re.compile(
+    br"""(?xi)
+    (     [{}]                        # open or closed bracket
+        | " (?: [^"\\] | \\. )* "     # or a quoted string
+        | [^"{},\\]+                  # or an unquoted non-empty string
+    ) ,?
+    """
+)
+
+
+def escape_item(item: Optional[bytes]) -> bytes:
+    if item is None:
+        return b"NULL"
+    if _re_needs_quote.search(item) is None:
+        return item
+    else:
+        return b'"' + _re_escape.sub(br"\\\1", item) + b'"'
+
+
+@Adapter.text(list)
+class ListAdapter(Adapter):
+    def __init__(self, cls: type, conn: "BaseConnection"):
+        super().__init__(cls, conn)
+        self.tx = Transformer(conn)
+
+    def adapt(self, obj: List[Any]) -> bytes:
+        tokens: List[bytes] = []
+        self.adapt_list(obj, tokens)
+        return b"".join(tokens)
+
+    def adapt_list(self, obj: List[Any], tokens: List[bytes]) -> None:
+        if not obj:
+            tokens.append(b"{}")
+            return
+
+        tokens.append(b"{")
+        for item in obj:
+            if isinstance(item, list):
+                self.adapt_list(item, tokens)
+            elif item is None:
+                tokens.append(b"NULL")
+            else:
+                ad = self.tx.adapt(item)
+                if isinstance(ad, tuple):
+                    ad = ad[0]
+                tokens.append(escape_item(ad))
+
+            tokens.append(b",")
+
+        tokens[-1] = b"}"
+
+
+class ArrayCasterBase(Typecaster):
+    base_caster: TypecasterType
+
+    def __init__(
+        self, oid: int, conn: Optional["BaseConnection"],
+    ):
+        super().__init__(oid, conn)
+        self.caster_func = TypecasterFunc  # type: ignore
+
+        if isinstance(self.base_caster, type):
+            self.caster_func = self.base_caster(oid, conn).cast
+        else:
+            self.caster_func = cast(TypecasterFunc, type(self).base_caster)
+
+    def cast(self, data: bytes) -> List[Any]:
+        rv = None
+        stack: List[Any] = []
+        for m in _re_parse.finditer(data):
+            t = m.group(1)
+            if t == b"{":
+                a: List[Any] = []
+                if rv is None:
+                    rv = a
+                if stack:
+                    stack[-1].append(a)
+                stack.append(a)
+
+            elif t == b"}":
+                if not stack:
+                    raise e.DataError("malformed array, unexpected '}'")
+                rv = stack.pop()
+
+            else:
+                if not stack:
+                    raise e.DataError(
+                        f"malformed array, unexpected"
+                        f" '{t.decode('utf8', 'replace')}'"
+                    )
+                if t == b"NULL":
+                    v = None
+                else:
+                    if t.startswith(b'"'):
+                        t = _re_unescape.sub(br"\1", t[1:-1])
+                    v = self.caster_func(t)
+
+                stack[-1].append(v)
+
+        assert rv is not None
+        return rv
+
+
+class ArrayCaster(Typecaster):
+    @staticmethod
+    def register(
+        oid: int,  # array oid
+        caster: TypecasterType,
+        context: AdaptContext = None,
+        format: Format = Format.TEXT,
+    ) -> TypecasterType:
+        t = type(
+            caster.__name__ + "_array",  # type: ignore
+            (ArrayCasterBase,),
+            {"base_caster": caster},
+        )
+        return Typecaster.register(oid, t, context=context, format=format)
+
+    @staticmethod
+    def text(oid: int) -> Callable[[Any], Any]:
+        def text_(caster: TypecasterType) -> TypecasterType:
+            ArrayCaster.register(oid, caster, format=Format.TEXT)
+            return caster
+
+        return text_
+
+
+class UnknownArrayCaster(ArrayCasterBase):
+    base_caster = UnknownCaster
index 9b22a2bb6a30a101baf904a2a28cfc8c39d1b272..c942f2a30367f85e4392712b0e6adebcaaf7261b 100644 (file)
@@ -11,6 +11,7 @@ from typing import Tuple
 
 from ..adapt import Adapter, Typecaster
 from .oids import builtins
+from .array import ArrayCaster
 
 FLOAT8_OID = builtins["float8"].oid
 NUMERIC_OID = builtins["numeric"].oid
@@ -59,6 +60,10 @@ def adapt_bool(obj: bool) -> Tuple[bytes, int]:
 @Typecaster.text(builtins["int4"].oid)
 @Typecaster.text(builtins["int8"].oid)
 @Typecaster.text(builtins["oid"].oid)
+@ArrayCaster.text(builtins["int2"].array_oid)
+@ArrayCaster.text(builtins["int4"].array_oid)
+@ArrayCaster.text(builtins["int8"].array_oid)
+@ArrayCaster.text(builtins["oid"].array_oid)
 def cast_int(data: bytes) -> int:
     return int(_decode(data)[0])
 
index 20d473fc259a1eb4c33a18d7af561b63debfc9dd..4083d30225c70dfbc4e928e3b096836e145ed0e2 100644 (file)
@@ -15,6 +15,7 @@ from ..connection import BaseConnection
 from ..utils.typing import EncodeFunc, DecodeFunc
 from ..pq import Escaping
 from .oids import builtins
+from .array import ArrayCaster
 
 TEXT_OID = builtins["text"].oid
 BYTEA_OID = builtins["bytea"].oid
@@ -41,6 +42,7 @@ class StringAdapter(Adapter):
 
 @Typecaster.text(builtins["text"].oid)
 @Typecaster.binary(builtins["text"].oid)
+@ArrayCaster.text(builtins["text"].array_oid)
 class StringCaster(Typecaster):
 
     decode: Optional[DecodeFunc]
@@ -80,6 +82,7 @@ def adapt_bytes(b: bytes) -> Tuple[bytes, int]:
 
 
 @Typecaster.text(builtins["bytea"].oid)
+@ArrayCaster.text(builtins["bytea"].array_oid)
 def cast_bytea(data: bytes) -> bytes:
     return Escaping.unescape_bytea(data)
 
diff --git a/tests/types/test_array.py b/tests/types/test_array.py
new file mode 100644 (file)
index 0000000..48853ff
--- /dev/null
@@ -0,0 +1,96 @@
+import pytest
+from psycopg3.types import builtins
+from psycopg3.adapt import Typecaster, UnknownCaster
+from psycopg3.types.array import UnknownArrayCaster, ArrayCaster
+
+
+tests_str = [
+    ([], "{}"),
+    (["foo", "bar", "baz"], "{foo,bar,baz}"),
+    (["foo", None, "baz"], "{foo,null,baz}"),
+    (["foo", "null", "", "baz"], '{foo,"null","",baz}'),
+    ([["foo", "bar"], ["baz", "qux"]], "{{foo,bar},{baz,qux}}"),
+    (
+        [[["fo{o", "ba}r"], ['ba"z', "qu'x"], ["qu ux", " "]]],
+        r'{{{"fo{o","ba}r"},{"ba\"z",qu\'x},{"qu ux"," "}}}',
+    ),
+]
+
+
+@pytest.mark.parametrize("obj, want", tests_str)
+def test_adapt_list_str(conn, obj, want):
+    cur = conn.cursor()
+    cur.execute("select %s::text[] = %s::text[]", (obj, want))
+    assert cur.fetchone()[0]
+
+
+@pytest.mark.parametrize("want, obj", tests_str)
+def test_cast_list_str(conn, obj, want):
+    cur = conn.cursor()
+    cur.execute("select %s::text[]", (obj,))
+    assert cur.fetchone()[0] == want
+
+
+def test_all_chars(conn):
+    cur = conn.cursor()
+    for i in range(1, 256):
+        c = chr(i)
+        cur.execute("select %s::text[]", ([c],))
+        assert cur.fetchone()[0] == [c]
+
+    a = list(map(chr, range(1, 256)))
+    a.append("\u20ac")
+    cur.execute("select %s::text[]", (a,))
+    assert cur.fetchone()[0] == a
+
+    a = "".join(a)
+    cur.execute("select %s::text[]", ([a],))
+    assert cur.fetchone()[0] == [a]
+
+
+tests_int = [
+    ([], "{}"),
+    ([10, 20, -30], "{10,20,-30}"),
+    ([10, None, 30], "{10,null,30}"),
+    ([[10, 20], [30, 40]], "{{10,20},{30,40}}"),
+]
+
+
+@pytest.mark.parametrize("obj, want", tests_int)
+def test_adapt_list_int(conn, obj, want):
+    cur = conn.cursor()
+    cur.execute("select %s::int[] = %s::int[]", (obj, want))
+    assert cur.fetchone()[0]
+
+
+@pytest.mark.parametrize("want, obj", tests_int)
+def test_cast_list_int(conn, obj, want):
+    cur = conn.cursor()
+    cur.execute("select %s::int[]", (obj,))
+    assert cur.fetchone()[0] == want
+
+
+def test_unknown(conn):
+    # unknown for real
+    assert builtins["aclitem"].array_oid not in Typecaster.globals
+    Typecaster.register(
+        builtins["aclitem"].array_oid, UnknownArrayCaster, context=conn
+    )
+    cur = conn.cursor()
+    cur.execute("select '{postgres=arwdDxt/postgres}'::aclitem[]")
+    res = cur.fetchone()[0]
+    assert res == ["postgres=arwdDxt/postgres"]
+
+
+def test_array_register(conn):
+    cur = conn.cursor()
+    cur.execute("select '{postgres=arwdDxt/postgres}'::aclitem[]")
+    res = cur.fetchone()[0]
+    assert res == "{postgres=arwdDxt/postgres}"
+
+    ArrayCaster.register(
+        builtins["aclitem"].array_oid, UnknownCaster, context=conn
+    )
+    cur.execute("select '{postgres=arwdDxt/postgres}'::aclitem[]")
+    res = cur.fetchone()[0]
+    assert res == ["postgres=arwdDxt/postgres"]