]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Consolidate NSEC/NSEC3/CSYNC bitmap handling.
authorBrian Wellington <bwelling@xbill.org>
Thu, 2 Jul 2020 22:29:25 +0000 (15:29 -0700)
committerBrian Wellington <bwelling@xbill.org>
Thu, 2 Jul 2020 22:29:25 +0000 (15:29 -0700)
This also fixes several bugs; the NSEC3 code would properly avoid empty
windows, but the NSEC and CSYNC code did not.  Also, none of the wire
parsing routines properly checked to see that the window number was
monotonically increasing.

dns/rdtypes/ANY/CSYNC.py
dns/rdtypes/ANY/NSEC.py
dns/rdtypes/ANY/NSEC3.py
dns/rdtypes/util.py
tests/test_nsec3.py

index c62dad8a90429779e9b172e142ab6e277976b226..9cba5fad09610d99fc7591f76528273af9b17f08 100644 (file)
@@ -21,6 +21,12 @@ import dns.exception
 import dns.rdata
 import dns.rdatatype
 import dns.name
+import dns.rdtypes.util
+
+
+class Bitmap(dns.rdtypes.util.Bitmap):
+    type_name = 'CSYNC'
+
 
 class CSYNC(dns.rdata.Rdata):
 
@@ -35,15 +41,7 @@ class CSYNC(dns.rdata.Rdata):
         object.__setattr__(self, 'windows', dns.rdata._constify(windows))
 
     def to_text(self, origin=None, relativize=True, **kw):
-        text = ''
-        for (window, bitmap) in self.windows:
-            bits = []
-            for (i, byte) in enumerate(bitmap):
-                for j in range(0, 8):
-                    if byte & (0x80 >> j):
-                        bits.append(dns.rdatatype.to_text(window * 256 +
-                                                          i * 8 + j))
-            text += (' ' + ' '.join(bits))
+        text = Bitmap(self.windows).to_text()
         return '%d %d%s' % (self.serial, self.flags, text)
 
     @classmethod
@@ -51,56 +49,15 @@ class CSYNC(dns.rdata.Rdata):
                   relativize_to=None):
         serial = tok.get_uint32()
         flags = tok.get_uint16()
-        rdtypes = []
-        while 1:
-            token = tok.get().unescape()
-            if token.is_eol_or_eof():
-                break
-            nrdtype = dns.rdatatype.from_text(token.value)
-            if nrdtype == 0:
-                raise dns.exception.SyntaxError("CSYNC with bit 0")
-            if nrdtype > 65535:
-                raise dns.exception.SyntaxError("CSYNC with bit > 65535")
-            rdtypes.append(nrdtype)
-        rdtypes.sort()
-        window = 0
-        octets = 0
-        prior_rdtype = 0
-        bitmap = bytearray(b'\0' * 32)
-        windows = []
-        for nrdtype in rdtypes:
-            if nrdtype == prior_rdtype:
-                continue
-            prior_rdtype = nrdtype
-            new_window = nrdtype // 256
-            if new_window != window:
-                windows.append((window, bitmap[0:octets]))
-                bitmap = bytearray(b'\0' * 32)
-                window = new_window
-            offset = nrdtype % 256
-            byte = offset // 8
-            bit = offset % 8
-            octets = byte + 1
-            bitmap[byte] = bitmap[byte] | (0x80 >> bit)
-
-        windows.append((window, bitmap[0:octets]))
+        windows = Bitmap().from_text(tok)
         return cls(rdclass, rdtype, serial, flags, windows)
 
     def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
         file.write(struct.pack('!IH', self.serial, self.flags))
-        for (window, bitmap) in self.windows:
-            file.write(struct.pack('!BB', window, len(bitmap)))
-            file.write(bitmap)
+        Bitmap(self.windows).to_wire(file)
 
     @classmethod
     def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
         (serial, flags) = parser.get_struct("!IH")
-        windows = []
-        while parser.remaining() > 0:
-            window = parser.get_uint8()
-            octets = parser.get_uint8()
-            if octets == 0 or octets > 32:
-                raise dns.exception.FormError("bad CSYNC octets")
-            bitmap = parser.get_bytes(octets)
-            windows.append((window, bitmap))
+        windows = Bitmap().from_wire_parser(parser)
         return cls(rdclass, rdtype, serial, flags, windows)
index 8c1da5aec6ee7a559184e1c7c26321cbeaf11abd..85bc662fe33dcccc2026a2cce3af2041702cf841 100644 (file)
@@ -21,6 +21,11 @@ import dns.exception
 import dns.rdata
 import dns.rdatatype
 import dns.name
+import dns.rdtypes.util
+
+
+class Bitmap(dns.rdtypes.util.Bitmap):
+    type_name = 'NSEC'
 
 
 class NSEC(dns.rdata.Rdata):
@@ -36,70 +41,22 @@ class NSEC(dns.rdata.Rdata):
 
     def to_text(self, origin=None, relativize=True, **kw):
         next = self.next.choose_relativity(origin, relativize)
-        text = ''
-        for (window, bitmap) in self.windows:
-            bits = []
-            for (i, byte) in enumerate(bitmap):
-                for j in range(0, 8):
-                    if byte & (0x80 >> j):
-                        bits.append(dns.rdatatype.to_text(window * 256 +
-                                                          i * 8 + j))
-            text += (' ' + ' '.join(bits))
+        text = Bitmap(self.windows).to_text()
         return '{}{}'.format(next, text)
 
     @classmethod
     def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
                   relativize_to=None):
         next = tok.get_name(origin, relativize, relativize_to)
-        rdtypes = []
-        while 1:
-            token = tok.get().unescape()
-            if token.is_eol_or_eof():
-                break
-            nrdtype = dns.rdatatype.from_text(token.value)
-            if nrdtype == 0:
-                raise dns.exception.SyntaxError("NSEC with bit 0")
-            if nrdtype > 65535:
-                raise dns.exception.SyntaxError("NSEC with bit > 65535")
-            rdtypes.append(nrdtype)
-        rdtypes.sort()
-        window = 0
-        octets = 0
-        prior_rdtype = 0
-        bitmap = bytearray(b'\0' * 32)
-        windows = []
-        for nrdtype in rdtypes:
-            if nrdtype == prior_rdtype:
-                continue
-            prior_rdtype = nrdtype
-            new_window = nrdtype // 256
-            if new_window != window:
-                windows.append((window, bitmap[0:octets]))
-                bitmap = bytearray(b'\0' * 32)
-                window = new_window
-            offset = nrdtype % 256
-            byte = offset // 8
-            bit = offset % 8
-            octets = byte + 1
-            bitmap[byte] = bitmap[byte] | (0x80 >> bit)
-
-        windows.append((window, bitmap[0:octets]))
+        windows = Bitmap().from_text(tok)
         return cls(rdclass, rdtype, next, windows)
 
     def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
         self.next.to_wire(file, None, origin, False)
-        for (window, bitmap) in self.windows:
-            file.write(struct.pack('!BB', window, len(bitmap)))
-            file.write(bitmap)
+        Bitmap(self.windows).to_wire(file)
 
     @classmethod
     def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
         next = parser.get_name(origin)
-        windows = []
-        while parser.remaining() > 0:
-            window = parser.get_uint8()
-            bitmap = parser.get_counted_bytes()
-            if len(bitmap) == 0 or len(bitmap) > 32:
-                raise dns.exception.FormError("bad NSEC octets")
-            windows.append((window, bitmap))
+        windows = Bitmap().from_wire_parser(parser)
         return cls(rdclass, rdtype, next, windows)
index 32dfe3e08442155c162b8f6000a209d3e83fd117..91471f0f571e0a837196aae180e5a0e2820d3651 100644 (file)
@@ -22,6 +22,7 @@ import struct
 import dns.exception
 import dns.rdata
 import dns.rdatatype
+import dns.rdtypes.util
 
 
 b32_hex_to_normal = bytes.maketrans(b'0123456789ABCDEFGHIJKLMNOPQRSTUV',
@@ -36,6 +37,10 @@ SHA1 = 1
 OPTOUT = 1
 
 
+class Bitmap(dns.rdtypes.util.Bitmap):
+    type_name = 'NSEC3'
+
+
 class NSEC3(dns.rdata.Rdata):
 
     """NSEC3 record"""
@@ -62,15 +67,7 @@ class NSEC3(dns.rdata.Rdata):
             salt = '-'
         else:
             salt = binascii.hexlify(self.salt).decode()
-        text = ''
-        for (window, bitmap) in self.windows:
-            bits = []
-            for (i, byte) in enumerate(bitmap):
-                for j in range(0, 8):
-                    if byte & (0x80 >> j):
-                        bits.append(dns.rdatatype.to_text(window * 256 +
-                                                          i * 8 + j))
-            text += (' ' + ' '.join(bits))
+        text = Bitmap(self.windows).to_text()
         return '%u %u %u %s %s%s' % (self.algorithm, self.flags,
                                      self.iterations, salt, next, text)
 
@@ -88,40 +85,7 @@ class NSEC3(dns.rdata.Rdata):
         next = tok.get_string().encode(
             'ascii').upper().translate(b32_hex_to_normal)
         next = base64.b32decode(next)
-        rdtypes = []
-        while 1:
-            token = tok.get().unescape()
-            if token.is_eol_or_eof():
-                break
-            nrdtype = dns.rdatatype.from_text(token.value)
-            if nrdtype == 0:
-                raise dns.exception.SyntaxError("NSEC3 with bit 0")
-            if nrdtype > 65535:
-                raise dns.exception.SyntaxError("NSEC3 with bit > 65535")
-            rdtypes.append(nrdtype)
-        rdtypes.sort()
-        window = 0
-        octets = 0
-        prior_rdtype = 0
-        bitmap = bytearray(b'\0' * 32)
-        windows = []
-        for nrdtype in rdtypes:
-            if nrdtype == prior_rdtype:
-                continue
-            prior_rdtype = nrdtype
-            new_window = nrdtype // 256
-            if new_window != window:
-                if octets != 0:
-                    windows.append((window, bitmap[0:octets]))
-                bitmap = bytearray(b'\0' * 32)
-                window = new_window
-            offset = nrdtype % 256
-            byte = offset // 8
-            bit = offset % 8
-            octets = byte + 1
-            bitmap[byte] = bitmap[byte] | (0x80 >> bit)
-        if octets != 0:
-            windows.append((window, bitmap[0:octets]))
+        windows = Bitmap().from_text(tok)
         return cls(rdclass, rdtype, algorithm, flags, iterations, salt, next,
                    windows)
 
@@ -133,21 +97,13 @@ class NSEC3(dns.rdata.Rdata):
         l = len(self.next)
         file.write(struct.pack("!B", l))
         file.write(self.next)
-        for (window, bitmap) in self.windows:
-            file.write(struct.pack("!BB", window, len(bitmap)))
-            file.write(bitmap)
+        Bitmap(self.windows).to_wire(file)
 
     @classmethod
     def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
         (algorithm, flags, iterations) = parser.get_struct('!BBH')
         salt = parser.get_counted_bytes()
         next = parser.get_counted_bytes()
-        windows = []
-        while parser.remaining() > 0:
-            window = parser.get_uint8()
-            bitmap = parser.get_counted_bytes()
-            if len(bitmap) == 0 or len(bitmap) > 32:
-                raise dns.exception.FormError("bad NSEC3 octets")
-            windows.append((window, bitmap))
+        windows = Bitmap().from_wire_parser(parser)
         return cls(rdclass, rdtype, algorithm, flags, iterations, salt, next,
                    windows)
index 3dc636d06599d52be77260657890096596752d2d..a63d1a0abc1c6b7af6ef1a0faa135f401e35c6b7 100644 (file)
@@ -15,6 +15,8 @@
 # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
 # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 
+import struct
+
 import dns.exception
 import dns.name
 import dns.ipv4
@@ -89,3 +91,76 @@ class Gateway:
             return parser.get_name(origin)
         else:
             raise dns.exception.FormError(self._invalid_type())
+
+class Bitmap:
+    """A helper class for the NSEC/NSEC3/CSYNC type bitmaps"""
+    type_name = ""
+
+    def __init__(self, windows=None):
+        self.windows = windows
+
+    def to_text(self):
+        text = ""
+        for (window, bitmap) in self.windows:
+            bits = []
+            for (i, byte) in enumerate(bitmap):
+                for j in range(0, 8):
+                    if byte & (0x80 >> j):
+                        rdtype = window * 256 + i * 8 + j
+                        bits.append(dns.rdatatype.to_text(rdtype))
+            text += (' ' + ' '.join(bits))
+        return text
+
+    def from_text(self, tok):
+        rdtypes = []
+        while True:
+            token = tok.get().unescape()
+            if token.is_eol_or_eof():
+                break
+            rdtype = dns.rdatatype.from_text(token.value)
+            if rdtype == 0:
+                raise dns.exception.SyntaxError(f"{self.type_name} with bit 0")
+            rdtypes.append(rdtype)
+        rdtypes.sort()
+        window = 0
+        octets = 0
+        prior_rdtype = 0
+        bitmap = bytearray(b'\0' * 32)
+        windows = []
+        for rdtype in rdtypes:
+            if rdtype == prior_rdtype:
+                continue
+            prior_rdtype = rdtype
+            new_window = rdtype // 256
+            if new_window != window:
+                if octets != 0:
+                    windows.append((window, bitmap[0:octets]))
+                bitmap = bytearray(b'\0' * 32)
+                window = new_window
+            offset = rdtype % 256
+            byte = offset // 8
+            bit = offset % 8
+            octets = byte + 1
+            bitmap[byte] = bitmap[byte] | (0x80 >> bit)
+        if octets != 0:
+            windows.append((window, bitmap[0:octets]))
+        return windows
+
+    def to_wire(self, file):
+        for (window, bitmap) in self.windows:
+            file.write(struct.pack('!BB', window, len(bitmap)))
+            file.write(bitmap)
+
+    def from_wire_parser(self, parser):
+        windows = []
+        last_window = -1
+        while parser.remaining() > 0:
+            window = parser.get_uint8()
+            if window <= last_window:
+                raise dns.exception.FormError(f"bad {self.type_name} bitmap")
+            bitmap = parser.get_counted_bytes()
+            if len(bitmap) == 0 or len(bitmap) > 32:
+                raise dns.exception.FormError(f"bad {self.type_name} octets")
+            windows.append((window, bitmap))
+            last_window = window
+        return windows
index 0b75b294a35fcda75dce8f12af23acdc644a1720..bf7d1151b83737be214b23c987d9a50338fe7d2d 100644 (file)
@@ -17,6 +17,7 @@
 
 import unittest
 
+import dns.exception
 import dns.rdata
 import dns.rdataclass
 import dns.rdatatype
@@ -34,5 +35,14 @@ class NSEC3TestCase(unittest.TestCase):
                                          (255, bitmap)
                                          ))
 
+    def test_NSEC3_bad_bitmaps(self):
+        rdata = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NSEC3,
+                u"1 0 100 ABCD SCBCQHKU35969L2A68P3AD59LHF30715 A CAA")
+
+        with self.assertRaises(dns.exception.FormError):
+            copy = bytearray(rdata.to_wire())
+            copy[-3] = 0
+            dns.rdata.from_wire('IN', 'NSEC3', copy, 0, len(copy))
+
 if __name__ == '__main__':
     unittest.main()