]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Fix dumping arrays of different versions of network objects
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 14 Jul 2021 15:48:21 +0000 (17:48 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 14 Jul 2021 15:48:21 +0000 (17:48 +0200)
Use a single dumper for both IPv4 and IPv6 so that the array sub_dumper
is not ambiguous.

psycopg/psycopg/types/net.py
tests/types/test_net.py

index 96cf1c76a2d899ef369a1bc35bc3d7968819520d..05f0d261cd64b1a30283a3e6430cda762a4127ed 100644 (file)
@@ -54,80 +54,42 @@ class NetworkDumper(Dumper):
         return str(obj).encode("utf8")
 
 
-class _IPv4Mixin:
-    _family = PGSQL_AF_INET
-    _prefixlen = IPV4_PREFIXLEN
-
-
-class _IPv6Mixin:
-    _family = PGSQL_AF_INET6
-    _prefixlen = IPV6_PREFIXLEN
-
-
-class _AddressBinaryDumper(Dumper):
+class AddressBinaryDumper(Dumper):
 
     format = Format.BINARY
     _oid = postgres.types["inet"].oid
 
-    _family: int
-    _prefixlen: int
-
     def dump(self, obj: Address) -> bytes:
         packed = obj.packed
-        head = bytes((self._family, self._prefixlen, 0, len(packed)))
+        family = PGSQL_AF_INET if obj.version == 4 else PGSQL_AF_INET6
+        head = bytes((family, obj.max_prefixlen, 0, len(packed)))
         return head + packed
 
 
-class IPv4AddressBinaryDumper(_IPv4Mixin, _AddressBinaryDumper):
-    pass
-
-
-class IPv6AddressBinaryDumper(_IPv6Mixin, _AddressBinaryDumper):
-    pass
-
-
-class _InterfaceBinaryDumper(Dumper):
+class InterfaceBinaryDumper(Dumper):
 
     format = Format.BINARY
     _oid = postgres.types["inet"].oid
 
-    _family: int
-
     def dump(self, obj: Interface) -> bytes:
         packed = obj.packed
-        head = bytes((self._family, obj.network.prefixlen, 0, len(packed)))
+        family = PGSQL_AF_INET if obj.version == 4 else PGSQL_AF_INET6
+        head = bytes((family, obj.network.prefixlen, 0, len(packed)))
         return head + packed
 
 
-class IPv4InterfaceBinaryDumper(_IPv4Mixin, _InterfaceBinaryDumper):
-    pass
-
-
-class IPv6InterfaceBinaryDumper(_IPv6Mixin, _InterfaceBinaryDumper):
-    pass
-
-
-class _NetworkBinaryDumper(Dumper):
+class NetworkBinaryDumper(Dumper):
 
     format = Format.BINARY
     _oid = postgres.types["cidr"].oid
 
-    _family: int
-
     def dump(self, obj: Network) -> bytes:
         packed = obj.network_address.packed
-        head = bytes((self._family, obj.prefixlen, 1, len(packed)))
+        family = PGSQL_AF_INET if obj.version == 4 else PGSQL_AF_INET6
+        head = bytes((family, obj.prefixlen, 1, len(packed)))
         return head + packed
 
 
-class IPv4NetworkBinaryDumper(_IPv4Mixin, _NetworkBinaryDumper):
-    pass
-
-
-class IPv6NetworkBinaryDumper(_IPv6Mixin, _NetworkBinaryDumper):
-    pass
-
-
 class _LazyIpaddress(Loader):
     def __init__(self, oid: int, context: Optional[AdaptContext] = None):
         super().__init__(oid, context)
@@ -217,16 +179,12 @@ def register_default_adapters(context: AdaptContext) -> None:
     adapters.register_dumper("ipaddress.IPv6Interface", InterfaceDumper)
     adapters.register_dumper("ipaddress.IPv4Network", NetworkDumper)
     adapters.register_dumper("ipaddress.IPv6Network", NetworkDumper)
-    adapters.register_dumper("ipaddress.IPv4Address", IPv4AddressBinaryDumper)
-    adapters.register_dumper("ipaddress.IPv6Address", IPv6AddressBinaryDumper)
-    adapters.register_dumper(
-        "ipaddress.IPv4Interface", IPv4InterfaceBinaryDumper
-    )
-    adapters.register_dumper(
-        "ipaddress.IPv6Interface", IPv6InterfaceBinaryDumper
-    )
-    adapters.register_dumper("ipaddress.IPv4Network", IPv4NetworkBinaryDumper)
-    adapters.register_dumper("ipaddress.IPv6Network", IPv6NetworkBinaryDumper)
+    adapters.register_dumper("ipaddress.IPv4Address", AddressBinaryDumper)
+    adapters.register_dumper("ipaddress.IPv6Address", AddressBinaryDumper)
+    adapters.register_dumper("ipaddress.IPv4Interface", InterfaceBinaryDumper)
+    adapters.register_dumper("ipaddress.IPv6Interface", InterfaceBinaryDumper)
+    adapters.register_dumper("ipaddress.IPv4Network", NetworkBinaryDumper)
+    adapters.register_dumper("ipaddress.IPv6Network", NetworkBinaryDumper)
     adapters.register_loader("inet", InetLoader)
     adapters.register_loader("inet", InetBinaryLoader)
     adapters.register_loader("cidr", CidrLoader)
index 154c74da02fba5c4893561eee8b6b53968ec0734..fdc6263399b16c248cccaf1861dd0b4bf1939480 100644 (file)
@@ -55,6 +55,18 @@ def test_network_dump(conn, fmt_in, val):
     assert cur.fetchone()[0] is True
 
 
+@pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY])
+def test_network_mixed_size_array(conn, fmt_in):
+    val = [
+        ipaddress.IPv4Network("192.168.0.1/32"),
+        ipaddress.IPv6Network("::1/128"),
+    ]
+    cur = conn.cursor()
+    cur.execute(f"select %{fmt_in}", (val,))
+    got = cur.fetchone()[0]
+    assert val == got
+
+
 @pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY])
 @pytest.mark.parametrize("val", ["127.0.0.1/32", "::ffff:102:300/128"])
 def test_inet_load_address(conn, fmt_out, val):