]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
add serial number arithmetic helper
authorBob Halley <halley@play-bow.org>
Sat, 27 Jun 2020 22:33:58 +0000 (15:33 -0700)
committerBob Halley <halley@play-bow.org>
Sat, 27 Jun 2020 22:33:58 +0000 (15:33 -0700)
dns/__init__.py
dns/serial.py [new file with mode: 0644]
tests/test_serial.py [new file with mode: 0644]

index 61ec41271ddbe25e9a93dc15be6ef92a38d300b3..bb87ff477702607fd13505e8330b25bfe332656f 100644 (file)
@@ -45,6 +45,7 @@ __all__ = [
     'resolver',
     'reversename',
     'rrset',
+    'serial',
     'set',
     'tokenizer',
     'tsig',
diff --git a/dns/serial.py b/dns/serial.py
new file mode 100644 (file)
index 0000000..b047415
--- /dev/null
@@ -0,0 +1,117 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+"""Serial Number Arthimetic from RFC 1982"""
+
+class Serial:
+    def __init__(self, value, bits=32):
+        self.value = value % 2 ** bits
+        self.bits = bits
+
+    def __repr__(self):
+        return f'dns.serial.Serial({self.value}, {self.bits})'
+
+    def __eq__(self, other):
+        if isinstance(other, int):
+            other = Serial(other, self.bits)
+        elif not isinstance(other, Serial) or other.bits != self.bits:
+            return NotImplemented
+        return self.value == other.value
+
+    def __ne__(self, other):
+        if isinstance(other, int):
+            other = Serial(other, self.bits)
+        elif not isinstance(other, Serial) or other.bits != self.bits:
+            return NotImplemented
+        return self.value != other.value
+
+    def __lt__(self, other):
+        if isinstance(other, int):
+            other = Serial(other, self.bits)
+        elif not isinstance(other, Serial) or other.bits != self.bits:
+            return NotImplemented
+        if self.value < other.value and \
+           other.value - self.value < 2 ** (self.bits - 1):
+            return True
+        elif self.value > other.value and \
+             self.value - other.value > 2 ** (self.bits - 1):
+            return True
+        else:
+            return False
+
+    def __le__(self, other):
+        return self == other or self < other
+
+    def __gt__(self, other):
+        if isinstance(other, int):
+            other = Serial(other, self.bits)
+        elif not isinstance(other, Serial) or other.bits != self.bits:
+            return NotImplemented
+        if self.value < other.value and \
+           other.value - self.value > 2 ** (self.bits - 1):
+            return True
+        elif self.value > other.value and \
+             self.value - other.value < 2 ** (self.bits - 1):
+            return True
+        else:
+            return False
+
+    def __ge__(self, other):
+        return self == other or self > other
+
+    def __add__(self, other):
+        v = self.value
+        if isinstance(other, Serial):
+            delta = other.value
+        elif isinstance(other, int):
+            delta = other
+        else:
+            raise ValueError
+        if abs(delta) > (2 ** (self.bits - 1) - 1):
+            raise ValueError
+        v += delta
+        v = v % 2 ** self.bits
+        return Serial(v, self.bits)
+
+    def __iadd__(self, other):
+        v = self.value
+        if isinstance(other, Serial):
+            delta = other.value
+        elif isinstance(other, int):
+            delta = other
+        else:
+            raise ValueError
+        if abs(delta) > (2 ** (self.bits - 1) - 1):
+            raise ValueError
+        v += delta
+        v = v % 2 ** self.bits
+        self.value = v
+        return self
+
+    def __sub__(self, other):
+        v = self.value
+        if isinstance(other, Serial):
+            delta = other.value
+        elif isinstance(other, int):
+            delta = other
+        else:
+            raise ValueError
+        if abs(delta) > (2 ** (self.bits - 1) - 1):
+            raise ValueError
+        v -= delta
+        v = v % 2 ** self.bits
+        return Serial(v, self.bits)
+
+    def __isub__(self, other):
+        v = self.value
+        if isinstance(other, Serial):
+            delta = other.value
+        elif isinstance(other, int):
+            delta = other
+        else:
+            raise ValueError
+        if abs(delta) > (2 ** (self.bits - 1) - 1):
+            raise ValueError
+        v -= delta
+        v = v % 2 ** self.bits
+        self.value = v
+        return self
diff --git a/tests/test_serial.py b/tests/test_serial.py
new file mode 100644 (file)
index 0000000..d632a46
--- /dev/null
@@ -0,0 +1,115 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+import unittest
+
+import dns.serial
+
+def S2(v):
+    return dns.serial.Serial(v, bits=2)
+
+def S8(v):
+    return dns.serial.Serial(v, bits=8)
+
+class SerialTestCase(unittest.TestCase):
+    def test_rfc_1982_2_bit_cases(self):
+        self.assertEqual(S2(0) + S2(1), S2(1))
+        self.assertEqual(S2(1) + S2(1), S2(2))
+        self.assertEqual(S2(2) + S2(1), S2(3))
+        self.assertEqual(S2(3) + S2(1), S2(0))
+        self.assertTrue(S2(1) > S2(0))
+        self.assertTrue(S2(2) > S2(1))
+        self.assertTrue(S2(3) > S2(2))
+        self.assertTrue(S2(0) > S2(3))
+        self.assertFalse(S2(2) > S2(0))
+        self.assertFalse(S2(0) > S2(2))
+        self.assertFalse(S2(2) < S2(0))
+        self.assertFalse(S2(0) < S2(2))
+
+    def test_rfc_1982_8_bit_cases(self):
+        self.assertEqual(S8(255) + S8(1), S8(0))
+        self.assertEqual(S8(100) + S8(100), S8(200))
+        self.assertEqual(S8(200) + S8(100), S8(44))
+        self.assertTrue(S8(1) > S8(0))
+        self.assertTrue(S8(44) > S8(0))
+        self.assertTrue(S8(100) > S8(0))
+        self.assertTrue(S8(100) > S8(44))
+        self.assertTrue(S8(200) > S8(100))
+        self.assertTrue(S8(255) > S8(200))
+        self.assertTrue(S8(0) > S8(255))
+        self.assertTrue(S8(255) < S8(0))
+        self.assertTrue(S8(100) > S8(255))
+        self.assertTrue(S8(0) > S8(200))
+        self.assertTrue(S8(44) > S8(200))
+        self.assertFalse(S8(0) > S8(128))
+        self.assertFalse(S8(128) > S8(0))
+        self.assertFalse(S8(0) < S8(128))
+        self.assertFalse(S8(128) < S8(0))
+        self.assertFalse(S8(1) > S8(129))
+        self.assertFalse(S8(129) > S8(1))
+
+    def test_incremental_ops(self):
+        v = S8(255)
+        v += 1
+        self.assertEqual(v, 0)
+        v = S8(255)
+        v += S8(1)
+        self.assertEqual(v, 0)
+        v = S8(0)
+        v -= 1
+        self.assertEqual(v, 255)
+        v = S8(0)
+        v -= S8(1)
+        self.assertEqual(v, 255)
+
+    def test_sub(self):
+        self.assertEqual(S8(0) - S8(1), S8(255))
+
+    def test_sub(self):
+        self.assertEqual(S8(0) - S8(1), S8(255))
+
+    def test_addition_bounds(self):
+        self.assertRaises(ValueError, lambda: S8(0) + 128)
+        self.assertRaises(ValueError, lambda: S8(0) - 128)
+        def bad1():
+            v = S8(0)
+            v += 128
+        self.assertRaises(ValueError, bad1)
+        def bad2():
+            v = S8(0)
+            v -= 128
+        self.assertRaises(ValueError, bad2)
+
+    def test_casting(self):
+        self.assertTrue(S8(0) == 0)
+        self.assertTrue(S8(0) != 1)
+        self.assertTrue(S8(0) < 1)
+        self.assertTrue(S8(0) <= 1)
+        self.assertTrue(S8(0) > 255)
+        self.assertTrue(S8(0) >= 255)
+
+    def test_uncastable(self):
+        self.assertRaises(ValueError, lambda: S8(0) + 'a')
+        self.assertRaises(ValueError, lambda: S8(0) - 'a')
+        def bad1():
+            v = S8(0)
+            v += 'a'
+        self.assertRaises(ValueError, bad1)
+        def bad2():
+            v = S8(0)
+            v -= 'a'
+        self.assertRaises(ValueError, bad2)
+
+    def test_uncomparable(self):
+        self.assertFalse(S8(0) == 'a')
+        self.assertTrue(S8(0) != 'a')
+        self.assertRaises(TypeError, lambda: S8(0) < 'a')
+        self.assertRaises(TypeError, lambda: S8(0) <= 'a')
+        self.assertRaises(TypeError, lambda: S8(0) > 'a')
+        self.assertRaises(TypeError, lambda: S8(0) >= 'a')
+
+    def test_modulo(self):
+        self.assertEqual(S8(-1), 255)
+        self.assertEqual(S8(257), 1)
+
+    def test_repr(self):
+        self.assertEqual(repr(S8(1)), 'dns.serial.Serial(1, 8)')