]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add Float4/Float8 wrappers
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 28 Jul 2021 15:16:07 +0000 (17:16 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 28 Jul 2021 15:16:07 +0000 (17:16 +0200)
Also improved numeric wrapper repr and made them lighter avoiding a dict
creation.

psycopg/psycopg/_struct.py
psycopg/psycopg/_wrappers.py
psycopg/psycopg/types/numeric.py
tests/fix_faker.py
tests/types/test_array.py
tests/types/test_numeric.py

index e4529ff0854f1dcd42c4156c17c079aa09ab46ba..f390800eb3640c16aa5782373f881b3ca03a7a02 100644 (file)
@@ -26,6 +26,7 @@ pack_uint2 = cast(PackInt, struct.Struct("!H").pack)
 pack_int4 = cast(PackInt, struct.Struct("!i").pack)
 pack_uint4 = cast(PackInt, struct.Struct("!I").pack)
 pack_int8 = cast(PackInt, struct.Struct("!q").pack)
+pack_float4 = cast(PackFloat, struct.Struct("!f").pack)
 pack_float8 = cast(PackFloat, struct.Struct("!d").pack)
 
 unpack_int2 = cast(UnpackInt, struct.Struct("!h").unpack)
index ecbb34c3a58fa99329d7ff89da86e447dc6545db..0064872abf46fc91e730c97adf21f52555469ba8 100644 (file)
@@ -14,38 +14,103 @@ _MODULE = "psycopg.types.numeric"
 class Int2(int):
 
     __module__ = _MODULE
+    __slots__ = ()
 
     def __new__(cls, arg: int) -> "Int2":
         return super().__new__(cls, arg)
 
+    def __str__(self) -> str:
+        return super().__repr__()
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}({super().__repr__()})"
+
 
 class Int4(int):
 
     __module__ = _MODULE
+    __slots__ = ()
 
     def __new__(cls, arg: int) -> "Int4":
         return super().__new__(cls, arg)
 
+    def __str__(self) -> str:
+        return super().__repr__()
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}({super().__repr__()})"
+
 
 class Int8(int):
 
     __module__ = _MODULE
+    __slots__ = ()
 
     def __new__(cls, arg: int) -> "Int8":
         return super().__new__(cls, arg)
 
+    def __str__(self) -> str:
+        return super().__repr__()
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}({super().__repr__()})"
+
 
 class IntNumeric(int):
 
     __module__ = _MODULE
+    __slots__ = ()
 
     def __new__(cls, arg: int) -> "IntNumeric":
         return super().__new__(cls, arg)
 
+    def __str__(self) -> str:
+        return super().__repr__()
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}({super().__repr__()})"
+
+
+class Float4(float):
+
+    __module__ = _MODULE
+    __slots__ = ()
+
+    def __new__(cls, arg: float) -> "Float4":
+        return super().__new__(cls, arg)
+
+    def __str__(self) -> str:
+        return super().__repr__()
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}({super().__repr__()})"
+
+
+class Float8(float):
+
+    __module__ = _MODULE
+    __slots__ = ()
+
+    def __new__(cls, arg: float) -> "Float8":
+        return super().__new__(cls, arg)
+
+    def __str__(self) -> str:
+        return super().__repr__()
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}({super().__repr__()})"
+
 
 class Oid(int):
 
     __module__ = _MODULE
+    __slots__ = ()
 
     def __new__(cls, arg: int) -> "Oid":
         return super().__new__(cls, arg)
+
+    def __str__(self) -> str:
+        return super().__repr__()
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}({super().__repr__()})"
index 93fa217bbec6b2af31f727af3c65f070eaa9b2cc..323a85dd3d1e3af0c5bd8b3e4f69b3dc6d02671e 100644 (file)
@@ -17,7 +17,7 @@ from ..adapt import Buffer, Dumper, Loader, PyFormat
 from .._struct import pack_int2, pack_uint2, unpack_int2
 from .._struct import pack_int4, pack_uint4, unpack_int4, unpack_uint4
 from .._struct import pack_int8, unpack_int8
-from .._struct import pack_float8, unpack_float4, unpack_float8
+from .._struct import pack_float4, pack_float8, unpack_float4, unpack_float8
 
 # Exposed here
 from .._wrappers import (
@@ -26,6 +26,8 @@ from .._wrappers import (
     Int8 as Int8,
     IntNumeric as IntNumeric,
     Oid as Oid,
+    Float4 as Float4,
+    Float8 as Float8,
 )
 
 
@@ -66,6 +68,10 @@ class FloatDumper(_SpecialValuesDumper):
     }
 
 
+class Float4Dumper(FloatDumper):
+    _oid = postgres.types["float4"].oid
+
+
 class FloatBinaryDumper(Dumper):
 
     format = Format.BINARY
@@ -75,6 +81,14 @@ class FloatBinaryDumper(Dumper):
         return pack_float8(obj)
 
 
+class Float4BinaryDumper(FloatBinaryDumper):
+
+    _oid = postgres.types["float4"].oid
+
+    def dump(self, obj: float) -> bytes:
+        return pack_float4(obj)
+
+
 class DecimalDumper(_SpecialValuesDumper):
 
     _oid = postgres.types["numeric"].oid
@@ -449,10 +463,14 @@ def register_default_adapters(context: AdaptContext) -> None:
     adapters.register_dumper(Int8, Int8Dumper)
     adapters.register_dumper(IntNumeric, IntNumericDumper)
     adapters.register_dumper(Oid, OidDumper)
+    adapters.register_dumper(Float4, Float4Dumper)
+    adapters.register_dumper(Float8, FloatDumper)
     adapters.register_dumper(Int2, Int2BinaryDumper)
     adapters.register_dumper(Int4, Int4BinaryDumper)
     adapters.register_dumper(Int8, Int8BinaryDumper)
     adapters.register_dumper(Oid, OidBinaryDumper)
+    adapters.register_dumper(Float4, Float4BinaryDumper)
+    adapters.register_dumper(Float8, FloatBinaryDumper)
     adapters.register_loader("int2", IntLoader)
     adapters.register_loader("int4", IntLoader)
     adapters.register_loader("int8", IntLoader)
index d1ad460b59a91cc1025021deaa98397134f3dd0e..530cb3bcd6e928d11ed4b16e6851583f58dd6f9b 100644 (file)
@@ -358,29 +358,42 @@ class Faker:
         else:
             assert got == want
 
-    def make_float(self, spec):
+    def make_float(self, spec, double=True):
         if random() <= 0.99:
-            # this exponent should generate no inf
+            # These exponents should generate no inf
             return float(
                 f"{choice('-+')}0.{randrange(1 << 53)}e{randrange(-310,309)}"
+                if double
+                else f"{choice('-+')}0.{randrange(1 << 22)}e{randrange(-37,38)}"
             )
         else:
             return choice(
                 (0.0, -0.0, float("-inf"), float("inf"), float("nan"))
             )
 
-    def match_float(self, spec, got, want):
+    def match_float(self, spec, got, want, approx=False):
         if got is not None and isnan(got):
             assert isnan(want)
         else:
             # Versions older than 12 make some rounding. e.g. in Postgres 10.4
             # select '-1.409006204063909e+112'::float8
             #      -> -1.40900620406391e+112
-            if self.conn.info.server_version >= 120000:
+            if not approx and self.conn.info.server_version >= 120000:
                 assert got == want
             else:
                 assert got == pytest.approx(want)
 
+    def make_Float4(self, spec):
+        return spec(self.make_float(spec, double=False))
+
+    def match_Float4(self, spec, got, want):
+        return self.match_float(spec, got, want, approx=True)
+
+    def make_Float8(self, spec):
+        return spec(self.make_float(spec))
+
+    match_Float8 = match_float
+
     def make_int(self, spec):
         return randrange(-(1 << 90), 1 << 90)
 
index 69e83d6666371a501c09be20f44761b3cb43eed5..ba31553f72f2c523cdb1b633d121809dc2c8675c 100644 (file)
@@ -151,6 +151,28 @@ def test_array_mixed_numbers(array, type):
     assert dumper.oid == builtins[type].array_oid
 
 
+@pytest.mark.parametrize(
+    "wrapper", "Int2 Int4 Int8 Float4 Float8 Decimal".split()
+)
+@pytest.mark.parametrize("fmt_in", fmts_in)
+@pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY])
+def test_list_number_wrapper(conn, wrapper, fmt_in, fmt_out):
+    wrapper = getattr(psycopg.types.numeric, wrapper)
+    if wrapper is Decimal:
+        want_cls = Decimal
+    else:
+        assert wrapper.__mro__[1] in (int, float)
+        want_cls = wrapper.__mro__[1]
+
+    obj = [wrapper(1), wrapper(0), wrapper(-1), None]
+    cur = conn.cursor(binary=fmt_out)
+    got = cur.execute("select %s", [obj]).fetchone()[0]
+    assert got == obj
+    for i in got:
+        if i is not None:
+            assert type(i) is want_cls
+
+
 def test_mix_types(conn):
     cur = conn.cursor()
     cur.execute("create table test (id serial primary key, data numeric[])")
index d8461f0350e4add1f707df5527b63d8cba7aea74..f4621dcafccaa14d961a5bf041ac2e5fc7686327 100644 (file)
@@ -3,6 +3,7 @@ from math import isnan, isinf, exp
 
 import pytest
 
+import psycopg
 from psycopg import pq
 from psycopg import sql
 from psycopg.adapt import Transformer, PyFormat as Format
@@ -566,3 +567,22 @@ def test_minus_minus_quote(conn, pgtype):
     cur.execute(sql.SQL("select -{}{}").format(sql.Literal(-1), sql.SQL(cast)))
     result = cur.fetchone()[0]
     assert result == 1
+
+
+@pytest.mark.parametrize("wrapper", "Int2 Int4 Int8 Oid Float4 Float8".split())
+def test_dump_wrapper_oid(wrapper):
+    wrapper = getattr(psycopg.types.numeric, wrapper)
+    base = wrapper.__mro__[1]
+    assert base in (int, float)
+    n = base(3.14)
+    assert str(wrapper(n)) == str(n)
+    assert repr(wrapper(n)) == f"{wrapper.__name__}({n})"
+
+
+@pytest.mark.parametrize("wrapper", "Int2 Int4 Int8 Oid Float4 Float8".split())
+@pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY])
+def test_repr_wrapper(conn, wrapper, fmt_in):
+    wrapper = getattr(psycopg.types.numeric, wrapper)
+    cur = conn.execute(f"select pg_typeof(%{fmt_in})::oid", [wrapper(0)])
+    oid = cur.fetchone()[0]
+    assert oid == psycopg.postgres.types[wrapper.__name__.lower()].oid