from ..oids import builtins
from ..adapt import Dumper, Loader
-from ..utils.codecs import EncodeFunc, DecodeFunc, encode_ascii, decode_ascii
+from ..utils.codecs import DecodeFunc, decode_ascii
+PackInt = Callable[[int], bytes]
UnpackInt = Callable[[bytes], Tuple[int]]
UnpackFloat = Callable[[bytes], Tuple[float]]
+_pack_int2 = cast(PackInt, struct.Struct("!h").pack)
+_pack_int4 = cast(PackInt, struct.Struct("!i").pack)
+_pack_uint4 = cast(PackInt, struct.Struct("!I").pack)
+_pack_int8 = cast(PackInt, struct.Struct("!q").pack)
+
+# Wrappers to force numbers to be cast as specific PostgreSQL types
+
+
+class Int2(int):
+ def __new__(cls, arg: int) -> "Int2":
+ rv: Int2 = super().__new__(cls, arg) # type: ignore[call-arg]
+ return rv
+
+
+class Int4(int):
+ def __new__(cls, arg: int) -> "Int4":
+ rv: Int4 = super().__new__(cls, arg) # type: ignore[call-arg]
+ return rv
+
+
+class Int8(int):
+ def __new__(cls, arg: int) -> "Int8":
+ rv: Int8 = super().__new__(cls, arg) # type: ignore[call-arg]
+ return rv
+
+
+class Oid(int):
+ def __new__(cls, arg: int) -> "Oid":
+ rv: Oid = super().__new__(cls, arg) # type: ignore[call-arg]
+ return rv
+
class NumberDumper(Dumper):
_special: Dict[bytes, bytes] = {}
- def dump(self, obj: Any, __encode: EncodeFunc = encode_ascii) -> bytes:
- return __encode(str(obj))[0]
+ def dump(self, obj: Any) -> bytes:
+ return str(obj).encode("utf8")
def quote(self, obj: Any) -> bytes:
value = self.dump(obj)
}
+@Dumper.text(Int2)
+class Int2Dumper(NumberDumper):
+ oid = builtins["int2"].oid
+
+
+@Dumper.text(Int4)
+class Int4Dumper(NumberDumper):
+ oid = builtins["int4"].oid
+
+
+@Dumper.text(Int8)
+class Int8Dumper(NumberDumper):
+ oid = builtins["int8"].oid
+
+
+@Dumper.text(Oid)
+class OidDumper(NumberDumper):
+ oid = builtins["oid"].oid
+
+
+@Dumper.binary(Int2)
+class Int2BinaryDumper(Int2Dumper):
+ def dump(self, obj: int) -> bytes:
+ return _pack_int2(obj)
+
+
+@Dumper.binary(Int4)
+class Int4BinaryDumper(Int4Dumper):
+ def dump(self, obj: int) -> bytes:
+ return _pack_int4(obj)
+
+
+@Dumper.binary(Int8)
+class Int8BinaryDumper(Int8Dumper):
+ def dump(self, obj: int) -> bytes:
+ return _pack_int8(obj)
+
+
+@Dumper.binary(Oid)
+class OidBinaryDumper(OidDumper):
+ def dump(self, obj: int) -> bytes:
+ return _pack_uint4(obj)
+
+
@Loader.text(builtins["int2"].oid)
@Loader.text(builtins["int4"].oid)
@Loader.text(builtins["int8"].oid)
import pytest
+import psycopg3
from psycopg3 import sql
from psycopg3.oids import builtins
from psycopg3.adapt import Transformer, Format
(0, "'0'::int"),
(1, "'1'::int"),
(-1, "'-1'::int"),
- (42, "'42'::int"),
- (-42, "'-42'::int"),
+ (42, "'42'::smallint"),
+ (-42, "'-42'::smallint"),
(int(2 ** 63 - 1), "'9223372036854775807'::bigint"),
(int(-(2 ** 63)), "'-9223372036854775808'::bigint"),
],
assert cur.fetchone()[0] is True
+@pytest.mark.parametrize(
+ "val, expr",
+ [
+ (0, "'0'::integer"),
+ (1, "'1'::integer"),
+ (-1, "'-1'::integer"),
+ (42, "'42'::smallint"),
+ (-42, "'-42'::smallint"),
+ (int(2 ** 63 - 1), "'9223372036854775807'::bigint"),
+ (int(-(2 ** 63)), "'-9223372036854775808'::bigint"),
+ (0, "'0'::oid"),
+ (4294967295, "'4294967295'::oid"),
+ ],
+)
+@pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY])
+def test_dump_int_subtypes(conn, val, expr, fmt_in):
+ tname = builtins[expr.rsplit(":", 1)[-1]].name.title()
+ assert tname in "Int2 Int4 Int8 Oid".split()
+ Type = getattr(psycopg3.types.numeric, tname)
+ ph = "%s" if fmt_in == Format.TEXT else "%b"
+ cur = conn.cursor()
+ cur.execute(f"select pg_typeof({expr}) = pg_typeof({ph})", (Type(val),))
+ assert cur.fetchone()[0] is True
+ cur.execute(f"select {expr} = {ph}", (Type(val),))
+ assert cur.fetchone()[0] is True
+
+
@pytest.mark.parametrize(
"val, expr",
[