]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Unicode label escapify was not escapifying special characters.
authorBob Halley <halley@dnspython.org>
Sun, 6 Jan 2019 01:36:49 +0000 (17:36 -0800)
committerBob Halley <halley@dnspython.org>
Sun, 6 Jan 2019 01:36:49 +0000 (17:36 -0800)
[Issue #339]

This commit also simplifies code and changes u'string' to 'string'.

dns/name.py
tests/test_name.py

index 84968c16321b7a88b818441b37fff76781de2afa..cd465cd5ce8ce21e7762916fd6dc337a0ee07c4f 100644 (file)
@@ -108,20 +108,28 @@ class IDNACodec(object):
     def __init__(self):
         pass
 
+    def is_idna(self, label):
+        return label.lower().startswith(b'xn--')
+
+    def is_all_ascii(self, label):
+        for c in label:
+            if ord(c) > 0x7f:
+                return False
+        return True
+
     def encode(self, label):
         raise NotImplementedError
 
     def decode(self, label):
-        # We do not apply any IDNA policy on decode; we just
-        downcased = label.lower()
-        if downcased.startswith(b'xn--'):
+        # We do not apply any IDNA policy on decode.
+        if self.is_idna(label):
             try:
-                label = downcased[4:].decode('punycode')
+                label = label[4:].decode('punycode')
             except Exception as e:
                 raise IDNAException(idna_exception=e)
         else:
             label = label.decode()
-        return _escapify(label, True)
+        return _escapify(label)
 
 
 class IDNA2003Codec(IDNACodec):
@@ -153,9 +161,9 @@ class IDNA2003Codec(IDNACodec):
         if not self.strict_decode:
             return super(IDNA2003Codec, self).decode(label)
         if label == b'':
-            return u''
+            return ''
         try:
-            return _escapify(encodings.idna.ToUnicode(label), True)
+            return _escapify(encodings.idna.ToUnicode(label))
         except Exception as e:
             raise IDNAException(idna_exception=e)
 
@@ -193,12 +201,6 @@ class IDNA2008Codec(IDNACodec):
         self.allow_pure_ascii = allow_pure_ascii
         self.strict_decode = strict_decode
 
-    def is_all_ascii(self, label):
-        for c in label:
-            if ord(c) > 0x7f:
-                return False
-        return True
-
     def encode(self, label):
         if label == '':
             return b''
@@ -217,17 +219,18 @@ class IDNA2008Codec(IDNACodec):
         if not self.strict_decode:
             return super(IDNA2008Codec, self).decode(label)
         if label == b'':
-            return u''
+            return ''
         if not have_idna_2008:
             raise NoIDNA2008
         try:
             if self.uts_46:
                 label = idna.uts46_remap(label, False, False)
-            return _escapify(idna.ulabel(label), True)
+            return _escapify(idna.ulabel(label))
         except idna.IDNAError as e:
             raise IDNAException(idna_exception=e)
 
 _escaped = b'"().;\\@$'
+_escaped_text = '"().;\\@$'
 
 IDNA_2003_Practical = IDNA2003Codec(False)
 IDNA_2003_Strict = IDNA2003Codec(True)
@@ -238,13 +241,13 @@ IDNA_2008_Strict = IDNA2008Codec(False, False, False, True)
 IDNA_2008_Transitional = IDNA2008Codec(True, True, False, False)
 IDNA_2008 = IDNA_2008_Practical
 
-def _escapify(label, unicode_mode=False):
+def _escapify(label):
     """Escape the characters in label which need it.
-    @param unicode_mode: escapify only special and whitespace (<= 0x20)
-    characters
     @returns: the escaped string
     @rtype: string"""
-    if not unicode_mode:
+    if isinstance(label, bytes):
+        # Ordinary DNS label mode.  Escape special characters and values
+        # < 0x20 or > 0x7f.
         text = ''
         if isinstance(label, str):
             label = label.encode()
@@ -255,19 +258,17 @@ def _escapify(label, unicode_mode=False):
                 text += chr(c)
             else:
                 text += '\\%03d' % c
-        return text.encode()
+        return text
 
-    text = u''
-    if isinstance(label, bytes):
-        label = label.decode()
+    # Unicode label mode.  Escape only special characters and values < 0x20
+    text = ''
     for c in label:
-        if c > u'\x20' and c < u'\x7f':
-            text += c
+        if c in _escaped_text:
+            text += '\\' + c
+        elif c <= '\x20':
+            text += '\\%03d' % ord(c)
         else:
-            if c >= u'\x7f':
-                text += c
-            else:
-                text += u'\\%03d' % ord(c)
+            text += c
     return text
 
 def _validate_labels(labels):
@@ -549,8 +550,8 @@ class Name(object):
             l = self.labels[:-1]
         else:
             l = self.labels
-        s = b'.'.join(map(_escapify, l))
-        return s.decode()
+        s = '.'.join(map(_escapify, l))
+        return s
 
     def to_unicode(self, omit_final_dot=False, idna_codec=None):
         """Convert name to Unicode text format.
@@ -571,16 +572,16 @@ class Name(object):
         """
 
         if len(self.labels) == 0:
-            return u'@'
+            return '@'
         if len(self.labels) == 1 and self.labels[0] == b'':
-            return u'.'
+            return '.'
         if omit_final_dot and self.is_absolute():
             l = self.labels[:-1]
         else:
             l = self.labels
         if idna_codec is None:
             idna_codec = IDNA_2003_Practical
-        return u'.'.join([idna_codec.decode(x) for x in l])
+        return '.'.join([idna_codec.decode(x) for x in l])
 
     def to_digestable(self, origin=None):
         """Convert name to a format suitable for digesting in hashes.
@@ -816,16 +817,16 @@ def from_unicode(text, origin=root, idna_codec=None):
     if not (origin is None or isinstance(origin, Name)):
         raise ValueError("origin must be a Name or None")
     labels = []
-    label = u''
+    label = ''
     escaping = False
     edigits = 0
     total = 0
     if idna_codec is None:
         idna_codec = IDNA_2003
-    if text == u'@':
-        text = u''
+    if text == '@':
+        text = ''
     if text:
-        if text == u'.':
+        if text == '.':
             return Name([b''])        # no Unicode "u" on this constant!
         for c in text:
             if escaping:
@@ -845,12 +846,12 @@ def from_unicode(text, origin=root, idna_codec=None):
                     if edigits == 3:
                         escaping = False
                         label += chr(total)
-            elif c in [u'.', u'\u3002', u'\uff0e', u'\uff61']:
+            elif c in ['.', '\u3002', '\uff0e', '\uff61']:
                 if len(label) == 0:
                     raise EmptyLabel
                 labels.append(idna_codec.encode(label))
-                label = u''
-            elif c == u'\\':
+                label = ''
+            elif c == '\\':
                 escaping = True
                 edigits = 0
                 total = 0
index 62a94973931b754bde710742637665486e25c96d..02a67c28a1c15ae9766f0ee193859e35e1680970 100644 (file)
@@ -255,6 +255,11 @@ class NameTestCase(unittest.TestCase):
         t = dns.name.root.to_unicode()
         self.assertEqual(t, '.')
 
+    def testToText12(self):
+        n = dns.name.from_text(r'a\.b.c')
+        t = n.to_unicode()
+        self.assertEqual(t, r'a\.b.c.')
+
     def testSlice1(self):
         n = dns.name.from_text(r'a.b.c.', origin=None)
         s = n[:]