]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Add dns.tsig.Key class.
authorBrian Wellington <bwelling@xbill.org>
Wed, 1 Jul 2020 20:06:14 +0000 (13:06 -0700)
committerBrian Wellington <bwelling@xbill.org>
Wed, 1 Jul 2020 20:06:14 +0000 (13:06 -0700)
This creates a new class to represent a TSIG key, containing name,
secret, and algorithm.

The keyring format is changed to be {name : key}, and the methods in
dns.tsigkeyring are updated to deal with old and new formats.

The Message class is updated to use dns.tsig.Key, although (to avoid
breaking existing code), it stores them in the keyring field.

Message.use_tsig() can accept either explicit keys, or keyrings; it will
extract and/or create a key.

dns.message.from_wire() can accept either a key or a keyring in the
keyring parameter.  If passed a key, it will now raise if the TSIG
record in the message was signed with a different key.  If passed a
keyring containing keys (as opposed to bare secrets), it will check that
the TSIG record's algorithm matches that of the key.

dns/message.py
dns/renderer.py
dns/resolver.py
dns/tsig.py
dns/tsigkeyring.py
dns/update.py
doc/message-class.rst
tests/test_resolution.py
tests/test_tsigkeyring.py

index 00359ef387751328fefd56268dad3cf5baa43dab..fdaec026321e8b260777be741ddffd0243d3b3c3 100644 (file)
@@ -437,9 +437,8 @@ class Message:
         r.write_header()
         if self.tsig is not None:
             (new_tsig, ctx) = dns.tsig.sign(r.get_wire(),
-                                            self.tsig.name,
+                                            self.keyring,
                                             self.tsig[0],
-                                            self.keyring[self.tsig.name],
                                             int(time.time()),
                                             self.request_mac,
                                             tsig_ctx,
@@ -463,21 +462,27 @@ class Message:
     def use_tsig(self, keyring, keyname=None, fudge=300,
                  original_id=None, tsig_error=0, other_data=b'',
                  algorithm=dns.tsig.default_algorithm):
-        """When sending, a TSIG signature using the specified keyring
-        and keyname should be added.
+        """When sending, a TSIG signature using the specified key
+        should be added.
 
-        See the documentation of the Message class for a complete
-        description of the keyring dictionary.
+        *key*, a ``dns.tsig.Key`` is the key to use.  If a key is specified,
+        the *keyring* and *algorithm* fields are not used.
 
-        *keyring*, a ``dict``, the TSIG keyring to use.  If a
-        *keyring* is specified but a *keyname* is not, then the key
-        used will be the first key in the *keyring*.  Note that the
-        order of keys in a dictionary is not defined, so applications
-        should supply a keyname when a keyring is used, unless they
-        know the keyring contains only one key.
+        *keyring*, a ``dict`` or ``dns.tsig.Key``, is either the TSIG
+        keyring or key to use.
 
-        *keyname*, a ``dns.name.Name`` or ``None``, the name of the TSIG key
-        to use; defaults to ``None``. The key must be defined in the keyring.
+        The format of a keyring dict is a mapping from TSIG key name, as
+        ``dns.name.Name`` to ``dns.tsig.Key`` or a TSIG secret, a ``bytes``.
+        If a ``dict`` *keyring* is specified but a *keyname* is not, the key
+        used will be the first key in the *keyring*.  Note that the order of
+        keys in a dictionary is not defined, so applications should supply a
+        keyname when a ``dict`` keyring is used, unless they know the keyring
+        contains only one key.
+
+        *keyname*, a ``dns.name.Name``, ``str`` or ``None``, the name of
+        thes TSIG key to use; defaults to ``None``.  If *keyring* is a
+        ``dict``, the key must be defined in it.  If *keyring* is a
+        ``dns.tsig.Key``, this is ignored.
 
         *fudge*, an ``int``, the TSIG time fudge.
 
@@ -488,18 +493,25 @@ class Message:
 
         *other_data*, a ``bytes``, the TSIG other data.
 
-        *algorithm*, a ``dns.name.Name``, the TSIG algorithm to use.
+        *algorithm*, a ``dns.name.Name``, the TSIG algorithm to use.  This is
+        only used if *keyring* is a ``dict``, and the key entry is a ``bytes``.
         """
 
-        self.keyring = keyring
-        if keyname is None:
-            keyname = list(self.keyring.keys())[0]
-        elif isinstance(keyname, str):
-            keyname = dns.name.from_text(keyname)
+        if isinstance(keyring, dns.tsig.Key):
+            self.keyring = keyring
+        else:
+            if isinstance(keyname, str):
+                keyname = dns.name.from_text(keyname)
+            if keyname is None:
+                keyname = next(iter(keyring))
+            key = keyring[keyname]
+            if isinstance(key, bytes):
+                key = dns.tsig.Key(keyname, key, algorithm)
+            self.keyring = key
         if original_id is None:
             original_id = self.id
-        self.tsig = self._make_tsig(keyname, algorithm, 0, fudge, b'',
-                                    original_id, tsig_error, other_data)
+        self.tsig = self._make_tsig(keyname, self.keyring.algorithm, 0, fudge,
+                                    b'', original_id, tsig_error, other_data)
 
     @property
     def keyname(self):
@@ -723,13 +735,15 @@ class _WireReader:
     initialize_message: Callback to set message parsing options
     question_only: Are we only reading the question?
     one_rr_per_rrset: Put each RR into its own RRset?
+    keyring: TSIG keyring
     ignore_trailing: Ignore trailing junk at end of request?
     multi: Is this message part of a multi-message sequence?
     DNS dynamic updates.
     """
 
     def __init__(self, wire, initialize_message, question_only=False,
-                 one_rr_per_rrset=False, ignore_trailing=False, multi=False):
+                 one_rr_per_rrset=False, ignore_trailing=False,
+                 keyring=None, multi=False):
         self.wire = dns.wiredata.maybe_wrap(wire)
         self.message = None
         self.current = 0
@@ -737,6 +751,7 @@ class _WireReader:
         self.question_only = question_only
         self.one_rr_per_rrset = one_rr_per_rrset
         self.ignore_trailing = ignore_trailing
+        self.keyring = keyring
         self.multi = multi
 
     def _get_question(self, section_number, qcount):
@@ -805,16 +820,22 @@ class _WireReader:
             if rdtype == dns.rdatatype.OPT:
                 self.message.opt = dns.rrset.from_rdata(name, ttl, rd)
             elif rdtype == dns.rdatatype.TSIG:
-                if self.message.keyring is None:
+                if self.keyring is None:
                     raise UnknownTSIGKey('got signed message without keyring')
-                secret = self.message.keyring.get(absolute_name)
-                if secret is None:
+                if isinstance(self.keyring, dict):
+                    key = self.keyring.get(absolute_name)
+                    if isinstance(key, bytes):
+                        key = dns.tsig.Key(absolute_name, key, rd.algorithm)
+                else:
+                    key = self.keyring
+                if key is None:
                     raise UnknownTSIGKey("key '%s' unknown" % name)
+                self.message.keyring = key
                 self.message.tsig_ctx = \
                     dns.tsig.validate(self.wire,
+                                      key,
                                       absolute_name,
                                       rd,
-                                      secret,
                                       int(time.time()),
                                       self.message.request_mac,
                                       rr_start,
@@ -868,7 +889,8 @@ def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None,
     """Convert a DNS wire format message into a message
     object.
 
-    *keyring*, a ``dict``, the keyring to use if the message is signed.
+    *keyring*, a ``dns.tsig.Key`` or ``dict``, the key or keyring to use
+    if the message is signed.
 
     *request_mac*, a ``bytes``.  If the message is a response to a
     TSIG-signed request, *request_mac* should be set to the MAC of
@@ -918,14 +940,13 @@ def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None,
     """
 
     def initialize_message(message):
-        message.keyring = keyring
         message.request_mac = request_mac
         message.xfr = xfr
         message.origin = origin
         message.tsig_ctx = tsig_ctx
 
     reader = _WireReader(wire, initialize_message, question_only,
-                         one_rr_per_rrset, ignore_trailing, multi)
+                         one_rr_per_rrset, ignore_trailing, keyring, multi)
     try:
         m = reader.read()
     except dns.exception.FormError:
index be57a62f9b893be29a0280089ec91a0eb5d93c41..72f0f7a8a36d00b28efc8d0d78667241985f25df 100644 (file)
@@ -179,10 +179,14 @@ class Renderer:
 
         s = self.output.getvalue()
 
+        if isinstance(secret, dns.tsig.Key):
+            key = secret
+        else:
+            key = dns.tsig.Key(keyname, secret, algorithm)
         tsig = dns.message.Message._make_tsig(keyname, algorithm, 0, fudge,
                                               b'', id, tsig_error, other_data)
-        (tsig, _) = dns.tsig.sign(s, keyname, tsig[0], secret,
-                                  int(time.time()), request_mac)
+        (tsig, _) = dns.tsig.sign(s, key, tsig[0], int(time.time()),
+                                  request_mac)
         self._write_tsig(tsig, keyname)
 
     def add_multi_tsig(self, ctx, keyname, secret, fudge, id, tsig_error,
@@ -198,11 +202,14 @@ class Renderer:
 
         s = self.output.getvalue()
 
+        if isinstance(secret, dns.tsig.Key):
+            key = secret
+        else:
+            key = dns.tsig.Key(keyname, secret, algorithm)
         tsig = dns.message.Message._make_tsig(keyname, algorithm, 0, fudge,
                                               b'', id, tsig_error, other_data)
-        (tsig, ctx) = dns.tsig.sign(s, keyname, tsig[0], secret,
-                                    int(time.time()), request_mac,
-                                    ctx, True)
+        (tsig, ctx) = dns.tsig.sign(s, key, tsig[0], int(time.time()),
+                                    request_mac, ctx, True)
         self._write_tsig(tsig, keyname)
         return ctx
 
index f4a07b48e2d13c38b5378abeed41d075e803dcb9..62d019851ec9a50941df85a9e6c915e8b47b0ef7 100644 (file)
@@ -1111,29 +1111,14 @@ class Resolver:
 
     def use_tsig(self, keyring, keyname=None,
                  algorithm=dns.tsig.default_algorithm):
-        """Add a TSIG signature to the query.
+        """Add a TSIG signature to each query.
 
-        See the documentation of the Message class for a complete
-        description of the keyring dictionary.
-
-        *keyring*, a ``dict``, the TSIG keyring to use.  If a
-        *keyring* is specified but a *keyname* is not, then the key
-        used will be the first key in the *keyring*.  Note that the
-        order of keys in a dictionary is not defined, so applications
-        should supply a keyname when a keyring is used, unless they
-        know the keyring contains only one key.
-
-        *keyname*, a ``dns.name.Name`` or ``None``, the name of the TSIG key
-        to use; defaults to ``None``. The key must be defined in the keyring.
-
-        *algorithm*, a ``dns.name.Name``, the TSIG algorithm to use.
+        The parameters are passed to ``dns.message.Message.use_tsig()``;
+        see its documentation for details.
         """
 
         self.keyring = keyring
-        if keyname is None:
-            self.keyname = list(self.keyring.keys())[0]
-        else:
-            self.keyname = keyname
+        self.keyname = keyname
         self.keyalgorithm = algorithm
 
     def use_edns(self, edns, ednsflags, payload):
index 12cbae68f260068faa89f7722dba489172356d51..c3c849c7dba484ebc0753458f273eae6894e927d 100644 (file)
@@ -17,6 +17,7 @@
 
 """DNS TSIG support."""
 
+import base64
 import hashlib
 import hmac
 import struct
@@ -35,6 +36,16 @@ class BadSignature(dns.exception.DNSException):
     """The TSIG signature fails to verify."""
 
 
+class BadKey(dns.exception.DNSException):
+
+    """The TSIG record owner name does not match the key."""
+
+
+class BadAlgorithm(dns.exception.DNSException):
+
+    """The TSIG algorithm does not match the key."""
+
+
 class PeerError(dns.exception.DNSException):
 
     """Base class for all TSIG errors generated by the remote peer"""
@@ -85,8 +96,7 @@ BADTIME = 18
 BADTRUNC = 22
 
 
-def sign(wire, keyname, rdata, secret, time=None, request_mac=None,
-         ctx=None, multi=False):
+def sign(wire, key, rdata, time=None, request_mac=None, ctx=None, multi=False):
     """Return a (tsig_rdata, mac, ctx) tuple containing the HMAC TSIG rdata
     for the input parameters, the HMAC MAC calculated by applying the
     TSIG signature algorithm, and the TSIG digest context.
@@ -98,14 +108,14 @@ def sign(wire, keyname, rdata, secret, time=None, request_mac=None,
     first = not (ctx and multi)
     (algorithm_name, digestmod) = get_algorithm(rdata.algorithm)
     if first:
-        ctx = hmac.new(secret, digestmod=digestmod)
+        ctx = hmac.new(key.secret, digestmod=digestmod)
         if request_mac:
             ctx.update(struct.pack('!H', len(request_mac)))
             ctx.update(request_mac)
     ctx.update(struct.pack('!H', rdata.original_id))
     ctx.update(wire[2:])
     if first:
-        ctx.update(keyname.to_digestable())
+        ctx.update(key.name.to_digestable())
         ctx.update(struct.pack('!H', dns.rdataclass.ANY))
         ctx.update(struct.pack('!I', 0))
     if time is None:
@@ -123,7 +133,7 @@ def sign(wire, keyname, rdata, secret, time=None, request_mac=None,
         ctx.update(time_encoded)
     mac = ctx.digest()
     if multi:
-        ctx = hmac.new(secret, digestmod=digestmod)
+        ctx = hmac.new(key.secret, digestmod=digestmod)
         ctx.update(struct.pack('!H', len(mac)))
         ctx.update(mac)
     else:
@@ -136,8 +146,8 @@ def sign(wire, keyname, rdata, secret, time=None, request_mac=None,
     return (tsig, ctx)
 
 
-def validate(wire, keyname, rdata, secret, now, request_mac, tsig_start,
-             ctx=None, multi=False):
+def validate(wire, key, owner, rdata, now, request_mac, tsig_start, ctx=None,
+             multi=False):
     """Validate the specified TSIG rdata against the other input parameters.
 
     @raises FormError: The TSIG is badly formed.
@@ -164,8 +174,11 @@ def validate(wire, keyname, rdata, secret, now, request_mac, tsig_start,
             raise PeerError('unknown TSIG error code %d' % rdata.error)
     if abs(rdata.time_signed - now) > rdata.fudge:
         raise BadTime
-    (our_rdata, ctx) = sign(new_wire, keyname, rdata, secret, None, request_mac,
-                            ctx, multi)
+    if key.name != owner:
+        raise BadKey
+    if key.algorithm != rdata.algorithm:
+        raise BadAlgorithm
+    (our_rdata, ctx) = sign(new_wire, key, rdata, None, request_mac, ctx, multi)
     if our_rdata.mac != rdata.mac:
         raise BadSignature
     return ctx
@@ -187,3 +200,19 @@ def get_algorithm(algorithm):
     except KeyError:
         raise NotImplementedError("TSIG algorithm " + str(algorithm) +
                                   " is not supported")
+
+class Key:
+    def __init__(self, name, secret, algorithm=default_algorithm):
+        if isinstance(name, str):
+            name = dns.name.from_text(name)
+        self.name = name
+        if isinstance(secret, str):
+            secret = base64.decodebytes(secret.encode())
+        self.secret = secret
+        self.algorithm = algorithm
+
+    def __eq__(self, other):
+        return (isinstance(other, Key) and
+                self.name == other.name and
+                self.secret == other.secret and
+                self.algorithm == other.algorithm)
index 32baf803e318facca11b42490d3786d1c17d47ca..b93bdb76db0a9c5c442bad077a9d724d2574bca4 100644 (file)
@@ -23,27 +23,41 @@ import dns.name
 
 
 def from_text(textring):
-    """Convert a dictionary containing (textual DNS name, base64 secret) pairs
-    into a binary keyring which has (dns.name.Name, binary secret) pairs.
+    """Convert a dictionary containing (textual DNS name, base64 secret)
+    or (textual DNS name, (algorithm, base64 secret)) where algorithm
+    can be a dns.name.Name or string into a binary keyring which has
+    (dns.name.Name, dns.tsig.Key) pairs.
     @rtype: dict"""
 
     keyring = {}
-    for keytext in textring:
-        keyname = dns.name.from_text(keytext)
-        secret = base64.decodebytes(textring[keytext].encode())
-        keyring[keyname] = secret
+    for (name, value) in textring.items():
+        name = dns.name.from_text(name)
+        if isinstance(value, str):
+            algorithm = dns.tsig.default_algorithm
+            secret = value
+        else:
+            (algorithm, secret) = value
+            if isinstance(algorithm, str):
+                algorithm = dns.name.from_text(algorithm)
+        keyring[name] = dns.tsig.Key(name, secret, algorithm)
     return keyring
 
 
 def to_text(keyring):
-    """Convert a dictionary containing (dns.name.Name, binary secret) pairs
-    into a text keyring which has (textual DNS name, base64 secret) pairs.
+    """Convert a dictionary containing (dns.name.Name, dns.tsig.Key) pairs
+    into a text keyring which has (textual DNS name, (textual algorithm,
+    base64 secret)) pairs.
     @rtype: dict"""
 
     textring = {}
-    for keyname in keyring:
-        keytext = keyname.to_text()
-        # rstrip to get rid of the \n encoding adds
-        secret = base64.encodebytes(keyring[keyname]).decode().rstrip()
-        textring[keytext] = secret
+    for (name, key) in keyring.items():
+        name = name.to_text()
+        if isinstance(key, bytes):
+            algorithm = dns.tsig.default_algorithm
+            secret = key
+        else:
+            algorithm = key.algorithm
+            secret = key.secret
+        textring[name] = (algorithm.to_text(),
+                          base64.encodebytes(secret).decode().rstrip())
     return textring
index 130577d12c785adbd4c4102c40d12d6d7cc4c716..8e796504f36a2f32279089c43dd7244351c1c1b5 100644 (file)
@@ -60,18 +60,8 @@ class UpdateMessage(dns.message.Message):
 
         *rdclass*, an ``int`` or ``str``, the class of the zone.
 
-        *keyring*, a ``dict``, the TSIG keyring to use.  If a
-        *keyring* is specified but a *keyname* is not, then the key
-        used will be the first key in the *keyring*.  Note that the
-        order of keys in a dictionary is not defined, so applications
-        should supply a keyname when a keyring is used, unless they
-        know the keyring contains only one key.
-
-        *keyname*, a ``dns.name.Name`` or ``None``, the name of the TSIG key
-        to use; defaults to ``None``. The key must be defined in the keyring.
-
-        *keyalgorithm*, a ``dns.name.Name``, the TSIG algorithm to use.
-
+        The *keyring*, *keyname*, and *keyalgorithm* parameters are passed to
+        ``use_tsig()``; see its documentation for details.
         """
         super().__init__(id=id)
         self.flags |= dns.opcode.to_flags(dns.opcode.UPDATE)
index b235d900d1f4078edb9edc8c2910ebc0a4384664..08d99586d462a5c0a0683a2ff47c01039a057c51 100644 (file)
@@ -47,9 +47,7 @@ DNS opcodes that do not have a more specific class.
 
    .. attribute:: keyring
 
-      The TSIG keyring to use.  The default is `None`.  A TSIG keyring
-      is a dictionary mapping from TSIG key name, a ``dns.name.Name``, to
-      a TSIG secret, a ``bytes``.
+      A ``dns.tsig.Key``, the TSIG key.  The default is None.
 
    .. attribute:: keyname
 
index 1ba4463760ca8673998107c967248f24d53a2bb5..aa1cd0cbb10a696014b6f7549a8a9c3db4582462 100644 (file)
@@ -197,11 +197,11 @@ class ResolutionTestCase(unittest.TestCase):
         self.resolver.keyring = dns.tsigkeyring.from_text({
             'keyname.' : 'NjHwPsMKjdN++dOfE5iAiQ=='
         })
+        key = next(iter(self.resolver.keyring.values()))
         self.resolver.keyname = dns.name.from_text('keyname.')
         (request, answer) = self.resn.next_request()
         self.assertFalse(request is None)
-        self.assertEqual(request.keyring, self.resolver.keyring)
-        self.assertEqual(request.keyname, self.resolver.keyname)
+        self.assertEqual(request.keyring, key)
 
     def test_next_request_flags(self):
         self.resolver.flags = dns.flags.RD | dns.flags.CD
index ce8888d94cbc5c8d6f262a0f79280c7045bae889..25c41ccab42621a174afda0355af4529c92f1c5b 100644 (file)
@@ -3,17 +3,27 @@
 import base64
 import unittest
 
+import dns.tsig
 import dns.tsigkeyring
 
 text_keyring = {
+    'keyname.' : ('hmac-sha256.', 'NjHwPsMKjdN++dOfE5iAiQ==')
+}
+
+old_text_keyring = {
     'keyname.' : 'NjHwPsMKjdN++dOfE5iAiQ=='
 }
 
-rich_keyring = {
-    dns.name.from_text('keyname.') : \
-    base64.decodebytes('NjHwPsMKjdN++dOfE5iAiQ=='.encode())
+alt_text_keyring = {
+    'keyname.' : (dns.tsig.HMAC_SHA256, 'NjHwPsMKjdN++dOfE5iAiQ==')
 }
 
+key = dns.tsig.Key('keyname.', 'NjHwPsMKjdN++dOfE5iAiQ==')
+
+rich_keyring = { key.name : key }
+
+old_rich_keyring = { key.name : key.secret }
+
 class TSIGKeyRingTestCase(unittest.TestCase):
 
     def test_from_text(self):
@@ -21,11 +31,26 @@ class TSIGKeyRingTestCase(unittest.TestCase):
         rkeyring = dns.tsigkeyring.from_text(text_keyring)
         self.assertEqual(rkeyring, rich_keyring)
 
+    def test_from_old_text(self):
+        """old format text keyring -> rich keyring"""
+        rkeyring = dns.tsigkeyring.from_text(old_text_keyring)
+        self.assertEqual(rkeyring, rich_keyring)
+
+    def test_from_alt_text(self):
+        """alternate format text keyring -> rich keyring"""
+        rkeyring = dns.tsigkeyring.from_text(alt_text_keyring)
+        self.assertEqual(rkeyring, rich_keyring)
+
     def test_to_text(self):
         """text keyring -> rich keyring -> text keyring"""
         tkeyring = dns.tsigkeyring.to_text(rich_keyring)
         self.assertEqual(tkeyring, text_keyring)
 
+    def test_old_to_text(self):
+        """text keyring -> rich keyring -> text keyring"""
+        tkeyring = dns.tsigkeyring.to_text(old_rich_keyring)
+        self.assertEqual(tkeyring, text_keyring)
+
     def test_from_and_to_text(self):
         """text keyring -> rich keyring -> text keyring"""
         rkeyring = dns.tsigkeyring.from_text(text_keyring)