]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Move basic struct-related functions to a common module
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 8 Jun 2021 18:24:39 +0000 (19:24 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 25 Jun 2021 15:16:26 +0000 (16:16 +0100)
psycopg3/psycopg3/_struct.py [new file with mode: 0644]
psycopg3/psycopg3/types/array.py
psycopg3/psycopg3/types/composite.py
psycopg3/psycopg3/types/date.py
psycopg3/psycopg3/types/numeric.py

diff --git a/psycopg3/psycopg3/_struct.py b/psycopg3/psycopg3/_struct.py
new file mode 100644 (file)
index 0000000..3bb798a
--- /dev/null
@@ -0,0 +1,41 @@
+"""
+Utility functions to deal with binary structs.
+"""
+
+# Copyright (C) 2020-2021 The Psycopg Team
+
+import struct
+from typing import Callable, cast, Optional, Tuple
+
+from .proto import Buffer
+from .compat import Protocol
+
+PackInt = Callable[[int], bytes]
+UnpackInt = Callable[[bytes], Tuple[int]]
+PackFloat = Callable[[float], bytes]
+UnpackFloat = Callable[[bytes], Tuple[float]]
+
+
+class UnpackLen(Protocol):
+    def __call__(self, data: Buffer, start: Optional[int]) -> Tuple[int]:
+        ...
+
+
+pack_int2 = cast(PackInt, struct.Struct("!h").pack)
+pack_uint2 = cast(PackInt, struct.Struct("!H").pack)
+pack_int4 = cast(PackInt, struct.Struct("!i").pack)
+pack_uint4 = cast(PackInt, struct.Struct("!I").pack)
+pack_int8 = cast(PackInt, struct.Struct("!q").pack)
+pack_float8 = cast(PackFloat, struct.Struct("!d").pack)
+
+unpack_int2 = cast(UnpackInt, struct.Struct("!h").unpack)
+unpack_uint2 = cast(UnpackInt, struct.Struct("!H").unpack)
+unpack_int4 = cast(UnpackInt, struct.Struct("!i").unpack)
+unpack_uint4 = cast(UnpackInt, struct.Struct("!I").unpack)
+unpack_int8 = cast(UnpackInt, struct.Struct("!q").unpack)
+unpack_float4 = cast(UnpackFloat, struct.Struct("!f").unpack)
+unpack_float8 = cast(UnpackFloat, struct.Struct("!d").unpack)
+
+_struct_len = struct.Struct("!i")
+pack_len = cast(Callable[[int], bytes], _struct_len.pack)
+unpack_len = cast(UnpackLen, _struct_len.unpack_from)
index d8d2be9163935e2a1507029c884596e5923ed394..9d133d899bc65d2ee99d246336697ea79b25bd7b 100644 (file)
@@ -6,7 +6,8 @@ Adapters for arrays
 
 import re
 import struct
-from typing import Any, Iterator, List, Optional, Set, Tuple, Type
+from typing import Any, Callable, Iterator, List, Optional, Set, Tuple, Type
+from typing import cast
 
 from .. import pq
 from .. import errors as e
@@ -14,8 +15,20 @@ from ..oids import postgres_types, TEXT_OID, TEXT_ARRAY_OID, INVALID_OID
 from ..adapt import Buffer, Dumper, Loader, Transformer
 from ..adapt import Format as Pg3Format
 from ..proto import AdaptContext
+from .._struct import pack_len, unpack_len
 from .._typeinfo import TypeInfo
 
+_struct_head = struct.Struct("!III")  # ndims, hasnull, elem oid
+_pack_head = cast(Callable[[int, int, int], bytes], _struct_head.pack)
+_unpack_head = cast(
+    Callable[[bytes], Tuple[int, int, int]], _struct_head.unpack_from
+)
+_struct_dim = struct.Struct("!II")  # dim, lower bound
+_pack_dim = cast(Callable[[int, int], bytes], _struct_dim.pack)
+_unpack_dim = cast(
+    Callable[[bytes, int], Tuple[int, int]], _struct_dim.unpack_from
+)
+
 
 class BaseListDumper(Dumper):
     def __init__(self, cls: type, context: Optional[AdaptContext] = None):
@@ -162,7 +175,7 @@ class ListBinaryDumper(BaseListDumper):
         sub_oid = self.sub_dumper and self.sub_dumper.oid or TEXT_OID
 
         if not obj:
-            return _struct_head.pack(0, 0, sub_oid)
+            return _pack_head(0, 0, sub_oid)
 
         data: List[bytes] = [b"", b""]  # placeholders to avoid a resize
         dims: List[int] = []
@@ -187,7 +200,7 @@ class ListBinaryDumper(BaseListDumper):
                     if item is not None:
                         # If we get here, the sub_dumper must have been set
                         ad = self.sub_dumper.dump(item)  # type: ignore[union-attr]
-                        data.append(_struct_len.pack(len(ad)))
+                        data.append(pack_len(len(ad)))
                         data.append(ad)
                     else:
                         hasnull = 1
@@ -202,8 +215,8 @@ class ListBinaryDumper(BaseListDumper):
 
         dump_list(obj, 0)
 
-        data[0] = _struct_head.pack(len(dims), hasnull, sub_oid)
-        data[1] = b"".join(_struct_dim.pack(dim, 1) for dim in dims)
+        data[0] = _pack_head(len(dims), hasnull, sub_oid)
+        data[1] = b"".join(_pack_dim(dim, 1) for dim in dims)
         return b"".join(data)
 
 
@@ -273,30 +286,23 @@ class ArrayLoader(BaseArrayLoader):
     _re_unescape = re.compile(br"\\(.)")
 
 
-_struct_head = struct.Struct("!III")  # ndims, hasnull, elem oid
-_struct_dim = struct.Struct("!II")  # dim, lower bound
-_struct_len = struct.Struct("!i")
-
-
 class ArrayBinaryLoader(BaseArrayLoader):
 
     format = pq.Format.BINARY
 
     def load(self, data: Buffer) -> List[Any]:
-        ndims, hasnull, oid = _struct_head.unpack_from(data[:12])
+        ndims, hasnull, oid = _unpack_head(data)
         if not ndims:
             return []
 
         fcast = self._tx.get_loader(oid, self.format).load
 
         p = 12 + 8 * ndims
-        dims = [
-            _struct_dim.unpack_from(data, i)[0] for i in list(range(12, p, 8))
-        ]
+        dims = [_unpack_dim(data, i)[0] for i in list(range(12, p, 8))]
 
         def consume(p: int) -> Iterator[Any]:
             while 1:
-                size = _struct_len.unpack_from(data, p)[0]
+                size = unpack_len(data, p)[0]
                 p += 4
                 if size != -1:
                     yield fcast(data[p : p + size])
index e2275ff9bffe55fce53849d95ced05f5dfd99771..56a3d41e2e9f837d08159a303817a6480ba2300c 100644 (file)
@@ -7,15 +7,21 @@ Support for composite types adaptation.
 import re
 import struct
 from collections import namedtuple
-from typing import Any, Callable, Iterator, List, Optional
+from typing import Any, Callable, cast, Iterator, List, Optional
 from typing import Sequence, Tuple, Type
 
 from .. import pq
 from ..oids import TEXT_OID
 from ..adapt import Buffer, Format, Dumper, Loader, Transformer
 from ..proto import AdaptContext
+from .._struct import unpack_len
 from .._typeinfo import CompositeInfo
 
+_struct_oidlen = struct.Struct("!Ii")
+_unpack_oidlen = cast(
+    Callable[[bytes, int], Tuple[int, int]], _struct_oidlen.unpack_from
+)
+
 
 class SequenceDumper(Dumper):
 
@@ -116,10 +122,6 @@ class RecordLoader(BaseCompositeLoader):
         )
 
 
-_struct_len = struct.Struct("!i")
-_struct_oidlen = struct.Struct("!Ii")
-
-
 class RecordBinaryLoader(Loader):
 
     format = pq.Format.BINARY
@@ -145,10 +147,10 @@ class RecordBinaryLoader(Loader):
         """
         Yield a sequence of (oid, offset, length) for the content of the record
         """
-        nfields = _struct_len.unpack_from(data, 0)[0]
+        nfields = unpack_len(data, 0)[0]
         i = 4
         for _ in range(nfields):
-            oid, length = _struct_oidlen.unpack_from(data, i)
+            oid, length = _unpack_oidlen(data, i)
             yield oid, i + 8, length
             i += (8 + length) if length > 0 else 8
 
index bd877166cd1dce9268cbba8fbfcacf73c01496ae..262ec23d75a20677a803f5feb2c7230295be1368 100644 (file)
@@ -11,32 +11,26 @@ from datetime import date, datetime, time, timedelta, timezone
 from typing import Any, Callable, cast, Optional, Tuple, Union, TYPE_CHECKING
 
 from ..pq import Format
+from .._tz import get_tzinfo
 from ..oids import postgres_types as builtins
 from ..adapt import Buffer, Dumper, Loader, Format as Pg3Format
 from ..proto import AdaptContext
 from ..errors import InterfaceError, DataError
-from .._tz import get_tzinfo
+from .._struct import pack_int4, pack_int8, unpack_int4, unpack_int8
 
 if TYPE_CHECKING:
     from ..connection import BaseConnection
 
-_PackInt = Callable[[int], bytes]
-_UnpackInt = Callable[[bytes], Tuple[int]]
-
-_pack_int4 = cast(_PackInt, struct.Struct("!i").pack)
-_pack_int8 = cast(_PackInt, struct.Struct("!q").pack)
-_unpack_int4 = cast(_UnpackInt, struct.Struct("!i").unpack)
-_unpack_int8 = cast(_UnpackInt, struct.Struct("!q").unpack)
-
-_pack_timetz = cast(Callable[[int, int], bytes], struct.Struct("!qi").pack)
+_struct_timetz = struct.Struct("!qi")  # microseconds, sec tz offset
+_pack_timetz = cast(Callable[[int, int], bytes], _struct_timetz.pack)
 _unpack_timetz = cast(
-    Callable[[bytes], Tuple[int, int]], struct.Struct("!qi").unpack
-)
-_pack_interval = cast(
-    Callable[[int, int, int], bytes], struct.Struct("!qii").pack
+    Callable[[bytes], Tuple[int, int]], _struct_timetz.unpack
 )
+
+_struct_interval = struct.Struct("!qii")  # microseconds, days, months
+_pack_interval = cast(Callable[[int, int, int], bytes], _struct_interval.pack)
 _unpack_interval = cast(
-    Callable[[bytes], Tuple[int, int, int]], struct.Struct("!qii").unpack
+    Callable[[bytes], Tuple[int, int, int]], _struct_interval.unpack
 )
 
 utc = timezone.utc
@@ -64,7 +58,7 @@ class DateBinaryDumper(Dumper):
 
     def dump(self, obj: date) -> bytes:
         days = obj.toordinal() - _pg_date_epoch_days
-        return _pack_int4(days)
+        return pack_int4(days)
 
 
 class _BaseTimeDumper(Dumper):
@@ -115,7 +109,7 @@ class TimeBinaryDumper(_BaseTimeDumper):
         us = obj.microsecond + 1_000_000 * (
             obj.second + 60 * (obj.minute + 60 * obj.hour)
         )
-        return _pack_int8(us)
+        return pack_int8(us)
 
     def upgrade(self, obj: time, format: Pg3Format) -> Dumper:
         if not obj.tzinfo:
@@ -189,7 +183,7 @@ class DateTimeTzBinaryDumper(_BaseDateTimeDumper):
         micros = delta.microseconds + 1_000_000 * (
             86_400 * delta.days + delta.seconds
         )
-        return _pack_int8(micros)
+        return pack_int8(micros)
 
     def upgrade(self, obj: datetime, format: Pg3Format) -> Dumper:
         if obj.tzinfo:
@@ -208,7 +202,7 @@ class DateTimeBinaryDumper(_BaseDateTimeDumper):
         micros = delta.microseconds + 1_000_000 * (
             86_400 * delta.days + delta.seconds
         )
-        return _pack_int8(micros)
+        return pack_int8(micros)
 
 
 class TimeDeltaDumper(Dumper):
@@ -298,7 +292,7 @@ class DateBinaryLoader(Loader):
     format = Format.BINARY
 
     def load(self, data: Buffer) -> date:
-        days = _unpack_int4(data)[0] + _pg_date_epoch_days
+        days = unpack_int4(data)[0] + _pg_date_epoch_days
         try:
             return date.fromordinal(days)
         except ValueError:
@@ -342,7 +336,7 @@ class TimeBinaryLoader(Loader):
     format = Format.BINARY
 
     def load(self, data: Buffer) -> time:
-        val = _unpack_int8(data)[0]
+        val = unpack_int8(data)[0]
         val, us = divmod(val, 1_000_000)
         val, s = divmod(val, 60)
         h, m = divmod(val, 60)
@@ -535,7 +529,7 @@ class TimestampBinaryLoader(Loader):
     format = Format.BINARY
 
     def load(self, data: Buffer) -> datetime:
-        micros = _unpack_int8(data)[0]
+        micros = unpack_int8(data)[0]
         try:
             return _pg_datetime_epoch + timedelta(microseconds=micros)
         except OverflowError:
@@ -633,7 +627,7 @@ class TimestamptzBinaryLoader(Loader):
         )
 
     def load(self, data: Buffer) -> datetime:
-        micros = _unpack_int8(data)[0]
+        micros = unpack_int8(data)[0]
         try:
             ts = _pg_datetimetz_epoch + timedelta(microseconds=micros)
             return ts.astimezone(self._timezone)
index d1dca439f5670a5bb48254ea39aec511bf8af3d7..a940841a199725e1590932cf42978b2e1a2ceda4 100644 (file)
@@ -14,28 +14,12 @@ from ..pq import Format
 from ..oids import postgres_types as builtins
 from ..adapt import Buffer, Dumper, Loader
 from ..adapt import Format as Pg3Format
+from .._struct import pack_int2, pack_uint2, unpack_int2
+from .._struct import pack_int4, pack_uint4, unpack_int4, unpack_uint4
+from .._struct import pack_int8, unpack_int8
+from .._struct import pack_float8, unpack_float4, unpack_float8
 from ..wrappers.numeric import Int2, Int4, Int8, IntNumeric
 
-_PackInt = Callable[[int], bytes]
-_PackFloat = Callable[[float], bytes]
-_UnpackInt = Callable[[bytes], Tuple[int]]
-_UnpackFloat = Callable[[bytes], Tuple[float]]
-
-_pack_int2 = cast(_PackInt, struct.Struct("!h").pack)
-_pack_uint2 = cast(_PackInt, struct.Struct("!H").pack)
-_pack_int4 = cast(_PackInt, struct.Struct("!i").pack)
-_pack_uint4 = cast(_PackInt, struct.Struct("!I").pack)
-_pack_int8 = cast(_PackInt, struct.Struct("!q").pack)
-_pack_float8 = cast(_PackFloat, struct.Struct("!d").pack)
-_unpack_int2 = cast(_UnpackInt, struct.Struct("!h").unpack)
-_unpack_uint2 = cast(_UnpackInt, struct.Struct("!H").unpack)
-_unpack_int4 = cast(_UnpackInt, struct.Struct("!i").unpack)
-_unpack_uint4 = cast(_UnpackInt, struct.Struct("!I").unpack)
-_unpack_int8 = cast(_UnpackInt, struct.Struct("!q").unpack)
-_unpack_float4 = cast(_UnpackFloat, struct.Struct("!f").unpack)
-_unpack_float8 = cast(_UnpackFloat, struct.Struct("!d").unpack)
-
-
 # Wrappers to force numbers to be cast as specific PostgreSQL types
 
 
@@ -82,7 +66,7 @@ class FloatBinaryDumper(Dumper):
     _oid = builtins["float8"].oid
 
     def dump(self, obj: float) -> bytes:
-        return _pack_float8(obj)
+        return pack_float8(obj)
 
 
 class DecimalDumper(SpecialValuesDumper):
@@ -159,7 +143,7 @@ class Int2BinaryDumper(Int2Dumper):
     format = Format.BINARY
 
     def dump(self, obj: int) -> bytes:
-        return _pack_int2(obj)
+        return pack_int2(obj)
 
 
 class Int4BinaryDumper(Int4Dumper):
@@ -167,7 +151,7 @@ class Int4BinaryDumper(Int4Dumper):
     format = Format.BINARY
 
     def dump(self, obj: int) -> bytes:
-        return _pack_int4(obj)
+        return pack_int4(obj)
 
 
 class Int8BinaryDumper(Int8Dumper):
@@ -175,7 +159,7 @@ class Int8BinaryDumper(Int8Dumper):
     format = Format.BINARY
 
     def dump(self, obj: int) -> bytes:
-        return _pack_int8(obj)
+        return pack_int8(obj)
 
 
 # Ratio between number of bits required to store a number and number of pg
@@ -201,7 +185,7 @@ class IntNumericBinaryDumper(IntNumericDumper):
         while obj:
             rem = obj % 10_000
             obj //= 10_000
-            out[i : i + 2] = _pack_uint2(rem)
+            out[i : i + 2] = pack_uint2(rem)
             i -= 2
 
         return out
@@ -212,7 +196,7 @@ class OidBinaryDumper(OidDumper):
     format = Format.BINARY
 
     def dump(self, obj: int) -> bytes:
-        return _pack_uint4(obj)
+        return pack_uint4(obj)
 
 
 class IntBinaryDumper(IntDumper):
@@ -239,7 +223,7 @@ class Int2BinaryLoader(Loader):
     format = Format.BINARY
 
     def load(self, data: Buffer) -> int:
-        return _unpack_int2(data)[0]
+        return unpack_int2(data)[0]
 
 
 class Int4BinaryLoader(Loader):
@@ -247,7 +231,7 @@ class Int4BinaryLoader(Loader):
     format = Format.BINARY
 
     def load(self, data: Buffer) -> int:
-        return _unpack_int4(data)[0]
+        return unpack_int4(data)[0]
 
 
 class Int8BinaryLoader(Loader):
@@ -255,7 +239,7 @@ class Int8BinaryLoader(Loader):
     format = Format.BINARY
 
     def load(self, data: Buffer) -> int:
-        return _unpack_int8(data)[0]
+        return unpack_int8(data)[0]
 
 
 class OidBinaryLoader(Loader):
@@ -263,7 +247,7 @@ class OidBinaryLoader(Loader):
     format = Format.BINARY
 
     def load(self, data: Buffer) -> int:
-        return _unpack_uint4(data)[0]
+        return unpack_uint4(data)[0]
 
 
 class FloatLoader(Loader):
@@ -280,7 +264,7 @@ class Float4BinaryLoader(Loader):
     format = Format.BINARY
 
     def load(self, data: Buffer) -> float:
-        return _unpack_float4(data)[0]
+        return unpack_float4(data)[0]
 
 
 class Float8BinaryLoader(Loader):
@@ -288,7 +272,7 @@ class Float8BinaryLoader(Loader):
     format = Format.BINARY
 
     def load(self, data: Buffer) -> float:
-        return _unpack_float8(data)[0]
+        return unpack_float8(data)[0]
 
 
 class NumericLoader(Loader):
@@ -435,10 +419,10 @@ class DecimalBinaryDumper(Dumper):
             pgdigit += weights[wi] * digits[i]
             wi += 1
             if wi >= DEC_DIGITS:
-                out += _pack_uint2(pgdigit)
+                out += pack_uint2(pgdigit)
                 pgdigit = wi = 0
 
         if pgdigit:
-            out += _pack_uint2(pgdigit)
+            out += pack_uint2(pgdigit)
 
         return out