]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added int subtypes to cast to specific sizes
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 6 Nov 2020 23:51:43 +0000 (23:51 +0000)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 7 Nov 2020 01:49:40 +0000 (01:49 +0000)
psycopg3/psycopg3/types/numeric.py
tests/types/test_numeric.py

index e32f7f638042f7ba06ee1bc6013ad5dcb641fc43..c239a47c697ac832e2acfc3ccafb90639b210d94 100644 (file)
@@ -10,17 +10,49 @@ from decimal import Decimal
 
 from ..oids import builtins
 from ..adapt import Dumper, Loader
-from ..utils.codecs import EncodeFunc, DecodeFunc, encode_ascii, decode_ascii
+from ..utils.codecs import DecodeFunc, decode_ascii
 
+PackInt = Callable[[int], bytes]
 UnpackInt = Callable[[bytes], Tuple[int]]
 UnpackFloat = Callable[[bytes], Tuple[float]]
 
+_pack_int2 = cast(PackInt, struct.Struct("!h").pack)
+_pack_int4 = cast(PackInt, struct.Struct("!i").pack)
+_pack_uint4 = cast(PackInt, struct.Struct("!I").pack)
+_pack_int8 = cast(PackInt, struct.Struct("!q").pack)
+
+# Wrappers to force numbers to be cast as specific PostgreSQL types
+
+
+class Int2(int):
+    def __new__(cls, arg: int) -> "Int2":
+        rv: Int2 = super().__new__(cls, arg)  # type: ignore[call-arg]
+        return rv
+
+
+class Int4(int):
+    def __new__(cls, arg: int) -> "Int4":
+        rv: Int4 = super().__new__(cls, arg)  # type: ignore[call-arg]
+        return rv
+
+
+class Int8(int):
+    def __new__(cls, arg: int) -> "Int8":
+        rv: Int8 = super().__new__(cls, arg)  # type: ignore[call-arg]
+        return rv
+
+
+class Oid(int):
+    def __new__(cls, arg: int) -> "Oid":
+        rv: Oid = super().__new__(cls, arg)  # type: ignore[call-arg]
+        return rv
+
 
 class NumberDumper(Dumper):
     _special: Dict[bytes, bytes] = {}
 
-    def dump(self, obj: Any, __encode: EncodeFunc = encode_ascii) -> bytes:
-        return __encode(str(obj))[0]
+    def dump(self, obj: Any) -> bytes:
+        return str(obj).encode("utf8")
 
     def quote(self, obj: Any) -> bytes:
         value = self.dump(obj)
@@ -61,6 +93,50 @@ class DecimalDumper(NumberDumper):
     }
 
 
+@Dumper.text(Int2)
+class Int2Dumper(NumberDumper):
+    oid = builtins["int2"].oid
+
+
+@Dumper.text(Int4)
+class Int4Dumper(NumberDumper):
+    oid = builtins["int4"].oid
+
+
+@Dumper.text(Int8)
+class Int8Dumper(NumberDumper):
+    oid = builtins["int8"].oid
+
+
+@Dumper.text(Oid)
+class OidDumper(NumberDumper):
+    oid = builtins["oid"].oid
+
+
+@Dumper.binary(Int2)
+class Int2BinaryDumper(Int2Dumper):
+    def dump(self, obj: int) -> bytes:
+        return _pack_int2(obj)
+
+
+@Dumper.binary(Int4)
+class Int4BinaryDumper(Int4Dumper):
+    def dump(self, obj: int) -> bytes:
+        return _pack_int4(obj)
+
+
+@Dumper.binary(Int8)
+class Int8BinaryDumper(Int8Dumper):
+    def dump(self, obj: int) -> bytes:
+        return _pack_int8(obj)
+
+
+@Dumper.binary(Oid)
+class OidBinaryDumper(OidDumper):
+    def dump(self, obj: int) -> bytes:
+        return _pack_uint4(obj)
+
+
 @Loader.text(builtins["int2"].oid)
 @Loader.text(builtins["int4"].oid)
 @Loader.text(builtins["int8"].oid)
index d982b0f0c45869227cd913054bc43d6099e3c158..a4bfe89edcdad9744d656cbf72b7f4d46408eee7 100644 (file)
@@ -3,6 +3,7 @@ from math import isnan, isinf, exp
 
 import pytest
 
+import psycopg3
 from psycopg3 import sql
 from psycopg3.oids import builtins
 from psycopg3.adapt import Transformer, Format
@@ -20,8 +21,8 @@ from psycopg3.types.numeric import FloatLoader
         (0, "'0'::int"),
         (1, "'1'::int"),
         (-1, "'-1'::int"),
-        (42, "'42'::int"),
-        (-42, "'-42'::int"),
+        (42, "'42'::smallint"),
+        (-42, "'-42'::smallint"),
         (int(2 ** 63 - 1), "'9223372036854775807'::bigint"),
         (int(-(2 ** 63)), "'-9223372036854775808'::bigint"),
     ],
@@ -33,6 +34,33 @@ def test_dump_int(conn, val, expr):
     assert cur.fetchone()[0] is True
 
 
+@pytest.mark.parametrize(
+    "val, expr",
+    [
+        (0, "'0'::integer"),
+        (1, "'1'::integer"),
+        (-1, "'-1'::integer"),
+        (42, "'42'::smallint"),
+        (-42, "'-42'::smallint"),
+        (int(2 ** 63 - 1), "'9223372036854775807'::bigint"),
+        (int(-(2 ** 63)), "'-9223372036854775808'::bigint"),
+        (0, "'0'::oid"),
+        (4294967295, "'4294967295'::oid"),
+    ],
+)
+@pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY])
+def test_dump_int_subtypes(conn, val, expr, fmt_in):
+    tname = builtins[expr.rsplit(":", 1)[-1]].name.title()
+    assert tname in "Int2 Int4 Int8 Oid".split()
+    Type = getattr(psycopg3.types.numeric, tname)
+    ph = "%s" if fmt_in == Format.TEXT else "%b"
+    cur = conn.cursor()
+    cur.execute(f"select pg_typeof({expr}) = pg_typeof({ph})", (Type(val),))
+    assert cur.fetchone()[0] is True
+    cur.execute(f"select {expr} = {ph}", (Type(val),))
+    assert cur.fetchone()[0] is True
+
+
 @pytest.mark.parametrize(
     "val, expr",
     [