]> git.ipfire.org Git - thirdparty/bird.git/commitdiff
maria's test aggregator works on IPv4 as well
authorMaria Matejka <mq@ucw.cz>
Mon, 25 Dec 2023 22:23:19 +0000 (23:23 +0100)
committerMaria Matejka <mq@ucw.cz>
Mon, 25 Dec 2023 22:24:49 +0000 (23:24 +0100)
mq-sketch/myagr.py

index 6d2a66682484e6500b031a0f310aa9f0c94ed288..dde2dcc2cac370312186d50e8a7cada0dce554c3 100755 (executable)
@@ -2,7 +2,11 @@
 
 import ipaddress
 
+
 class IPTrie:
+    rootnet = None
+    agrclass = None
+
     def __init__(self, up=None):
         self.children = [ None, None ]
         self.local = None
@@ -31,10 +35,13 @@ class IPTrie:
                 (self.children[0].dump([ *path, 0 ]) if self.children[0] is not None else "") + \
                 (self.children[1].dump([ *path, 1 ]) if self.children[1] is not None else "")
 
-    def aggregate(self, up=None, net=ipaddress.IPv6Network("::/0"), covered=None):
+    def aggregate(self, up=None, net=None, covered=None):
         if self.children[0] is None and self.children[1] is None:
             return self
 
+        if net is None:
+            net = self.rootnet
+
         if self.local:
             covered = self.local
         else:
@@ -42,7 +49,7 @@ class IPTrie:
 
         def coveredNode(bit):
             t = IPTrie(self)
-            t.local = AgrPointv6(list(net.subnets())[bit], covered.bucket)
+            t.local = self.agrclass(list(net.subnets())[bit], covered.bucket)
             t.buckets.add(covered.bucket)
             return t
 
@@ -59,7 +66,7 @@ class IPTrie:
         intersection = ac[0].buckets & ac[1].buckets
 
         if len(intersection) > 0:
-            nap.local = AgrPointv6(net, sorted(intersection)[0])
+            nap.local = self.agrclass(net, sorted(intersection)[0])
             nap.buckets = intersection
         else:
             nap.buckets = ac[0].buckets | ac[1].buckets
@@ -107,6 +114,21 @@ class AgrPointv6(ipaddress.IPv6Network):
     def __init__(self, net, bucket):
         super().__init__(net)
         self.bucket = bucket
+        if IPTrie.rootnet is None:
+            IPTrie.rootnet = ipaddress.IPv6Network("::/0")
+            IPTrie.agrclass = AgrPointv6
+
+    def __str__(self):
+#        print(type(self), super().__str__(), type(self.bucket), self.bucket)
+        return super().__str__() + " -> " + self.bucket
+
+class AgrPointv4(ipaddress.IPv4Network):
+    def __init__(self, net, bucket):
+        super().__init__(net)
+        self.bucket = bucket
+        if IPTrie.rootnet is None:
+            IPTrie.rootnet = ipaddress.IPv4Network("0.0.0.0/0")
+            IPTrie.agrclass = AgrPointv4
 
     def __str__(self):
 #        print(type(self), super().__str__(), type(self.bucket), self.bucket)
@@ -115,12 +137,25 @@ class AgrPointv6(ipaddress.IPv6Network):
 # Load
 t = IPTrie()
 
+p = input()
+data = p.split(" ")
+
 try:
-    while p := input():
-        data = p.split(" ")
-        t.add(AgrPointv6(data[0], data[1]))
-except EOFError:
-    pass
+    t.add(AgrPointv6(data[0], data[1]))
+    try:
+        while p := input():
+            data = p.split(" ")
+            t.add(AgrPointv6(data[0], data[1]))
+    except EOFError:
+        pass
+except ipaddress.AddressValueError:
+    t.add(AgrPointv4(data[0], data[1]))
+    try:
+        while p := input():
+            data = p.split(" ")
+            t.add(AgrPointv4(data[0], data[1]))
+    except EOFError:
+        pass
 
 # Dump
 print("Dump After Load")