]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Fix selection of dumper by oid
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 23 Sep 2021 14:22:26 +0000 (16:22 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 26 Sep 2021 19:51:39 +0000 (21:51 +0200)
Add a few dumpers which can deal with the different types and make sure
to register them so that the lookup by oid would use them. Because they
are slightly less efficient (with isinstance tests) we prefer to have
them only used by oid, instead of extending the domain of the dumpers
used in normal adaptation by type.

This commit only fixes the Python side. The C numeric dumpers require
more work.

psycopg/psycopg/__init__.py
psycopg/psycopg/dbapi20.py
psycopg/psycopg/postgres.py
psycopg/psycopg/types/array.py
psycopg/psycopg/types/net.py
psycopg/psycopg/types/numeric.py

index acc4fbb7059b36afde3b31d042c39c0420c2ef4b..c1c1dc6ebce06594c770c2f88b0c2416f38ea7cd 100644 (file)
@@ -23,6 +23,7 @@ from .cursor_async import AsyncCursor
 from .server_cursor import AsyncServerCursor, ServerCursor
 from .connection_async import AsyncConnection
 
+from . import dbapi20
 from .dbapi20 import BINARY, DATETIME, NUMBER, ROWID, STRING
 from .dbapi20 import Binary, BinaryTextDumper, BinaryBinaryDumper
 from .dbapi20 import Date, DateFromTicks, Time, TimeFromTicks
@@ -43,12 +44,15 @@ connect = Connection.connect
 apilevel = "2.0"
 threadsafety = 2
 paramstyle = "pyformat"
-adapters.register_dumper(Binary, BinaryTextDumper)  # dbapi20
-adapters.register_dumper(Binary, BinaryBinaryDumper)  # dbapi20
 
-# After registering the dbapi20 dumpers to clobber the oid they set
 postgres.register_default_adapters(adapters)
 
+# After the default one because they can deal with the bytea oid better
+dbapi20.register_dbapi20_adapters(adapters)
+
+# Must come after all the types are registered
+types.array.register_all_arrays(adapters)
+
 # Note: defining the exported methods helps both Sphynx in documenting that
 # this is the canonical place to obtain them and should be used by MyPy too,
 # so that function signatures are consistent with the documentation.
index a0a87590e8370b050624a2b3782b2748095352af..3e2c31700818500cdc7a9325e7d7f19f0cf8a694 100644 (file)
@@ -7,10 +7,10 @@ Compatibility objects with DBAPI 2.0
 import time
 import datetime as dt
 from math import floor
-from typing import Any, Sequence
+from typing import Any, Sequence, Union
 
 from . import postgres
-from .abc import Buffer
+from .abc import AdaptContext, Buffer
 from .types.string import BytesDumper, BytesBinaryDumper
 
 
@@ -58,13 +58,19 @@ class Binary:
 
 
 class BinaryBinaryDumper(BytesBinaryDumper):
-    def dump(self, obj: Binary) -> Buffer:  # type: ignore
-        return super().dump(obj.obj)
+    def dump(self, obj: Union[Buffer, Binary]) -> Buffer:
+        if isinstance(obj, Binary):
+            return super().dump(obj.obj)
+        else:
+            return super().dump(obj)
 
 
 class BinaryTextDumper(BytesDumper):
-    def dump(self, obj: Binary) -> Buffer:  # type: ignore
-        return super().dump(obj.obj)
+    def dump(self, obj: Union[Buffer, Binary]) -> Buffer:
+        if isinstance(obj, Binary):
+            return super().dump(obj.obj)
+        else:
+            return super().dump(obj)
 
 
 def Date(year: int, month: int, day: int) -> dt.date:
@@ -96,3 +102,13 @@ def TimestampFromTicks(ticks: float) -> dt.datetime:
     tzinfo = dt.timezone(dt.timedelta(seconds=t.tm_gmtoff))
     rv = dt.datetime(*t[:6], round(frac * 1_000_000), tzinfo=tzinfo)
     return rv
+
+
+def register_dbapi20_adapters(context: AdaptContext) -> None:
+    adapters = context.adapters
+    adapters.register_dumper(Binary, BinaryTextDumper)
+    adapters.register_dumper(Binary, BinaryBinaryDumper)
+
+    # Make them also the default dumpers when dumping by bytea oid
+    adapters.register_dumper(None, BinaryTextDumper)
+    adapters.register_dumper(None, BinaryBinaryDumper)
index e3a83afda40523070504c249e00de548884f6f80..af052bcab5db987264c1e8556697a33071468022 100644 (file)
@@ -115,6 +115,3 @@ def register_default_adapters(context: AdaptContext) -> None:
     range.register_default_adapters(context)
     string.register_default_adapters(context)
     uuid.register_default_adapters(context)
-
-    # Must come after all the types are registered
-    array.register_all_arrays(context)
index b3b6e5e387d726f36f0eedcf3077b8b620c26fd3..4c3c66881e2a4f6e6e692bc44712d95769f3ca90 100644 (file)
@@ -32,19 +32,23 @@ _unpack_dim = cast(
 )
 
 TEXT_ARRAY_OID = postgres.types["text"].array_oid
+NoneType: type = type(None)
 
 
 class BaseListDumper(RecursiveDumper):
     element_oid = 0
 
     def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+        if cls is NoneType:
+            cls = list
+
         super().__init__(cls, context)
         self.sub_dumper: Optional[Dumper] = None
         if self.element_oid and context:
             sdclass = context.adapters.get_dumper_by_oid(
                 self.element_oid, self.format
             )
-            self.sub_dumper = sdclass(type(None), context)
+            self.sub_dumper = sdclass(NoneType, context)
 
     def _find_list_element(self, L: List[Any]) -> Any:
         """
index 0cb21c1ea6fb2a6aed9a4957793de770e1039203..614fe566d00592223ebedeb662b247ed1b3bbdf6 100644 (file)
@@ -36,6 +36,21 @@ IPV4_PREFIXLEN = 32
 IPV6_PREFIXLEN = 128
 
 
+class _LazyIpaddress:
+    def _ensure_module(self) -> None:
+        global imported, ip_address, ip_interface, ip_network
+        global IPv4Address, IPv6Address, IPv4Interface, IPv6Interface
+        global IPv4Network, IPv6Network
+
+        if not imported:
+            from ipaddress import ip_address, ip_interface, ip_network
+            from ipaddress import IPv4Address, IPv6Address
+            from ipaddress import IPv4Interface, IPv6Interface
+            from ipaddress import IPv4Network, IPv6Network
+
+            imported = True
+
+
 class InterfaceDumper(Dumper):
 
     oid = postgres.types["inet"].oid
@@ -52,11 +67,12 @@ class NetworkDumper(Dumper):
         return str(obj).encode()
 
 
-class AddressBinaryDumper(Dumper):
-
+class _AIBinaryDumper(Dumper):
     format = Format.BINARY
     oid = postgres.types["inet"].oid
 
+
+class AddressBinaryDumper(_AIBinaryDumper):
     def dump(self, obj: Address) -> bytes:
         packed = obj.packed
         family = PGSQL_AF_INET if obj.version == 4 else PGSQL_AF_INET6
@@ -64,11 +80,7 @@ class AddressBinaryDumper(Dumper):
         return head + packed
 
 
-class InterfaceBinaryDumper(Dumper):
-
-    format = Format.BINARY
-    oid = postgres.types["inet"].oid
-
+class InterfaceBinaryDumper(_AIBinaryDumper):
     def dump(self, obj: Interface) -> bytes:
         packed = obj.packed
         family = PGSQL_AF_INET if obj.version == 4 else PGSQL_AF_INET6
@@ -76,6 +88,28 @@ class InterfaceBinaryDumper(Dumper):
         return head + packed
 
 
+class InetBinaryDumper(_AIBinaryDumper, _LazyIpaddress):
+    """Either an address or an interface to inet
+
+    Used when looking up by oid.
+    """
+
+    def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+        super().__init__(cls, context)
+        self._ensure_module()
+
+    def dump(self, obj: Union[Address, Interface]) -> bytes:
+        packed = obj.packed
+        family = PGSQL_AF_INET if obj.version == 4 else PGSQL_AF_INET6
+        if isinstance(obj, (IPv4Interface, IPv6Interface)):
+            prefixlen = obj.network.prefixlen
+        else:
+            prefixlen = obj.max_prefixlen
+
+        head = bytes((family, prefixlen, 0, len(packed)))
+        return head + packed
+
+
 class NetworkBinaryDumper(Dumper):
 
     format = Format.BINARY
@@ -88,23 +122,13 @@ class NetworkBinaryDumper(Dumper):
         return head + packed
 
 
-class _LazyIpaddress(Loader):
+class _LazyIpaddressLoader(Loader, _LazyIpaddress):
     def __init__(self, oid: int, context: Optional[AdaptContext] = None):
         super().__init__(oid, context)
-        global imported, ip_address, ip_interface, ip_network
-        global IPv4Address, IPv6Address, IPv4Interface, IPv6Interface
-        global IPv4Network, IPv6Network
-
-        if not imported:
-            from ipaddress import ip_address, ip_interface, ip_network
-            from ipaddress import IPv4Address, IPv6Address
-            from ipaddress import IPv4Interface, IPv6Interface
-            from ipaddress import IPv4Network, IPv6Network
-
-            imported = True
+        self._ensure_module()
 
 
-class InetLoader(_LazyIpaddress):
+class InetLoader(_LazyIpaddressLoader):
     def load(self, data: Buffer) -> Union[Address, Interface]:
         if isinstance(data, memoryview):
             data = bytes(data)
@@ -115,7 +139,7 @@ class InetLoader(_LazyIpaddress):
             return ip_address(data.decode())
 
 
-class InetBinaryLoader(_LazyIpaddress):
+class InetBinaryLoader(_LazyIpaddressLoader):
 
     format = Format.BINARY
 
@@ -137,7 +161,7 @@ class InetBinaryLoader(_LazyIpaddress):
                 return IPv6Interface((packed, prefix))
 
 
-class CidrLoader(_LazyIpaddress):
+class CidrLoader(_LazyIpaddressLoader):
     def load(self, data: Buffer) -> Network:
         if isinstance(data, memoryview):
             data = bytes(data)
@@ -145,7 +169,7 @@ class CidrLoader(_LazyIpaddress):
         return ip_network(data.decode())
 
 
-class CidrBinaryLoader(_LazyIpaddress):
+class CidrBinaryLoader(_LazyIpaddressLoader):
 
     format = Format.BINARY
 
@@ -177,6 +201,7 @@ def register_default_adapters(context: AdaptContext) -> None:
     adapters.register_dumper("ipaddress.IPv6Interface", InterfaceBinaryDumper)
     adapters.register_dumper("ipaddress.IPv4Network", NetworkBinaryDumper)
     adapters.register_dumper("ipaddress.IPv6Network", NetworkBinaryDumper)
+    adapters.register_dumper(None, InetBinaryDumper)
     adapters.register_loader("inet", InetLoader)
     adapters.register_loader("inet", InetBinaryLoader)
     adapters.register_loader("cidr", CidrLoader)
index b0baf9db6a269f07c8d286f275676f4c2791cd99..f1af106e27f8dea0f8ad123a6b1df25738078545 100644 (file)
@@ -184,24 +184,8 @@ class IntNumericBinaryDumper(IntNumericDumper):
 
     format = Format.BINARY
 
-    def dump(self, obj: int) -> bytearray:
-        ndigits = int(obj.bit_length() * BIT_PER_PGDIGIT) + 1
-        out = bytearray(b"\x00\x00" * (ndigits + 4))
-        if obj < 0:
-            sign = NUMERIC_NEG
-            obj = -obj
-        else:
-            sign = NUMERIC_POS
-
-        out[:8] = _pack_numeric_head(ndigits, ndigits - 1, sign, 0)
-        i = 8 + (ndigits - 1) * 2
-        while obj:
-            rem = obj % 10_000
-            obj //= 10_000
-            out[i : i + 2] = pack_uint2(rem)
-            i -= 2
-
-        return out
+    def dump(self, obj: int) -> Buffer:
+        return dump_int_to_numeric_binary(obj)
 
 
 class OidBinaryDumper(OidDumper):
@@ -372,64 +356,108 @@ class DecimalBinaryDumper(Dumper):
     format = Format.BINARY
     oid = postgres.types["numeric"].oid
 
-    def dump(self, obj: Decimal) -> Union[bytearray, bytes]:
-        sign, digits, exp = obj.as_tuple()
-        if exp == "n" or exp == "N":  # type: ignore[comparison-overlap]
-            return NUMERIC_NAN_BIN
-        elif exp == "F":  # type: ignore[comparison-overlap]
-            return NUMERIC_NINF_BIN if sign else NUMERIC_PINF_BIN
+    def dump(self, obj: Decimal) -> Buffer:
+        return dump_decimal_to_numeric_binary(obj)
 
-        # Weights of py digits into a pg digit according to their positions.
-        # Starting with an index wi != 0 is equivalent to prepending 0's to
-        # the digits tuple, but without really changing it.
-        weights = (1000, 100, 10, 1)
-        wi = 0
 
-        ndigits = nzdigits = len(digits)
+class NumericDumper(DecimalDumper):
+    def dump(self, obj: Union[Decimal, int]) -> bytes:
+        if isinstance(obj, int):
+            return str(obj).encode()
+        else:
+            return super().dump(obj)
 
-        # Find the last nonzero digit
-        while nzdigits > 0 and digits[nzdigits - 1] == 0:
-            nzdigits -= 1
 
-        if exp <= 0:
-            dscale = -exp
+class NumericBinaryDumper(Dumper):
+
+    format = Format.BINARY
+    oid = postgres.types["numeric"].oid
+
+    def dump(self, obj: Union[Decimal, int]) -> Buffer:
+        if isinstance(obj, int):
+            return dump_int_to_numeric_binary(obj)
         else:
-            dscale = 0
-            # align the py digits to the pg digits if there's some py exponent
-            ndigits += exp % DEC_DIGITS
-
-        if not nzdigits:
-            return _pack_numeric_head(0, 0, NUMERIC_POS, dscale)
-
-        # Equivalent of 0-padding left to align the py digits to the pg digits
-        # but without changing the digits tuple.
-        mod = (ndigits - dscale) % DEC_DIGITS
-        if mod:
-            wi = DEC_DIGITS - mod
-            ndigits += wi
-
-        tmp = nzdigits + wi
-        out = bytearray(
-            _pack_numeric_head(
-                tmp // DEC_DIGITS + (tmp % DEC_DIGITS and 1),  # ndigits
-                (ndigits + exp) // DEC_DIGITS - 1,  # weight
-                NUMERIC_NEG if sign else NUMERIC_POS,  # sign
-                dscale,
-            )
+            return dump_decimal_to_numeric_binary(obj)
+
+
+def dump_decimal_to_numeric_binary(obj: Decimal) -> Union[bytearray, bytes]:
+    sign, digits, exp = obj.as_tuple()
+    if exp == "n" or exp == "N":  # type: ignore[comparison-overlap]
+        return NUMERIC_NAN_BIN
+    elif exp == "F":  # type: ignore[comparison-overlap]
+        return NUMERIC_NINF_BIN if sign else NUMERIC_PINF_BIN
+
+    # Weights of py digits into a pg digit according to their positions.
+    # Starting with an index wi != 0 is equivalent to prepending 0's to
+    # the digits tuple, but without really changing it.
+    weights = (1000, 100, 10, 1)
+    wi = 0
+
+    ndigits = nzdigits = len(digits)
+
+    # Find the last nonzero digit
+    while nzdigits > 0 and digits[nzdigits - 1] == 0:
+        nzdigits -= 1
+
+    if exp <= 0:
+        dscale = -exp
+    else:
+        dscale = 0
+        # align the py digits to the pg digits if there's some py exponent
+        ndigits += exp % DEC_DIGITS
+
+    if not nzdigits:
+        return _pack_numeric_head(0, 0, NUMERIC_POS, dscale)
+
+    # Equivalent of 0-padding left to align the py digits to the pg digits
+    # but without changing the digits tuple.
+    mod = (ndigits - dscale) % DEC_DIGITS
+    if mod:
+        wi = DEC_DIGITS - mod
+        ndigits += wi
+
+    tmp = nzdigits + wi
+    out = bytearray(
+        _pack_numeric_head(
+            tmp // DEC_DIGITS + (tmp % DEC_DIGITS and 1),  # ndigits
+            (ndigits + exp) // DEC_DIGITS - 1,  # weight
+            NUMERIC_NEG if sign else NUMERIC_POS,  # sign
+            dscale,
         )
+    )
 
-        pgdigit = 0
-        for i in range(nzdigits):
-            pgdigit += weights[wi] * digits[i]
-            wi += 1
-            if wi >= DEC_DIGITS:
-                out += pack_uint2(pgdigit)
-                pgdigit = wi = 0
-
-        if pgdigit:
+    pgdigit = 0
+    for i in range(nzdigits):
+        pgdigit += weights[wi] * digits[i]
+        wi += 1
+        if wi >= DEC_DIGITS:
             out += pack_uint2(pgdigit)
+            pgdigit = wi = 0
+
+    if pgdigit:
+        out += pack_uint2(pgdigit)
+
+    return out
 
-        return out
+
+def dump_int_to_numeric_binary(obj: int) -> bytearray:
+    ndigits = int(obj.bit_length() * BIT_PER_PGDIGIT) + 1
+    out = bytearray(b"\x00\x00" * (ndigits + 4))
+    if obj < 0:
+        sign = NUMERIC_NEG
+        obj = -obj
+    else:
+        sign = NUMERIC_POS
+
+    out[:8] = _pack_numeric_head(ndigits, ndigits - 1, sign, 0)
+    i = 8 + (ndigits - 1) * 2
+    while obj:
+        rem = obj % 10_000
+        obj //= 10_000
+        out[i : i + 2] = pack_uint2(rem)
+        i -= 2
+
+    return out
 
 
 def register_default_adapters(context: AdaptContext) -> None:
@@ -450,6 +478,10 @@ def register_default_adapters(context: AdaptContext) -> None:
     adapters.register_dumper("decimal.Decimal", DecimalBinaryDumper)
     adapters.register_dumper("decimal.Decimal", DecimalDumper)
 
+    # Used only by oid, can take both int and Decimal as input
+    adapters.register_dumper(None, NumericBinaryDumper)
+    adapters.register_dumper(None, NumericDumper)
+
     adapters.register_dumper(Float4, Float4Dumper)
     adapters.register_dumper(Float8, FloatDumper)
     adapters.register_dumper(Int2, Int2BinaryDumper)