From: Daniele Varrazzo Date: Wed, 14 Jul 2021 15:48:21 +0000 (+0200) Subject: Fix dumping arrays of different versions of network objects X-Git-Tag: 3.0.dev1~7 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=7ac207944d612197ddb41af32f18167dfa6b1f49;p=thirdparty%2Fpsycopg.git Fix dumping arrays of different versions of network objects Use a single dumper for both IPv4 and IPv6 so that the array sub_dumper is not ambiguous. --- diff --git a/psycopg/psycopg/types/net.py b/psycopg/psycopg/types/net.py index 96cf1c76a..05f0d261c 100644 --- a/psycopg/psycopg/types/net.py +++ b/psycopg/psycopg/types/net.py @@ -54,80 +54,42 @@ class NetworkDumper(Dumper): return str(obj).encode("utf8") -class _IPv4Mixin: - _family = PGSQL_AF_INET - _prefixlen = IPV4_PREFIXLEN - - -class _IPv6Mixin: - _family = PGSQL_AF_INET6 - _prefixlen = IPV6_PREFIXLEN - - -class _AddressBinaryDumper(Dumper): +class AddressBinaryDumper(Dumper): format = Format.BINARY _oid = postgres.types["inet"].oid - _family: int - _prefixlen: int - def dump(self, obj: Address) -> bytes: packed = obj.packed - head = bytes((self._family, self._prefixlen, 0, len(packed))) + family = PGSQL_AF_INET if obj.version == 4 else PGSQL_AF_INET6 + head = bytes((family, obj.max_prefixlen, 0, len(packed))) return head + packed -class IPv4AddressBinaryDumper(_IPv4Mixin, _AddressBinaryDumper): - pass - - -class IPv6AddressBinaryDumper(_IPv6Mixin, _AddressBinaryDumper): - pass - - -class _InterfaceBinaryDumper(Dumper): +class InterfaceBinaryDumper(Dumper): format = Format.BINARY _oid = postgres.types["inet"].oid - _family: int - def dump(self, obj: Interface) -> bytes: packed = obj.packed - head = bytes((self._family, obj.network.prefixlen, 0, len(packed))) + family = PGSQL_AF_INET if obj.version == 4 else PGSQL_AF_INET6 + head = bytes((family, obj.network.prefixlen, 0, len(packed))) return head + packed -class IPv4InterfaceBinaryDumper(_IPv4Mixin, _InterfaceBinaryDumper): - pass - - -class IPv6InterfaceBinaryDumper(_IPv6Mixin, _InterfaceBinaryDumper): - pass - - -class _NetworkBinaryDumper(Dumper): +class NetworkBinaryDumper(Dumper): format = Format.BINARY _oid = postgres.types["cidr"].oid - _family: int - def dump(self, obj: Network) -> bytes: packed = obj.network_address.packed - head = bytes((self._family, obj.prefixlen, 1, len(packed))) + family = PGSQL_AF_INET if obj.version == 4 else PGSQL_AF_INET6 + head = bytes((family, obj.prefixlen, 1, len(packed))) return head + packed -class IPv4NetworkBinaryDumper(_IPv4Mixin, _NetworkBinaryDumper): - pass - - -class IPv6NetworkBinaryDumper(_IPv6Mixin, _NetworkBinaryDumper): - pass - - class _LazyIpaddress(Loader): def __init__(self, oid: int, context: Optional[AdaptContext] = None): super().__init__(oid, context) @@ -217,16 +179,12 @@ def register_default_adapters(context: AdaptContext) -> None: adapters.register_dumper("ipaddress.IPv6Interface", InterfaceDumper) adapters.register_dumper("ipaddress.IPv4Network", NetworkDumper) adapters.register_dumper("ipaddress.IPv6Network", NetworkDumper) - adapters.register_dumper("ipaddress.IPv4Address", IPv4AddressBinaryDumper) - adapters.register_dumper("ipaddress.IPv6Address", IPv6AddressBinaryDumper) - adapters.register_dumper( - "ipaddress.IPv4Interface", IPv4InterfaceBinaryDumper - ) - adapters.register_dumper( - "ipaddress.IPv6Interface", IPv6InterfaceBinaryDumper - ) - adapters.register_dumper("ipaddress.IPv4Network", IPv4NetworkBinaryDumper) - adapters.register_dumper("ipaddress.IPv6Network", IPv6NetworkBinaryDumper) + adapters.register_dumper("ipaddress.IPv4Address", AddressBinaryDumper) + adapters.register_dumper("ipaddress.IPv6Address", AddressBinaryDumper) + adapters.register_dumper("ipaddress.IPv4Interface", InterfaceBinaryDumper) + adapters.register_dumper("ipaddress.IPv6Interface", InterfaceBinaryDumper) + adapters.register_dumper("ipaddress.IPv4Network", NetworkBinaryDumper) + adapters.register_dumper("ipaddress.IPv6Network", NetworkBinaryDumper) adapters.register_loader("inet", InetLoader) adapters.register_loader("inet", InetBinaryLoader) adapters.register_loader("cidr", CidrLoader) diff --git a/tests/types/test_net.py b/tests/types/test_net.py index 154c74da0..fdc626339 100644 --- a/tests/types/test_net.py +++ b/tests/types/test_net.py @@ -55,6 +55,18 @@ def test_network_dump(conn, fmt_in, val): assert cur.fetchone()[0] is True +@pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY]) +def test_network_mixed_size_array(conn, fmt_in): + val = [ + ipaddress.IPv4Network("192.168.0.1/32"), + ipaddress.IPv6Network("::1/128"), + ] + cur = conn.cursor() + cur.execute(f"select %{fmt_in}", (val,)) + got = cur.fetchone()[0] + assert val == got + + @pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY]) @pytest.mark.parametrize("val", ["127.0.0.1/32", "::ffff:102:300/128"]) def test_inet_load_address(conn, fmt_out, val):