From a3f3b7843e8f10a20d12f146578fcbb1196e6f2b Mon Sep 17 00:00:00 2001
From: Daniele Varrazzo
Date: Sat, 4 Apr 2020 00:17:31 +1300
Subject: [PATCH] Added basic array adaptation infrastructure

---
 psycopg3/types/__init__.py |   2 +-
 psycopg3/types/array.py    | 220 +++++++++++++++++++++++++++++++++++++
 psycopg3/types/numeric.py  |   5 +
 psycopg3/types/text.py     |   3 +
 tests/types/test_array.py  |  96 ++++++++++++++++
 5 files changed, 325 insertions(+), 1 deletion(-)
 create mode 100644 psycopg3/types/array.py
 create mode 100644 tests/types/test_array.py

diff --git a/psycopg3/types/__init__.py b/psycopg3/types/__init__.py
index 5dbaf1073..8379dfd22 100644
--- a/psycopg3/types/__init__.py
+++ b/psycopg3/types/__init__.py
@@ -8,6 +8,6 @@ psycopg3 types package
 from .oids import builtins
 
 # Register default adapters
-from . import numeric, text  # noqa
+from . import array, numeric, text  # noqa
 
 __all__ = ["builtins"]
diff --git a/psycopg3/types/array.py b/psycopg3/types/array.py
new file mode 100644
index 000000000..90b1bc669
--- /dev/null
+++ b/psycopg3/types/array.py
@@ -0,0 +1,220 @@
+"""
+Adapters for arrays
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import re
+from typing import Any, Callable, List, Optional, cast, TYPE_CHECKING
+
+from .. import errors as e
+from ..pq import Format
+from ..adapt import Adapter, Typecaster, Transformer, UnknownCaster
+from ..adapt import AdaptContext, TypecasterType, TypecasterFunc
+
+if TYPE_CHECKING:
+    from ..connection import BaseConnection
+
+
+# from https://www.postgresql.org/docs/current/arrays.html#ARRAYS-IO
+#
+# The array output routine will put double quotes around element values if they
+# are empty strings, contain curly braces, delimiter characters, double quotes,
+# backslashes, or white space, or match the word NULL.
+# TODO: recognise only , as delimiter. Should be configured
+_re_needs_quote = re.compile(
+    br"""(?xi)
+      ^$              # the empty string
+    | ["{},\\\s]      # or a char to escape
+    | ^null$          # or the word NULL
+    """
+)
+
+# Double quotes and backslashes embedded in element values will be
+# backslash-escaped.
+_re_escape = re.compile(br'(["\\])')
+_re_unescape = re.compile(br"\\(.)")
+
+
+# Tokenize an array representation into item and brackets
+# TODO: currently recognise only , as delimiter. Should be configured
+_re_parse = re.compile(
+    br"""(?xi)
+    (     [{}]                        # open or closed bracket
+        | " (?: [^"\\] | \\. )* "     # or a quoted string
+        | [^"{},\\]+                  # or an unquoted non-empty string
+    ) ,?
+    """
+)
+
+
+def escape_item(item: Optional[bytes]) -> bytes:
+    """Return *item* in the form required inside an array literal.
+
+    `None` is returned as the unquoted word NULL. An item matching
+    `_re_needs_quote` is returned in double quotes, with embedded double
+    quotes and backslashes backslash-escaped. Any other item is returned
+    unchanged.
+    """
+    if item is None:
+        return b"NULL"
+    if _re_needs_quote.search(item) is None:
+        return item
+    else:
+        return b'"' + _re_escape.sub(br"\\\1", item) + b'"'
+
+
+@Adapter.text(list)
+class ListAdapter(Adapter):
+    """Adapt a Python list, possibly nested, to an array literal."""
+
+    def __init__(self, cls: type, conn: "BaseConnection"):
+        super().__init__(cls, conn)
+        # Used to adapt the items of the list
+        self.tx = Transformer(conn)
+
+    def adapt(self, obj: List[Any]) -> bytes:
+        """Return the array literal representing *obj*."""
+        tokens: List[bytes] = []
+        self.adapt_list(obj, tokens)
+        return b"".join(tokens)
+
+    def adapt_list(self, obj: List[Any], tokens: List[bytes]) -> None:
+        """Append to *tokens* the pieces of the literal representing *obj*.
+
+        Nested lists are adapted recursively.
+        """
+        if not obj:
+            tokens.append(b"{}")
+            return
+
+        tokens.append(b"{")
+        for item in obj:
+            if isinstance(item, list):
+                self.adapt_list(item, tokens)
+            elif item is None:
+                tokens.append(b"NULL")
+            else:
+                ad = self.tx.adapt(item)
+                if isinstance(ad, tuple):
+                    # Only the first element of the tuple is used
+                    ad = ad[0]
+                tokens.append(escape_item(ad))
+
+            tokens.append(b",")
+
+        # obj is not empty: replace the delimiter after the last item
+        tokens[-1] = b"}"
+
+
+class ArrayCasterBase(Typecaster):
+    """Base class to cast an array literal to a list, possibly nested.
+
+    Subclasses set `base_caster` to the caster used to cast the items.
+    """
+
+    base_caster: TypecasterType
+
+    def __init__(
+        self, oid: int, conn: Optional["BaseConnection"],
+    ):
+        super().__init__(oid, conn)
+        # NOTE(review): overwritten just below in both branches; presumably
+        # only here to declare the attribute for the type checker - confirm
+        self.caster_func = TypecasterFunc  # type: ignore
+
+        if isinstance(self.base_caster, type):
+            self.caster_func = self.base_caster(oid, conn).cast
+        else:
+            # Look the function up on the class: accessed via the instance it
+            # would be returned bound to self
+            self.caster_func = cast(TypecasterFunc, type(self).base_caster)
+
+    def cast(self, data: bytes) -> List[Any]:
+        """Parse the array literal *data* and return it as a list.
+
+        Items are cast using `caster_func`; unquoted NULL items are returned
+        as `None`.
+
+        Raise `DataError` if brackets are unbalanced, if an item is found
+        outside the brackets, or if no array is found in *data*.
+        """
+        rv = None
+        stack: List[Any] = []
+        for m in _re_parse.finditer(data):
+            t = m.group(1)
+            if t == b"{":
+                a: List[Any] = []
+                if stack:
+                    stack[-1].append(a)
+                stack.append(a)
+
+            elif t == b"}":
+                if not stack:
+                    raise e.DataError("malformed array, unexpected '}'")
+                rv = stack.pop()
+
+            else:
+                if not stack:
+                    raise e.DataError(
+                        f"malformed array, unexpected"
+                        f" '{t.decode('utf8', 'replace')}'"
+                    )
+                if t == b"NULL":
+                    v = None
+                else:
+                    if t.startswith(b'"'):
+                        t = _re_unescape.sub(br"\1", t[1:-1])
+                    v = self.caster_func(t)
+
+                stack[-1].append(v)
+
+        # Validate with exceptions rather than assert, which python -O strips
+        if stack:
+            raise e.DataError("malformed array, missing '}'")
+        if rv is None:
+            raise e.DataError("malformed array, no array found")
+        return rv
+
+
+class ArrayCaster(Typecaster):
+    """Register array casters built on top of item casters."""
+
+    @staticmethod
+    def register(
+        oid: int,  # array oid
+        caster: TypecasterType,
+        context: AdaptContext = None,
+        format: Format = Format.TEXT,
+    ) -> TypecasterType:
+        """Register on *oid* an array caster using *caster* for the items.
+
+        An `ArrayCasterBase` subclass having *caster* as `base_caster` is
+        created and passed to `Typecaster.register()`, whose result is
+        returned.
+        """
+        t = type(
+            caster.__name__ + "_array",  # type: ignore
+            (ArrayCasterBase,),
+            {"base_caster": caster},
+        )
+        return Typecaster.register(oid, t, context=context, format=format)
+
+    @staticmethod
+    def text(oid: int) -> Callable[[Any], Any]:
+        """Return a decorator registering the array caster, in text format.
+
+        The decorated caster is used for the items and is returned unchanged.
+        """
+
+        def text_(caster: TypecasterType) -> TypecasterType:
+            ArrayCaster.register(oid, caster, format=Format.TEXT)
+            return caster
+
+        return text_
+
+
+class UnknownArrayCaster(ArrayCasterBase):
+    """Array caster using `UnknownCaster` to cast the items."""
+
+    base_caster = UnknownCaster
diff --git a/psycopg3/types/numeric.py b/psycopg3/types/numeric.py
index 9b22a2bb6..c942f2a30 100644
--- a/psycopg3/types/numeric.py
+++ b/psycopg3/types/numeric.py
@@ -11,6 +11,7 @@ from typing import Tuple
 
 from ..adapt import Adapter, Typecaster
 from .oids import builtins
+from .array import ArrayCaster
 
 FLOAT8_OID = builtins["float8"].oid
 NUMERIC_OID = builtins["numeric"].oid
@@ -59,6 +60,10 @@ def adapt_bool(obj: bool) -> Tuple[bytes, int]:
 @Typecaster.text(builtins["int4"].oid)
 @Typecaster.text(builtins["int8"].oid)
 @Typecaster.text(builtins["oid"].oid)
+@ArrayCaster.text(builtins["int2"].array_oid)
+@ArrayCaster.text(builtins["int4"].array_oid)
+@ArrayCaster.text(builtins["int8"].array_oid)
+@ArrayCaster.text(builtins["oid"].array_oid)
 def cast_int(data: bytes) -> int:
     return int(_decode(data)[0])
 
diff --git a/psycopg3/types/text.py b/psycopg3/types/text.py
index 20d473fc2..4083d3022 100644
--- a/psycopg3/types/text.py
+++ b/psycopg3/types/text.py
@@ -15,6 +15,7 @@ from ..connection import BaseConnection
 from ..utils.typing import EncodeFunc, DecodeFunc
 from ..pq import Escaping
 from .oids import builtins
+from .array import ArrayCaster
 
 TEXT_OID = builtins["text"].oid
 BYTEA_OID = builtins["bytea"].oid
@@ -41,6 +42,7 @@ class StringAdapter(Adapter):
 
 @Typecaster.text(builtins["text"].oid)
 @Typecaster.binary(builtins["text"].oid)
+@ArrayCaster.text(builtins["text"].array_oid)
 class StringCaster(Typecaster):
 
     decode: Optional[DecodeFunc]
@@ -80,6 +82,7 @@ def adapt_bytes(b: bytes) -> Tuple[bytes, int]:
 
 
 @Typecaster.text(builtins["bytea"].oid)
+@ArrayCaster.text(builtins["bytea"].array_oid)
 def cast_bytea(data: bytes) -> bytes:
     return Escaping.unescape_bytea(data)
 
diff --git a/tests/types/test_array.py b/tests/types/test_array.py
new file mode 100644
index 000000000..48853ff57
--- /dev/null
+++ b/tests/types/test_array.py
@@ -0,0 +1,96 @@
+import pytest
+from psycopg3.types import builtins
+from psycopg3.adapt import Typecaster, UnknownCaster
+from psycopg3.types.array import UnknownArrayCaster, ArrayCaster
+
+
+tests_str = [
+    ([], "{}"),
+    (["foo", "bar", "baz"], "{foo,bar,baz}"),
+    (["foo", None, "baz"], "{foo,null,baz}"),
+    (["foo", "null", "", "baz"], '{foo,"null","",baz}'),
+    ([["foo", "bar"], ["baz", "qux"]], "{{foo,bar},{baz,qux}}"),
+    (
+        [[["fo{o", "ba}r"], ['ba"z', "qu'x"], ["qu ux", " "]]],
+        r'{{{"fo{o","ba}r"},{"ba\"z",qu\'x},{"qu ux"," "}}}',
+    ),
+]
+
+
+@pytest.mark.parametrize("obj, want", tests_str)
+def test_adapt_list_str(conn, obj, want):
+    cur = conn.cursor()
+    cur.execute("select %s::text[] = %s::text[]", (obj, want))
+    assert cur.fetchone()[0]
+
+
+@pytest.mark.parametrize("want, obj", tests_str)
+def test_cast_list_str(conn, obj, want):
+    cur = conn.cursor()
+    cur.execute("select %s::text[]", (obj,))
+    assert cur.fetchone()[0] == want
+
+
+def test_all_chars(conn):
+    cur = conn.cursor()
+    for i in range(1, 256):
+        c = chr(i)
+        cur.execute("select %s::text[]", ([c],))
+        assert cur.fetchone()[0] == [c]
+
+    a = list(map(chr, range(1, 256)))
+    a.append("\u20ac")
+    cur.execute("select %s::text[]", (a,))
+    assert cur.fetchone()[0] == a
+
+    a = "".join(a)
+    cur.execute("select %s::text[]", ([a],))
+    assert cur.fetchone()[0] == [a]
+
+
+tests_int = [
+    ([], "{}"),
+    ([10, 20, -30], "{10,20,-30}"),
+    ([10, None, 30], "{10,null,30}"),
+    ([[10, 20], [30, 40]], "{{10,20},{30,40}}"),
+]
+
+
+@pytest.mark.parametrize("obj, want", tests_int)
+def test_adapt_list_int(conn, obj, want):
+    cur = conn.cursor()
+    cur.execute("select %s::int[] = %s::int[]", (obj, want))
+    assert cur.fetchone()[0]
+
+
+@pytest.mark.parametrize("want, obj", tests_int)
+def test_cast_list_int(conn, obj, want):
+    cur = conn.cursor()
+    cur.execute("select %s::int[]", (obj,))
+    assert cur.fetchone()[0] == want
+
+
+def test_unknown(conn):
+    # unknown for real
+    assert builtins["aclitem"].array_oid not in Typecaster.globals
+    Typecaster.register(
+        builtins["aclitem"].array_oid, UnknownArrayCaster, context=conn
+    )
+    cur = conn.cursor()
+    cur.execute("select '{postgres=arwdDxt/postgres}'::aclitem[]")
+    res = cur.fetchone()[0]
+    assert res == ["postgres=arwdDxt/postgres"]
+
+
+def test_array_register(conn):
+    cur = conn.cursor()
+    cur.execute("select '{postgres=arwdDxt/postgres}'::aclitem[]")
+    res = cur.fetchone()[0]
+    assert res == "{postgres=arwdDxt/postgres}"
+
+    ArrayCaster.register(
+        builtins["aclitem"].array_oid, UnknownCaster, context=conn
+    )
+    cur.execute("select '{postgres=arwdDxt/postgres}'::aclitem[]")
+    res = cur.fetchone()[0]
+    assert res == ["postgres=arwdDxt/postgres"]
-- 
2.47.3