From: Daniele Varrazzo Date: Tue, 27 Oct 2020 15:31:13 +0000 (+0100) Subject: Use local functions for encode/decode in date adaptation X-Git-Tag: 3.0.dev0~427 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=ecf2e30ac02c49ea6e2a6b198bfdce6755a90e54;p=thirdparty%2Fpsycopg.git Use local functions for encode/decode in date adaptation --- diff --git a/psycopg3/psycopg3/types/date.py b/psycopg3/psycopg3/types/date.py index 9d4f16f24..f1a369e63 100644 --- a/psycopg3/psycopg3/types/date.py +++ b/psycopg3/psycopg3/types/date.py @@ -11,21 +11,23 @@ from datetime import date, datetime, time, timedelta from typing import cast from ..adapt import Dumper, Loader -from ..proto import AdaptContext +from ..proto import AdaptContext, EncodeFunc, DecodeFunc from ..errors import InterfaceError, DataError from .oids import builtins +_encode_ascii = codecs.lookup("ascii").encode +_decode_ascii = codecs.lookup("ascii").decode + @Dumper.text(date) class DateDumper(Dumper): - _encode = codecs.lookup("ascii").encode DATE_OID = builtins["date"].oid - def dump(self, obj: date) -> bytes: + def dump(self, obj: date, __encode: EncodeFunc = _encode_ascii) -> bytes: # NOTE: whatever the PostgreSQL DateStyle input format (DMY, MDY, YMD) # the YYYY-MM-DD is always understood correctly. - return self._encode(str(obj))[0] + return __encode(str(obj))[0] @property def oid(self) -> int: @@ -35,11 +37,10 @@ class DateDumper(Dumper): @Dumper.text(time) class TimeDumper(Dumper): - _encode = codecs.lookup("ascii").encode TIMETZ_OID = builtins["timetz"].oid - def dump(self, obj: time) -> bytes: - return self._encode(str(obj))[0] + def dump(self, obj: time, __encode: EncodeFunc = _encode_ascii) -> bytes: + return __encode(str(obj))[0] @property def oid(self) -> int: @@ -49,13 +50,12 @@ class TimeDumper(Dumper): @Dumper.text(datetime) class DateTimeDumper(Dumper): - _encode = codecs.lookup("ascii").encode TIMESTAMPTZ_OID = builtins["timestamptz"].oid - def dump(self, obj: date) -> bytes: + def dump(self, obj: date, __encode: EncodeFunc = _encode_ascii) -> bytes: # NOTE: whatever the PostgreSQL DateStyle input format (DMY, MDY, YMD) # the YYYY-MM-DD is always understood correctly. - return self._encode(str(obj))[0] + return __encode(str(obj))[0] @property def oid(self) -> int: @@ -64,7 +64,6 @@ class DateTimeDumper(Dumper): @Dumper.text(timedelta) class TimeDeltaDumper(Dumper): - _encode = codecs.lookup("ascii").encode INTERVAL_OID = builtins["interval"].oid def __init__(self, src: type, context: AdaptContext = None): @@ -76,8 +75,10 @@ class TimeDeltaDumper(Dumper): ): self.dump = self._dump_sql # type: ignore[assignment] - def dump(self, obj: timedelta) -> bytes: - return self._encode(str(obj))[0] + def dump( + self, obj: timedelta, __encode: EncodeFunc = _encode_ascii + ) -> bytes: + return __encode(str(obj))[0] def _dump_sql(self, obj: timedelta) -> bytes: # sql_standard format needs explicit signs @@ -95,18 +96,13 @@ class TimeDeltaDumper(Dumper): @Loader.text(builtins["date"].oid) class DateLoader(Loader): - - _decode = codecs.lookup("ascii").decode - def __init__(self, oid: int, context: AdaptContext): super().__init__(oid, context) self._format = self._format_from_context() - def load(self, data: bytes) -> date: + def load(self, data: bytes, __decode: DecodeFunc = _decode_ascii) -> date: try: - return datetime.strptime( - self._decode(data)[0], self._format - ).date() + return datetime.strptime(__decode(data)[0], self._format).date() except ValueError as e: return self._raise_error(data, e) @@ -159,15 +155,14 @@ class DateLoader(Loader): @Loader.text(builtins["time"].oid) class TimeLoader(Loader): - _decode = codecs.lookup("ascii").decode _format = "%H:%M:%S.%f" _format_no_micro = _format.replace(".%f", "") - def load(self, data: bytes) -> time: + def load(self, data: bytes, __decode: DecodeFunc = _decode_ascii) -> time: # check if the data contains microseconds fmt = self._format if b"." in data else self._format_no_micro try: - return datetime.strptime(self._decode(data)[0], fmt).time() + return datetime.strptime(__decode(data)[0], fmt).time() except ValueError as e: return self._raise_error(data, e) @@ -193,14 +188,14 @@ class TimeTzLoader(TimeLoader): super().__init__(oid, context) - def load(self, data: bytes) -> time: + def load(self, data: bytes, __decode: DecodeFunc = _decode_ascii) -> time: # Hack to convert +HH in +HHMM if data[-3] in (43, 45): data += b"00" fmt = self._format if b"." in data else self._format_no_micro try: - dt = datetime.strptime(self._decode(data)[0], fmt) + dt = datetime.strptime(__decode(data)[0], fmt) except ValueError as e: return self._raise_error(data, e) @@ -223,13 +218,15 @@ class TimestampLoader(DateLoader): super().__init__(oid, context) self._format_no_micro = self._format.replace(".%f", "") - def load(self, data: bytes) -> datetime: + def load( + self, data: bytes, __decode: DecodeFunc = _decode_ascii + ) -> datetime: # check if the data contains microseconds fmt = ( self._format if data.find(b".", 19) >= 0 else self._format_no_micro ) try: - return datetime.strptime(self._decode(data)[0], fmt) + return datetime.strptime(__decode(data)[0], fmt) except ValueError as e: return self._raise_error(data, e) @@ -303,7 +300,9 @@ class TimestamptzLoader(TimestampLoader): self.load = self._load_notimpl # type: ignore[assignment] return "" - def load(self, data: bytes) -> datetime: + def load( + self, data: bytes, __decode: DecodeFunc = _decode_ascii + ) -> datetime: # Hack to convert +HH in +HHMM if data[-3] in (43, 45): data += b"00" @@ -333,7 +332,6 @@ class TimestamptzLoader(TimestampLoader): @Loader.text(builtins["interval"].oid) class IntervalLoader(Loader): - _decode = codecs.lookup("ascii").decode _re_interval = re.compile( br""" (?: (?P [-+]?\d+) \s+ years? \s* )? diff --git a/psycopg3/psycopg3/types/numeric.py b/psycopg3/psycopg3/types/numeric.py index a276fbaf1..cae84c72f 100644 --- a/psycopg3/psycopg3/types/numeric.py +++ b/psycopg3/psycopg3/types/numeric.py @@ -20,12 +20,13 @@ FLOAT8_OID = builtins["float8"].oid NUMERIC_OID = builtins["numeric"].oid BOOL_OID = builtins["bool"].oid +_encode_ascii = codecs.lookup("ascii").encode +_decode_ascii = codecs.lookup("ascii").decode + @Dumper.text(int) class TextIntDumper(Dumper): - def dump( - self, obj: int, __encode: EncodeFunc = codecs.lookup("ascii").encode - ) -> bytes: + def dump(self, obj: int, __encode: EncodeFunc = _encode_ascii) -> bytes: return __encode(str(obj))[0] @property @@ -36,9 +37,7 @@ class TextIntDumper(Dumper): @Dumper.text(float) class TextFloatDumper(Dumper): - def dump( - self, obj: float, __encode: EncodeFunc = codecs.lookup("ascii").encode - ) -> bytes: + def dump(self, obj: float, __encode: EncodeFunc = _encode_ascii) -> bytes: return __encode(str(obj))[0] @property @@ -50,9 +49,7 @@ class TextFloatDumper(Dumper): @Dumper.text(Decimal) class TextDecimalDumper(Dumper): def dump( - self, - obj: Decimal, - __encode: EncodeFunc = codecs.lookup("ascii").encode, + self, obj: Decimal, __encode: EncodeFunc = _encode_ascii ) -> bytes: return __encode(str(obj))[0] @@ -86,9 +83,7 @@ class BinaryBoolDumper(Dumper): @Loader.text(builtins["int8"].oid) @Loader.text(builtins["oid"].oid) class TextIntLoader(Loader): - def load( - self, data: bytes, __decode: DecodeFunc = codecs.lookup("ascii").decode - ) -> int: + def load(self, data: bytes, __decode: DecodeFunc = _decode_ascii) -> int: return int(__decode(data)[0]) @@ -163,7 +158,7 @@ class BinaryFloat8Loader(Loader): @Loader.text(builtins["numeric"].oid) class TextNumericLoader(Loader): def load( - self, data: bytes, __decode: DecodeFunc = codecs.lookup("ascii").decode + self, data: bytes, __decode: DecodeFunc = _decode_ascii ) -> Decimal: return Decimal(__decode(data)[0])