Added binary array adaptation for test types and some numeric types.
import re
import struct
-from typing import Any, Generator, List, Optional
+from typing import Any, Generator, List, Optional, Tuple
from .. import errors as e
from ..pq import Format
from ..adapt import Adapter, TypeCaster, Transformer, UnknownCaster
from ..adapt import AdaptContext, TypeCasterType, TypeCasterFunc
+from .oids import builtins
+
+TEXT_OID = builtins["text"].oid
+TEXT_ARRAY_OID = builtins["text"].array_oid
# from https://www.postgresql.org/docs/current/arrays.html#ARRAYS-IO
return b'"' + _re_escape.sub(br"\\\1", item) + b'"'
-@Adapter.text(list)
-class ListAdapter(Adapter):
+class BaseListAdapter(Adapter):
+ """
+ Common base of the text and binary list adapters.
+
+ Hold the Transformer used to adapt the items of the list and the
+ logic to choose the oid of the resulting array.
+ """
def __init__(self, src: type, context: AdaptContext = None):
super().__init__(src, context)
+ # used by the subclasses to adapt the single items of the list
self.tx = Transformer(context)
- def adapt(self, obj: List[Any]) -> bytes:
+ def _array_oid(self, base_oid: int) -> int:
+ """
+ Return the oid of the array from the oid of the base item.
+
+ Fall back on text[].
+ TODO: we shouldn't consider builtins only, but other adaptation
+ contexts too
+ """
+ oid = 0
+ if base_oid:
+ info = builtins.get(base_oid)
+ if info is not None:
+ oid = info.array_oid
+
+ # oid is still 0 if base_oid is 0, if it is not found in builtins or
+ # if its entry has a 0 array_oid: use text[] in all these cases
+ return oid or TEXT_ARRAY_OID
+
+
+@Adapter.text(list)
+class ListAdapter(BaseListAdapter):
+ """Adapt a Python list, possibly nested, to an array literal in text format."""
+ def adapt(self, obj: List[Any]) -> Tuple[bytes, int]:
+ """
+ Return the ``{...}`` literal for *obj* and the oid of the array.
+
+ The array oid is derived from the first non-zero oid returned by
+ the adaptation of an item (see `_array_oid()` for the fallback).
+ """
tokens: List[bytes] = []
- self.adapt_list(obj, tokens)
- return b"".join(tokens)
- def adapt_list(self, obj: List[Any], tokens: List[bytes]) -> None:
+ oid = 0
+
+ # recursive helper: append to tokens the representation of one list
+ def adapt_list(obj: List[Any]) -> None:
+ nonlocal oid
+
+ if not obj:
+ tokens.append(b"{}")
+ return
+
+ tokens.append(b"{")
+ for item in obj:
+ if isinstance(item, list):
+ adapt_list(item)
+ elif item is None:
+ tokens.append(b"NULL")
+ else:
+ ad = self.tx.adapt(item)
+ # the item adapter may return (data, oid) instead of data
+ if isinstance(ad, tuple):
+ if oid == 0:
+ oid = ad[1]
+ ad = ad[0]
+ tokens.append(escape_item(ad))
+
+ tokens.append(b",")
+
+ # replace the comma appended after the last item
+ tokens[-1] = b"}"
+
+ adapt_list(obj)
+
+ return b"".join(tokens), self._array_oid(oid)
+
+
+@Adapter.binary(list)
+class BinaryListAdapter(BaseListAdapter):
+ def adapt(self, obj: List[Any]) -> Tuple[bytes, int]:
if not obj:
- tokens.append(b"{}")
- return
-
- tokens.append(b"{")
- for item in obj:
- if isinstance(item, list):
- self.adapt_list(item, tokens)
- elif item is None:
- tokens.append(b"NULL")
+ return _struct_head.pack(0, 0, TEXT_OID), TEXT_ARRAY_OID
+
+ data: List[bytes] = []
+ head = [0, 0, 0] # to fill: ndims, hasnull, base_oid
+ dims = []
+ data = []
+
+ def calc_dims(L: List[Any]) -> None:
+ if isinstance(L, self.src):
+ if not L:
+ raise e.DataError("lists cannot contain empty lists")
+ dims.append(len(L))
+ calc_dims(L[0])
+
+ calc_dims(obj)
+
+ def adapt_list(L: List[Any], dim: int) -> None:
+ if len(L) != dims[dim]:
+ raise e.DataError("nested lists have inconsistent lengths")
+
+ if dim == len(dims) - 1:
+ for item in L:
+ ad = self.tx.adapt(item, Format.BINARY)
+ if isinstance(ad, tuple):
+ if head[2] == 0:
+ head[2] = ad[1]
+ ad = ad[0]
+ if ad is None:
+ head[1] = 1
+ data.append(b"\xff\xff\xff\xff")
+ else:
+ data.append(_struct_len.pack(len(ad)))
+ data.append(ad)
else:
- ad = self.tx.adapt(item)
- if isinstance(ad, tuple):
- ad = ad[0]
- tokens.append(escape_item(ad))
+ for item in L:
+ if not isinstance(item, self.src):
+ raise e.DataError(
+ "nested lists have inconsistent depths"
+ )
+ adapt_list(item, dim + 1) # type: ignore
+
+ adapt_list(obj, 0)
+
+ head[0] = len(dims)
+ if head[2] == 0:
+ head[2] = TEXT_OID
- tokens.append(b",")
+ oid = self._array_oid(head[2])
- tokens[-1] = b"}"
+ bhead = _struct_head.pack(*head) + b"".join(
+ _struct_dim.pack(dim, 1) for dim in dims
+ )
+ return bhead + b"".join(data), oid
class ArrayCasterBase(TypeCaster):
return rv
-_unpack_head = struct.Struct("!III").unpack_from
-_unpack_dim = struct.Struct("!II").unpack_from
-_unpack_len = struct.Struct("!i").unpack_from
+# Precompiled structs for the binary array format, used both to pack
+# (adaptation) and to unpack (casting):
+# - head: ndims, hasnull, oid of the items
+# - dim: one per dimension; the first field is the dimension length
+# - len: length of the item that follows, -1 if the item is NULL
+_struct_head = struct.Struct("!III")
+_struct_dim = struct.Struct("!II")
+_struct_len = struct.Struct("!i")
class ArrayCasterBinary(ArrayCasterBase):
self.tx = Transformer(context)
def cast(self, data: bytes) -> List[Any]:
- ndims, hasnull, oid = _unpack_head(data[:12])
+ ndims, hasnull, oid = _struct_head.unpack_from(data[:12])
if not ndims:
return []
fcast = self.tx.get_cast_function(oid, Format.BINARY)
p = 12 + 8 * ndims
- dims = [_unpack_dim(data, i)[0] for i in list(range(12, p, 8))]
+ dims = [
+ _struct_dim.unpack_from(data, i)[0] for i in list(range(12, p, 8))
+ ]
def consume(p: int) -> Generator[Any, None, None]:
while 1:
- size = _unpack_len(data, p)[0]
+ size = _struct_len.unpack_from(data, p)[0]
p += 4
if size != -1:
yield fcast(data[p : p + size])
True: (b"t", builtins["bool"].oid),
False: (b"f", builtins["bool"].oid),
}
+# Binary representation of the Python bools, paired with the bool oid
+_bool_binary_adapt = {
+ True: (b"\x01", builtins["bool"].oid),
+ False: (b"\x00", builtins["bool"].oid),
+}
@Adapter.text(bool)
return _bool_adapt[obj]
+@Adapter.binary(bool)
+def adapt_binary_bool(obj: bool) -> Tuple[bytes, int]:
+ """Return the binary representation of *obj* and the bool oid."""
+ return _bool_binary_adapt[obj]
+
+
@TypeCaster.text(builtins["int2"].oid)
@TypeCaster.text(builtins["int4"].oid)
@TypeCaster.text(builtins["int8"].oid)
@TypeCaster.binary(builtins["int2"].oid)
+@ArrayCaster.binary(builtins["int2"].array_oid)
def cast_binary_int2(data: bytes) -> int:
+ """Unpack *data* with `_int2_struct` and return the first value."""
rv: int = _int2_struct.unpack(data)[0]
return rv
@TypeCaster.binary(builtins["int4"].oid)
+@ArrayCaster.binary(builtins["int4"].array_oid)
def cast_binary_int4(data: bytes) -> int:
+ """Unpack *data* with `_int4_struct` and return the first value."""
rv: int = _int4_struct.unpack(data)[0]
return rv
@TypeCaster.binary(builtins["int8"].oid)
+@ArrayCaster.binary(builtins["int8"].array_oid)
def cast_binary_int8(data: bytes) -> int:
+ """Unpack *data* with `_int8_struct` and return the first value."""
rv: int = _int8_struct.unpack(data)[0]
return rv
@TypeCaster.binary(builtins["oid"].oid)
+@ArrayCaster.binary(builtins["oid"].array_oid)
def cast_binary_oid(data: bytes) -> int:
+ """Unpack *data* with `_oid_struct` and return the first value."""
rv: int = _oid_struct.unpack(data)[0]
return rv
@TypeCaster.binary(builtins["bytea"].oid)
+@ArrayCaster.binary(builtins["bytea"].array_oid)
def cast_bytea_binary(data: bytes) -> bytes:
+ """Return *data* unchanged."""
return data
@pytest.mark.parametrize("obj, want", tests_str)
def test_adapt_list_str(conn, obj, want, fmt_in):
+ """*obj*, passed in the fmt_in format, must be equal to *want* as text[]."""
cur = conn.cursor()
- cur.execute("select %s::text[] = %s::text[]", (obj, want))
+ # %s is the placeholder for the text format, %b for the binary one
+ ph = "%s" if fmt_in == Format.TEXT else "%b"
+ cur.execute(f"select {ph}::text[] = %s::text[]", (obj, want))
assert cur.fetchone()[0]
@pytest.mark.parametrize("want, obj", tests_str)
def test_cast_list_str(conn, obj, want, fmt_out):
+ """*obj* cast to text[] must be returned as *want* in the fmt_out format."""
cur = conn.cursor(binary=fmt_out == Format.BINARY)
- ph = "%s" if format == Format.TEXT else "%b"
- cur.execute("select %s::text[]" % ph, (obj,))
+ cur.execute("select %s::text[]", (obj,))
assert cur.fetchone()[0] == want
-def test_all_chars(conn):
- cur = conn.cursor()
+@pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY])
+@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
+def test_all_chars(conn, fmt_in, fmt_out):
+ """
+ Round-trip through text[] the chars 1-255, for every combination of
+ input and output format: one per array, all in one array, and joined
+ (together with "\u20ac") in a single item.
+ """
+ cur = conn.cursor(binary=fmt_out == Format.BINARY)
+ ph = "%s" if fmt_in == Format.TEXT else "%b"
for i in range(1, 256):
c = chr(i)
- cur.execute("select %s::text[]", ([c],))
+ cur.execute(f"select {ph}::text[]", ([c],))
assert cur.fetchone()[0] == [c]
a = list(map(chr, range(1, 256)))
a.append("\u20ac")
- cur.execute("select %s::text[]", (a,))
+ cur.execute(f"select {ph}::text[]", (a,))
assert cur.fetchone()[0] == a
a = "".join(a)
- cur.execute("select %s::text[]", ([a],))
+ cur.execute(f"select {ph}::text[]", ([a],))
assert cur.fetchone()[0] == [a]
assert cur.fetchone()[0]
+@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
@pytest.mark.parametrize("want, obj", tests_int)
-def test_cast_list_int(conn, obj, want):
- cur = conn.cursor()
+def test_cast_list_int(conn, obj, want, fmt_out):
+ """*obj* cast to int[] must be returned as *want* in the fmt_out format."""
+ cur = conn.cursor(binary=fmt_out == Format.BINARY)
cur.execute("select %s::int[]", (obj,))
assert cur.fetchone()[0] == want
("4294967295", "oid", 4294967295),
],
)
-@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
-def test_cast_int(conn, val, pgtype, want, format):
- cur = conn.cursor(binary=format == Format.BINARY)
+@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
+def test_cast_int(conn, val, pgtype, want, fmt_out):
+ cur = conn.cursor(binary=fmt_out == Format.BINARY)
cur.execute("select %%s::%s" % pgtype, (val,))
- assert cur.pgresult.fformat(0) == format
+ assert cur.pgresult.fformat(0) == fmt_out
assert cur.pgresult.ftype(0) == builtins[pgtype].oid
result = cur.fetchone()[0]
assert result == want
("-inf", "float8", -float("inf")),
],
)
-@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
-def test_cast_float(conn, val, pgtype, want, format):
- cur = conn.cursor(binary=format == Format.BINARY)
+@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
+def test_cast_float(conn, val, pgtype, want, fmt_out):
+ cur = conn.cursor(binary=fmt_out == Format.BINARY)
cur.execute("select %%s::%s" % pgtype, (val,))
- assert cur.pgresult.fformat(0) == format
+ assert cur.pgresult.fformat(0) == fmt_out
result = cur.fetchone()[0]
assert type(result) is type(want)
if isnan(want):
("-1.42e40", "float8", -1.42e40),
],
)
-@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
-def test_cast_float_approx(conn, expr, pgtype, want, format):
- cur = conn.cursor(binary=format == Format.BINARY)
+@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
+def test_cast_float_approx(conn, expr, pgtype, want, fmt_out):
+ cur = conn.cursor(binary=fmt_out == Format.BINARY)
cur.execute("select %s::%s" % (expr, pgtype))
- assert cur.pgresult.fformat(0) == format
+ assert cur.pgresult.fformat(0) == fmt_out
result = cur.fetchone()[0]
assert result == pytest.approx(want)
#
-@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+@pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY])
+@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
@pytest.mark.parametrize("b", [True, False, None])
-def test_roundtrip_bool(conn, b, format):
- cur = conn.cursor(binary=format == Format.BINARY)
- result = cur.execute("select %s", (b,)).fetchone()[0]
- assert cur.pgresult.fformat(0) == format
+def test_roundtrip_bool(conn, b, fmt_in, fmt_out):
+ """
+ *b*, passed in the fmt_in format, must come back as the same object,
+ in a result column in the fmt_out format.
+ """
+ cur = conn.cursor(binary=fmt_out == Format.BINARY)
+ ph = "%s" if fmt_in == Format.TEXT else "%b"
+ result = cur.execute(f"select {ph}", (b,)).fetchone()[0]
+ assert cur.pgresult.fformat(0) == fmt_out
assert result is b
#
-@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
-def test_adapt_1char(conn, format):
+@pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY])
+def test_adapt_1char(conn, fmt_in):
+ """Each char 1-255, passed in the fmt_in format, must be equal to chr(i)."""
cur = conn.cursor()
- ph = "%s" if format == Format.TEXT else "%b"
+ ph = "%s" if fmt_in == Format.TEXT else "%b"
for i in range(1, 256):
- cur.execute("select %s = chr(%%s::int)" % ph, (chr(i), i))
+ cur.execute(f"select {ph} = chr(%s::int)", (chr(i), i))
assert cur.fetchone()[0], chr(i)
-@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
-def test_cast_1char(conn, format):
- cur = conn.cursor(binary=format == Format.BINARY)
+@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
+def test_cast_1char(conn, fmt_out):
+ """chr(i), for i in 1-255, must be returned as the same Python char."""
+ cur = conn.cursor(binary=fmt_out == Format.BINARY)
for i in range(1, 256):
cur.execute("select chr(%s::int)", (i,))
assert cur.fetchone()[0] == chr(i)
- assert cur.pgresult.fformat(0) == format
+ assert cur.pgresult.fformat(0) == fmt_out
-@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+@pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY])
@pytest.mark.parametrize("encoding", ["utf8", "latin9"])
-def test_adapt_enc(conn, format, encoding):
+def test_adapt_enc(conn, fmt_in, encoding):
+ """
+ With the connection encoding set to *encoding*, `eur` passed in the
+ fmt_in format and cast to bytea must be equal to its utf8 encoding.
+ """
cur = conn.cursor()
- ph = "%s" if format == Format.TEXT else "%b"
+ ph = "%s" if fmt_in == Format.TEXT else "%b"
conn.encoding = encoding
- (res,) = cur.execute("select %s::bytea" % ph, (eur,)).fetchone()
+ (res,) = cur.execute(f"select {ph}::bytea", (eur,)).fetchone()
assert res == eur.encode("utf8")
-@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
-def test_adapt_ascii(conn, format):
- cur = conn.cursor(binary=format == Format.BINARY)
- ph = "%s" if format == Format.TEXT else "%b"
+@pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY])
+def test_adapt_ascii(conn, fmt_in):
+ """
+ With the connection encoding set to sql_ascii, ascii() of `eur`
+ passed in the fmt_in format must be equal to ord(eur).
+ """
+ cur = conn.cursor()
+ ph = "%s" if fmt_in == Format.TEXT else "%b"
conn.encoding = "sql_ascii"
- (res,) = cur.execute("select ascii(%s)" % ph, (eur,)).fetchone()
+ (res,) = cur.execute(f"select ascii({ph})", (eur,)).fetchone()
assert res == ord(eur)
-@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
-def test_adapt_badenc(conn, format):
+@pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY])
+def test_adapt_badenc(conn, fmt_in):
+ """
+ With the connection encoding set to latin1, passing `eur` in the
+ fmt_in format must raise UnicodeEncodeError.
+ """
cur = conn.cursor()
- ph = "%s" if format == Format.TEXT else "%b"
+ ph = "%s" if fmt_in == Format.TEXT else "%b"
conn.encoding = "latin1"
with pytest.raises(UnicodeEncodeError):
- cur.execute("select %s::bytea" % ph, (eur,))
+ cur.execute(f"select {ph}::bytea", (eur,))
-@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
@pytest.mark.parametrize("encoding", ["utf8", "latin9"])
-def test_cast_enc(conn, format, encoding):
- cur = conn.cursor(binary=format == Format.BINARY)
+def test_cast_enc(conn, fmt_out, encoding):
+ """
+ With the connection encoding set to *encoding*, chr(ord(eur)) must
+ be returned as `eur`.
+ """
+ cur = conn.cursor(binary=fmt_out == Format.BINARY)
conn.encoding = encoding
(res,) = cur.execute("select chr(%s::int)", (ord(eur),)).fetchone()
assert res == eur
-@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
-def test_cast_badenc(conn, format):
- cur = conn.cursor(binary=format == Format.BINARY)
+@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
+def test_cast_badenc(conn, fmt_out):
+ """
+ With the connection encoding set to latin1, selecting chr(ord(eur))
+ must raise psycopg3.DatabaseError.
+ """
+ cur = conn.cursor(binary=fmt_out == Format.BINARY)
conn.encoding = "latin1"
with pytest.raises(psycopg3.DatabaseError):
cur.execute("select chr(%s::int)", (ord(eur),))
-@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
-def test_cast_ascii(conn, format):
- cur = conn.cursor(binary=format == Format.BINARY)
+@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
+def test_cast_ascii(conn, fmt_out):
+ """
+ With the connection encoding set to sql_ascii, chr(ord(eur)) must be
+ returned as the utf8 encoding of `eur`.
+ """
+ cur = conn.cursor(binary=fmt_out == Format.BINARY)
conn.encoding = "sql_ascii"
(res,) = cur.execute("select chr(%s::int)", (ord(eur),)).fetchone()
assert res == eur.encode("utf8")
-def test_text_array(conn):
- cur = conn.cursor()
+@pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY])
+@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
+def test_text_array(conn, fmt_in, fmt_out):
+ """
+ Round-trip through text[] a list of the chars 1-255 plus `eur`, for
+ every combination of input and output format.
+ """
+ cur = conn.cursor(binary=fmt_out == Format.BINARY)
+ ph = "%s" if fmt_in == Format.TEXT else "%b"
a = list(map(chr, range(1, 256))) + [eur]
- (res,) = cur.execute("select %s::text[]", (a,)).fetchone()
+
+ (res,) = cur.execute(f"select {ph}::text[]", (a,)).fetchone()
assert res == a
-def test_text_array_ascii(conn):
+@pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY])
+@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
+def test_text_array_ascii(conn, fmt_in, fmt_out):
+ """
+ With the connection encoding set to sql_ascii, a text[] of the chars
+ 1-255 plus `eur` must be returned as the list of their utf8 encodings.
+ """
conn.encoding = "sql_ascii"
- cur = conn.cursor()
+ cur = conn.cursor(binary=fmt_out == Format.BINARY)
a = list(map(chr, range(1, 256))) + [eur]
exp = [s.encode("utf8") for s in a]
- (res,) = cur.execute("select %s::text[]", (a,)).fetchone()
+ ph = "%s" if fmt_in == Format.TEXT else "%b"
+ (res,) = cur.execute(f"select {ph}::text[]", (a,)).fetchone()
assert res == exp
#
-@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
-def test_adapt_1byte(conn, format):
+@pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY])
+def test_adapt_1byte(conn, fmt_in):
+ """
+ Each single byte 0-255, passed in the fmt_in format, must be equal to
+ the bytea obtained from its hex literal.
+ """
cur = conn.cursor()
- ph = "%s" if format == Format.TEXT else "%b"
- "select %s = %%s::bytea" % ph
+ ph = "%s" if fmt_in == Format.TEXT else "%b"
for i in range(0, 256):
- cur.execute("select %s = %%s::bytea" % ph, (bytes([i]), fr"\x{i:02x}"))
+ cur.execute(f"select {ph} = %s::bytea", (bytes([i]), fr"\x{i:02x}"))
assert cur.fetchone()[0], i
-@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
-def test_cast_1byte(conn, format):
- cur = conn.cursor(binary=format == Format.BINARY)
+@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
+def test_cast_1byte(conn, fmt_out):
+ """
+ The bytea obtained from the hex literal of each byte 0-255 must be
+ returned as the matching single-byte bytes object.
+ """
+ cur = conn.cursor(binary=fmt_out == Format.BINARY)
for i in range(0, 256):
cur.execute("select %s::bytea", (fr"\x{i:02x}",))
assert cur.fetchone()[0] == bytes([i])
- assert cur.pgresult.fformat(0) == format
+ assert cur.pgresult.fformat(0) == fmt_out
-def test_bytea_array(conn):
- cur = conn.cursor()
+@pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY])
+@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
+def test_bytea_array(conn, fmt_in, fmt_out):
+ """
+ Round-trip through bytea[] a list with one item containing all the
+ bytes 0-255, for every combination of input and output format.
+ """
+ cur = conn.cursor(binary=fmt_out == Format.BINARY)
a = [bytes(range(0, 256))]
- (res,) = cur.execute("select %s::bytea[]", (a,)).fetchone()
+ ph = "%s" if fmt_in == Format.TEXT else "%b"
+ (res,) = cur.execute(f"select {ph}::bytea[]", (a,)).fetchone()
assert res == a