From: Daniele Varrazzo Date: Fri, 16 Dec 2022 15:45:19 +0000 (+0000) Subject: refactor(numpy): use existing numeric base classes X-Git-Tag: pool-3.2.0~70^2~23 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=9387e91e3e2a756e42fb1decbbb6891668844684;p=thirdparty%2Fpsycopg.git refactor(numpy): use existing numeric base classes --- diff --git a/psycopg/psycopg/types/numpy.py b/psycopg/psycopg/types/numpy.py index 17d6c1011..a62fd0a10 100644 --- a/psycopg/psycopg/types/numpy.py +++ b/psycopg/psycopg/types/numpy.py @@ -4,17 +4,7 @@ Adapters for numpy types. # Copyright (C) 2022 The Psycopg Team -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - Optional, - Union, - cast, -) - -from .numeric import dump_int_to_numeric_binary +from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast from .. import postgres from .._struct import pack_int2, pack_int4, pack_int8 @@ -23,6 +13,8 @@ from ..adapt import Dumper from ..pq import Format from .. import _struct +from .numeric import dump_int_to_numeric_binary, _SpecialValuesDumper + if TYPE_CHECKING: import numpy as np @@ -34,10 +26,10 @@ pack_float4 = cast(PackNumpyFloat4, _struct.pack_float4) class _NPIntDumper(Dumper): - def dump(self, obj: Any) -> bytes: + def dump(self, obj: Any) -> Buffer: return str(obj).encode() - def quote(self, obj: Any) -> bytes: + def quote(self, obj: Any) -> Buffer: value = self.dump(obj) return value if obj >= 0 else b" " + value @@ -78,22 +70,6 @@ class NPUInt64Dumper(_NPIntDumper): NPULongLongDumper = NPUInt64Dumper -class _SpecialValuesDumper(Dumper): - - _special: Dict[bytes, bytes] = {} - - def dump(self, obj: float) -> Buffer: - return str(obj).encode() - - def quote(self, obj: float) -> Buffer: - value = self.dump(obj) - - if value in self._special: - return self._special[value] - - return value if obj >= 0 else b" " + value - - class _NPFloatDumper(_SpecialValuesDumper): _special = {