from typing import cast
from ..adapt import Dumper, Loader
-from ..proto import AdaptContext
+from ..proto import AdaptContext, EncodeFunc, DecodeFunc
from ..errors import InterfaceError, DataError
from .oids import builtins
+_encode_ascii = codecs.lookup("ascii").encode
+_decode_ascii = codecs.lookup("ascii").decode
+
@Dumper.text(date)
class DateDumper(Dumper):
- _encode = codecs.lookup("ascii").encode
DATE_OID = builtins["date"].oid
- def dump(self, obj: date) -> bytes:
+ def dump(self, obj: date, __encode: EncodeFunc = _encode_ascii) -> bytes:
# NOTE: whatever the PostgreSQL DateStyle input format (DMY, MDY, YMD)
# the YYYY-MM-DD is always understood correctly.
- return self._encode(str(obj))[0]
+ return __encode(str(obj))[0]
@property
def oid(self) -> int:
@Dumper.text(time)
class TimeDumper(Dumper):
- _encode = codecs.lookup("ascii").encode
TIMETZ_OID = builtins["timetz"].oid
- def dump(self, obj: time) -> bytes:
- return self._encode(str(obj))[0]
+ def dump(self, obj: time, __encode: EncodeFunc = _encode_ascii) -> bytes:
+ return __encode(str(obj))[0]
@property
def oid(self) -> int:
@Dumper.text(datetime)
class DateTimeDumper(Dumper):
- _encode = codecs.lookup("ascii").encode
TIMESTAMPTZ_OID = builtins["timestamptz"].oid
- def dump(self, obj: date) -> bytes:
+ def dump(self, obj: date, __encode: EncodeFunc = _encode_ascii) -> bytes:
# NOTE: whatever the PostgreSQL DateStyle input format (DMY, MDY, YMD)
# the YYYY-MM-DD is always understood correctly.
- return self._encode(str(obj))[0]
+ return __encode(str(obj))[0]
@property
def oid(self) -> int:
@Dumper.text(timedelta)
class TimeDeltaDumper(Dumper):
- _encode = codecs.lookup("ascii").encode
INTERVAL_OID = builtins["interval"].oid
def __init__(self, src: type, context: AdaptContext = None):
):
self.dump = self._dump_sql # type: ignore[assignment]
- def dump(self, obj: timedelta) -> bytes:
- return self._encode(str(obj))[0]
+ def dump(
+ self, obj: timedelta, __encode: EncodeFunc = _encode_ascii
+ ) -> bytes:
+ return __encode(str(obj))[0]
def _dump_sql(self, obj: timedelta) -> bytes:
# sql_standard format needs explicit signs
@Loader.text(builtins["date"].oid)
class DateLoader(Loader):
-
- _decode = codecs.lookup("ascii").decode
-
def __init__(self, oid: int, context: AdaptContext):
super().__init__(oid, context)
self._format = self._format_from_context()
- def load(self, data: bytes) -> date:
+ def load(self, data: bytes, __decode: DecodeFunc = _decode_ascii) -> date:
try:
- return datetime.strptime(
- self._decode(data)[0], self._format
- ).date()
+ return datetime.strptime(__decode(data)[0], self._format).date()
except ValueError as e:
return self._raise_error(data, e)
@Loader.text(builtins["time"].oid)
class TimeLoader(Loader):
- _decode = codecs.lookup("ascii").decode
_format = "%H:%M:%S.%f"
_format_no_micro = _format.replace(".%f", "")
- def load(self, data: bytes) -> time:
+ def load(self, data: bytes, __decode: DecodeFunc = _decode_ascii) -> time:
# check if the data contains microseconds
fmt = self._format if b"." in data else self._format_no_micro
try:
- return datetime.strptime(self._decode(data)[0], fmt).time()
+ return datetime.strptime(__decode(data)[0], fmt).time()
except ValueError as e:
return self._raise_error(data, e)
super().__init__(oid, context)
- def load(self, data: bytes) -> time:
+ def load(self, data: bytes, __decode: DecodeFunc = _decode_ascii) -> time:
# Hack to convert +HH in +HHMM
if data[-3] in (43, 45):
data += b"00"
fmt = self._format if b"." in data else self._format_no_micro
try:
- dt = datetime.strptime(self._decode(data)[0], fmt)
+ dt = datetime.strptime(__decode(data)[0], fmt)
except ValueError as e:
return self._raise_error(data, e)
super().__init__(oid, context)
self._format_no_micro = self._format.replace(".%f", "")
- def load(self, data: bytes) -> datetime:
+ def load(
+ self, data: bytes, __decode: DecodeFunc = _decode_ascii
+ ) -> datetime:
# check if the data contains microseconds
fmt = (
self._format if data.find(b".", 19) >= 0 else self._format_no_micro
)
try:
- return datetime.strptime(self._decode(data)[0], fmt)
+ return datetime.strptime(__decode(data)[0], fmt)
except ValueError as e:
return self._raise_error(data, e)
self.load = self._load_notimpl # type: ignore[assignment]
return ""
- def load(self, data: bytes) -> datetime:
+ def load(
+ self, data: bytes, __decode: DecodeFunc = _decode_ascii
+ ) -> datetime:
# Hack to convert +HH in +HHMM
if data[-3] in (43, 45):
data += b"00"
@Loader.text(builtins["interval"].oid)
class IntervalLoader(Loader):
- _decode = codecs.lookup("ascii").decode
_re_interval = re.compile(
br"""
(?: (?P<years> [-+]?\d+) \s+ years? \s* )?
NUMERIC_OID = builtins["numeric"].oid
BOOL_OID = builtins["bool"].oid
+_encode_ascii = codecs.lookup("ascii").encode
+_decode_ascii = codecs.lookup("ascii").decode
+
@Dumper.text(int)
class TextIntDumper(Dumper):
- def dump(
- self, obj: int, __encode: EncodeFunc = codecs.lookup("ascii").encode
- ) -> bytes:
+ def dump(self, obj: int, __encode: EncodeFunc = _encode_ascii) -> bytes:
return __encode(str(obj))[0]
@property
@Dumper.text(float)
class TextFloatDumper(Dumper):
- def dump(
- self, obj: float, __encode: EncodeFunc = codecs.lookup("ascii").encode
- ) -> bytes:
+ def dump(self, obj: float, __encode: EncodeFunc = _encode_ascii) -> bytes:
return __encode(str(obj))[0]
@property
@Dumper.text(Decimal)
class TextDecimalDumper(Dumper):
def dump(
- self,
- obj: Decimal,
- __encode: EncodeFunc = codecs.lookup("ascii").encode,
+ self, obj: Decimal, __encode: EncodeFunc = _encode_ascii
) -> bytes:
return __encode(str(obj))[0]
@Loader.text(builtins["int8"].oid)
@Loader.text(builtins["oid"].oid)
class TextIntLoader(Loader):
- def load(
- self, data: bytes, __decode: DecodeFunc = codecs.lookup("ascii").decode
- ) -> int:
+ def load(self, data: bytes, __decode: DecodeFunc = _decode_ascii) -> int:
return int(__decode(data)[0])
@Loader.text(builtins["numeric"].oid)
class TextNumericLoader(Loader):
def load(
- self, data: bytes, __decode: DecodeFunc = codecs.lookup("ascii").decode
+ self, data: bytes, __decode: DecodeFunc = _decode_ascii
) -> Decimal:
return Decimal(__decode(data)[0])