git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix(numpy): fix dumping numpy values by oid
author: Daniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 13 Jan 2023 01:24:45 +0000 (01:24 +0000)
committer: Daniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 5 Aug 2023 14:21:30 +0000 (15:21 +0100)
psycopg/psycopg/types/numeric.py
psycopg_c/psycopg_c/types/numeric.pyx
tests/types/test_numpy.py

index da7178ee16bf1859bafc33228340fff0b12abd80..441927372499ab79a7e9efb5fd1da6b5cd00c420 100644 (file)
@@ -4,9 +4,12 @@ Adapers for numeric types.
 
 # Copyright (C) 2020 The Psycopg Team
 
+import sys
 import struct
+from abc import ABC, abstractmethod
 from math import log
-from typing import Any, Callable, DefaultDict, Dict, Tuple, Union, cast
+from typing import Any, Callable, DefaultDict, Dict, Optional, Tuple, Union
+from typing import cast, TYPE_CHECKING
 from decimal import Decimal, DefaultContext, Context
 
 from .. import _oids
@@ -30,6 +33,9 @@ from .._wrappers import (
     Float8 as Float8,
 )
 
+if TYPE_CHECKING:
+    import numpy
+
 
 class _IntDumper(Dumper):
     def dump(self, obj: Any) -> Buffer:
@@ -43,12 +49,9 @@ class _IntDumper(Dumper):
 class _IntOrSubclassDumper(_IntDumper):
     def dump(self, obj: Any) -> Buffer:
         t = type(obj)
+        # Convert to int in order to dump IntEnum or numpy.integer correctly
         if t is not int:
-            # Convert to int in order to dump IntEnum correctly
-            if issubclass(t, int):
-                obj = int(obj)
-            else:
-                raise e.DataError(f"integer expected, got {type(obj).__name__!r}")
+            obj = int(obj)
 
         return str(obj).encode()
 
@@ -101,11 +104,7 @@ class DecimalDumper(_SpecialValuesDumper):
     oid = _oids.NUMERIC_OID
 
     def dump(self, obj: Decimal) -> bytes:
-        if obj.is_nan():
-            # cover NaN and sNaN
-            return b"NaN"
-        else:
-            return str(obj).encode()
+        return dump_decimal_to_text(obj)
 
     _special = {
         b"Infinity": b"'Infinity'::numeric",
@@ -355,23 +354,69 @@ class DecimalBinaryDumper(Dumper):
         return dump_decimal_to_numeric_binary(obj)
 
 
-class NumericDumper(DecimalDumper):
-    def dump(self, obj: Union[Decimal, int]) -> bytes:
-        if isinstance(obj, int):
+class _MixedNumericDumper(Dumper, ABC):
+    """Base for dumper to dump int, decimal, numpy.integer to Postgres numeric
+
+    Only used when looking up by oid.
+    """
+
+    oid = _oids.NUMERIC_OID
+
+    # If numpy is available, the dumped object might be a numpy integer too
+    int_classes: Union[type, Tuple[type, ...]] = ()
+
+    def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+        super().__init__(cls, context)
+
+        # Verify if numpy is available. If it is, we might have to dump
+        # its integers too.
+        if not _MixedNumericDumper.int_classes:
+            if "numpy" in sys.modules:
+                import numpy
+
+                _MixedNumericDumper.int_classes = (int, numpy.integer)
+            else:
+                _MixedNumericDumper.int_classes = int
+
+    @abstractmethod
+    def dump(self, obj: Union[Decimal, int, "numpy.integer[Any]"]) -> Buffer:
+        ...
+
+
+class NumericDumper(_MixedNumericDumper):
+    def dump(self, obj: Union[Decimal, int, "numpy.integer[Any]"]) -> Buffer:
+        if isinstance(obj, self.int_classes):
             return str(obj).encode()
+        elif isinstance(obj, Decimal):
+            return dump_decimal_to_text(obj)
         else:
-            return super().dump(obj)
+            raise TypeError(
+                f"class {type(self).__name__} cannot dump {type(obj).__name__}"
+            )
 
 
-class NumericBinaryDumper(Dumper):
+class NumericBinaryDumper(_MixedNumericDumper):
     format = Format.BINARY
-    oid = _oids.NUMERIC_OID
 
-    def dump(self, obj: Union[Decimal, int]) -> Buffer:
-        if isinstance(obj, int):
+    def dump(self, obj: Union[Decimal, int, "numpy.integer[Any]"]) -> Buffer:
+        if type(obj) is int:
             return dump_int_to_numeric_binary(obj)
-        else:
+        elif isinstance(obj, Decimal):
             return dump_decimal_to_numeric_binary(obj)
+        elif isinstance(obj, self.int_classes):
+            return dump_int_to_numeric_binary(int(obj))
+        else:
+            raise TypeError(
+                f"class {type(self).__name__} cannot dump {type(obj).__name__}"
+            )
+
+
+def dump_decimal_to_text(obj: Decimal) -> bytes:
+    if obj.is_nan():
+        # cover NaN and sNaN
+        return b"NaN"
+    else:
+        return str(obj).encode()
 
 
 def dump_decimal_to_numeric_binary(obj: Decimal) -> Union[bytearray, bytes]:
index fdf9561ac1834b0e0fe44bfb6ea916e19fb83d94..941785b01d43ee001f37c4399935232daade3e96 100644 (file)
@@ -17,6 +17,7 @@ from cpython.bytes cimport PyBytes_AsStringAndSize
 from cpython.float cimport PyFloat_FromDouble, PyFloat_AsDouble
 from cpython.unicode cimport PyUnicode_DecodeUTF8
 
+import sys
 from decimal import Decimal, Context, DefaultContext
 
 from psycopg_c._psycopg cimport endian
@@ -497,30 +498,55 @@ cdef class DecimalBinaryDumper(CDumper):
         return dump_decimal_to_numeric_binary(obj, rv, offset)
 
 
+cdef class _MixedNumericDumper(CDumper):
+
+    int_classes = None
+    oid = oids.NUMERIC_OID
+
+    def __cinit__(self, cls, context: Optional[AdaptContext] = None):
+        if _MixedNumericDumper.int_classes is None:
+            if "numpy" in sys.modules:
+                import numpy
+
+                _MixedNumericDumper.int_classes = (int, numpy.integer)
+            else:
+                _MixedNumericDumper.int_classes = int
+
+
 @cython.final
-cdef class NumericDumper(CDumper):
+cdef class NumericDumper(_MixedNumericDumper):
 
     format = PQ_TEXT
-    oid = oids.NUMERIC_OID
 
     cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
-        if isinstance(obj, int):
+        if type(obj) is int:  # fast path
             return dump_int_to_text(obj, rv, offset)
-        else:
+        elif isinstance(obj, Decimal):
             return dump_decimal_to_text(obj, rv, offset)
+        elif isinstance(obj, self.int_classes):
+            return dump_int_to_text(obj, rv, offset)
+        else:
+            raise TypeError(
+                f"class {type(self).__name__} cannot dump {type(obj).__name__}"
+            )
 
 
 @cython.final
-cdef class NumericBinaryDumper(CDumper):
+cdef class NumericBinaryDumper(_MixedNumericDumper):
 
     format = PQ_BINARY
-    oid = oids.NUMERIC_OID
 
     cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
-        if isinstance(obj, int):
+        if type(obj) is int:
             return dump_int_to_numeric_binary(obj, rv, offset)
-        else:
+        elif isinstance(obj, Decimal):
             return dump_decimal_to_numeric_binary(obj, rv, offset)
+        elif isinstance(obj, self.int_classes):
+            return dump_int_to_numeric_binary(int(obj), rv, offset)
+        else:
+            raise TypeError(
+                f"class {type(self).__name__} cannot dump {type(obj).__name__}"
+            )
 
 
 cdef Py_ssize_t dump_decimal_to_text(obj, bytearray rv, Py_ssize_t offset) except -1:
index 2eb04a64d8220a0e26cb382be8fe146e59f9366c..84e1082433e863c994c9d7c0ba969025633a8e62 100644 (file)
@@ -2,6 +2,7 @@ from math import isnan
 
 import pytest
 from psycopg.adapt import PyFormat
+from psycopg.pq import Format
 
 try:
     import numpy as np
@@ -140,6 +141,57 @@ def test_dump_float(conn, nptype, val, pgtype, fmt_in):
         assert rec[0][0] == rec[1][0]
 
 
+@pytest.mark.parametrize(
+    "nptype, val, pgtypes",
+    [
+        ("int8", -128, "int2 int4 int8 numeric"),
+        ("int8", 127, "int2 int4 int8 numeric"),
+        ("int16", -32_768, "int2 int4 int8 numeric"),
+        ("int16", 32_767, "int2 int4 int8 numeric"),
+        ("int32", -(2**31), "int4 int8 numeric"),
+        ("int32", 0, "int2 int4 int8 numeric"),
+        ("int32", 2**31 - 1, "int4 int8 numeric"),
+        ("int64", -(2**63), "int8 numeric"),
+        ("int64", 2**63 - 1, "int8 numeric"),
+        ("longlong", -(2**63), "int8"),
+        ("longlong", 2**63 - 1, "int8"),
+        ("bool_", True, "bool"),
+        ("bool_", False, "bool"),
+        ("uint8", 0, "int2 int4 int8 numeric"),
+        ("uint8", 255, "int2 int4 int8 numeric"),
+        ("uint16", 0, "int2 int4 int8 numeric"),
+        ("uint16", 65_535, "int4 int8 numeric"),
+        ("uint32", 0, "int4 int8 numeric"),
+        ("uint32", (2**32 - 1), "int8 numeric"),
+        ("uint64", 0, "int8 numeric"),
+        ("uint64", (2**64 - 1), "numeric"),
+        ("ulonglong", 0, "int8 numeric"),
+        ("ulonglong", (2**64 - 1), "numeric"),
+    ],
+)
+@pytest.mark.parametrize("fmt", Format)
+def test_copy_by_oid(conn, val, nptype, pgtypes, fmt):
+    nptype = getattr(np, nptype)
+    val = nptype(val)
+    pgtypes = pgtypes.split()
+    cur = conn.cursor()
+
+    fnames = [f"f{t}" for t in pgtypes]
+    fields = [f"f{t} {t}" for fname, t in zip(fnames, pgtypes)]
+    cur.execute(
+        f"create table numpyoid (id serial primary key, {', '.join(fields)})",
+    )
+    with cur.copy(
+        f"copy numpyoid ({', '.join(fnames)}) from stdin (format {fmt.name})"
+    ) as copy:
+        copy.set_types(pgtypes)
+        copy.write_row((val,) * len(fnames))
+
+    cur.execute(f"select {', '.join(fnames)} from numpyoid")
+    rec = cur.fetchone()
+    assert rec == (int(val),) * len(fnames)
+
+
 @pytest.mark.slow
 @pytest.mark.parametrize("fmt", PyFormat)
 def test_random(conn, faker, fmt):