--- /dev/null
+"""
+Adapters for numpy types.
+"""
+
+# Copyright (C) 2022 The Psycopg Team
+
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Dict,
+ Optional,
+ Union,
+ cast,
+)
+
+from .numeric import dump_int_to_numeric_binary
+
+from .. import postgres
+from .._struct import pack_int2, pack_int4, pack_int8
+from ..abc import AdaptContext, Buffer
+from ..adapt import Dumper
+from ..pq import Format
+from .. import _struct
+
+if TYPE_CHECKING:
+ import numpy as np
+
+PackNumpyFloat4 = Callable[[Union["np.float16", "np.float32"]], bytes]
+PackNumpyFloat8 = Callable[["np.float64"], bytes]
+
+pack_float8 = cast(PackNumpyFloat8, _struct.pack_float8)
+pack_float4 = cast(PackNumpyFloat4, _struct.pack_float4)
+
+
+class _NPIntDumper(Dumper):
+ def dump(self, obj: Any) -> bytes:
+ return str(obj).encode()
+
+ def quote(self, obj: Any) -> bytes:
+ value = self.dump(obj)
+ return value if obj >= 0 else b" " + value
+
+
+class NPInt8Dumper(_NPIntDumper):
+ oid = postgres.types["int2"].oid
+
+
+NPInt16Dumper = NPInt8Dumper
+NPUInt8Dumper = NPInt8Dumper
+
+
+class NPInt32Dumper(_NPIntDumper):
+ oid = postgres.types["int4"].oid
+
+
+NPUInt16Dumper = NPInt32Dumper
+
+
+class NPInt64Dumper(_NPIntDumper):
+ oid = postgres.types["int8"].oid
+
+
+NPUInt32Dumper = NPInt64Dumper
+
+
+class NPBooleanDumper(_NPIntDumper):
+ oid = postgres.types["bool"].oid
+
+ def dump(self, obj: "np.bool_") -> bytes:
+ return "t".encode() if bool(obj) is True else "f".encode()
+
+
+class NPUInt64Dumper(_NPIntDumper):
+ oid = postgres.types["numeric"].oid
+
+
+NPULongLongDumper = NPUInt64Dumper
+
+
+class _SpecialValuesDumper(Dumper):
+
+ _special: Dict[bytes, bytes] = {}
+
+ def dump(self, obj: float) -> Buffer:
+ return str(obj).encode()
+
+ def quote(self, obj: float) -> Buffer:
+ value = self.dump(obj)
+
+ if value in self._special:
+ return self._special[value]
+
+ return value if obj >= 0 else b" " + value
+
+
+class _NPFloatDumper(_SpecialValuesDumper):
+
+ _special = {
+ b"inf": b"'Infinity'::float8",
+ b"-inf": b"'-Infinity'::float8",
+ b"nan": b"'NaN'::float8",
+ }
+
+
+class NPFloat16Dumper(_NPFloatDumper):
+ oid = postgres.types["float4"].oid
+
+
+NPFloat32Dumper = NPFloat16Dumper
+
+
+class NPFloat64Dumper(_NPFloatDumper):
+
+ oid = postgres.types["float8"].oid
+
+
+# Binary Dumpers
+
+
+class NPFloat16BinaryDumper(Dumper):
+ format = Format.BINARY
+ oid = postgres.types["float4"].oid
+
+ def dump(self, obj: "np.float16") -> bytes:
+ return pack_float4(obj)
+
+
+class NPFloat32BinaryDumper(Dumper):
+ format = Format.BINARY
+ oid = postgres.types["float4"].oid
+
+ def dump(self, obj: "np.float32") -> bytes:
+ return pack_float4(obj)
+
+
+class NPFloat64BinaryDumper(Dumper):
+ format = Format.BINARY
+ oid = postgres.types["float8"].oid
+
+ def dump(self, obj: "np.float64") -> bytes:
+ return pack_float8(obj)
+
+
+class NPInt8BinaryDumper(NPInt8Dumper):
+
+ format = Format.BINARY
+
+ def dump(self, obj: "np.int8") -> bytes:
+ return pack_int2(int(obj))
+
+
+NPInt16BinaryDumper = NPInt8BinaryDumper
+NPUInt8BinaryDumper = NPInt8BinaryDumper
+
+
+class NPInt32BinaryDumper(NPInt32Dumper):
+
+ format = Format.BINARY
+
+ def dump(self, obj: Any) -> bytes:
+ return pack_int4(int(obj))
+
+
+NPUInt16BinaryDumper = NPInt32BinaryDumper
+
+
+class NPInt64BinaryDumper(NPInt64Dumper):
+
+ format = Format.BINARY
+
+ def dump(self, obj: Any) -> bytes:
+ return pack_int8(int(obj))
+
+
+NPUInt32BinaryDumper = NPInt64BinaryDumper
+
+
+class NPBooleanBinaryDumper(NPBooleanDumper):
+
+ format = Format.BINARY
+
+ def dump(self, obj: Any) -> bytes:
+ return b"\x01" if obj else b"\x00"
+
+
+class NPUInt64BinaryDumper(NPUInt64Dumper):
+
+ format = Format.BINARY
+
+ def dump(self, obj: Any) -> Buffer:
+
+ return dump_int_to_numeric_binary(int(obj))
+
+
+NPUlongLongBinaryDumper = NPUInt64BinaryDumper
+
+
+def register_default_adapters(context: Optional[AdaptContext] = None) -> None:
+ adapters = context.adapters if context else postgres.adapters
+
+ adapters.register_dumper("numpy.int8", NPInt8Dumper)
+ adapters.register_dumper("numpy.int16", NPInt16Dumper)
+ adapters.register_dumper("numpy.int32", NPInt32Dumper)
+ adapters.register_dumper("numpy.int64", NPInt64Dumper)
+ adapters.register_dumper("numpy.bool_", NPBooleanDumper)
+ adapters.register_dumper("numpy.uint8", NPUInt8Dumper)
+ adapters.register_dumper("numpy.uint16", NPUInt16Dumper)
+ adapters.register_dumper("numpy.uint32", NPUInt32Dumper)
+ adapters.register_dumper("numpy.uint64", NPUInt64Dumper)
+ adapters.register_dumper("numpy.ulonglong", NPULongLongDumper)
+ adapters.register_dumper("numpy.float16", NPFloat16Dumper)
+ adapters.register_dumper("numpy.float32", NPFloat32Dumper)
+ adapters.register_dumper("numpy.float64", NPFloat64Dumper)
+
+ adapters.register_dumper("numpy.int8", NPInt8BinaryDumper)
+ adapters.register_dumper("numpy.int16", NPInt16BinaryDumper)
+ adapters.register_dumper("numpy.int32", NPInt32BinaryDumper)
+ adapters.register_dumper("numpy.int64", NPInt64BinaryDumper)
+ adapters.register_dumper("numpy.bool_", NPBooleanBinaryDumper)
+ adapters.register_dumper("numpy.uint8", NPUInt8BinaryDumper)
+ adapters.register_dumper("numpy.uint16", NPUInt16BinaryDumper)
+ adapters.register_dumper("numpy.uint32", NPUInt32BinaryDumper)
+ adapters.register_dumper("numpy.uint64", NPUInt64BinaryDumper)
+ adapters.register_dumper("numpy.ulonglong", NPUlongLongBinaryDumper)
+ adapters.register_dumper("numpy.float16", NPFloat16BinaryDumper)
+ adapters.register_dumper("numpy.float32", NPFloat32BinaryDumper)
+ adapters.register_dumper("numpy.float64", NPFloat64BinaryDumper)
--- /dev/null
+import pytest
+from psycopg.adapt import PyFormat
+
+pytest.importorskip("numpy")
+
+
+import numpy as np # noqa: E402
+
+
+pytestmark = [pytest.mark.numpy]
+
+
+@pytest.mark.parametrize(
+ "val, expr",
+ [
+ (-128, "'-128'::int2"),
+ (127, "'127'::int2"),
+ (0, "'0'::int2"),
+ (45, "'45'::int2"),
+ ],
+)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_numpy_int8(conn, val, expr, fmt_in):
+ val = np.byte(val)
+
+ assert isinstance(val, np.byte)
+ assert np.byte is np.int8
+
+ cur = conn.cursor()
+
+ cur.execute(f"select pg_typeof({expr}) = pg_typeof(%{fmt_in.value})", (val,))
+ assert cur.fetchone()[0] is True
+
+ cur.execute(f"select {expr} = %{fmt_in.value}", (val,))
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize(
+ "val, expr",
+ [
+ (-32_768, "'-32768'::int2"),
+ (32_767, "'32767'::int2"),
+ (0, "'0'::int2"),
+ (45, "'45'::int2"),
+ ],
+)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_numpy_int16(conn, val, expr, fmt_in):
+
+ val = np.short(val)
+
+ assert isinstance(val, np.short)
+ assert np.short is np.int16
+
+ cur = conn.cursor()
+
+ cur.execute(f"select pg_typeof({expr}) = pg_typeof(%{fmt_in.value})", (val,))
+ assert cur.fetchone()[0] is True
+
+ cur.execute(f"select {expr} = %{fmt_in.value}", (val,))
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize(
+ "val, expr",
+ [
+ (-2_147_483_648, "'-2147483648'::int4"),
+ (2_147_483_647, "'2147483647'::int4"),
+ (0, "'0'::int4"),
+ (45, "'45'::int4"),
+ ],
+)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_numpy_int32(conn, val, expr, fmt_in):
+
+ val = np.intc(val)
+
+ assert isinstance(val, np.intc)
+ assert np.intc is np.int32
+
+ cur = conn.cursor()
+
+ cur.execute(f"select pg_typeof({expr}) = pg_typeof(%{fmt_in.value})", (val,))
+ assert cur.fetchone()[0] is True
+
+ cur.execute(f"select {expr} = %{fmt_in.value}", (val,))
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize(
+ "val, expr",
+ [
+ (-9_223_372_036_854_775_808, "'-9223372036854775808'::int8"),
+ (9_223_372_036_854_775_807, "'9223372036854775807'::int8"),
+ (0, "'0'::int8"),
+ (45, "'45'::int8"),
+ ],
+)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_numpy_int64(conn, val, expr, fmt_in):
+
+ val = np.int_(val)
+
+ assert isinstance(val, np.int_)
+ assert np.int_ is np.int64
+
+ cur = conn.cursor()
+
+ cur.execute(f"select pg_typeof({expr}) = pg_typeof(%{fmt_in.value})", (val,))
+ assert cur.fetchone()[0] is True
+
+ cur.execute(f"select {expr} = %{fmt_in.value}", (val,))
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize(
+ "val, expr",
+ [
+ (True, "'t'::bool"),
+ (False, "'f'::bool"),
+ ],
+)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_numpy_bool8(conn, val, expr, fmt_in):
+
+ val = np.bool_(val)
+
+ assert isinstance(val, np.bool_)
+ assert np.bool_ is np.bool8
+
+ cur = conn.cursor()
+
+ cur.execute(f"select pg_typeof({expr}) = pg_typeof(%{fmt_in.value})", (val,))
+ assert cur.fetchone()[0] is True
+
+ cur.execute(f"select {expr} = %{fmt_in.value}", (bool(val),))
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize(
+ "val, expr",
+ [(0, "'0'::int2"), (255, "'255'::int2")],
+)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_numpy_uint8(conn, val, expr, fmt_in):
+
+ val = np.ubyte(val)
+
+ assert isinstance(val, np.ubyte)
+ assert np.ubyte is np.uint8
+
+ cur = conn.cursor()
+
+ cur.execute(f"select pg_typeof({expr}) = pg_typeof(%{fmt_in.value})", (val,))
+ assert cur.fetchone()[0] is True
+
+ cur.execute(f"select {expr} = %{fmt_in.value}", (val,))
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize(
+ "val, expr",
+ [
+ (0, "'0'::int4"),
+ (65_535, "'65535'::int4"),
+ ],
+)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_numpy_uint16(conn, val, expr, fmt_in):
+
+ val = np.ushort(val)
+
+ assert isinstance(val, np.ushort)
+ assert np.ushort is np.uint16
+
+ cur = conn.cursor()
+
+ cur.execute(f"select pg_typeof({expr}) = pg_typeof(%{fmt_in.value})", (val,))
+ assert cur.fetchone()[0] is True
+
+ cur.execute(f"select {expr} = %{fmt_in.value}", (val,))
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize(
+ "val, expr",
+ [
+ (0, "'0'::int8"),
+ (4_294_967_295, "'4294967295'::int8"),
+ ],
+)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_numpy_uint32(conn, val, expr, fmt_in):
+
+ val = np.uintc(val)
+
+ assert isinstance(val, np.uintc)
+ assert np.uintc is np.uint32
+
+ cur = conn.cursor()
+
+ cur.execute(f"select pg_typeof({expr}) = pg_typeof(%{fmt_in.value})", (val,))
+ assert cur.fetchone()[0] is True
+
+ cur.execute(f"select {expr} = %{fmt_in.value}", (val,))
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize(
+ "val, expr",
+ [
+ (0, "'0'::numeric"),
+ (18_446_744_073_709_551_615, "'18446744073709551615'::numeric"),
+ ],
+)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_numpy_uint64(conn, val, expr, fmt_in):
+
+ val = np.uint(val)
+
+ assert isinstance(val, np.uint)
+ assert np.uint is np.uint64
+
+ cur = conn.cursor()
+
+ cur.execute(f"select pg_typeof({expr}) = pg_typeof(%{fmt_in.value})", (val,))
+ assert cur.fetchone()[0] is True
+
+ cur.execute(f"select {expr} = %{fmt_in.value}", (val,))
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize(
+ "val, expr",
+ [
+ (0, "'0'::numeric"),
+ (18_446_744_073_709_551_615, "'18446744073709551615'::numeric"),
+ ],
+)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_numpy_ulonglong(conn, val, expr, fmt_in):
+
+ val = np.ulonglong(val)
+
+ assert isinstance(val, np.ulonglong)
+
+ cur = conn.cursor()
+
+ cur.execute(f"select pg_typeof({expr}) = pg_typeof(%{fmt_in.value})", (val,))
+ assert cur.fetchone()[0] is True
+
+ cur.execute(f"select {expr} = %{fmt_in.value}", (val,))
+ assert cur.fetchone()[0] is True
+
+
+# Test float special values
+@pytest.mark.parametrize(
+ "val, expr",
+ [
+ (np.PZERO, "'0.0'::float8"),
+ (np.NZERO, "'-0.0'::float8"),
+ (np.nan, "'NaN'::float8"),
+ (np.inf, "'Infinity'::float8"),
+ (np.NINF, "'-Infinity'::float8"),
+ ],
+)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_special_values(conn, val, expr, fmt_in):
+
+ if val == np.nan:
+ assert np.nan == np.NAN == np.NaN
+
+ if val == np.inf:
+ assert np.inf == np.Inf == np.PINF == np.infty
+
+ assert isinstance(val, float)
+
+ cur = conn.cursor()
+ cur.execute(f"select pg_typeof({expr}) = pg_typeof(%{fmt_in.value})", (val,))
+
+ assert cur.fetchone()[0] is True
+
+ cur.execute(f"select {expr} = %{fmt_in.value}", (val,))
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize(
+ "val",
+ [
+ "4e4",
+ # "4e-4",
+ "4000.0",
+ # "3.14",
+ ],
+)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_numpy_float16(conn, val, fmt_in):
+
+ val = np.float16(val)
+ cur = conn.cursor()
+
+ cur.execute(f"select pg_typeof({val}::float4) = pg_typeof(%{fmt_in.value})", (val,))
+ assert cur.fetchone()[0] is True
+
+ cur.execute(f"select {val}::float4 = %{fmt_in.value}", (val,))
+
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize(
+ "val",
+ [
+ "256e6",
+ "256e-6",
+ "2.7182817",
+ "3.1415927",
+ ],
+)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_numpy_float32(conn, val, fmt_in):
+
+ val = np.float32(val)
+ cur = conn.cursor()
+
+ cur.execute(f"select pg_typeof({val}::float4) = pg_typeof(%{fmt_in.value})", (val,))
+ assert cur.fetchone()[0] is True
+
+ cur.execute(f"select {val}::float4 = %{fmt_in.value}", (val,))
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize(
+ "val",
+ [
+ "256e12",
+ "256e-12",
+ "2.718281828459045",
+ "3.141592653589793",
+ ],
+)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_numpy_float64(conn, val, fmt_in):
+
+ val = np.float64(val)
+ cur = conn.cursor()
+
+ cur.execute(f"select pg_typeof({val}::float8) = pg_typeof(%{fmt_in.value})", (val,))
+ assert cur.fetchone()[0] is True
+
+ cur.execute(f"select {val}::float8 = %{fmt_in.value}", (val,))
+ assert cur.fetchone()[0] is True