]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add ipaddress binary dumpers
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 8 Jun 2021 03:15:34 +0000 (04:15 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 8 Jun 2021 03:15:34 +0000 (04:15 +0100)
psycopg3/psycopg3/types/__init__.py
psycopg3/psycopg3/types/network.py
tests/types/test_network.py

index 511dcbf02a177333143e4b19e267ea4de7932c85..d0d1eef923678ea98b25e84c68c49233a3ead62b 100644 (file)
@@ -124,6 +124,12 @@ from .uuid import (
 from .network import (
     InterfaceDumper as InterfaceDumper,
     NetworkDumper as NetworkDumper,
+    IPv4AddressBinaryDumper as IPv4AddressBinaryDumper,
+    IPv6AddressBinaryDumper as IPv6AddressBinaryDumper,
+    IPv4InterfaceBinaryDumper as IPv4InterfaceBinaryDumper,
+    IPv6InterfaceBinaryDumper as IPv6InterfaceBinaryDumper,
+    IPv4NetworkBinaryDumper as IPv4NetworkBinaryDumper,
+    IPv6NetworkBinaryDumper as IPv6NetworkBinaryDumper,
     InetLoader as InetLoader,
     CidrLoader as CidrLoader,
 )
@@ -258,6 +264,12 @@ def register_default_globals(ctx: AdaptContext) -> None:
     InterfaceDumper.register("ipaddress.IPv6Interface", ctx)
     NetworkDumper.register("ipaddress.IPv4Network", ctx)
     NetworkDumper.register("ipaddress.IPv6Network", ctx)
+    IPv4AddressBinaryDumper.register("ipaddress.IPv4Address", ctx)
+    IPv6AddressBinaryDumper.register("ipaddress.IPv6Address", ctx)
+    IPv4InterfaceBinaryDumper.register("ipaddress.IPv4Interface", ctx)
+    IPv6InterfaceBinaryDumper.register("ipaddress.IPv6Interface", ctx)
+    IPv4NetworkBinaryDumper.register("ipaddress.IPv4Network", ctx)
+    IPv6NetworkBinaryDumper.register("ipaddress.IPv6Network", ctx)
     InetLoader.register("inet", ctx)
     CidrLoader.register("cidr", ctx)
 
index c08e6971d3bbf23f8cd169c5054bb6e385862c46..6f4d36f217837a99743ab18f0d5a52e8fb2eae67 100644 (file)
@@ -24,6 +24,9 @@ ip_address: Callable[[str], Address]
 ip_interface: Callable[[str], Interface]
 ip_network: Callable[[str], Network]
 
+PGSQL_AF_INET = 2
+PGSQL_AF_INET6 = 3
+
 
 class InterfaceDumper(Dumper):
 
@@ -43,6 +46,80 @@ class NetworkDumper(Dumper):
         return str(obj).encode("utf8")
 
 
+class _IPv4Mixin:
+    _family = PGSQL_AF_INET
+    _prefixlen = 32
+
+
+class _IPv6Mixin:
+    _family = PGSQL_AF_INET6
+    _prefixlen = 128
+
+
+class _AddressBinaryDumper(Dumper):
+
+    format = Format.BINARY
+    _oid = builtins["inet"].oid
+
+    _family: int
+    _prefixlen: int
+
+    def dump(self, obj: Address) -> bytes:
+        packed = obj.packed
+        head = bytes((self._family, self._prefixlen, 0, len(packed)))
+        return head + packed
+
+
+class IPv4AddressBinaryDumper(_IPv4Mixin, _AddressBinaryDumper):
+    pass
+
+
+class IPv6AddressBinaryDumper(_IPv6Mixin, _AddressBinaryDumper):
+    pass
+
+
+class _InterfaceBinaryDumper(Dumper):
+
+    format = Format.BINARY
+    _oid = builtins["inet"].oid
+
+    _family: int
+
+    def dump(self, obj: Interface) -> bytes:
+        packed = obj.packed
+        head = bytes((self._family, obj.network.prefixlen, 0, len(packed)))
+        return head + packed
+
+
+class IPv4InterfaceBinaryDumper(_IPv4Mixin, _InterfaceBinaryDumper):
+    pass
+
+
+class IPv6InterfaceBinaryDumper(_IPv6Mixin, _InterfaceBinaryDumper):
+    pass
+
+
+class _NetworkBinaryDumper(Dumper):
+
+    format = Format.BINARY
+    _oid = builtins["cidr"].oid
+
+    _family: int
+
+    def dump(self, obj: Network) -> bytes:
+        packed = obj.network_address.packed
+        head = bytes((self._family, obj.prefixlen, 1, len(packed)))
+        return head + packed
+
+
+class IPv4NetworkBinaryDumper(_IPv4Mixin, _NetworkBinaryDumper):
+    pass
+
+
+class IPv6NetworkBinaryDumper(_IPv6Mixin, _NetworkBinaryDumper):
+    pass
+
+
 class _LazyIpaddress(Loader):
     def __init__(self, oid: int, context: Optional[AdaptContext] = None):
         super().__init__(oid, context)
index 3b88d38a184b7ca263e309af5ecb61637365f5db..4934622432ca9eaabd0910b9257aa5c58eb770e4 100644 (file)
@@ -12,7 +12,6 @@ from psycopg3.adapt import Format
 @pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY])
 @pytest.mark.parametrize("val", ["192.168.0.1", "2001:db8::"])
 def test_address_dump(conn, fmt_in, val):
-    binary_check(fmt_in)
     cur = conn.cursor()
     cur.execute(
         f"select %{fmt_in} = %s::inet", (ipaddress.ip_address(val), val)
@@ -28,12 +27,12 @@ def test_address_dump(conn, fmt_in, val):
 @pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY])
 @pytest.mark.parametrize("val", ["127.0.0.1/24", "::ffff:102:300/128"])
 def test_interface_dump(conn, fmt_in, val):
-    binary_check(fmt_in)
     cur = conn.cursor()
-    cur.execute(
-        f"select %{fmt_in} = %s::inet", (ipaddress.ip_interface(val), val)
-    )
-    assert cur.fetchone()[0] is True
+    rec = cur.execute(
+        f"select %(val){fmt_in} = %(repr)s::inet, %(val){fmt_in}, %(repr)s::inet",
+        {"val": ipaddress.ip_interface(val), "repr": val},
+    ).fetchone()
+    assert rec[0] is True, f"{rec[1]} != {rec[2]}"
     cur.execute(
         f"select %{fmt_in} = array[null, %s]::inet[]",
         ([None, ipaddress.ip_interface(val)], val),
@@ -44,7 +43,6 @@ def test_interface_dump(conn, fmt_in, val):
 @pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY])
 @pytest.mark.parametrize("val", ["127.0.0.0/24", "::ffff:102:300/128"])
 def test_network_dump(conn, fmt_in, val):
-    binary_check(fmt_in)
     cur = conn.cursor()
     cur.execute(
         f"select %{fmt_in} = %s::cidr", (ipaddress.ip_network(val), val)