From: Bob Halley Date: Sat, 27 Jun 2020 22:33:58 +0000 (-0700) Subject: add serial number arithmetic helper X-Git-Tag: v2.0.0rc2~53 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=846055124fd2ff0dd97af203c19776bddd97a59a;p=thirdparty%2Fdnspython.git add serial number arithmetic helper --- diff --git a/dns/__init__.py b/dns/__init__.py index 61ec4127..bb87ff47 100644 --- a/dns/__init__.py +++ b/dns/__init__.py @@ -45,6 +45,7 @@ __all__ = [ 'resolver', 'reversename', 'rrset', + 'serial', 'set', 'tokenizer', 'tsig', diff --git a/dns/serial.py b/dns/serial.py new file mode 100644 index 00000000..b0474151 --- /dev/null +++ b/dns/serial.py @@ -0,0 +1,117 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +"""Serial Number Arthimetic from RFC 1982""" + +class Serial: + def __init__(self, value, bits=32): + self.value = value % 2 ** bits + self.bits = bits + + def __repr__(self): + return f'dns.serial.Serial({self.value}, {self.bits})' + + def __eq__(self, other): + if isinstance(other, int): + other = Serial(other, self.bits) + elif not isinstance(other, Serial) or other.bits != self.bits: + return NotImplemented + return self.value == other.value + + def __ne__(self, other): + if isinstance(other, int): + other = Serial(other, self.bits) + elif not isinstance(other, Serial) or other.bits != self.bits: + return NotImplemented + return self.value != other.value + + def __lt__(self, other): + if isinstance(other, int): + other = Serial(other, self.bits) + elif not isinstance(other, Serial) or other.bits != self.bits: + return NotImplemented + if self.value < other.value and \ + other.value - self.value < 2 ** (self.bits - 1): + return True + elif self.value > other.value and \ + self.value - other.value > 2 ** (self.bits - 1): + return True + else: + return False + + def __le__(self, other): + return self == other or self < other + + def __gt__(self, other): + if isinstance(other, int): + other = Serial(other, self.bits) + elif not isinstance(other, Serial) or other.bits != self.bits: + return NotImplemented + if self.value < other.value and \ + other.value - self.value > 2 ** (self.bits - 1): + return True + elif self.value > other.value and \ + self.value - other.value < 2 ** (self.bits - 1): + return True + else: + return False + + def __ge__(self, other): + return self == other or self > other + + def __add__(self, other): + v = self.value + if isinstance(other, Serial): + delta = other.value + elif isinstance(other, int): + delta = other + else: + raise ValueError + if abs(delta) > (2 ** (self.bits - 1) - 1): + raise ValueError + v += delta + v = v % 2 ** self.bits + return Serial(v, self.bits) + + def __iadd__(self, other): + v = self.value + if isinstance(other, Serial): + delta = other.value + elif isinstance(other, int): + delta = other + else: + raise ValueError + if abs(delta) > (2 ** (self.bits - 1) - 1): + raise ValueError + v += delta + v = v % 2 ** self.bits + self.value = v + return self + + def __sub__(self, other): + v = self.value + if isinstance(other, Serial): + delta = other.value + elif isinstance(other, int): + delta = other + else: + raise ValueError + if abs(delta) > (2 ** (self.bits - 1) - 1): + raise ValueError + v -= delta + v = v % 2 ** self.bits + return Serial(v, self.bits) + + def __isub__(self, other): + v = self.value + if isinstance(other, Serial): + delta = other.value + elif isinstance(other, int): + delta = other + else: + raise ValueError + if abs(delta) > (2 ** (self.bits - 1) - 1): + raise ValueError + v -= delta + v = v % 2 ** self.bits + self.value = v + return self diff --git a/tests/test_serial.py b/tests/test_serial.py new file mode 100644 index 00000000..d632a460 --- /dev/null +++ b/tests/test_serial.py @@ -0,0 +1,115 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import unittest + +import dns.serial + +def S2(v): + return dns.serial.Serial(v, bits=2) + +def S8(v): + return dns.serial.Serial(v, bits=8) + +class SerialTestCase(unittest.TestCase): + def test_rfc_1982_2_bit_cases(self): + self.assertEqual(S2(0) + S2(1), S2(1)) + self.assertEqual(S2(1) + S2(1), S2(2)) + self.assertEqual(S2(2) + S2(1), S2(3)) + self.assertEqual(S2(3) + S2(1), S2(0)) + self.assertTrue(S2(1) > S2(0)) + self.assertTrue(S2(2) > S2(1)) + self.assertTrue(S2(3) > S2(2)) + self.assertTrue(S2(0) > S2(3)) + self.assertFalse(S2(2) > S2(0)) + self.assertFalse(S2(0) > S2(2)) + self.assertFalse(S2(2) < S2(0)) + self.assertFalse(S2(0) < S2(2)) + + def test_rfc_1982_8_bit_cases(self): + self.assertEqual(S8(255) + S8(1), S8(0)) + self.assertEqual(S8(100) + S8(100), S8(200)) + self.assertEqual(S8(200) + S8(100), S8(44)) + self.assertTrue(S8(1) > S8(0)) + self.assertTrue(S8(44) > S8(0)) + self.assertTrue(S8(100) > S8(0)) + self.assertTrue(S8(100) > S8(44)) + self.assertTrue(S8(200) > S8(100)) + self.assertTrue(S8(255) > S8(200)) + self.assertTrue(S8(0) > S8(255)) + self.assertTrue(S8(255) < S8(0)) + self.assertTrue(S8(100) > S8(255)) + self.assertTrue(S8(0) > S8(200)) + self.assertTrue(S8(44) > S8(200)) + self.assertFalse(S8(0) > S8(128)) + self.assertFalse(S8(128) > S8(0)) + self.assertFalse(S8(0) < S8(128)) + self.assertFalse(S8(128) < S8(0)) + self.assertFalse(S8(1) > S8(129)) + self.assertFalse(S8(129) > S8(1)) + + def test_incremental_ops(self): + v = S8(255) + v += 1 + self.assertEqual(v, 0) + v = S8(255) + v += S8(1) + self.assertEqual(v, 0) + v = S8(0) + v -= 1 + self.assertEqual(v, 255) + v = S8(0) + v -= S8(1) + self.assertEqual(v, 255) + + def test_sub(self): + self.assertEqual(S8(0) - S8(1), S8(255)) + + def test_sub(self): + self.assertEqual(S8(0) - S8(1), S8(255)) + + def test_addition_bounds(self): + self.assertRaises(ValueError, lambda: S8(0) + 128) + self.assertRaises(ValueError, lambda: S8(0) - 128) + def bad1(): + v = S8(0) + v += 128 + self.assertRaises(ValueError, bad1) + def bad2(): + v = S8(0) + v -= 128 + self.assertRaises(ValueError, bad2) + + def test_casting(self): + self.assertTrue(S8(0) == 0) + self.assertTrue(S8(0) != 1) + self.assertTrue(S8(0) < 1) + self.assertTrue(S8(0) <= 1) + self.assertTrue(S8(0) > 255) + self.assertTrue(S8(0) >= 255) + + def test_uncastable(self): + self.assertRaises(ValueError, lambda: S8(0) + 'a') + self.assertRaises(ValueError, lambda: S8(0) - 'a') + def bad1(): + v = S8(0) + v += 'a' + self.assertRaises(ValueError, bad1) + def bad2(): + v = S8(0) + v -= 'a' + self.assertRaises(ValueError, bad2) + + def test_uncomparable(self): + self.assertFalse(S8(0) == 'a') + self.assertTrue(S8(0) != 'a') + self.assertRaises(TypeError, lambda: S8(0) < 'a') + self.assertRaises(TypeError, lambda: S8(0) <= 'a') + self.assertRaises(TypeError, lambda: S8(0) > 'a') + self.assertRaises(TypeError, lambda: S8(0) >= 'a') + + def test_modulo(self): + self.assertEqual(S8(-1), 255) + self.assertEqual(S8(257), 1) + + def test_repr(self): + self.assertEqual(repr(S8(1)), 'dns.serial.Serial(1, 8)')