From: Daniele Varrazzo Date: Fri, 6 Nov 2020 23:51:43 +0000 (+0000) Subject: Added int subtypes to cast to specific sizes X-Git-Tag: 3.0.dev0~396 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=125f9bb3162c2c406fc43e1060c0264527e79af9;p=thirdparty%2Fpsycopg.git Added int subtypes to cast to specific sizes --- diff --git a/psycopg3/psycopg3/types/numeric.py b/psycopg3/psycopg3/types/numeric.py index e32f7f638..c239a47c6 100644 --- a/psycopg3/psycopg3/types/numeric.py +++ b/psycopg3/psycopg3/types/numeric.py @@ -10,17 +10,49 @@ from decimal import Decimal from ..oids import builtins from ..adapt import Dumper, Loader -from ..utils.codecs import EncodeFunc, DecodeFunc, encode_ascii, decode_ascii +from ..utils.codecs import DecodeFunc, decode_ascii +PackInt = Callable[[int], bytes] UnpackInt = Callable[[bytes], Tuple[int]] UnpackFloat = Callable[[bytes], Tuple[float]] +_pack_int2 = cast(PackInt, struct.Struct("!h").pack) +_pack_int4 = cast(PackInt, struct.Struct("!i").pack) +_pack_uint4 = cast(PackInt, struct.Struct("!I").pack) +_pack_int8 = cast(PackInt, struct.Struct("!q").pack) + +# Wrappers to force numbers to be cast as specific PostgreSQL types + + +class Int2(int): + def __new__(cls, arg: int) -> "Int2": + rv: Int2 = super().__new__(cls, arg) # type: ignore[call-arg] + return rv + + +class Int4(int): + def __new__(cls, arg: int) -> "Int4": + rv: Int4 = super().__new__(cls, arg) # type: ignore[call-arg] + return rv + + +class Int8(int): + def __new__(cls, arg: int) -> "Int8": + rv: Int8 = super().__new__(cls, arg) # type: ignore[call-arg] + return rv + + +class Oid(int): + def __new__(cls, arg: int) -> "Oid": + rv: Oid = super().__new__(cls, arg) # type: ignore[call-arg] + return rv + class NumberDumper(Dumper): _special: Dict[bytes, bytes] = {} - def dump(self, obj: Any, __encode: EncodeFunc = encode_ascii) -> bytes: - return __encode(str(obj))[0] + def dump(self, obj: Any) -> bytes: + return str(obj).encode("utf8") def quote(self, obj: Any) -> bytes: value = self.dump(obj) @@ -61,6 +93,50 @@ class DecimalDumper(NumberDumper): } +@Dumper.text(Int2) +class Int2Dumper(NumberDumper): + oid = builtins["int2"].oid + + +@Dumper.text(Int4) +class Int4Dumper(NumberDumper): + oid = builtins["int4"].oid + + +@Dumper.text(Int8) +class Int8Dumper(NumberDumper): + oid = builtins["int8"].oid + + +@Dumper.text(Oid) +class OidDumper(NumberDumper): + oid = builtins["oid"].oid + + +@Dumper.binary(Int2) +class Int2BinaryDumper(Int2Dumper): + def dump(self, obj: int) -> bytes: + return _pack_int2(obj) + + +@Dumper.binary(Int4) +class Int4BinaryDumper(Int4Dumper): + def dump(self, obj: int) -> bytes: + return _pack_int4(obj) + + +@Dumper.binary(Int8) +class Int8BinaryDumper(Int8Dumper): + def dump(self, obj: int) -> bytes: + return _pack_int8(obj) + + +@Dumper.binary(Oid) +class OidBinaryDumper(OidDumper): + def dump(self, obj: int) -> bytes: + return _pack_uint4(obj) + + @Loader.text(builtins["int2"].oid) @Loader.text(builtins["int4"].oid) @Loader.text(builtins["int8"].oid) diff --git a/tests/types/test_numeric.py b/tests/types/test_numeric.py index d982b0f0c..a4bfe89ed 100644 --- a/tests/types/test_numeric.py +++ b/tests/types/test_numeric.py @@ -3,6 +3,7 @@ from math import isnan, isinf, exp import pytest +import psycopg3 from psycopg3 import sql from psycopg3.oids import builtins from psycopg3.adapt import Transformer, Format @@ -20,8 +21,8 @@ from psycopg3.types.numeric import FloatLoader (0, "'0'::int"), (1, "'1'::int"), (-1, "'-1'::int"), - (42, "'42'::int"), - (-42, "'-42'::int"), + (42, "'42'::smallint"), + (-42, "'-42'::smallint"), (int(2 ** 63 - 1), "'9223372036854775807'::bigint"), (int(-(2 ** 63)), "'-9223372036854775808'::bigint"), ], @@ -33,6 +34,33 @@ def test_dump_int(conn, val, expr): assert cur.fetchone()[0] is True +@pytest.mark.parametrize( + "val, expr", + [ + (0, "'0'::integer"), + (1, "'1'::integer"), + (-1, "'-1'::integer"), + (42, "'42'::smallint"), + (-42, "'-42'::smallint"), + (int(2 ** 63 - 1), "'9223372036854775807'::bigint"), + (int(-(2 ** 63)), "'-9223372036854775808'::bigint"), + (0, "'0'::oid"), + (4294967295, "'4294967295'::oid"), + ], +) +@pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY]) +def test_dump_int_subtypes(conn, val, expr, fmt_in): + tname = builtins[expr.rsplit(":", 1)[-1]].name.title() + assert tname in "Int2 Int4 Int8 Oid".split() + Type = getattr(psycopg3.types.numeric, tname) + ph = "%s" if fmt_in == Format.TEXT else "%b" + cur = conn.cursor() + cur.execute(f"select pg_typeof({expr}) = pg_typeof({ph})", (Type(val),)) + assert cur.fetchone()[0] is True + cur.execute(f"select {expr} = {ph}", (Type(val),)) + assert cur.fetchone()[0] is True + + @pytest.mark.parametrize( "val, expr", [