]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
feat: add numpy dumpers for int, float, bool types
authorVertemati Francesco <verte.fra@gmail.com>
Tue, 5 Jul 2022 15:37:32 +0000 (11:37 -0400)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 5 Aug 2023 14:21:30 +0000 (15:21 +0100)
Close #192.

.github/workflows/tests.yml
psycopg/psycopg/postgres.py
psycopg/psycopg/types/numpy.py [new file with mode: 0644]
pyproject.toml
tests/conftest.py
tests/types/test_numpy.py [new file with mode: 0644]

index 98cf39c9bd9b223268382e2ca1efbbbb969a180a..880436462d00c6cce4ce3a6de5887f8d2b4a9de0 100644 (file)
@@ -42,6 +42,8 @@ jobs:
 
           - {impl: python, python: "3.9", ext: dns, postgres: "postgres:14"}
           - {impl: python, python: "3.9", ext: postgis, postgres: "postgis/postgis"}
+          - {impl: python, python: "3.10", ext: numpy, postgres: "postgres:14"}
+          - {impl: c, python: "3.11", ext: numpy, postgres: "postgres:15"}
 
           # Test with minimum dependencies versions
           - {impl: c, python: "3.7", ext: min, postgres: "postgres:15"}
@@ -88,10 +90,15 @@ jobs:
           echo "DEPS=$DEPS shapely" >> $GITHUB_ENV
           echo "MARKERS=$MARKERS postgis" >> $GITHUB_ENV
 
+      - if: ${{ matrix.ext == 'numpy' }}
+        run: |
+          echo "DEPS=$DEPS numpy" >> $GITHUB_ENV
+          echo "MARKERS=$MARKERS numpy" >> $GITHUB_ENV
+
       - name: Configure to use the oldest dependencies
         if: ${{ matrix.ext == 'min' }}
         run: |
-          echo "DEPS=$DEPS dnspython shapely" >> $GITHUB_ENV
+          echo "DEPS=$DEPS dnspython shapely numpy" >> $GITHUB_ENV
           echo "PIP_CONSTRAINT=${{ github.workspace }}/tests/constraints.txt" \
             >> $GITHUB_ENV
 
index 44565952d2ffff2527606aa27ca17302e19f9224..045491373aa986e0610627a004c26e961015e2df 100644 (file)
@@ -106,7 +106,7 @@ def register_default_types(types: TypesRegistry) -> None:
 
 def register_default_adapters(context: AdaptContext) -> None:
     from .types import array, bool, composite, datetime, enum, json, multirange
-    from .types import net, none, numeric, range, string, uuid
+    from .types import net, none, numeric, numpy, range, string, uuid
 
     array.register_default_adapters(context)
     bool.register_default_adapters(context)
@@ -118,6 +118,7 @@ def register_default_adapters(context: AdaptContext) -> None:
     net.register_default_adapters(context)
     none.register_default_adapters(context)
     numeric.register_default_adapters(context)
+    numpy.register_default_adapters(context)
     range.register_default_adapters(context)
     string.register_default_adapters(context)
     uuid.register_default_adapters(context)
diff --git a/psycopg/psycopg/types/numpy.py b/psycopg/psycopg/types/numpy.py
new file mode 100644 (file)
index 0000000..17d6c10
--- /dev/null
@@ -0,0 +1,228 @@
+"""
+Adapters for numpy types.
+"""
+
+# Copyright (C) 2022 The Psycopg Team
+
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    Optional,
+    Union,
+    cast,
+)
+
+from .numeric import dump_int_to_numeric_binary
+
+from .. import postgres
+from .._struct import pack_int2, pack_int4, pack_int8
+from ..abc import AdaptContext, Buffer
+from ..adapt import Dumper
+from ..pq import Format
+from .. import _struct
+
+if TYPE_CHECKING:
+    import numpy as np
+
+PackNumpyFloat4 = Callable[[Union["np.float16", "np.float32"]], bytes]
+PackNumpyFloat8 = Callable[["np.float64"], bytes]
+
+pack_float8 = cast(PackNumpyFloat8, _struct.pack_float8)
+pack_float4 = cast(PackNumpyFloat4, _struct.pack_float4)
+
+
+class _NPIntDumper(Dumper):
+    def dump(self, obj: Any) -> bytes:
+        return str(obj).encode()
+
+    def quote(self, obj: Any) -> bytes:
+        value = self.dump(obj)
+        return value if obj >= 0 else b" " + value
+
+
+class NPInt8Dumper(_NPIntDumper):
+    oid = postgres.types["int2"].oid
+
+
+NPInt16Dumper = NPInt8Dumper
+NPUInt8Dumper = NPInt8Dumper
+
+
+class NPInt32Dumper(_NPIntDumper):
+    oid = postgres.types["int4"].oid
+
+
+NPUInt16Dumper = NPInt32Dumper
+
+
+class NPInt64Dumper(_NPIntDumper):
+    oid = postgres.types["int8"].oid
+
+
+NPUInt32Dumper = NPInt64Dumper
+
+
+class NPBooleanDumper(_NPIntDumper):
+    oid = postgres.types["bool"].oid
+
+    def dump(self, obj: "np.bool_") -> bytes:
+        return "t".encode() if bool(obj) is True else "f".encode()
+
+
+class NPUInt64Dumper(_NPIntDumper):
+    oid = postgres.types["numeric"].oid
+
+
+NPULongLongDumper = NPUInt64Dumper
+
+
+class _SpecialValuesDumper(Dumper):
+
+    _special: Dict[bytes, bytes] = {}
+
+    def dump(self, obj: float) -> Buffer:
+        return str(obj).encode()
+
+    def quote(self, obj: float) -> Buffer:
+        value = self.dump(obj)
+
+        if value in self._special:
+            return self._special[value]
+
+        return value if obj >= 0 else b" " + value
+
+
+class _NPFloatDumper(_SpecialValuesDumper):
+
+    _special = {
+        b"inf": b"'Infinity'::float8",
+        b"-inf": b"'-Infinity'::float8",
+        b"nan": b"'NaN'::float8",
+    }
+
+
+class NPFloat16Dumper(_NPFloatDumper):
+    oid = postgres.types["float4"].oid
+
+
+NPFloat32Dumper = NPFloat16Dumper
+
+
+class NPFloat64Dumper(_NPFloatDumper):
+
+    oid = postgres.types["float8"].oid
+
+
+# Binary Dumpers
+
+
+class NPFloat16BinaryDumper(Dumper):
+    format = Format.BINARY
+    oid = postgres.types["float4"].oid
+
+    def dump(self, obj: "np.float16") -> bytes:
+        return pack_float4(obj)
+
+
+class NPFloat32BinaryDumper(Dumper):
+    format = Format.BINARY
+    oid = postgres.types["float4"].oid
+
+    def dump(self, obj: "np.float32") -> bytes:
+        return pack_float4(obj)
+
+
+class NPFloat64BinaryDumper(Dumper):
+    format = Format.BINARY
+    oid = postgres.types["float8"].oid
+
+    def dump(self, obj: "np.float64") -> bytes:
+        return pack_float8(obj)
+
+
+class NPInt8BinaryDumper(NPInt8Dumper):
+
+    format = Format.BINARY
+
+    def dump(self, obj: "np.int8") -> bytes:
+        return pack_int2(int(obj))
+
+
+NPInt16BinaryDumper = NPInt8BinaryDumper
+NPUInt8BinaryDumper = NPInt8BinaryDumper
+
+
+class NPInt32BinaryDumper(NPInt32Dumper):
+
+    format = Format.BINARY
+
+    def dump(self, obj: Any) -> bytes:
+        return pack_int4(int(obj))
+
+
+NPUInt16BinaryDumper = NPInt32BinaryDumper
+
+
+class NPInt64BinaryDumper(NPInt64Dumper):
+
+    format = Format.BINARY
+
+    def dump(self, obj: Any) -> bytes:
+        return pack_int8(int(obj))
+
+
+NPUInt32BinaryDumper = NPInt64BinaryDumper
+
+
+class NPBooleanBinaryDumper(NPBooleanDumper):
+
+    format = Format.BINARY
+
+    def dump(self, obj: Any) -> bytes:
+        return b"\x01" if obj else b"\x00"
+
+
+class NPUInt64BinaryDumper(NPUInt64Dumper):
+
+    format = Format.BINARY
+
+    def dump(self, obj: Any) -> Buffer:
+
+        return dump_int_to_numeric_binary(int(obj))
+
+
+NPUlongLongBinaryDumper = NPUInt64BinaryDumper
+
+
+def register_default_adapters(context: Optional[AdaptContext] = None) -> None:
+    adapters = context.adapters if context else postgres.adapters
+
+    adapters.register_dumper("numpy.int8", NPInt8Dumper)
+    adapters.register_dumper("numpy.int16", NPInt16Dumper)
+    adapters.register_dumper("numpy.int32", NPInt32Dumper)
+    adapters.register_dumper("numpy.int64", NPInt64Dumper)
+    adapters.register_dumper("numpy.bool_", NPBooleanDumper)
+    adapters.register_dumper("numpy.uint8", NPUInt8Dumper)
+    adapters.register_dumper("numpy.uint16", NPUInt16Dumper)
+    adapters.register_dumper("numpy.uint32", NPUInt32Dumper)
+    adapters.register_dumper("numpy.uint64", NPUInt64Dumper)
+    adapters.register_dumper("numpy.ulonglong", NPULongLongDumper)
+    adapters.register_dumper("numpy.float16", NPFloat16Dumper)
+    adapters.register_dumper("numpy.float32", NPFloat32Dumper)
+    adapters.register_dumper("numpy.float64", NPFloat64Dumper)
+
+    adapters.register_dumper("numpy.int8", NPInt8BinaryDumper)
+    adapters.register_dumper("numpy.int16", NPInt16BinaryDumper)
+    adapters.register_dumper("numpy.int32", NPInt32BinaryDumper)
+    adapters.register_dumper("numpy.int64", NPInt64BinaryDumper)
+    adapters.register_dumper("numpy.bool_", NPBooleanBinaryDumper)
+    adapters.register_dumper("numpy.uint8", NPUInt8BinaryDumper)
+    adapters.register_dumper("numpy.uint16", NPUInt16BinaryDumper)
+    adapters.register_dumper("numpy.uint32", NPUInt32BinaryDumper)
+    adapters.register_dumper("numpy.uint64", NPUInt64BinaryDumper)
+    adapters.register_dumper("numpy.ulonglong", NPUlongLongBinaryDumper)
+    adapters.register_dumper("numpy.float16", NPFloat16BinaryDumper)
+    adapters.register_dumper("numpy.float32", NPFloat32BinaryDumper)
+    adapters.register_dumper("numpy.float64", NPFloat64BinaryDumper)
index 950bccb4a7bbf60e42b759c6c93a575ddd582083..7a507bc43f376e013120e63cdc1104fa831e9424 100644 (file)
@@ -40,6 +40,7 @@ strict = true
 [[tool.mypy.overrides]]
 module = [
     "shapely.*",
+    "numpy.*",
 ]
 ignore_missing_imports = true
 
index 1ec997bf9872e291bb1e50d9eb9834c147c5e31e..e674f6e765f7f2258e8b78eab3e3d495c573f4c2 100644 (file)
@@ -27,6 +27,7 @@ def pytest_configure(config):
         "timing: the test is timing based and can fail on cheese hardware",
         "dns: the test requires dnspython to run",
         "postgis: the test requires the PostGIS extension to run",
+        "numpy: the test requires numpy module to be installed",
     ]
 
     for marker in markers:
diff --git a/tests/types/test_numpy.py b/tests/types/test_numpy.py
new file mode 100644 (file)
index 0000000..5fc8bf7
--- /dev/null
@@ -0,0 +1,351 @@
+import pytest
+from psycopg.adapt import PyFormat
+
+pytest.importorskip("numpy")
+
+
+import numpy as np  # noqa: E402
+
+
+pytestmark = [pytest.mark.numpy]
+
+
+@pytest.mark.parametrize(
+    "val, expr",
+    [
+        (-128, "'-128'::int2"),
+        (127, "'127'::int2"),
+        (0, "'0'::int2"),
+        (45, "'45'::int2"),
+    ],
+)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_numpy_int8(conn, val, expr, fmt_in):
+    val = np.byte(val)
+
+    assert isinstance(val, np.byte)
+    assert np.byte is np.int8
+
+    cur = conn.cursor()
+
+    cur.execute(f"select pg_typeof({expr}) = pg_typeof(%{fmt_in.value})", (val,))
+    assert cur.fetchone()[0] is True
+
+    cur.execute(f"select {expr} = %{fmt_in.value}", (val,))
+    assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize(
+    "val, expr",
+    [
+        (-32_768, "'-32768'::int2"),
+        (32_767, "'32767'::int2"),
+        (0, "'0'::int2"),
+        (45, "'45'::int2"),
+    ],
+)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_numpy_int16(conn, val, expr, fmt_in):
+
+    val = np.short(val)
+
+    assert isinstance(val, np.short)
+    assert np.short is np.int16
+
+    cur = conn.cursor()
+
+    cur.execute(f"select pg_typeof({expr}) = pg_typeof(%{fmt_in.value})", (val,))
+    assert cur.fetchone()[0] is True
+
+    cur.execute(f"select {expr} = %{fmt_in.value}", (val,))
+    assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize(
+    "val, expr",
+    [
+        (-2_147_483_648, "'-2147483648'::int4"),
+        (2_147_483_647, "'2147483647'::int4"),
+        (0, "'0'::int4"),
+        (45, "'45'::int4"),
+    ],
+)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_numpy_int32(conn, val, expr, fmt_in):
+
+    val = np.intc(val)
+
+    assert isinstance(val, np.intc)
+    assert np.intc is np.int32
+
+    cur = conn.cursor()
+
+    cur.execute(f"select pg_typeof({expr}) = pg_typeof(%{fmt_in.value})", (val,))
+    assert cur.fetchone()[0] is True
+
+    cur.execute(f"select {expr} = %{fmt_in.value}", (val,))
+    assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize(
+    "val, expr",
+    [
+        (-9_223_372_036_854_775_808, "'-9223372036854775808'::int8"),
+        (9_223_372_036_854_775_807, "'9223372036854775807'::int8"),
+        (0, "'0'::int8"),
+        (45, "'45'::int8"),
+    ],
+)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_numpy_int64(conn, val, expr, fmt_in):
+
+    val = np.int_(val)
+
+    assert isinstance(val, np.int_)
+    assert np.int_ is np.int64
+
+    cur = conn.cursor()
+
+    cur.execute(f"select pg_typeof({expr}) = pg_typeof(%{fmt_in.value})", (val,))
+    assert cur.fetchone()[0] is True
+
+    cur.execute(f"select {expr} = %{fmt_in.value}", (val,))
+    assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize(
+    "val, expr",
+    [
+        (True, "'t'::bool"),
+        (False, "'f'::bool"),
+    ],
+)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_numpy_bool8(conn, val, expr, fmt_in):
+
+    val = np.bool_(val)
+
+    assert isinstance(val, np.bool_)
+    assert np.bool_ is np.bool8
+
+    cur = conn.cursor()
+
+    cur.execute(f"select pg_typeof({expr}) = pg_typeof(%{fmt_in.value})", (val,))
+    assert cur.fetchone()[0] is True
+
+    cur.execute(f"select {expr} = %{fmt_in.value}", (bool(val),))
+    assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize(
+    "val, expr",
+    [(0, "'0'::int2"), (255, "'255'::int2")],
+)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_numpy_uint8(conn, val, expr, fmt_in):
+
+    val = np.ubyte(val)
+
+    assert isinstance(val, np.ubyte)
+    assert np.ubyte is np.uint8
+
+    cur = conn.cursor()
+
+    cur.execute(f"select pg_typeof({expr}) = pg_typeof(%{fmt_in.value})", (val,))
+    assert cur.fetchone()[0] is True
+
+    cur.execute(f"select {expr} = %{fmt_in.value}", (val,))
+    assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize(
+    "val, expr",
+    [
+        (0, "'0'::int4"),
+        (65_535, "'65535'::int4"),
+    ],
+)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_numpy_uint16(conn, val, expr, fmt_in):
+
+    val = np.ushort(val)
+
+    assert isinstance(val, np.ushort)
+    assert np.ushort is np.uint16
+
+    cur = conn.cursor()
+
+    cur.execute(f"select pg_typeof({expr}) = pg_typeof(%{fmt_in.value})", (val,))
+    assert cur.fetchone()[0] is True
+
+    cur.execute(f"select {expr} = %{fmt_in.value}", (val,))
+    assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize(
+    "val, expr",
+    [
+        (0, "'0'::int8"),
+        (4_294_967_295, "'4294967295'::int8"),
+    ],
+)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_numpy_uint32(conn, val, expr, fmt_in):
+
+    val = np.uintc(val)
+
+    assert isinstance(val, np.uintc)
+    assert np.uintc is np.uint32
+
+    cur = conn.cursor()
+
+    cur.execute(f"select pg_typeof({expr}) = pg_typeof(%{fmt_in.value})", (val,))
+    assert cur.fetchone()[0] is True
+
+    cur.execute(f"select {expr} = %{fmt_in.value}", (val,))
+    assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize(
+    "val, expr",
+    [
+        (0, "'0'::numeric"),
+        (18_446_744_073_709_551_615, "'18446744073709551615'::numeric"),
+    ],
+)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_numpy_uint64(conn, val, expr, fmt_in):
+
+    val = np.uint(val)
+
+    assert isinstance(val, np.uint)
+    assert np.uint is np.uint64
+
+    cur = conn.cursor()
+
+    cur.execute(f"select pg_typeof({expr}) = pg_typeof(%{fmt_in.value})", (val,))
+    assert cur.fetchone()[0] is True
+
+    cur.execute(f"select {expr} = %{fmt_in.value}", (val,))
+    assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize(
+    "val, expr",
+    [
+        (0, "'0'::numeric"),
+        (18_446_744_073_709_551_615, "'18446744073709551615'::numeric"),
+    ],
+)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_numpy_ulonglong(conn, val, expr, fmt_in):
+
+    val = np.ulonglong(val)
+
+    assert isinstance(val, np.ulonglong)
+
+    cur = conn.cursor()
+
+    cur.execute(f"select pg_typeof({expr}) = pg_typeof(%{fmt_in.value})", (val,))
+    assert cur.fetchone()[0] is True
+
+    cur.execute(f"select {expr} = %{fmt_in.value}", (val,))
+    assert cur.fetchone()[0] is True
+
+
+# Test float special values
+@pytest.mark.parametrize(
+    "val, expr",
+    [
+        (np.PZERO, "'0.0'::float8"),
+        (np.NZERO, "'-0.0'::float8"),
+        (np.nan, "'NaN'::float8"),
+        (np.inf, "'Infinity'::float8"),
+        (np.NINF, "'-Infinity'::float8"),
+    ],
+)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_special_values(conn, val, expr, fmt_in):
+
+    if val == np.nan:
+        assert np.nan == np.NAN == np.NaN
+
+    if val == np.inf:
+        assert np.inf == np.Inf == np.PINF == np.infty
+
+    assert isinstance(val, float)
+
+    cur = conn.cursor()
+    cur.execute(f"select pg_typeof({expr}) = pg_typeof(%{fmt_in.value})", (val,))
+
+    assert cur.fetchone()[0] is True
+
+    cur.execute(f"select {expr} = %{fmt_in.value}", (val,))
+    assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize(
+    "val",
+    [
+        "4e4",
+        # "4e-4",
+        "4000.0",
+        # "3.14",
+    ],
+)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_numpy_float16(conn, val, fmt_in):
+
+    val = np.float16(val)
+    cur = conn.cursor()
+
+    cur.execute(f"select pg_typeof({val}::float4) = pg_typeof(%{fmt_in.value})", (val,))
+    assert cur.fetchone()[0] is True
+
+    cur.execute(f"select {val}::float4 = %{fmt_in.value}", (val,))
+
+    assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize(
+    "val",
+    [
+        "256e6",
+        "256e-6",
+        "2.7182817",
+        "3.1415927",
+    ],
+)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_numpy_float32(conn, val, fmt_in):
+
+    val = np.float32(val)
+    cur = conn.cursor()
+
+    cur.execute(f"select pg_typeof({val}::float4) = pg_typeof(%{fmt_in.value})", (val,))
+    assert cur.fetchone()[0] is True
+
+    cur.execute(f"select {val}::float4 = %{fmt_in.value}", (val,))
+    assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize(
+    "val",
+    [
+        "256e12",
+        "256e-12",
+        "2.718281828459045",
+        "3.141592653589793",
+    ],
+)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_numpy_float64(conn, val, fmt_in):
+
+    val = np.float64(val)
+    cur = conn.cursor()
+
+    cur.execute(f"select pg_typeof({val}::float8) = pg_typeof(%{fmt_in.value})", (val,))
+    assert cur.fetchone()[0] is True
+
+    cur.execute(f"select {val}::float8 = %{fmt_in.value}", (val,))
+    assert cur.fetchone()[0] is True