]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor(numpy): reuse base or final classes from builtin numeric types
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 16 Dec 2022 16:49:32 +0000 (16:49 +0000)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 5 Aug 2023 14:21:30 +0000 (15:21 +0100)
psycopg/psycopg/types/numeric.py
psycopg/psycopg/types/numpy.py

index b371f37ab46863d0727b45783f4f93bc6aad0aad..da7178ee16bf1859bafc33228340fff0b12abd80 100644 (file)
@@ -32,6 +32,15 @@ from .._wrappers import (
 
 
 class _IntDumper(Dumper):
+    def dump(self, obj: Any) -> Buffer:
+        return str(obj).encode()
+
+    def quote(self, obj: Any) -> Buffer:
+        value = self.dump(obj)
+        return value if obj >= 0 else b" " + value
+
+
+class _IntOrSubclassDumper(_IntDumper):
     def dump(self, obj: Any) -> Buffer:
         t = type(obj)
         if t is not int:
@@ -43,10 +52,6 @@ class _IntDumper(Dumper):
 
         return str(obj).encode()
 
-    def quote(self, obj: Any) -> Buffer:
-        value = self.dump(obj)
-        return value if obj >= 0 else b" " + value
-
 
 class _SpecialValuesDumper(Dumper):
     _special: Dict[bytes, bytes] = {}
@@ -109,23 +114,23 @@ class DecimalDumper(_SpecialValuesDumper):
     }
 
 
-class Int2Dumper(_IntDumper):
+class Int2Dumper(_IntOrSubclassDumper):
     oid = _oids.INT2_OID
 
 
-class Int4Dumper(_IntDumper):
+class Int4Dumper(_IntOrSubclassDumper):
     oid = _oids.INT4_OID
 
 
-class Int8Dumper(_IntDumper):
+class Int8Dumper(_IntOrSubclassDumper):
     oid = _oids.INT8_OID
 
 
-class IntNumericDumper(_IntDumper):
+class IntNumericDumper(_IntOrSubclassDumper):
     oid = _oids.NUMERIC_OID
 
 
-class OidDumper(_IntDumper):
+class OidDumper(_IntOrSubclassDumper):
     oid = _oids.OID_OID
 
 
index d0ee52393b63c8d17e7afe77975372a8553879f1..68bb5f2b98fe754f29ebed1397069da4d20e8f0d 100644 (file)
@@ -4,102 +4,42 @@ Adapters for numpy types.
 
 # Copyright (C) 2022 The Psycopg Team
 
-from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast
+from typing import Any, Optional
 
 from .. import postgres
 from .._struct import pack_int2, pack_int4, pack_int8
 from ..abc import AdaptContext, Buffer
-from ..adapt import Dumper
 from ..pq import Format
-from .. import _struct
 
 from .bool import BoolDumper, BoolBinaryDumper
-from .numeric import dump_int_to_numeric_binary, _SpecialValuesDumper
+from .numeric import _IntDumper, dump_int_to_numeric_binary
+from .numeric import FloatDumper, Float4Dumper, FloatBinaryDumper, Float4BinaryDumper
 
-if TYPE_CHECKING:
-    import numpy as np
 
-PackNumpyFloat4 = Callable[[Union["np.float16", "np.float32"]], bytes]
-PackNumpyFloat8 = Callable[["np.float64"], bytes]
-
-pack_float8 = cast(PackNumpyFloat8, _struct.pack_float8)
-pack_float4 = cast(PackNumpyFloat4, _struct.pack_float4)
-
-
-class _NPIntDumper(Dumper):
-    def dump(self, obj: Any) -> Buffer:
-        return str(obj).encode()
-
-    def quote(self, obj: Any) -> Buffer:
-        value = self.dump(obj)
-        return value if obj >= 0 else b" " + value
-
-
-class NPInt16Dumper(_NPIntDumper):
+class NPInt16Dumper(_IntDumper):
     oid = postgres.types["int2"].oid
 
 
-class NPInt32Dumper(_NPIntDumper):
+class NPInt32Dumper(_IntDumper):
     oid = postgres.types["int4"].oid
 
 
-class NPInt64Dumper(_NPIntDumper):
+class NPInt64Dumper(_IntDumper):
     oid = postgres.types["int8"].oid
 
 
-class NPNumericDumper(_NPIntDumper):
+class NPNumericDumper(_IntDumper):
     oid = postgres.types["numeric"].oid
 
 
-class _NPFloatDumper(_SpecialValuesDumper):
-
-    _special = {
-        b"inf": b"'Infinity'::float8",
-        b"-inf": b"'-Infinity'::float8",
-        b"nan": b"'NaN'::float8",
-    }
-
-
-class NPFloat32Dumper(_NPFloatDumper):
-    oid = postgres.types["float4"].oid
-
-
-class NPFloat64Dumper(_NPFloatDumper):
-    oid = postgres.types["float8"].oid
-
-
 # Binary Dumpers
 
 
-class NPFloat16BinaryDumper(Dumper):
-    format = Format.BINARY
-    oid = postgres.types["float4"].oid
-
-    def dump(self, obj: "np.float16") -> bytes:
-        return pack_float4(obj)
-
-
-class NPFloat32BinaryDumper(Dumper):
-    format = Format.BINARY
-    oid = postgres.types["float4"].oid
-
-    def dump(self, obj: "np.float32") -> bytes:
-        return pack_float4(obj)
-
-
-class NPFloat64BinaryDumper(Dumper):
-    format = Format.BINARY
-    oid = postgres.types["float8"].oid
-
-    def dump(self, obj: "np.float64") -> bytes:
-        return pack_float8(obj)
-
-
 class NPInt16BinaryDumper(NPInt16Dumper):
 
     format = Format.BINARY
 
-    def dump(self, obj: "np.int8") -> bytes:
+    def dump(self, obj: Any) -> bytes:
         return pack_int2(int(obj))
 
 
@@ -124,7 +64,6 @@ class NPNumericBinaryDumper(NPNumericDumper):
     format = Format.BINARY
 
     def dump(self, obj: Any) -> Buffer:
-
         return dump_int_to_numeric_binary(int(obj))
 
 
@@ -141,9 +80,9 @@ def register_default_adapters(context: Optional[AdaptContext] = None) -> None:
     adapters.register_dumper("numpy.uint32", NPInt64Dumper)
     adapters.register_dumper("numpy.uint64", NPNumericDumper)
     adapters.register_dumper("numpy.ulonglong", NPNumericDumper)
-    adapters.register_dumper("numpy.float16", NPFloat32Dumper)
-    adapters.register_dumper("numpy.float32", NPFloat32Dumper)
-    adapters.register_dumper("numpy.float64", NPFloat64Dumper)
+    adapters.register_dumper("numpy.float16", Float4Dumper)
+    adapters.register_dumper("numpy.float32", Float4Dumper)
+    adapters.register_dumper("numpy.float64", FloatDumper)
 
     adapters.register_dumper("numpy.int8", NPInt16BinaryDumper)
     adapters.register_dumper("numpy.int16", NPInt16BinaryDumper)
@@ -155,6 +94,6 @@ def register_default_adapters(context: Optional[AdaptContext] = None) -> None:
     adapters.register_dumper("numpy.uint32", NPInt64BinaryDumper)
     adapters.register_dumper("numpy.uint64", NPNumericBinaryDumper)
     adapters.register_dumper("numpy.ulonglong", NPNumericBinaryDumper)
-    adapters.register_dumper("numpy.float16", NPFloat16BinaryDumper)
-    adapters.register_dumper("numpy.float32", NPFloat32BinaryDumper)
-    adapters.register_dumper("numpy.float64", NPFloat64BinaryDumper)
+    adapters.register_dumper("numpy.float16", Float4BinaryDumper)
+    adapters.register_dumper("numpy.float32", Float4BinaryDumper)
+    adapters.register_dumper("numpy.float64", FloatBinaryDumper)