# Copyright (C) 2020 The Psycopg Team
+import sys
import struct
+from abc import ABC, abstractmethod
from math import log
-from typing import Any, Callable, DefaultDict, Dict, Tuple, Union, cast
+from typing import Any, Callable, DefaultDict, Dict, Optional, Tuple, Union
+from typing import cast, TYPE_CHECKING
from decimal import Decimal, DefaultContext, Context
from .. import _oids
Float8 as Float8,
)
+if TYPE_CHECKING:
+ import numpy
+
class _IntDumper(Dumper):
def dump(self, obj: Any) -> Buffer:
class _IntOrSubclassDumper(_IntDumper):
def dump(self, obj: Any) -> Buffer:
t = type(obj)
+ # Convert to int in order to dump IntEnum or numpy.integer correctly
if t is not int:
- # Convert to int in order to dump IntEnum correctly
- if issubclass(t, int):
- obj = int(obj)
- else:
- raise e.DataError(f"integer expected, got {type(obj).__name__!r}")
+ obj = int(obj)
return str(obj).encode()
oid = _oids.NUMERIC_OID
def dump(self, obj: Decimal) -> bytes:
- if obj.is_nan():
- # cover NaN and sNaN
- return b"NaN"
- else:
- return str(obj).encode()
+ return dump_decimal_to_text(obj)
_special = {
b"Infinity": b"'Infinity'::numeric",
return dump_decimal_to_numeric_binary(obj)
-class NumericDumper(DecimalDumper):
- def dump(self, obj: Union[Decimal, int]) -> bytes:
- if isinstance(obj, int):
+class _MixedNumericDumper(Dumper, ABC):
+ """Base for dumper to dump int, decimal, numpy.integer to Postgres numeric
+
+ Only used when looking up by oid.
+ """
+
+ oid = _oids.NUMERIC_OID
+
+ # If numpy is available, the dumped object might be a numpy integer too
+ int_classes: Union[type, Tuple[type, ...]] = ()
+
+ def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+ super().__init__(cls, context)
+
+ # Verify if numpy is available. If it is, we might have to dump
+ # its integers too.
+ if not _MixedNumericDumper.int_classes:
+ if "numpy" in sys.modules:
+ import numpy
+
+ _MixedNumericDumper.int_classes = (int, numpy.integer)
+ else:
+ _MixedNumericDumper.int_classes = int
+
+ @abstractmethod
+ def dump(self, obj: Union[Decimal, int, "numpy.integer[Any]"]) -> Buffer:
+ ...
+
+
+class NumericDumper(_MixedNumericDumper):
+ def dump(self, obj: Union[Decimal, int, "numpy.integer[Any]"]) -> Buffer:
+ if isinstance(obj, self.int_classes):
return str(obj).encode()
+ elif isinstance(obj, Decimal):
+ return dump_decimal_to_text(obj)
else:
- return super().dump(obj)
+ raise TypeError(
+ f"class {type(self).__name__} cannot dump {type(obj).__name__}"
+ )
-class NumericBinaryDumper(Dumper):
+class NumericBinaryDumper(_MixedNumericDumper):
format = Format.BINARY
- oid = _oids.NUMERIC_OID
- def dump(self, obj: Union[Decimal, int]) -> Buffer:
- if isinstance(obj, int):
+ def dump(self, obj: Union[Decimal, int, "numpy.integer[Any]"]) -> Buffer:
+ if type(obj) is int:
return dump_int_to_numeric_binary(obj)
- else:
+ elif isinstance(obj, Decimal):
return dump_decimal_to_numeric_binary(obj)
+ elif isinstance(obj, self.int_classes):
+ return dump_int_to_numeric_binary(int(obj))
+ else:
+ raise TypeError(
+ f"class {type(self).__name__} cannot dump {type(obj).__name__}"
+ )
+
+
+def dump_decimal_to_text(obj: Decimal) -> bytes:
+ if obj.is_nan():
+ # cover NaN and sNaN
+ return b"NaN"
+ else:
+ return str(obj).encode()
def dump_decimal_to_numeric_binary(obj: Decimal) -> Union[bytearray, bytes]:
from cpython.float cimport PyFloat_FromDouble, PyFloat_AsDouble
from cpython.unicode cimport PyUnicode_DecodeUTF8
+import sys
from decimal import Decimal, Context, DefaultContext
from psycopg_c._psycopg cimport endian
return dump_decimal_to_numeric_binary(obj, rv, offset)
+cdef class _MixedNumericDumper(CDumper):
+
+ int_classes = None
+ oid = oids.NUMERIC_OID
+
+ def __cinit__(self, cls, context: Optional[AdaptContext] = None):
+ if _MixedNumericDumper.int_classes is None:
+ if "numpy" in sys.modules:
+ import numpy
+
+ _MixedNumericDumper.int_classes = (int, numpy.integer)
+ else:
+ _MixedNumericDumper.int_classes = int
+
+
@cython.final
-cdef class NumericDumper(CDumper):
+cdef class NumericDumper(_MixedNumericDumper):
format = PQ_TEXT
- oid = oids.NUMERIC_OID
cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
- if isinstance(obj, int):
+ if type(obj) is int: # fast path
return dump_int_to_text(obj, rv, offset)
- else:
+ elif isinstance(obj, Decimal):
return dump_decimal_to_text(obj, rv, offset)
+ elif isinstance(obj, self.int_classes):
+ return dump_int_to_text(obj, rv, offset)
+ else:
+ raise TypeError(
+ f"class {type(self).__name__} cannot dump {type(obj).__name__}"
+ )
@cython.final
-cdef class NumericBinaryDumper(CDumper):
+cdef class NumericBinaryDumper(_MixedNumericDumper):
format = PQ_BINARY
- oid = oids.NUMERIC_OID
cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
- if isinstance(obj, int):
+ if type(obj) is int:
return dump_int_to_numeric_binary(obj, rv, offset)
- else:
+ elif isinstance(obj, Decimal):
return dump_decimal_to_numeric_binary(obj, rv, offset)
+ elif isinstance(obj, self.int_classes):
+ return dump_int_to_numeric_binary(int(obj), rv, offset)
+ else:
+ raise TypeError(
+ f"class {type(self).__name__} cannot dump {type(obj).__name__}"
+ )
cdef Py_ssize_t dump_decimal_to_text(obj, bytearray rv, Py_ssize_t offset) except -1:
import pytest
from psycopg.adapt import PyFormat
+from psycopg.pq import Format
try:
import numpy as np
assert rec[0][0] == rec[1][0]
+@pytest.mark.parametrize(
+ "nptype, val, pgtypes",
+ [
+ ("int8", -128, "int2 int4 int8 numeric"),
+ ("int8", 127, "int2 int4 int8 numeric"),
+ ("int16", -32_768, "int2 int4 int8 numeric"),
+ ("int16", 32_767, "int2 int4 int8 numeric"),
+ ("int32", -(2**31), "int4 int8 numeric"),
+ ("int32", 0, "int2 int4 int8 numeric"),
+ ("int32", 2**31 - 1, "int4 int8 numeric"),
+ ("int64", -(2**63), "int8 numeric"),
+ ("int64", 2**63 - 1, "int8 numeric"),
+ ("longlong", -(2**63), "int8"),
+ ("longlong", 2**63 - 1, "int8"),
+ ("bool_", True, "bool"),
+ ("bool_", False, "bool"),
+ ("uint8", 0, "int2 int4 int8 numeric"),
+ ("uint8", 255, "int2 int4 int8 numeric"),
+ ("uint16", 0, "int2 int4 int8 numeric"),
+ ("uint16", 65_535, "int4 int8 numeric"),
+ ("uint32", 0, "int4 int8 numeric"),
+ ("uint32", (2**32 - 1), "int8 numeric"),
+ ("uint64", 0, "int8 numeric"),
+ ("uint64", (2**64 - 1), "numeric"),
+ ("ulonglong", 0, "int8 numeric"),
+ ("ulonglong", (2**64 - 1), "numeric"),
+ ],
+)
+@pytest.mark.parametrize("fmt", Format)
+def test_copy_by_oid(conn, val, nptype, pgtypes, fmt):
+ nptype = getattr(np, nptype)
+ val = nptype(val)
+ pgtypes = pgtypes.split()
+ cur = conn.cursor()
+
+ fnames = [f"f{t}" for t in pgtypes]
+ fields = [f"f{t} {t}" for fname, t in zip(fnames, pgtypes)]
+ cur.execute(
+ f"create table numpyoid (id serial primary key, {', '.join(fields)})",
+ )
+ with cur.copy(
+ f"copy numpyoid ({', '.join(fnames)}) from stdin (format {fmt.name})"
+ ) as copy:
+ copy.set_types(pgtypes)
+ copy.write_row((val,) * len(fnames))
+
+ cur.execute(f"select {', '.join(fnames)} from numpyoid")
+ rec = cur.fetchone()
+ assert rec == (int(val),) * len(fnames)
+
+
@pytest.mark.slow
@pytest.mark.parametrize("fmt", PyFormat)
def test_random(conn, faker, fmt):