]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add network types binary loaders
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 8 Jun 2021 03:48:12 +0000 (04:48 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 8 Jun 2021 03:48:39 +0000 (04:48 +0100)
psycopg3/psycopg3/types/__init__.py
psycopg3/psycopg3/types/network.py
tests/types/test_network.py

index d0d1eef923678ea98b25e84c68c49233a3ead62b..e3d0ea4436e9fc2de4fb9122470209c0db382ffd 100644 (file)
@@ -131,7 +131,9 @@ from .network import (
     IPv4NetworkBinaryDumper as IPv4NetworkBinaryDumper,
     IPv6NetworkBinaryDumper as IPv6NetworkBinaryDumper,
     InetLoader as InetLoader,
+    InetBinaryLoader as InetBinaryLoader,
     CidrLoader as CidrLoader,
+    CidrBinaryLoader as CidrBinaryLoader,
 )
 from .range import (
     RangeDumper as RangeDumper,
@@ -271,7 +273,9 @@ def register_default_globals(ctx: AdaptContext) -> None:
     IPv4NetworkBinaryDumper.register("ipaddress.IPv4Network", ctx)
     IPv6NetworkBinaryDumper.register("ipaddress.IPv6Network", ctx)
     InetLoader.register("inet", ctx)
+    InetBinaryLoader.register("inet", ctx)
     CidrLoader.register("cidr", ctx)
+    CidrBinaryLoader.register("cidr", ctx)
 
     RangeDumper.register(Range, ctx)
     Int4RangeLoader.register("int4range", ctx)
index 6f4d36f217837a99743ab18f0d5a52e8fb2eae67..7248323bd0677e502f6fadb9a55039c76817ace2 100644 (file)
@@ -4,7 +4,7 @@ Adapters for network types.
 
 # Copyright (C) 2020-2021 The Psycopg Team
 
-from typing import Callable, Optional, Union, TYPE_CHECKING
+from typing import Callable, Optional, Type, Union, TYPE_CHECKING
 
 from ..pq import Format
 from ..oids import postgres_types as builtins
@@ -18,14 +18,22 @@ Address = Union["ipaddress.IPv4Address", "ipaddress.IPv6Address"]
 Interface = Union["ipaddress.IPv4Interface", "ipaddress.IPv6Interface"]
 Network = Union["ipaddress.IPv4Network", "ipaddress.IPv6Network"]
 
-# These functions will be imported lazily
+# These objects will be imported lazily
 imported = False
 ip_address: Callable[[str], Address]
 ip_interface: Callable[[str], Interface]
 ip_network: Callable[[str], Network]
+IPv4Address: "Type[ipaddress.IPv4Address]"
+IPv6Address: "Type[ipaddress.IPv6Address]"
+IPv4Interface: "Type[ipaddress.IPv4Interface]"
+IPv6Interface: "Type[ipaddress.IPv6Interface]"
+IPv4Network: "Type[ipaddress.IPv4Network]"
+IPv6Network: "Type[ipaddress.IPv6Network]"
 
 PGSQL_AF_INET = 2
 PGSQL_AF_INET6 = 3
+IPV4_PREFIXLEN = 32
+IPV6_PREFIXLEN = 128
 
 
 class InterfaceDumper(Dumper):
@@ -48,12 +56,12 @@ class NetworkDumper(Dumper):
 
 class _IPv4Mixin:
     _family = PGSQL_AF_INET
-    _prefixlen = 32
+    _prefixlen = IPV4_PREFIXLEN
 
 
 class _IPv6Mixin:
     _family = PGSQL_AF_INET6
-    _prefixlen = 128
+    _prefixlen = IPV6_PREFIXLEN
 
 
 class _AddressBinaryDumper(Dumper):
@@ -124,8 +132,14 @@ class _LazyIpaddress(Loader):
     def __init__(self, oid: int, context: Optional[AdaptContext] = None):
         super().__init__(oid, context)
         global imported, ip_address, ip_interface, ip_network
+        global IPv4Address, IPv6Address, IPv4Interface, IPv6Interface
+        global IPv4Network, IPv6Network
+
         if not imported:
             from ipaddress import ip_address, ip_interface, ip_network
+            from ipaddress import IPv4Address, IPv6Address
+            from ipaddress import IPv4Interface, IPv6Interface
+            from ipaddress import IPv4Network, IPv6Network
 
             imported = True
 
@@ -144,6 +158,28 @@ class InetLoader(_LazyIpaddress):
             return ip_address(data.decode("utf8"))
 
 
+class InetBinaryLoader(_LazyIpaddress):
+
+    format = Format.BINARY
+
+    def load(self, data: Buffer) -> Union[Address, Interface]:
+        if isinstance(data, memoryview):
+            data = bytes(data)
+
+        prefix = data[1]
+        packed = data[4:]
+        if data[0] == PGSQL_AF_INET:
+            if prefix == IPV4_PREFIXLEN:
+                return IPv4Address(packed)
+            else:
+                return IPv4Interface((packed, prefix))
+        else:
+            if prefix == IPV6_PREFIXLEN:
+                return IPv6Address(packed)
+            else:
+                return IPv6Interface((packed, prefix))
+
+
 class CidrLoader(_LazyIpaddress):
 
     format = Format.TEXT
@@ -153,3 +189,21 @@ class CidrLoader(_LazyIpaddress):
             data = bytes(data)
 
         return ip_network(data.decode("utf8"))
+
+
+class CidrBinaryLoader(_LazyIpaddress):
+
+    format = Format.BINARY
+
+    def load(self, data: Buffer) -> Network:
+        if isinstance(data, memoryview):
+            data = bytes(data)
+
+        prefix = data[1]
+        packed = data[4:]
+        if data[0] == PGSQL_AF_INET:
+            return IPv4Network((packed, prefix))
+        else:
+            return IPv6Network((packed, prefix))
+
+        return ip_network(data.decode("utf8"))
index 4934622432ca9eaabd0910b9257aa5c58eb770e4..bafc7afe38fe5bc190a0af5734050967490211bf 100644 (file)
@@ -58,7 +58,6 @@ def test_network_dump(conn, fmt_in, val):
 @pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY])
 @pytest.mark.parametrize("val", ["127.0.0.1/32", "::ffff:102:300/128"])
 def test_inet_load_address(conn, fmt_out, val):
-    binary_check(fmt_out)
     addr = ipaddress.ip_address(val.split("/", 1)[0])
     cur = conn.cursor(binary=fmt_out)
 
@@ -81,7 +80,6 @@ def test_inet_load_address(conn, fmt_out, val):
 @pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY])
 @pytest.mark.parametrize("val", ["127.0.0.1/24", "::ffff:102:300/127"])
 def test_inet_load_network(conn, fmt_out, val):
-    binary_check(fmt_out)
     pyval = ipaddress.ip_interface(val)
     cur = conn.cursor(binary=fmt_out)
 
@@ -104,7 +102,6 @@ def test_inet_load_network(conn, fmt_out, val):
 @pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY])
 @pytest.mark.parametrize("val", ["127.0.0.0/24", "::ffff:102:300/128"])
 def test_cidr_load(conn, fmt_out, val):
-    binary_check(fmt_out)
     pyval = ipaddress.ip_network(val)
     cur = conn.cursor(binary=fmt_out)
 
@@ -124,11 +121,6 @@ def test_cidr_load(conn, fmt_out, val):
     assert got == pyval
 
 
-def binary_check(fmt):
-    if fmt == Format.BINARY or fmt == pq.Format.BINARY:
-        pytest.xfail("inet binary not implemented")
-
-
 @pytest.mark.subprocess
 def test_lazy_load(dsn):
     script = f"""\