pack_int4 = cast(PackInt, struct.Struct("!i").pack)
pack_uint4 = cast(PackInt, struct.Struct("!I").pack)
pack_int8 = cast(PackInt, struct.Struct("!q").pack)
+pack_float4 = cast(PackFloat, struct.Struct("!f").pack)
pack_float8 = cast(PackFloat, struct.Struct("!d").pack)
unpack_int2 = cast(UnpackInt, struct.Struct("!h").unpack)
class Int2(int):
__module__ = _MODULE
+ __slots__ = ()
def __new__(cls, arg: int) -> "Int2":
return super().__new__(cls, arg)
+ def __str__(self) -> str:
+ return super().__repr__()
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}({super().__repr__()})"
+
class Int4(int):
__module__ = _MODULE
+ __slots__ = ()
def __new__(cls, arg: int) -> "Int4":
return super().__new__(cls, arg)
+ def __str__(self) -> str:
+ return super().__repr__()
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}({super().__repr__()})"
+
class Int8(int):
__module__ = _MODULE
+ __slots__ = ()
def __new__(cls, arg: int) -> "Int8":
return super().__new__(cls, arg)
+ def __str__(self) -> str:
+ return super().__repr__()
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}({super().__repr__()})"
+
class IntNumeric(int):
__module__ = _MODULE
+ __slots__ = ()
def __new__(cls, arg: int) -> "IntNumeric":
return super().__new__(cls, arg)
+ def __str__(self) -> str:
+ return super().__repr__()
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}({super().__repr__()})"
+
+
+class Float4(float):
+
+ __module__ = _MODULE
+ __slots__ = ()
+
+ def __new__(cls, arg: float) -> "Float4":
+ return super().__new__(cls, arg)
+
+ def __str__(self) -> str:
+ return super().__repr__()
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}({super().__repr__()})"
+
+
+class Float8(float):
+
+ __module__ = _MODULE
+ __slots__ = ()
+
+ def __new__(cls, arg: float) -> "Float8":
+ return super().__new__(cls, arg)
+
+ def __str__(self) -> str:
+ return super().__repr__()
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}({super().__repr__()})"
+
class Oid(int):
__module__ = _MODULE
+ __slots__ = ()
def __new__(cls, arg: int) -> "Oid":
return super().__new__(cls, arg)
+
+ def __str__(self) -> str:
+ return super().__repr__()
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}({super().__repr__()})"
from .._struct import pack_int2, pack_uint2, unpack_int2
from .._struct import pack_int4, pack_uint4, unpack_int4, unpack_uint4
from .._struct import pack_int8, unpack_int8
-from .._struct import pack_float8, unpack_float4, unpack_float8
+from .._struct import pack_float4, pack_float8, unpack_float4, unpack_float8
# Exposed here
from .._wrappers import (
Int8 as Int8,
IntNumeric as IntNumeric,
Oid as Oid,
+ Float4 as Float4,
+ Float8 as Float8,
)
}
+class Float4Dumper(FloatDumper):
+ _oid = postgres.types["float4"].oid
+
+
class FloatBinaryDumper(Dumper):
format = Format.BINARY
return pack_float8(obj)
+class Float4BinaryDumper(FloatBinaryDumper):
+
+ _oid = postgres.types["float4"].oid
+
+ def dump(self, obj: float) -> bytes:
+ return pack_float4(obj)
+
+
class DecimalDumper(_SpecialValuesDumper):
_oid = postgres.types["numeric"].oid
adapters.register_dumper(Int8, Int8Dumper)
adapters.register_dumper(IntNumeric, IntNumericDumper)
adapters.register_dumper(Oid, OidDumper)
+ adapters.register_dumper(Float4, Float4Dumper)
+ adapters.register_dumper(Float8, FloatDumper)
adapters.register_dumper(Int2, Int2BinaryDumper)
adapters.register_dumper(Int4, Int4BinaryDumper)
adapters.register_dumper(Int8, Int8BinaryDumper)
adapters.register_dumper(Oid, OidBinaryDumper)
+ adapters.register_dumper(Float4, Float4BinaryDumper)
+ adapters.register_dumper(Float8, FloatBinaryDumper)
adapters.register_loader("int2", IntLoader)
adapters.register_loader("int4", IntLoader)
adapters.register_loader("int8", IntLoader)
else:
assert got == want
- def make_float(self, spec):
+ def make_float(self, spec, double=True):
if random() <= 0.99:
- # this exponent should generate no inf
+ # These exponents should generate no inf
return float(
f"{choice('-+')}0.{randrange(1 << 53)}e{randrange(-310,309)}"
+ if double
+ else f"{choice('-+')}0.{randrange(1 << 22)}e{randrange(-37,38)}"
)
else:
return choice(
(0.0, -0.0, float("-inf"), float("inf"), float("nan"))
)
- def match_float(self, spec, got, want):
+ def match_float(self, spec, got, want, approx=False):
if got is not None and isnan(got):
assert isnan(want)
else:
# Versions older than 12 make some rounding. e.g. in Postgres 10.4
# select '-1.409006204063909e+112'::float8
# -> -1.40900620406391e+112
- if self.conn.info.server_version >= 120000:
+ if not approx and self.conn.info.server_version >= 120000:
assert got == want
else:
assert got == pytest.approx(want)
+ def make_Float4(self, spec):
+ return spec(self.make_float(spec, double=False))
+
+ def match_Float4(self, spec, got, want):
+ return self.match_float(spec, got, want, approx=True)
+
+ def make_Float8(self, spec):
+ return spec(self.make_float(spec))
+
+ match_Float8 = match_float
+
def make_int(self, spec):
return randrange(-(1 << 90), 1 << 90)
assert dumper.oid == builtins[type].array_oid
+@pytest.mark.parametrize(
+ "wrapper", "Int2 Int4 Int8 Float4 Float8 Decimal".split()
+)
+@pytest.mark.parametrize("fmt_in", fmts_in)
+@pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY])
+def test_list_number_wrapper(conn, wrapper, fmt_in, fmt_out):
+ wrapper = getattr(psycopg.types.numeric, wrapper)
+ if wrapper is Decimal:
+ want_cls = Decimal
+ else:
+ assert wrapper.__mro__[1] in (int, float)
+ want_cls = wrapper.__mro__[1]
+
+ obj = [wrapper(1), wrapper(0), wrapper(-1), None]
+ cur = conn.cursor(binary=fmt_out)
+ got = cur.execute("select %s", [obj]).fetchone()[0]
+ assert got == obj
+ for i in got:
+ if i is not None:
+ assert type(i) is want_cls
+
+
def test_mix_types(conn):
cur = conn.cursor()
cur.execute("create table test (id serial primary key, data numeric[])")
import pytest
+import psycopg
from psycopg import pq
from psycopg import sql
from psycopg.adapt import Transformer, PyFormat as Format
cur.execute(sql.SQL("select -{}{}").format(sql.Literal(-1), sql.SQL(cast)))
result = cur.fetchone()[0]
assert result == 1
+
+
+@pytest.mark.parametrize("wrapper", "Int2 Int4 Int8 Oid Float4 Float8".split())
+def test_dump_wrapper_oid(wrapper):
+ wrapper = getattr(psycopg.types.numeric, wrapper)
+ base = wrapper.__mro__[1]
+ assert base in (int, float)
+ n = base(3.14)
+ assert str(wrapper(n)) == str(n)
+ assert repr(wrapper(n)) == f"{wrapper.__name__}({n})"
+
+
+@pytest.mark.parametrize("wrapper", "Int2 Int4 Int8 Oid Float4 Float8".split())
+@pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY])
+def test_repr_wrapper(conn, wrapper, fmt_in):
+ wrapper = getattr(psycopg.types.numeric, wrapper)
+ cur = conn.execute(f"select pg_typeof(%{fmt_in})::oid", [wrapper(0)])
+ oid = cur.fetchone()[0]
+ assert oid == psycopg.postgres.types[wrapper.__name__.lower()].oid