From: Vertemati Francesco Date: Tue, 5 Jul 2022 15:37:32 +0000 (-0400) Subject: feat: add numpy dumpers for int, float, bool types X-Git-Tag: pool-3.2.0~70^2~25 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=15fbcd62b9b9cd772923edf7c3a9c4ed3e421c60;p=thirdparty%2Fpsycopg.git feat: add numpy dumpers for int, float, bool types Close #192. --- diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 98cf39c9b..880436462 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -42,6 +42,8 @@ jobs: - {impl: python, python: "3.9", ext: dns, postgres: "postgres:14"} - {impl: python, python: "3.9", ext: postgis, postgres: "postgis/postgis"} + - {impl: python, python: "3.10", ext: numpy, postgres: "postgres:14"} + - {impl: c, python: "3.11", ext: numpy, postgres: "postgres:15"} # Test with minimum dependencies versions - {impl: c, python: "3.7", ext: min, postgres: "postgres:15"} @@ -88,10 +90,15 @@ jobs: echo "DEPS=$DEPS shapely" >> $GITHUB_ENV echo "MARKERS=$MARKERS postgis" >> $GITHUB_ENV + - if: ${{ matrix.ext == 'numpy' }} + run: | + echo "DEPS=$DEPS numpy" >> $GITHUB_ENV + echo "MARKERS=$MARKERS numpy" >> $GITHUB_ENV + - name: Configure to use the oldest dependencies if: ${{ matrix.ext == 'min' }} run: | - echo "DEPS=$DEPS dnspython shapely" >> $GITHUB_ENV + echo "DEPS=$DEPS dnspython shapely numpy" >> $GITHUB_ENV echo "PIP_CONSTRAINT=${{ github.workspace }}/tests/constraints.txt" \ >> $GITHUB_ENV diff --git a/psycopg/psycopg/postgres.py b/psycopg/psycopg/postgres.py index 44565952d..045491373 100644 --- a/psycopg/psycopg/postgres.py +++ b/psycopg/psycopg/postgres.py @@ -106,7 +106,7 @@ def register_default_types(types: TypesRegistry) -> None: def register_default_adapters(context: AdaptContext) -> None: from .types import array, bool, composite, datetime, enum, json, multirange - from .types import net, none, numeric, range, string, uuid + from .types import net, none, numeric, numpy, range, string, uuid array.register_default_adapters(context) bool.register_default_adapters(context) @@ -118,6 +118,7 @@ def register_default_adapters(context: AdaptContext) -> None: net.register_default_adapters(context) none.register_default_adapters(context) numeric.register_default_adapters(context) + numpy.register_default_adapters(context) range.register_default_adapters(context) string.register_default_adapters(context) uuid.register_default_adapters(context) diff --git a/psycopg/psycopg/types/numpy.py b/psycopg/psycopg/types/numpy.py new file mode 100644 index 000000000..17d6c1011 --- /dev/null +++ b/psycopg/psycopg/types/numpy.py @@ -0,0 +1,228 @@ +""" +Adapters for numpy types. +""" + +# Copyright (C) 2022 The Psycopg Team + +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Optional, + Union, + cast, +) + +from .numeric import dump_int_to_numeric_binary + +from .. import postgres +from .._struct import pack_int2, pack_int4, pack_int8 +from ..abc import AdaptContext, Buffer +from ..adapt import Dumper +from ..pq import Format +from .. import _struct + +if TYPE_CHECKING: + import numpy as np + +PackNumpyFloat4 = Callable[[Union["np.float16", "np.float32"]], bytes] +PackNumpyFloat8 = Callable[["np.float64"], bytes] + +pack_float8 = cast(PackNumpyFloat8, _struct.pack_float8) +pack_float4 = cast(PackNumpyFloat4, _struct.pack_float4) + + +class _NPIntDumper(Dumper): + def dump(self, obj: Any) -> bytes: + return str(obj).encode() + + def quote(self, obj: Any) -> bytes: + value = self.dump(obj) + return value if obj >= 0 else b" " + value + + +class NPInt8Dumper(_NPIntDumper): + oid = postgres.types["int2"].oid + + +NPInt16Dumper = NPInt8Dumper +NPUInt8Dumper = NPInt8Dumper + + +class NPInt32Dumper(_NPIntDumper): + oid = postgres.types["int4"].oid + + +NPUInt16Dumper = NPInt32Dumper + + +class NPInt64Dumper(_NPIntDumper): + oid = postgres.types["int8"].oid + + +NPUInt32Dumper = NPInt64Dumper + + +class NPBooleanDumper(_NPIntDumper): + oid = postgres.types["bool"].oid + + def dump(self, obj: "np.bool_") -> bytes: + return "t".encode() if bool(obj) is True else "f".encode() + + +class NPUInt64Dumper(_NPIntDumper): + oid = postgres.types["numeric"].oid + + +NPULongLongDumper = NPUInt64Dumper + + +class _SpecialValuesDumper(Dumper): + + _special: Dict[bytes, bytes] = {} + + def dump(self, obj: float) -> Buffer: + return str(obj).encode() + + def quote(self, obj: float) -> Buffer: + value = self.dump(obj) + + if value in self._special: + return self._special[value] + + return value if obj >= 0 else b" " + value + + +class _NPFloatDumper(_SpecialValuesDumper): + + _special = { + b"inf": b"'Infinity'::float8", + b"-inf": b"'-Infinity'::float8", + b"nan": b"'NaN'::float8", + } + + +class NPFloat16Dumper(_NPFloatDumper): + oid = postgres.types["float4"].oid + + +NPFloat32Dumper = NPFloat16Dumper + + +class NPFloat64Dumper(_NPFloatDumper): + + oid = postgres.types["float8"].oid + + +# Binary Dumpers + + +class NPFloat16BinaryDumper(Dumper): + format = Format.BINARY + oid = postgres.types["float4"].oid + + def dump(self, obj: "np.float16") -> bytes: + return pack_float4(obj) + + +class NPFloat32BinaryDumper(Dumper): + format = Format.BINARY + oid = postgres.types["float4"].oid + + def dump(self, obj: "np.float32") -> bytes: + return pack_float4(obj) + + +class NPFloat64BinaryDumper(Dumper): + format = Format.BINARY + oid = postgres.types["float8"].oid + + def dump(self, obj: "np.float64") -> bytes: + return pack_float8(obj) + + +class NPInt8BinaryDumper(NPInt8Dumper): + + format = Format.BINARY + + def dump(self, obj: "np.int8") -> bytes: + return pack_int2(int(obj)) + + +NPInt16BinaryDumper = NPInt8BinaryDumper +NPUInt8BinaryDumper = NPInt8BinaryDumper + + +class NPInt32BinaryDumper(NPInt32Dumper): + + format = Format.BINARY + + def dump(self, obj: Any) -> bytes: + return pack_int4(int(obj)) + + +NPUInt16BinaryDumper = NPInt32BinaryDumper + + +class NPInt64BinaryDumper(NPInt64Dumper): + + format = Format.BINARY + + def dump(self, obj: Any) -> bytes: + return pack_int8(int(obj)) + + +NPUInt32BinaryDumper = NPInt64BinaryDumper + + +class NPBooleanBinaryDumper(NPBooleanDumper): + + format = Format.BINARY + + def dump(self, obj: Any) -> bytes: + return b"\x01" if obj else b"\x00" + + +class NPUInt64BinaryDumper(NPUInt64Dumper): + + format = Format.BINARY + + def dump(self, obj: Any) -> Buffer: + + return dump_int_to_numeric_binary(int(obj)) + + +NPUlongLongBinaryDumper = NPUInt64BinaryDumper + + +def register_default_adapters(context: Optional[AdaptContext] = None) -> None: + adapters = context.adapters if context else postgres.adapters + + adapters.register_dumper("numpy.int8", NPInt8Dumper) + adapters.register_dumper("numpy.int16", NPInt16Dumper) + adapters.register_dumper("numpy.int32", NPInt32Dumper) + adapters.register_dumper("numpy.int64", NPInt64Dumper) + adapters.register_dumper("numpy.bool_", NPBooleanDumper) + adapters.register_dumper("numpy.uint8", NPUInt8Dumper) + adapters.register_dumper("numpy.uint16", NPUInt16Dumper) + adapters.register_dumper("numpy.uint32", NPUInt32Dumper) + adapters.register_dumper("numpy.uint64", NPUInt64Dumper) + adapters.register_dumper("numpy.ulonglong", NPULongLongDumper) + adapters.register_dumper("numpy.float16", NPFloat16Dumper) + adapters.register_dumper("numpy.float32", NPFloat32Dumper) + adapters.register_dumper("numpy.float64", NPFloat64Dumper) + + adapters.register_dumper("numpy.int8", NPInt8BinaryDumper) + adapters.register_dumper("numpy.int16", NPInt16BinaryDumper) + adapters.register_dumper("numpy.int32", NPInt32BinaryDumper) + adapters.register_dumper("numpy.int64", NPInt64BinaryDumper) + adapters.register_dumper("numpy.bool_", NPBooleanBinaryDumper) + adapters.register_dumper("numpy.uint8", NPUInt8BinaryDumper) + adapters.register_dumper("numpy.uint16", NPUInt16BinaryDumper) + adapters.register_dumper("numpy.uint32", NPUInt32BinaryDumper) + adapters.register_dumper("numpy.uint64", NPUInt64BinaryDumper) + adapters.register_dumper("numpy.ulonglong", NPUlongLongBinaryDumper) + adapters.register_dumper("numpy.float16", NPFloat16BinaryDumper) + adapters.register_dumper("numpy.float32", NPFloat32BinaryDumper) + adapters.register_dumper("numpy.float64", NPFloat64BinaryDumper) diff --git a/pyproject.toml b/pyproject.toml index 950bccb4a..7a507bc43 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ strict = true [[tool.mypy.overrides]] module = [ "shapely.*", + "numpy.*", ] ignore_missing_imports = true diff --git a/tests/conftest.py b/tests/conftest.py index 1ec997bf9..e674f6e76 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -27,6 +27,7 @@ def pytest_configure(config): "timing: the test is timing based and can fail on cheese hardware", "dns: the test requires dnspython to run", "postgis: the test requires the PostGIS extension to run", + "numpy: the test requires numpy module to be installed", ] for marker in markers: diff --git a/tests/types/test_numpy.py b/tests/types/test_numpy.py new file mode 100644 index 000000000..5fc8bf7e5 --- /dev/null +++ b/tests/types/test_numpy.py @@ -0,0 +1,351 @@ +import pytest +from psycopg.adapt import PyFormat + +pytest.importorskip("numpy") + + +import numpy as np # noqa: E402 + + +pytestmark = [pytest.mark.numpy] + + +@pytest.mark.parametrize( + "val, expr", + [ + (-128, "'-128'::int2"), + (127, "'127'::int2"), + (0, "'0'::int2"), + (45, "'45'::int2"), + ], +) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_numpy_int8(conn, val, expr, fmt_in): + val = np.byte(val) + + assert isinstance(val, np.byte) + assert np.byte is np.int8 + + cur = conn.cursor() + + cur.execute(f"select pg_typeof({expr}) = pg_typeof(%{fmt_in.value})", (val,)) + assert cur.fetchone()[0] is True + + cur.execute(f"select {expr} = %{fmt_in.value}", (val,)) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize( + "val, expr", + [ + (-32_768, "'-32768'::int2"), + (32_767, "'32767'::int2"), + (0, "'0'::int2"), + (45, "'45'::int2"), + ], +) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_numpy_int16(conn, val, expr, fmt_in): + + val = np.short(val) + + assert isinstance(val, np.short) + assert np.short is np.int16 + + cur = conn.cursor() + + cur.execute(f"select pg_typeof({expr}) = pg_typeof(%{fmt_in.value})", (val,)) + assert cur.fetchone()[0] is True + + cur.execute(f"select {expr} = %{fmt_in.value}", (val,)) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize( + "val, expr", + [ + (-2_147_483_648, "'-2147483648'::int4"), + (2_147_483_647, "'2147483647'::int4"), + (0, "'0'::int4"), + (45, "'45'::int4"), + ], +) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_numpy_int32(conn, val, expr, fmt_in): + + val = np.intc(val) + + assert isinstance(val, np.intc) + assert np.intc is np.int32 + + cur = conn.cursor() + + cur.execute(f"select pg_typeof({expr}) = pg_typeof(%{fmt_in.value})", (val,)) + assert cur.fetchone()[0] is True + + cur.execute(f"select {expr} = %{fmt_in.value}", (val,)) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize( + "val, expr", + [ + (-9_223_372_036_854_775_808, "'-9223372036854775808'::int8"), + (9_223_372_036_854_775_807, "'9223372036854775807'::int8"), + (0, "'0'::int8"), + (45, "'45'::int8"), + ], +) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_numpy_int64(conn, val, expr, fmt_in): + + val = np.int_(val) + + assert isinstance(val, np.int_) + assert np.int_ is np.int64 + + cur = conn.cursor() + + cur.execute(f"select pg_typeof({expr}) = pg_typeof(%{fmt_in.value})", (val,)) + assert cur.fetchone()[0] is True + + cur.execute(f"select {expr} = %{fmt_in.value}", (val,)) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize( + "val, expr", + [ + (True, "'t'::bool"), + (False, "'f'::bool"), + ], +) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_numpy_bool8(conn, val, expr, fmt_in): + + val = np.bool_(val) + + assert isinstance(val, np.bool_) + assert np.bool_ is np.bool8 + + cur = conn.cursor() + + cur.execute(f"select pg_typeof({expr}) = pg_typeof(%{fmt_in.value})", (val,)) + assert cur.fetchone()[0] is True + + cur.execute(f"select {expr} = %{fmt_in.value}", (bool(val),)) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize( + "val, expr", + [(0, "'0'::int2"), (255, "'255'::int2")], +) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_numpy_uint8(conn, val, expr, fmt_in): + + val = np.ubyte(val) + + assert isinstance(val, np.ubyte) + assert np.ubyte is np.uint8 + + cur = conn.cursor() + + cur.execute(f"select pg_typeof({expr}) = pg_typeof(%{fmt_in.value})", (val,)) + assert cur.fetchone()[0] is True + + cur.execute(f"select {expr} = %{fmt_in.value}", (val,)) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize( + "val, expr", + [ + (0, "'0'::int4"), + (65_535, "'65535'::int4"), + ], +) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_numpy_uint16(conn, val, expr, fmt_in): + + val = np.ushort(val) + + assert isinstance(val, np.ushort) + assert np.ushort is np.uint16 + + cur = conn.cursor() + + cur.execute(f"select pg_typeof({expr}) = pg_typeof(%{fmt_in.value})", (val,)) + assert cur.fetchone()[0] is True + + cur.execute(f"select {expr} = %{fmt_in.value}", (val,)) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize( + "val, expr", + [ + (0, "'0'::int8"), + (4_294_967_295, "'4294967295'::int8"), + ], +) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_numpy_uint32(conn, val, expr, fmt_in): + + val = np.uintc(val) + + assert isinstance(val, np.uintc) + assert np.uintc is np.uint32 + + cur = conn.cursor() + + cur.execute(f"select pg_typeof({expr}) = pg_typeof(%{fmt_in.value})", (val,)) + assert cur.fetchone()[0] is True + + cur.execute(f"select {expr} = %{fmt_in.value}", (val,)) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize( + "val, expr", + [ + (0, "'0'::numeric"), + (18_446_744_073_709_551_615, "'18446744073709551615'::numeric"), + ], +) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_numpy_uint64(conn, val, expr, fmt_in): + + val = np.uint(val) + + assert isinstance(val, np.uint) + assert np.uint is np.uint64 + + cur = conn.cursor() + + cur.execute(f"select pg_typeof({expr}) = pg_typeof(%{fmt_in.value})", (val,)) + assert cur.fetchone()[0] is True + + cur.execute(f"select {expr} = %{fmt_in.value}", (val,)) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize( + "val, expr", + [ + (0, "'0'::numeric"), + (18_446_744_073_709_551_615, "'18446744073709551615'::numeric"), + ], +) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_numpy_ulonglong(conn, val, expr, fmt_in): + + val = np.ulonglong(val) + + assert isinstance(val, np.ulonglong) + + cur = conn.cursor() + + cur.execute(f"select pg_typeof({expr}) = pg_typeof(%{fmt_in.value})", (val,)) + assert cur.fetchone()[0] is True + + cur.execute(f"select {expr} = %{fmt_in.value}", (val,)) + assert cur.fetchone()[0] is True + + +# Test float special values +@pytest.mark.parametrize( + "val, expr", + [ + (np.PZERO, "'0.0'::float8"), + (np.NZERO, "'-0.0'::float8"), + (np.nan, "'NaN'::float8"), + (np.inf, "'Infinity'::float8"), + (np.NINF, "'-Infinity'::float8"), + ], +) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_special_values(conn, val, expr, fmt_in): + + if val == np.nan: + assert np.nan == np.NAN == np.NaN + + if val == np.inf: + assert np.inf == np.Inf == np.PINF == np.infty + + assert isinstance(val, float) + + cur = conn.cursor() + cur.execute(f"select pg_typeof({expr}) = pg_typeof(%{fmt_in.value})", (val,)) + + assert cur.fetchone()[0] is True + + cur.execute(f"select {expr} = %{fmt_in.value}", (val,)) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize( + "val", + [ + "4e4", + # "4e-4", + "4000.0", + # "3.14", + ], +) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_numpy_float16(conn, val, fmt_in): + + val = np.float16(val) + cur = conn.cursor() + + cur.execute(f"select pg_typeof({val}::float4) = pg_typeof(%{fmt_in.value})", (val,)) + assert cur.fetchone()[0] is True + + cur.execute(f"select {val}::float4 = %{fmt_in.value}", (val,)) + + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize( + "val", + [ + "256e6", + "256e-6", + "2.7182817", + "3.1415927", + ], +) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_numpy_float32(conn, val, fmt_in): + + val = np.float32(val) + cur = conn.cursor() + + cur.execute(f"select pg_typeof({val}::float4) = pg_typeof(%{fmt_in.value})", (val,)) + assert cur.fetchone()[0] is True + + cur.execute(f"select {val}::float4 = %{fmt_in.value}", (val,)) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize( + "val", + [ + "256e12", + "256e-12", + "2.718281828459045", + "3.141592653589793", + ], +) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_numpy_float64(conn, val, fmt_in): + + val = np.float64(val) + cur = conn.cursor() + + cur.execute(f"select pg_typeof({val}::float8) = pg_typeof(%{fmt_in.value})", (val,)) + assert cur.fetchone()[0] is True + + cur.execute(f"select {val}::float8 = %{fmt_in.value}", (val,)) + assert cur.fetchone()[0] is True