]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Use local functions for encode/decode in date adaptation
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 27 Oct 2020 15:31:13 +0000 (16:31 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 28 Oct 2020 03:19:24 +0000 (04:19 +0100)
psycopg3/psycopg3/types/date.py
psycopg3/psycopg3/types/numeric.py

index 9d4f16f247217d2915a4392f1f1d0cbfece7183b..f1a369e63737f049aa03613f196b46f698157bbf 100644 (file)
@@ -11,21 +11,23 @@ from datetime import date, datetime, time, timedelta
 from typing import cast
 
 from ..adapt import Dumper, Loader
-from ..proto import AdaptContext
+from ..proto import AdaptContext, EncodeFunc, DecodeFunc
 from ..errors import InterfaceError, DataError
 from .oids import builtins
 
+_encode_ascii = codecs.lookup("ascii").encode
+_decode_ascii = codecs.lookup("ascii").decode
+
 
 @Dumper.text(date)
 class DateDumper(Dumper):
 
-    _encode = codecs.lookup("ascii").encode
     DATE_OID = builtins["date"].oid
 
-    def dump(self, obj: date) -> bytes:
+    def dump(self, obj: date, __encode: EncodeFunc = _encode_ascii) -> bytes:
         # NOTE: whatever the PostgreSQL DateStyle input format (DMY, MDY, YMD)
         # the YYYY-MM-DD is always understood correctly.
-        return self._encode(str(obj))[0]
+        return __encode(str(obj))[0]
 
     @property
     def oid(self) -> int:
@@ -35,11 +37,10 @@ class DateDumper(Dumper):
 @Dumper.text(time)
 class TimeDumper(Dumper):
 
-    _encode = codecs.lookup("ascii").encode
     TIMETZ_OID = builtins["timetz"].oid
 
-    def dump(self, obj: time) -> bytes:
-        return self._encode(str(obj))[0]
+    def dump(self, obj: time, __encode: EncodeFunc = _encode_ascii) -> bytes:
+        return __encode(str(obj))[0]
 
     @property
     def oid(self) -> int:
@@ -49,13 +50,12 @@ class TimeDumper(Dumper):
 @Dumper.text(datetime)
 class DateTimeDumper(Dumper):
 
-    _encode = codecs.lookup("ascii").encode
     TIMESTAMPTZ_OID = builtins["timestamptz"].oid
 
-    def dump(self, obj: date) -> bytes:
+    def dump(self, obj: date, __encode: EncodeFunc = _encode_ascii) -> bytes:
         # NOTE: whatever the PostgreSQL DateStyle input format (DMY, MDY, YMD)
         # the YYYY-MM-DD is always understood correctly.
-        return self._encode(str(obj))[0]
+        return __encode(str(obj))[0]
 
     @property
     def oid(self) -> int:
@@ -64,7 +64,6 @@ class DateTimeDumper(Dumper):
 
 @Dumper.text(timedelta)
 class TimeDeltaDumper(Dumper):
-    _encode = codecs.lookup("ascii").encode
     INTERVAL_OID = builtins["interval"].oid
 
     def __init__(self, src: type, context: AdaptContext = None):
@@ -76,8 +75,10 @@ class TimeDeltaDumper(Dumper):
             ):
                 self.dump = self._dump_sql  # type: ignore[assignment]
 
-    def dump(self, obj: timedelta) -> bytes:
-        return self._encode(str(obj))[0]
+    def dump(
+        self, obj: timedelta, __encode: EncodeFunc = _encode_ascii
+    ) -> bytes:
+        return __encode(str(obj))[0]
 
     def _dump_sql(self, obj: timedelta) -> bytes:
         # sql_standard format needs explicit signs
@@ -95,18 +96,13 @@ class TimeDeltaDumper(Dumper):
 
 @Loader.text(builtins["date"].oid)
 class DateLoader(Loader):
-
-    _decode = codecs.lookup("ascii").decode
-
     def __init__(self, oid: int, context: AdaptContext):
         super().__init__(oid, context)
         self._format = self._format_from_context()
 
-    def load(self, data: bytes) -> date:
+    def load(self, data: bytes, __decode: DecodeFunc = _decode_ascii) -> date:
         try:
-            return datetime.strptime(
-                self._decode(data)[0], self._format
-            ).date()
+            return datetime.strptime(__decode(data)[0], self._format).date()
         except ValueError as e:
             return self._raise_error(data, e)
 
@@ -159,15 +155,14 @@ class DateLoader(Loader):
 @Loader.text(builtins["time"].oid)
 class TimeLoader(Loader):
 
-    _decode = codecs.lookup("ascii").decode
     _format = "%H:%M:%S.%f"
     _format_no_micro = _format.replace(".%f", "")
 
-    def load(self, data: bytes) -> time:
+    def load(self, data: bytes, __decode: DecodeFunc = _decode_ascii) -> time:
         # check if the data contains microseconds
         fmt = self._format if b"." in data else self._format_no_micro
         try:
-            return datetime.strptime(self._decode(data)[0], fmt).time()
+            return datetime.strptime(__decode(data)[0], fmt).time()
         except ValueError as e:
             return self._raise_error(data, e)
 
@@ -193,14 +188,14 @@ class TimeTzLoader(TimeLoader):
 
         super().__init__(oid, context)
 
-    def load(self, data: bytes) -> time:
+    def load(self, data: bytes, __decode: DecodeFunc = _decode_ascii) -> time:
         # Hack to convert +HH in +HHMM
         if data[-3] in (43, 45):
             data += b"00"
 
         fmt = self._format if b"." in data else self._format_no_micro
         try:
-            dt = datetime.strptime(self._decode(data)[0], fmt)
+            dt = datetime.strptime(__decode(data)[0], fmt)
         except ValueError as e:
             return self._raise_error(data, e)
 
@@ -223,13 +218,15 @@ class TimestampLoader(DateLoader):
         super().__init__(oid, context)
         self._format_no_micro = self._format.replace(".%f", "")
 
-    def load(self, data: bytes) -> datetime:
+    def load(
+        self, data: bytes, __decode: DecodeFunc = _decode_ascii
+    ) -> datetime:
         # check if the data contains microseconds
         fmt = (
             self._format if data.find(b".", 19) >= 0 else self._format_no_micro
         )
         try:
-            return datetime.strptime(self._decode(data)[0], fmt)
+            return datetime.strptime(__decode(data)[0], fmt)
         except ValueError as e:
             return self._raise_error(data, e)
 
@@ -303,7 +300,9 @@ class TimestamptzLoader(TimestampLoader):
             self.load = self._load_notimpl  # type: ignore[assignment]
             return ""
 
-    def load(self, data: bytes) -> datetime:
+    def load(
+        self, data: bytes, __decode: DecodeFunc = _decode_ascii
+    ) -> datetime:
         # Hack to convert +HH in +HHMM
         if data[-3] in (43, 45):
             data += b"00"
@@ -333,7 +332,6 @@ class TimestamptzLoader(TimestampLoader):
 @Loader.text(builtins["interval"].oid)
 class IntervalLoader(Loader):
 
-    _decode = codecs.lookup("ascii").decode
     _re_interval = re.compile(
         br"""
         (?: (?P<years> [-+]?\d+) \s+ years? \s* )?
index a276fbaf17c7e2f59e30b7db08b773b60e26bfa2..cae84c72f07c1931a64c306f724b6854d7b90867 100644 (file)
@@ -20,12 +20,13 @@ FLOAT8_OID = builtins["float8"].oid
 NUMERIC_OID = builtins["numeric"].oid
 BOOL_OID = builtins["bool"].oid
 
+_encode_ascii = codecs.lookup("ascii").encode
+_decode_ascii = codecs.lookup("ascii").decode
+
 
 @Dumper.text(int)
 class TextIntDumper(Dumper):
-    def dump(
-        self, obj: int, __encode: EncodeFunc = codecs.lookup("ascii").encode
-    ) -> bytes:
+    def dump(self, obj: int, __encode: EncodeFunc = _encode_ascii) -> bytes:
         return __encode(str(obj))[0]
 
     @property
@@ -36,9 +37,7 @@ class TextIntDumper(Dumper):
 
 @Dumper.text(float)
 class TextFloatDumper(Dumper):
-    def dump(
-        self, obj: float, __encode: EncodeFunc = codecs.lookup("ascii").encode
-    ) -> bytes:
+    def dump(self, obj: float, __encode: EncodeFunc = _encode_ascii) -> bytes:
         return __encode(str(obj))[0]
 
     @property
@@ -50,9 +49,7 @@ class TextFloatDumper(Dumper):
 @Dumper.text(Decimal)
 class TextDecimalDumper(Dumper):
     def dump(
-        self,
-        obj: Decimal,
-        __encode: EncodeFunc = codecs.lookup("ascii").encode,
+        self, obj: Decimal, __encode: EncodeFunc = _encode_ascii
     ) -> bytes:
         return __encode(str(obj))[0]
 
@@ -86,9 +83,7 @@ class BinaryBoolDumper(Dumper):
 @Loader.text(builtins["int8"].oid)
 @Loader.text(builtins["oid"].oid)
 class TextIntLoader(Loader):
-    def load(
-        self, data: bytes, __decode: DecodeFunc = codecs.lookup("ascii").decode
-    ) -> int:
+    def load(self, data: bytes, __decode: DecodeFunc = _decode_ascii) -> int:
         return int(__decode(data)[0])
 
 
@@ -163,7 +158,7 @@ class BinaryFloat8Loader(Loader):
 @Loader.text(builtins["numeric"].oid)
 class TextNumericLoader(Loader):
     def load(
-        self, data: bytes, __decode: DecodeFunc = codecs.lookup("ascii").decode
+        self, data: bytes, __decode: DecodeFunc = _decode_ascii
     ) -> Decimal:
         return Decimal(__decode(data)[0])