]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Better deal with backwards compatibility. 527/head
authorBrian Wellington <bwelling@xbill.org>
Wed, 1 Jul 2020 21:58:14 +0000 (14:58 -0700)
committerBrian Wellington <bwelling@xbill.org>
Wed, 1 Jul 2020 22:01:04 +0000 (15:01 -0700)
If dns.tsigkeyring.from_text() creates dns.tsig.Key objects with the
default algorithm, that causes problems for code that specifies a
different algorithm.  There's no good way to handle this, so change
dns.tsigkeyring.from_text() to not create dns.tsig.Key objects unless it
knows the algorithm.

dns/tsig.py
dns/tsigkeyring.py
tests/test_resolution.py
tests/test_tsigkeyring.py

index 89183cf03ea5432d8897e871ba780bda8bb95784..08ab41e45689b1e9618b72887cc78ab4794a9c90 100644 (file)
@@ -209,6 +209,8 @@ class Key:
         if isinstance(secret, str):
             secret = base64.decodebytes(secret.encode())
         self.secret = secret
+        if isinstance(algorithm, str):
+            algorithm = dns.name.from_text(algorithm)
         self.algorithm = algorithm
 
     def __eq__(self, other):
index b93bdb76db0a9c5c442bad077a9d724d2574bca4..aa3cae92f8649218a67ddcf2de0bf89c5a84c150 100644 (file)
@@ -24,40 +24,37 @@ import dns.name
 
 def from_text(textring):
     """Convert a dictionary containing (textual DNS name, base64 secret)
-    or (textual DNS name, (algorithm, base64 secret)) where algorithm
-    can be a dns.name.Name or string into a binary keyring which has
-    (dns.name.Name, dns.tsig.Key) pairs.
+    pairs into a binary keyring which has (dns.name.Name, bytes) pairs, or
+    a dictionary containing (textual DNS name, (algorithm, base64 secret))
+    pairs into a binary keyring which has (dns.name.Name, dns.tsig.Key) pairs.
     @rtype: dict"""
 
     keyring = {}
     for (name, value) in textring.items():
         name = dns.name.from_text(name)
         if isinstance(value, str):
-            algorithm = dns.tsig.default_algorithm
-            secret = value
+            keyring[name] = dns.tsig.Key(name, value).secret
         else:
             (algorithm, secret) = value
-            if isinstance(algorithm, str):
-                algorithm = dns.name.from_text(algorithm)
-        keyring[name] = dns.tsig.Key(name, secret, algorithm)
+            keyring[name] = dns.tsig.Key(name, secret, algorithm)
     return keyring
 
 
 def to_text(keyring):
     """Convert a dictionary containing (dns.name.Name, dns.tsig.Key) pairs
     into a text keyring which has (textual DNS name, (textual algorithm,
-    base64 secret)) pairs.
+    base64 secret)) pairs, or a dictionary containing (dns.name.Name, bytes)
+    pairs into a text keyring which has (textual DNS name, base64 secret) pairs.
     @rtype: dict"""
 
     textring = {}
+    b64encode = lambda secret: base64.encodebytes(secret).decode().rstrip()
     for (name, key) in keyring.items():
         name = name.to_text()
         if isinstance(key, bytes):
-            algorithm = dns.tsig.default_algorithm
-            secret = key
+            textring[name] = b64encode(key)
         else:
             algorithm = key.algorithm
             secret = key.secret
-        textring[name] = (algorithm.to_text(),
-                          base64.encodebytes(secret).decode().rstrip())
+            textring[name] = (key.algorithm.to_text(), b64encode(key.secret))
     return textring
index aa1cd0cbb10a696014b6f7549a8a9c3db4582462..9145f167a355dcbec298097557c4c7ef575f2312 100644 (file)
@@ -197,11 +197,12 @@ class ResolutionTestCase(unittest.TestCase):
         self.resolver.keyring = dns.tsigkeyring.from_text({
             'keyname.' : 'NjHwPsMKjdN++dOfE5iAiQ=='
         })
-        key = next(iter(self.resolver.keyring.values()))
+        (keyname, secret) = next(iter(self.resolver.keyring.items()))
         self.resolver.keyname = dns.name.from_text('keyname.')
         (request, answer) = self.resn.next_request()
         self.assertFalse(request is None)
-        self.assertEqual(request.keyring, key)
+        self.assertEqual(request.keyring.name, keyname)
+        self.assertEqual(request.keyring.secret, secret)
 
     def test_next_request_flags(self):
         self.resolver.flags = dns.flags.RD | dns.flags.CD
index 25c41ccab42621a174afda0355af4529c92f1c5b..47f88067b6eb64166cbad31db843d0b2ba797c53 100644 (file)
@@ -10,14 +10,14 @@ text_keyring = {
     'keyname.' : ('hmac-sha256.', 'NjHwPsMKjdN++dOfE5iAiQ==')
 }
 
-old_text_keyring = {
-    'keyname.' : 'NjHwPsMKjdN++dOfE5iAiQ=='
-}
-
 alt_text_keyring = {
     'keyname.' : (dns.tsig.HMAC_SHA256, 'NjHwPsMKjdN++dOfE5iAiQ==')
 }
 
+old_text_keyring = {
+    'keyname.' : 'NjHwPsMKjdN++dOfE5iAiQ=='
+}
+
 key = dns.tsig.Key('keyname.', 'NjHwPsMKjdN++dOfE5iAiQ==')
 
 rich_keyring = { key.name : key }
@@ -31,16 +31,16 @@ class TSIGKeyRingTestCase(unittest.TestCase):
         rkeyring = dns.tsigkeyring.from_text(text_keyring)
         self.assertEqual(rkeyring, rich_keyring)
 
-    def test_from_old_text(self):
-        """old format text keyring -> rich keyring"""
-        rkeyring = dns.tsigkeyring.from_text(old_text_keyring)
-        self.assertEqual(rkeyring, rich_keyring)
-
     def test_from_alt_text(self):
         """alternate format text keyring -> rich keyring"""
         rkeyring = dns.tsigkeyring.from_text(alt_text_keyring)
         self.assertEqual(rkeyring, rich_keyring)
 
+    def test_from_old_text(self):
+        """old format text keyring -> rich keyring"""
+        rkeyring = dns.tsigkeyring.from_text(old_text_keyring)
+        self.assertEqual(rkeyring, old_rich_keyring)
+
     def test_to_text(self):
         """text keyring -> rich keyring -> text keyring"""
         tkeyring = dns.tsigkeyring.to_text(rich_keyring)
@@ -49,10 +49,16 @@ class TSIGKeyRingTestCase(unittest.TestCase):
     def test_old_to_text(self):
         """text keyring -> rich keyring -> text keyring"""
         tkeyring = dns.tsigkeyring.to_text(old_rich_keyring)
-        self.assertEqual(tkeyring, text_keyring)
+        self.assertEqual(tkeyring, old_text_keyring)
 
     def test_from_and_to_text(self):
         """text keyring -> rich keyring -> text keyring"""
         rkeyring = dns.tsigkeyring.from_text(text_keyring)
         tkeyring = dns.tsigkeyring.to_text(rkeyring)
         self.assertEqual(tkeyring, text_keyring)
+
+    def test_old_from_and_to_text(self):
+        """text keyring -> rich keyring -> text keyring"""
+        rkeyring = dns.tsigkeyring.from_text(old_text_keyring)
+        tkeyring = dns.tsigkeyring.to_text(rkeyring)
+        self.assertEqual(tkeyring, old_text_keyring)