]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor(numpy): use existing numeric base classes
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 16 Dec 2022 15:45:19 +0000 (15:45 +0000)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 5 Aug 2023 14:21:30 +0000 (15:21 +0100)
psycopg/psycopg/types/numpy.py

index 17d6c1011f0bf9103b089f066d6b807a7e589869..a62fd0a1006dff7055e497d7e1d838bdfe1defe6 100644 (file)
@@ -4,17 +4,7 @@ Adapters for numpy types.
 
 # Copyright (C) 2022 The Psycopg Team
 
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Callable,
-    Dict,
-    Optional,
-    Union,
-    cast,
-)
-
-from .numeric import dump_int_to_numeric_binary
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast
 
 from .. import postgres
 from .._struct import pack_int2, pack_int4, pack_int8
@@ -23,6 +13,8 @@ from ..adapt import Dumper
 from ..pq import Format
 from .. import _struct
 
+from .numeric import dump_int_to_numeric_binary, _SpecialValuesDumper
+
 if TYPE_CHECKING:
     import numpy as np
 
@@ -34,10 +26,10 @@ pack_float4 = cast(PackNumpyFloat4, _struct.pack_float4)
 
 
 class _NPIntDumper(Dumper):
-    def dump(self, obj: Any) -> bytes:
+    def dump(self, obj: Any) -> Buffer:
         return str(obj).encode()
 
-    def quote(self, obj: Any) -> bytes:
+    def quote(self, obj: Any) -> Buffer:
         value = self.dump(obj)
         return value if obj >= 0 else b" " + value
 
@@ -78,22 +70,6 @@ class NPUInt64Dumper(_NPIntDumper):
 NPULongLongDumper = NPUInt64Dumper
 
 
-class _SpecialValuesDumper(Dumper):
-
-    _special: Dict[bytes, bytes] = {}
-
-    def dump(self, obj: float) -> Buffer:
-        return str(obj).encode()
-
-    def quote(self, obj: float) -> Buffer:
-        value = self.dump(obj)
-
-        if value in self._special:
-            return self._special[value]
-
-        return value if obj >= 0 else b" " + value
-
-
 class _NPFloatDumper(_SpecialValuesDumper):
 
     _special = {