from .oids import builtins
# Register default adapters
-from . import numeric, text # noqa
+from . import array, numeric, text # noqa
__all__ = ["builtins"]
--- /dev/null
+"""
+Adapters for arrays
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import re
+from typing import Any, Callable, List, Optional, cast, TYPE_CHECKING
+
+from .. import errors as e
+from ..pq import Format
+from ..adapt import Adapter, Typecaster, Transformer, UnknownCaster
+from ..adapt import AdaptContext, TypecasterType, TypecasterFunc
+
+if TYPE_CHECKING:
+ from ..connection import BaseConnection
+
+
+# from https://www.postgresql.org/docs/current/arrays.html#ARRAYS-IO
+#
+# The array output routine will put double quotes around element values if they
+# are empty strings, contain curly braces, delimiter characters, double quotes,
+# backslashes, or white space, or match the word NULL.
+# TODO: only "," is recognised as delimiter; it should be configurable.
+_re_needs_quote = re.compile(
+ br"""(?xi)
+ ^$ # the empty string
+ | ["{},\\\s] # or a char to escape
+ | ^null$ # or the word NULL
+ """
+)
+
+# Double quotes and backslashes embedded in element values will be
+# backslash-escaped.
+_re_escape = re.compile(br'(["\\])')
+_re_unescape = re.compile(br"\\(.)")
+
+
+# Tokenize an array representation into items and brackets.
+# TODO: currently only "," is recognised as delimiter; it should be
+# configurable.
+_re_parse = re.compile(
+ br"""(?xi)
+ ( [{}] # open or closed bracket
+ | " (?: [^"\\] | \\. )* " # or a quoted string
+ | [^"{},\\]+ # or an unquoted non-empty string
+ ) ,?
+ """
+)
+
+
def escape_item(item: Optional[bytes]) -> bytes:
    """
    Return *item* in the form needed to appear as an array literal element.

    `None` is rendered as the unquoted word ``NULL``. A value matched by
    `_re_needs_quote` is wrapped in double quotes, after the characters
    matched by `_re_escape` have been backslash-escaped. Any other value is
    returned unchanged.
    """
    if item is None:
        return b"NULL"

    needs_quote = _re_needs_quote.search(item) is not None
    if needs_quote:
        escaped = _re_escape.sub(br"\\\1", item)
        return b'"' + escaped + b'"'

    return item
+
+
@Adapter.text(list)
class ListAdapter(Adapter):
    """
    Adapt a Python list (possibly nested) to an array literal in text format.

    Elements are adapted through a `Transformer` bound to the same
    connection, then quoted/escaped with `escape_item()`.
    """

    def __init__(self, cls: type, conn: "BaseConnection"):
        super().__init__(cls, conn)
        self.tx = Transformer(conn)

    def adapt(self, obj: List[Any]) -> bytes:
        """Return the array literal representing *obj*."""
        parts: List[bytes] = []
        self.adapt_list(obj, parts)
        return b"".join(parts)

    def adapt_list(self, obj: List[Any], tokens: List[bytes]) -> None:
        """
        Append to *tokens* the pieces of the literal representing *obj*.

        Nested lists are handled recursively; `None` is emitted as ``NULL``.
        """
        if not obj:
            tokens.append(b"{}")
            return

        # The first element is preceded by the open bracket, every following
        # one by a comma.
        separator = b"{"
        for element in obj:
            tokens.append(separator)
            separator = b","

            if element is None:
                tokens.append(b"NULL")
            elif isinstance(element, list):
                self.adapt_list(element, tokens)
            else:
                adapted = self.tx.adapt(element)
                # The transformer may return a tuple: only its first item is
                # used here (presumably the adapted data - verify against
                # Transformer.adapt).
                data = adapted[0] if isinstance(adapted, tuple) else adapted
                tokens.append(escape_item(data))

        tokens.append(b"}")
+
+
class ArrayCasterBase(Typecaster):
    """
    Base class to parse the text representation of an array into a list.

    Subclasses must set `base_caster` to what is used to convert the single
    elements: either a caster class (instantiated with the same arguments
    received by this object, and whose `cast()` method is used) or a plain
    function taking the element `bytes`.
    """

    base_caster: TypecasterType

    def __init__(
        self, oid: int, conn: Optional["BaseConnection"],
    ):
        super().__init__(oid, conn)

        # Declared here: the attribute is always set by one of the branches
        # below.
        self.caster_func: TypecasterFunc

        if isinstance(self.base_caster, type):
            self.caster_func = self.base_caster(oid, conn).cast
        else:
            # Look the function up on the class: through the instance it
            # would be returned as a bound method.
            self.caster_func = cast(TypecasterFunc, type(self).base_caster)

    def cast(self, data: bytes) -> List[Any]:
        """
        Parse the array literal *data* and return it as a (nested) list.

        Elements equal to the unquoted word ``NULL`` are returned as `None`;
        quoted elements are unquoted and unescaped; every non-null element is
        then converted by the base caster.

        Raise `DataError` if *data* is not a well-formed array literal.
        """
        # PostgreSQL prefixes the literal with the dimensions, such as
        # "[0:2]={a,b,c}", when the lower bounds are not the default ones.
        # A Python list cannot represent them, so just skip the decoration.
        if data.startswith(b"["):
            eq_pos = data.find(b"=")
            if eq_pos < 0:
                raise e.DataError(
                    "malformed array, no '=' after dimension information"
                )
            data = data[eq_pos + 1 :]

        rv = None
        # Lists currently open: the innermost one is the last.
        stack: List[Any] = []
        for m in _re_parse.finditer(data):
            t = m.group(1)
            if t == b"{":
                a: List[Any] = []
                if rv is None:
                    rv = a
                if stack:
                    stack[-1].append(a)
                stack.append(a)

            elif t == b"}":
                if not stack:
                    raise e.DataError("malformed array, unexpected '}'")
                rv = stack.pop()

            else:
                if not stack:
                    raise e.DataError(
                        f"malformed array, unexpected"
                        f" '{t.decode('utf8', 'replace')}'"
                    )
                if t == b"NULL":
                    v = None
                else:
                    if t.startswith(b'"'):
                        # Drop the surrounding quotes and the escaping
                        t = _re_unescape.sub(br"\1", t[1:-1])
                    v = self.caster_func(t)

                stack[-1].append(v)

        # Explicit errors instead of an assert: the check must survive
        # running with -O and report bad data consistently.
        if rv is None:
            raise e.DataError("malformed array, no array start found")
        if stack:
            raise e.DataError("malformed array, missing '}'")

        return rv
+
+
class ArrayCaster(Typecaster):
    """
    Helpers to register the caster of an array type from the caster of its
    elements.
    """

    @staticmethod
    def register(
        oid: int,  # array oid
        caster: TypecasterType,
        context: AdaptContext = None,
        format: Format = Format.TEXT,
    ) -> TypecasterType:
        """
        Register an array caster for *oid* using *caster* for the elements.

        A new `ArrayCasterBase` subclass, named after *caster* with an
        ``_array`` suffix, is created and passed to `Typecaster.register()`,
        whose result is returned.
        """
        name = caster.__name__ + "_array"  # type: ignore
        namespace = {"base_caster": caster}
        array_caster = type(name, (ArrayCasterBase,), namespace)
        return Typecaster.register(
            oid, array_caster, context=context, format=format
        )

    @staticmethod
    def text(oid: int) -> Callable[[Any], Any]:
        """
        Return a decorator registering the decorated caster as element caster
        of the array type *oid*, in text format.

        The decorated object is returned unchanged.
        """

        def decorate(caster: TypecasterType) -> TypecasterType:
            ArrayCaster.register(oid, caster, format=Format.TEXT)
            return caster

        return decorate
+
+
class UnknownArrayCaster(ArrayCasterBase):
    """Array caster converting the elements with `UnknownCaster`."""

    base_caster = UnknownCaster
from ..adapt import Adapter, Typecaster
from .oids import builtins
+from .array import ArrayCaster
FLOAT8_OID = builtins["float8"].oid
NUMERIC_OID = builtins["numeric"].oid
@Typecaster.text(builtins["int4"].oid)
@Typecaster.text(builtins["int8"].oid)
@Typecaster.text(builtins["oid"].oid)
+@ArrayCaster.text(builtins["int2"].array_oid)
+@ArrayCaster.text(builtins["int4"].array_oid)
+@ArrayCaster.text(builtins["int8"].array_oid)
+@ArrayCaster.text(builtins["oid"].array_oid)
def cast_int(data: bytes) -> int:
return int(_decode(data)[0])
from ..utils.typing import EncodeFunc, DecodeFunc
from ..pq import Escaping
from .oids import builtins
+from .array import ArrayCaster
TEXT_OID = builtins["text"].oid
BYTEA_OID = builtins["bytea"].oid
@Typecaster.text(builtins["text"].oid)
@Typecaster.binary(builtins["text"].oid)
+@ArrayCaster.text(builtins["text"].array_oid)
class StringCaster(Typecaster):
decode: Optional[DecodeFunc]
@Typecaster.text(builtins["bytea"].oid)
+@ArrayCaster.text(builtins["bytea"].array_oid)
def cast_bytea(data: bytes) -> bytes:
return Escaping.unescape_bytea(data)
--- /dev/null
+import pytest
+from psycopg3.types import builtins
+from psycopg3.adapt import Typecaster, UnknownCaster
+from psycopg3.types.array import UnknownArrayCaster, ArrayCaster
+
+
+tests_str = [
+ ([], "{}"),
+ (["foo", "bar", "baz"], "{foo,bar,baz}"),
+ (["foo", None, "baz"], "{foo,null,baz}"),
+ (["foo", "null", "", "baz"], '{foo,"null","",baz}'),
+ ([["foo", "bar"], ["baz", "qux"]], "{{foo,bar},{baz,qux}}"),
+ (
+ [[["fo{o", "ba}r"], ['ba"z', "qu'x"], ["qu ux", " "]]],
+ r'{{{"fo{o","ba}r"},{"ba\"z",qu\'x},{"qu ux"," "}}}',
+ ),
+]
+
+
@pytest.mark.parametrize("obj, want", tests_str)
def test_adapt_list_str(conn, obj, want):
    """A list of strings compares equal, server-side, to its literal."""
    cursor = conn.cursor()
    cursor.execute("select %s::text[] = %s::text[]", (obj, want))
    row = cursor.fetchone()
    assert row[0]
+
+
# Parameter names are swapped on purpose: here the literal is the input and
# the Python list is the expected result.
@pytest.mark.parametrize("want, obj", tests_str)
def test_cast_list_str(conn, obj, want):
    """A text[] literal is returned by the database as a Python list."""
    cursor = conn.cursor()
    cursor.execute("select %s::text[]", (obj,))
    row = cursor.fetchone()
    assert row[0] == want
+
+
def test_all_chars(conn):
    """Every char from 1 to 255 round-trips as a text[] element."""
    cur = conn.cursor()
    chars = [chr(code) for code in range(1, 256)]

    # Each character alone, as a one-element array
    for char in chars:
        cur.execute("select %s::text[]", ([char],))
        assert cur.fetchone()[0] == [char]

    # All the characters, plus a non-latin1 one, as distinct elements
    items = chars + ["\u20ac"]
    cur.execute("select %s::text[]", (items,))
    assert cur.fetchone()[0] == items

    # All the characters together in a single element
    joined = "".join(items)
    cur.execute("select %s::text[]", ([joined],))
    assert cur.fetchone()[0] == [joined]
+
+
+tests_int = [
+ ([], "{}"),
+ ([10, 20, -30], "{10,20,-30}"),
+ ([10, None, 30], "{10,null,30}"),
+ ([[10, 20], [30, 40]], "{{10,20},{30,40}}"),
+]
+
+
@pytest.mark.parametrize("obj, want", tests_int)
def test_adapt_list_int(conn, obj, want):
    """A list of integers compares equal, server-side, to its literal."""
    cursor = conn.cursor()
    cursor.execute("select %s::int[] = %s::int[]", (obj, want))
    row = cursor.fetchone()
    assert row[0]
+
+
# Parameter names are swapped on purpose: here the literal is the input and
# the Python list is the expected result.
@pytest.mark.parametrize("want, obj", tests_int)
def test_cast_list_int(conn, obj, want):
    """An int[] literal is returned by the database as a Python list."""
    cursor = conn.cursor()
    cursor.execute("select %s::int[]", (obj,))
    row = cursor.fetchone()
    assert row[0] == want
+
+
def test_unknown(conn):
    """`UnknownArrayCaster` can parse arrays of a type with no caster."""
    array_oid = builtins["aclitem"].array_oid

    # unknown for real
    assert array_oid not in Typecaster.globals

    Typecaster.register(array_oid, UnknownArrayCaster, context=conn)

    cursor = conn.cursor()
    cursor.execute("select '{postgres=arwdDxt/postgres}'::aclitem[]")
    got = cursor.fetchone()[0]
    assert got == ["postgres=arwdDxt/postgres"]
+
+
def test_array_register(conn):
    """`ArrayCaster.register()` makes an array type parsed into a list."""
    query = "select '{postgres=arwdDxt/postgres}'::aclitem[]"
    cursor = conn.cursor()

    # Before the registration the value is returned as the plain literal
    cursor.execute(query)
    before = cursor.fetchone()[0]
    assert before == "{postgres=arwdDxt/postgres}"

    ArrayCaster.register(
        builtins["aclitem"].array_oid, UnknownCaster, context=conn
    )

    # After the registration the same query returns a list
    cursor.execute(query)
    after = cursor.fetchone()[0]
    assert after == ["postgres=arwdDxt/postgres"]