]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Implement EDNS Client Subnet option
authorpascal.bouchareine <pascal@gandi.net>
Thu, 3 Nov 2016 19:06:25 +0000 (12:06 -0700)
committerpascal.bouchareine <pascal@gandi.net>
Thu, 3 Nov 2016 19:32:07 +0000 (12:32 -0700)
dns/edns.py
dns/message.py
tests/test_option.py [new file with mode: 0644]

index 8ac676bc62108d991c6d4dc35a507084b2baa454..720b054999ee6a35e312e2051a24d67709754bc5 100644 (file)
 # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
 # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 
+import struct
+
+import dns.inet
+
+
 """EDNS Options"""
 
 NSID = 3
+ECS = 8
 
 
 class Option(object):
@@ -111,6 +117,9 @@ class GenericOption(Option):
     def to_wire(self, file):
         file.write(self.data)
 
+    def to_text(self):
+        return "Generic %d" % self.otype
+
     @classmethod
     def from_wire(cls, otype, wire, current, olen):
         return cls(otype, wire[current: current + olen])
@@ -122,10 +131,89 @@ class GenericOption(Option):
             return 1
         return -1
 
+
+class ECSOption(Option):
+    """EDNS Client Subnet (ECS, RFC7871)"""
+
+    def __init__(self, address, srclen=None, scopelen=0):
+        """Generate an ECS option
+
+        @ivar address: client address information
+        @type address: string
+        @ivar srclen: prefix length, leftmost number of bits of the address
+        to be used for the lookup. Sent by client, mirrored by server in
+        responses. If not provided at init, will use /24 for v4 and /56 for v6
+        @ivar srclen: int
+        @ivar scopelen: prefix length, leftmost number of bits of the address
+        that the response covers. 0 in queries, set by server.
+        """
+        super(ECSOption, self).__init__(ECS)
+        af = dns.inet.af_for_address(address)
+
+        if af == dns.inet.AF_INET6:
+            self.family = 2
+            if srclen is None:
+                srclen = 56
+        elif af == dns.inet.AF_INET:
+            self.family = 1
+            if srclen is None:
+                srclen = 24
+        else:
+            raise ValueError('Bad ip family')
+
+        self.srclen = srclen
+        self.scopelen = scopelen
+        self.address = address
+
+        self.addrdata = dns.inet.inet_pton(af, address)
+
+        # Truncate to srclen and pad to the end of the last octet needed
+        # See RFC section 6
+        self.addrdata = self.addrdata[:-(-srclen//8)]
+        last = ord(self.addrdata[-1:]) & (0xff << srclen % 8)
+        self.addrdata = self.addrdata[:-1] + chr(last).encode('latin1')
+
+    def to_text(self):
+        return "ECS %s/%s scope/%s" % (self.address, self.srclen,
+                                       self.scopelen)
+
+    def to_wire(self, file):
+        """Opt type and len are handled by renderer"""
+        file.write(struct.pack('!H', self.family))
+        file.write(struct.pack('!BB', self.srclen, self.scopelen))
+        file.write(self.addrdata)
+
+    @classmethod
+    def from_wire(cls, otype, wire, cur, olen):
+        """Opt type and len are handled by Message.from_wire"""
+        family, src, scope = struct.unpack('!HBB', wire[cur:cur+4])
+        cur += 4
+
+        addrlen = -(-src//8)
+
+        if family == 1:
+            af = dns.inet.AF_INET
+            pad = 4 - addrlen
+        elif family == 2:
+            af = dns.inet.AF_INET6
+            pad = 16 - addrlen
+        else:
+            raise ValueError('unsupported family')
+
+        addr = dns.inet.inet_ntop(af, wire[cur:cur+addrlen] + '\x00' * pad)
+        return cls(addr, src, scope)
+
+    def _cmp(self, other):
+        if self.addrdata == other.addrdata:
+            return 0
+        if self.addrdata > other.addrdata:
+            return 1
+        return -1
+
 _type_to_class = {
+        ECS: ECSOption
 }
 
-
 def get_option_class(otype):
     cls = _type_to_class.get(otype)
     if cls is None:
index a0df18e67f6c6557ed3796966812b783b040b324..faa29682f9709b69f3490d246f1045f39d270014 100644 (file)
@@ -209,6 +209,8 @@ class Message(object):
                 s.write(u'eflags %s\n' %
                         dns.flags.edns_to_text(self.ednsflags))
             s.write(u'payload %d\n' % self.payload)
+        for opt in self.options:
+            s.write(u'option %s\n' % opt.to_text())
         is_update = dns.opcode.is_update(self.flags)
         if is_update:
             s.write(u';ZONE\n')
diff --git a/tests/test_option.py b/tests/test_option.py
new file mode 100644 (file)
index 0000000..f91485a
--- /dev/null
@@ -0,0 +1,48 @@
+# -*- coding: utf-8
+# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+from __future__ import print_function
+
+try:
+    import unittest2 as unittest
+except ImportError:
+    import unittest
+
+from io import BytesIO
+
+import dns.edns
+
+class OptionTestCase(unittest.TestCase):
+    def testGenericOption(self):
+        opt = dns.edns.GenericOption(3, b'data')
+        io = BytesIO()
+        opt.to_wire(io)
+        data = io.getvalue()
+        self.assertEqual(data, b'data')
+
+    def testECSOption(self):
+        opt = dns.edns.ECSOption('1.2.3.4', 24)
+        io = BytesIO()
+        opt.to_wire(io)
+        data = io.getvalue()
+        self.assertEqual(data, b'\x00\x01\x18\x00\x01\x02\x03')
+
+    def testECSOption_v6(self):
+        opt = dns.edns.ECSOption('2001:4b98::1')
+        io = BytesIO()
+        opt.to_wire(io)
+        data = io.getvalue()
+        self.assertEqual(data, b'\x00\x02\x38\x00\x20\x01\x4b\x98\x00\x00\x00')