]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
a way of doing comments 543/head
authorBob Halley <halley@dnspython.org>
Sat, 18 Jul 2020 20:07:04 +0000 (13:07 -0700)
committerBob Halley <halley@dnspython.org>
Mon, 20 Jul 2020 20:41:22 +0000 (13:41 -0700)
34 files changed:
dns/rdata.py
dns/rdataset.py
dns/rdtypes/ANY/GPOS.py
dns/rdtypes/ANY/HINFO.py
dns/rdtypes/ANY/HIP.py
dns/rdtypes/ANY/ISDN.py
dns/rdtypes/ANY/LOC.py
dns/rdtypes/ANY/NSEC3PARAM.py
dns/rdtypes/ANY/RP.py
dns/rdtypes/ANY/SOA.py
dns/rdtypes/ANY/URI.py
dns/rdtypes/ANY/X25.py
dns/rdtypes/CH/A.py
dns/rdtypes/IN/A.py
dns/rdtypes/IN/AAAA.py
dns/rdtypes/IN/APL.py
dns/rdtypes/IN/NAPTR.py
dns/rdtypes/IN/NSAP.py
dns/rdtypes/IN/PX.py
dns/rdtypes/IN/SRV.py
dns/rdtypes/IN/WKS.py
dns/rdtypes/euibase.py
dns/rdtypes/mxbase.py
dns/rdtypes/nsbase.py
dns/rdtypes/txtbase.py
dns/rdtypes/util.py
dns/tokenizer.py
dns/zone.py
tests/mx-2-0.pickle [new file with mode: 0644]
tests/test_message.py
tests/test_rdata.py
tests/test_zone.py
tests/util.py [new file with mode: 0644]
util/generate-mx-pickle.py [new file with mode: 0644]

index e114fe32bcbf162b8d89f73e18bce6a43f098b4c..98ded794db519ff2542d407d64146f8a2489f668 100644 (file)
@@ -111,7 +111,7 @@ def _constify(o):
 class Rdata:
     """Base class for all DNS rdata types."""
 
-    __slots__ = ['rdclass', 'rdtype']
+    __slots__ = ['rdclass', 'rdtype', 'rdcomment']
 
     def __init__(self, rdclass, rdtype):
         """Initialize an rdata.
@@ -123,6 +123,7 @@ class Rdata:
 
         object.__setattr__(self, 'rdclass', rdclass)
         object.__setattr__(self, 'rdtype', rdtype)
+        object.__setattr__(self, 'rdcomment', None)
 
     def __setattr__(self, name, value):
         # Rdatas are immutable
@@ -153,6 +154,10 @@ class Rdata:
     def __setstate__(self, state):
         for slot, val in state.items():
             object.__setattr__(self, slot, val)
+        if not hasattr(self, 'rdcomment'):
+            # Pickled rdata from 2.0.x might not have a rdcomment, so add
+            # it if needed.
+            object.__setattr__(self, 'rdcomment', None)
 
     def covers(self):
         """Return the type a Rdata covers.
@@ -319,6 +324,8 @@ class Rdata:
         # Ensure that all of the arguments correspond to valid fields.
         # Don't allow rdclass or rdtype to be changed, though.
         for key in kwargs:
+            if key == 'rdcomment':
+                continue
             if key not in parameters:
                 raise AttributeError("'{}' object has no attribute '{}'"
                                      .format(self.__class__.__name__, key))
@@ -336,6 +343,11 @@ class Rdata:
         # this validation can go away.
         rd = self.__class__(*args)
         dns.rdata.from_text(rd.rdclass, rd.rdtype, rd.to_text())
+        # The comment is not set in the constructor, so give it special
+        # handling.
+        rdcomment = kwargs.get('rdcomment', self.rdcomment)
+        if rdcomment is not None:
+            object.__setattr__(rd, 'rdcomment', rdcomment)
         return rd
 
 
@@ -364,13 +376,7 @@ class GenericRdata(Rdata):
             raise dns.exception.SyntaxError(
                 r'generic rdata does not start with \#')
         length = tok.get_int()
-        chunks = []
-        while 1:
-            token = tok.get()
-            if token.is_eol_or_eof():
-                break
-            chunks.append(token.value.encode())
-        hex = b''.join(chunks)
+        hex = tok.concatenate_remaining_identifiers().encode()
         data = binascii.unhexlify(hex)
         if len(data) != length:
             raise dns.exception.SyntaxError(
@@ -459,6 +465,7 @@ def from_text(rdclass, rdtype, tok, origin=None, relativize=True,
     rdclass = dns.rdataclass.RdataClass.make(rdclass)
     rdtype = dns.rdatatype.RdataType.make(rdtype)
     cls = get_rdata_class(rdclass, rdtype)
+    rdata = None
     if cls != GenericRdata:
         # peek at first token
         token = tok.get()
@@ -470,12 +477,17 @@ def from_text(rdclass, rdtype, tok, origin=None, relativize=True,
             # wire form from the generic syntax, and then run
             # from_wire on it.
             #
-            rdata = GenericRdata.from_text(rdclass, rdtype, tok, origin,
-                                           relativize, relativize_to)
-            return from_wire(rdclass, rdtype, rdata.data, 0, len(rdata.data),
-                             origin)
-    return cls.from_text(rdclass, rdtype, tok, origin, relativize,
-                         relativize_to)
+            grdata = GenericRdata.from_text(rdclass, rdtype, tok, origin,
+                                            relativize, relativize_to)
+            rdata = from_wire(rdclass, rdtype, grdata.data, 0, len(grdata.data),
+                              origin)
+    if rdata is None:
+        rdata = cls.from_text(rdclass, rdtype, tok, origin, relativize,
+                              relativize_to)
+    token = tok.get_eol_as_token()
+    if token.comment is not None:
+        object.__setattr__(rdata, 'rdcomment', token.comment)
+    return rdata
 
 
 def from_wire_parser(rdclass, rdtype, parser, origin=None):
index 660415e78314ed9f698034276b9a7d15b053a8cf..8e70a08cd130be95285a41b750e6f24a1d891a2f 100644 (file)
@@ -176,7 +176,7 @@ class Rdataset(dns.set.Set):
         return not self.__eq__(other)
 
     def to_text(self, name=None, origin=None, relativize=True,
-                override_rdclass=None, **kw):
+                override_rdclass=None, want_comments=False, **kw):
         """Convert the rdataset into DNS master file format.
 
         See ``dns.name.Name.choose_relativity`` for more information
@@ -194,6 +194,9 @@ class Rdataset(dns.set.Set):
 
         *relativize*, a ``bool``.  If ``True``, names will be relativized
         to *origin*.
+
+        *want_comments*, a ``bool``.  If ``True``, emit comments for rdata
+        which have them.  The default is ``False``.
         """
 
         if name is not None:
@@ -219,11 +222,16 @@ class Rdataset(dns.set.Set):
                                          dns.rdatatype.to_text(self.rdtype)))
         else:
             for rd in self:
-                s.write('%s%s%d %s %s %s\n' %
+                extra = ''
+                if want_comments:
+                    if rd.rdcomment:
+                        extra = f' ;{rd.rdcomment}'
+                s.write('%s%s%d %s %s %s%s\n' %
                         (ntext, pad, self.ttl, dns.rdataclass.to_text(rdclass),
                          dns.rdatatype.to_text(self.rdtype),
                          rd.to_text(origin=origin, relativize=relativize,
-                         **kw)))
+                                    **kw),
+                         extra))
         #
         # We strip off the final \n for the caller's convenience in printing
         #
index 03677fd22cdd10256706365850fec26969b14bc3..8285b3fcb47e1fee3f7a05c6a59117faf9c0deb0 100644 (file)
@@ -93,7 +93,6 @@ class GPOS(dns.rdata.Rdata):
         latitude = tok.get_string()
         longitude = tok.get_string()
         altitude = tok.get_string()
-        tok.get_eol()
         return cls(rdclass, rdtype, latitude, longitude, altitude)
 
     def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
index 587e0ad169249c6b2da66c3e7f30ec9f9e1b9e74..6c1ccfae77a9b2145d0c513a34cfb792a5b45d3a 100644 (file)
@@ -50,7 +50,6 @@ class HINFO(dns.rdata.Rdata):
                   relativize_to=None):
         cpu = tok.get_string(max_length=255)
         os = tok.get_string(max_length=255)
-        tok.get_eol()
         return cls(rdclass, rdtype, cpu, os)
 
     def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
index 1c774bbff4eab998e5b99baee67f99272e514e84..437ee73381df80b445702b95903e89412e0f45c7 100644 (file)
@@ -59,10 +59,7 @@ class HIP(dns.rdata.Rdata):
             raise dns.exception.SyntaxError("HIT too long")
         key = base64.b64decode(tok.get_string().encode())
         servers = []
-        while 1:
-            token = tok.get()
-            if token.is_eol_or_eof():
-                break
+        for token in tok.get_remaining():
             server = tok.as_name(token, origin, relativize, relativize_to)
             servers.append(server)
         return cls(rdclass, rdtype, hit, algorithm, key, servers)
index 6834b3c7c964ddcf474e9517dd0b536862c5383c..b07594f9add164cdd7cb26dcfee5cf94d5b972b5 100644 (file)
@@ -52,14 +52,11 @@ class ISDN(dns.rdata.Rdata):
     def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
                   relativize_to=None):
         address = tok.get_string()
-        t = tok.get()
-        if not t.is_eol_or_eof():
-            tok.unget(t)
-            subaddress = tok.get_string()
+        tokens = tok.get_remaining(max_tokens=1)
+        if len(tokens) >= 1:
+            subaddress = tokens[0].unescape().value
         else:
-            tok.unget(t)
             subaddress = ''
-        tok.get_eol()
         return cls(rdclass, rdtype, address, subaddress)
 
     def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
index eb00a1cd03dcde76be58b7b653a20eb3c6c8ceaf..06027346cc1a55e07c8481f301f3fd016ff9bec5 100644 (file)
@@ -245,25 +245,22 @@ class LOC(dns.rdata.Rdata):
             t = t[0: -1]
         altitude = float(t) * 100.0        # m -> cm
 
-        token = tok.get().unescape()
-        if not token.is_eol_or_eof():
-            value = token.value
+        tokens = tok.get_remaining(max_tokens=3)
+        if len(tokens) >= 1:
+            value = tokens[0].unescape().value
             if value[-1] == 'm':
                 value = value[0: -1]
             size = float(value) * 100.0        # m -> cm
-            token = tok.get().unescape()
-            if not token.is_eol_or_eof():
-                value = token.value
+            if len(tokens) >= 2:
+                value = tokens[1].unescape().value
                 if value[-1] == 'm':
                     value = value[0: -1]
                 hprec = float(value) * 100.0        # m -> cm
-                token = tok.get().unescape()
-                if not token.is_eol_or_eof():
-                    value = token.value
+                if len(tokens) >= 3:
+                    value = tokens[2].unescape().value
                     if value[-1] == 'm':
                         value = value[0: -1]
                     vprec = float(value) * 100.0        # m -> cm
-                    tok.get_eol()
 
         # Try encoding these now so we raise if they are bad
         _encode_size(size, "size")
index 8ac76271df32ffef5ff420423298e4de5a38c384..31ab8b776f8bc10a594189b340645a6dad1636c3 100644 (file)
@@ -57,7 +57,6 @@ class NSEC3PARAM(dns.rdata.Rdata):
             salt = ''
         else:
             salt = binascii.unhexlify(salt.encode())
-        tok.get_eol()
         return cls(rdclass, rdtype, algorithm, flags, iterations, salt)
 
     def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
index 7446de6deef175531f88e3bdef3bd93936a67f5e..a6054da9b5c3d19f7719f3019dc1ddb1e2db1b2d 100644 (file)
@@ -43,7 +43,6 @@ class RP(dns.rdata.Rdata):
                   relativize_to=None):
         mbox = tok.get_name(origin, relativize, relativize_to)
         txt = tok.get_name(origin, relativize, relativize_to)
-        tok.get_eol()
         return cls(rdclass, rdtype, mbox, txt)
 
     def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
index e93274ed729167442201f8ccc3f54471c13be886..32b0a8663322f0c001924a814f05a201fb7a4840 100644 (file)
@@ -59,7 +59,6 @@ class SOA(dns.rdata.Rdata):
         retry = tok.get_ttl()
         expire = tok.get_ttl()
         minimum = tok.get_ttl()
-        tok.get_eol()
         return cls(rdclass, rdtype, mname, rname, serial, refresh, retry,
                    expire, minimum)
 
index 84296f52aee3ee096b388583e30fc1701b47883c..7d6d06854015c8a81f8dbac57d62717d940dd9e2 100644 (file)
@@ -54,7 +54,6 @@ class URI(dns.rdata.Rdata):
         target = tok.get().unescape()
         if not (target.is_quoted_string() or target.is_identifier()):
             raise dns.exception.SyntaxError("URI target must be a string")
-        tok.get_eol()
         return cls(rdclass, rdtype, priority, weight, target.value)
 
     def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
index 214f1dca4136cd68dae8a405101984731cd345f5..29b9c4d539cf119874cd74c47ac47c82773644ba 100644 (file)
@@ -44,7 +44,6 @@ class X25(dns.rdata.Rdata):
     def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
                   relativize_to=None):
         address = tok.get_string()
-        tok.get_eol()
         return cls(rdclass, rdtype, address)
 
     def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
index b738ac6cb020970612dabb08c79c9c436827c4a6..330fcae8e3dc7c32c62a0e76b88b85e27e8e9ee5 100644 (file)
@@ -41,7 +41,6 @@ class A(dns.rdata.Rdata):
                   relativize_to=None):
         domain = tok.get_name(origin, relativize, relativize_to)
         address = tok.get_uint16(base=8)
-        tok.get_eol()
         return cls(rdclass, rdtype, domain, address)
 
     def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
index 8b71e329f62ed9be316578ee79393046d90ce077..35ec46f58c45dd13f039c331274dabdd69bf5386 100644 (file)
@@ -40,7 +40,6 @@ class A(dns.rdata.Rdata):
     def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
                   relativize_to=None):
         address = tok.get_identifier()
-        tok.get_eol()
         return cls(rdclass, rdtype, address)
 
     def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
index 08f9d679489628e8a424ef99302fd79b38ae9408..c37b82a88ade7ac1f128f5106d7acc4e198833db 100644 (file)
@@ -40,7 +40,6 @@ class AAAA(dns.rdata.Rdata):
     def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
                   relativize_to=None):
         address = tok.get_identifier()
-        tok.get_eol()
         return cls(rdclass, rdtype, address)
 
     def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
index ab7fe4bca914d4dbbe490f3bea19702e1e906629..3b1b8d12a12f12d3bb9eda33ec4c24d2fe8ec090 100644 (file)
@@ -87,11 +87,8 @@ class APL(dns.rdata.Rdata):
     def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
                   relativize_to=None):
         items = []
-        while True:
-            token = tok.get().unescape()
-            if token.is_eol_or_eof():
-                break
-            item = token.value
+        for token in tok.get_remaining():
+            item = token.unescape().value
             if item[0] == '!':
                 negation = True
                 item = item[1:]
index 48d43562511994ae5419ebcaaf2b72d91ae3a78b..a861b67ff6a78ae82c0c6a28ce37c2aecf30511e 100644 (file)
@@ -72,7 +72,6 @@ class NAPTR(dns.rdata.Rdata):
         service = tok.get_string()
         regexp = tok.get_string()
         replacement = tok.get_name(origin, relativize, relativize_to)
-        tok.get_eol()
         return cls(rdclass, rdtype, order, preference, flags, service,
                    regexp, replacement)
 
index 227465fadc2885ea235796e4266ee15b8f198975..78730a1a115481615eaa4748d0411cc9378047bd 100644 (file)
@@ -41,7 +41,6 @@ class NSAP(dns.rdata.Rdata):
     def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
                   relativize_to=None):
         address = tok.get_string()
-        tok.get_eol()
         if address[0:2] != '0x':
             raise dns.exception.SyntaxError('string does not start with 0x')
         address = address[2:].replace('.', '')
index 946d79f8c3441f7116f81df161198129abd6d110..288bb125f1072ffbcf89b062a74eb89dcc1c98de 100644 (file)
@@ -47,7 +47,6 @@ class PX(dns.rdata.Rdata):
         preference = tok.get_uint16()
         map822 = tok.get_name(origin, relativize, relativize_to)
         mapx400 = tok.get_name(origin, relativize, relativize_to)
-        tok.get_eol()
         return cls(rdclass, rdtype, preference, map822, mapx400)
 
     def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
index 485153f4358345b7be3d86e385ffb7a4bcc3aeae..a3debab9c3b55d063b4b9006a31039efec464f98 100644 (file)
@@ -49,7 +49,6 @@ class SRV(dns.rdata.Rdata):
         weight = tok.get_uint16()
         port = tok.get_uint16()
         target = tok.get_name(origin, relativize, relativize_to)
-        tok.get_eol()
         return cls(rdclass, rdtype, priority, weight, port, target)
 
     def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
index d66d85832bd3345d97f28e0df4fcb7e1b5168d51..9b5e87e65010e4e17dd59e0ee04a1c9b16362ba5 100644 (file)
@@ -59,12 +59,10 @@ class WKS(dns.rdata.Rdata):
         else:
             protocol = socket.getprotobyname(protocol)
         bitmap = bytearray()
-        while 1:
-            token = tok.get().unescape()
-            if token.is_eol_or_eof():
-                break
-            if token.value.isdigit():
-                serv = int(token.value)
+        for token in tok.get_remaining():
+            value = token.unescape().value
+            if value.isdigit():
+                serv = int(value)
             else:
                 if protocol != _proto_udp and protocol != _proto_tcp:
                     raise NotImplementedError("protocol must be TCP or UDP")
@@ -72,7 +70,7 @@ class WKS(dns.rdata.Rdata):
                     protocol_text = "udp"
                 else:
                     protocol_text = "tcp"
-                serv = socket.getservbyname(token.value, protocol_text)
+                serv = socket.getservbyname(value, protocol_text)
             i = serv // 8
             l = len(bitmap)
             if l < i + 1:
index c1677a81d8b9a524b5caabb0f329b9435ee01ba4..ba44571f32a094d70d37b19cc6f6cf19aa192f3b 100644 (file)
@@ -44,7 +44,6 @@ class EUIBase(dns.rdata.Rdata):
     def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
                   relativize_to=None):
         text = tok.get_string()
-        tok.get_eol()
         if len(text) != cls.text_len:
             raise dns.exception.SyntaxError(
                 'Input text must have %s characters' % cls.text_len)
index d6a6efed5936df19bdccf6e6856832f7267845c3..723b762715fc003cf182319b654e880d07a95dc3 100644 (file)
@@ -44,7 +44,6 @@ class MXBase(dns.rdata.Rdata):
                   relativize_to=None):
         preference = tok.get_uint16()
         exchange = tok.get_name(origin, relativize, relativize_to)
-        tok.get_eol()
         return cls(rdclass, rdtype, preference, exchange)
 
     def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
index 93d3ee53659c08d0b745633f1c700538fdb51025..212f8c0a081f9c800d1297e79662f8bc6922e881 100644 (file)
@@ -40,7 +40,6 @@ class NSBase(dns.rdata.Rdata):
     def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
                   relativize_to=None):
         target = tok.get_name(origin, relativize, relativize_to)
-        tok.get_eol()
         return cls(rdclass, rdtype, target)
 
     def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
index ad0093daa39a1e695a2ebb94894207fd5300cfb2..38c56011d77a3694e4ae17e09fe6550ff2cab014 100644 (file)
@@ -63,10 +63,8 @@ class TXTBase(dns.rdata.Rdata):
     def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
                   relativize_to=None):
         strings = []
-        while 1:
-            token = tok.get().unescape_to_bytes()
-            if token.is_eol_or_eof():
-                break
+        for token in tok.get_remaining():
+            token = token.unescape_to_bytes()
             if not (token.is_quoted_string() or token.is_identifier()):
                 raise dns.exception.SyntaxError("expected a string")
             if len(token.value) > 255:
index a63d1a0abc1c6b7af6ef1a0faa135f401e35c6b7..3a089ed2c2f30d9369141813de8b27521af815e1 100644 (file)
@@ -113,11 +113,8 @@ class Bitmap:
 
     def from_text(self, tok):
         rdtypes = []
-        while True:
-            token = tok.get().unescape()
-            if token.is_eol_or_eof():
-                break
-            rdtype = dns.rdatatype.from_text(token.value)
+        for token in tok.get_remaining():
+            rdtype = dns.rdatatype.from_text(token.unescape().value)
             if rdtype == 0:
                 raise dns.exception.SyntaxError(f"{self.type_name} with bit 0")
             rdtypes.append(rdtype)
index 3e5d2ba92e8762532c2b3c5d6cbd0170298b26c7..0c117abd5617453ad555176f8a534f7341d1cf96 100644 (file)
@@ -48,12 +48,13 @@ class Token:
     has_escape: Does the token value contain escapes?
     """
 
-    def __init__(self, ttype, value='', has_escape=False):
+    def __init__(self, ttype, value='', has_escape=False, comment=None):
         """Initialize a token instance."""
 
         self.ttype = ttype
         self.value = value
         self.has_escape = has_escape
+        self.comment = comment
 
     def is_eof(self):
         return self.ttype == EOF
@@ -396,13 +397,13 @@ class Tokenizer:
                             if self.multiline:
                                 raise dns.exception.SyntaxError(
                                     'unbalanced parentheses')
-                            return Token(EOF)
+                            return Token(EOF, comment=token)
                         elif self.multiline:
                             self.skip_whitespace()
                             token = ''
                             continue
                         else:
-                            return Token(EOL, '\n')
+                            return Token(EOL, '\n', comment=token)
                     else:
                         # This code exists in case we ever want a
                         # delimiter to be returned.  It never produces
@@ -559,6 +560,25 @@ class Tokenizer:
             raise dns.exception.SyntaxError('expecting an identifier')
         return token.value
 
+    def get_remaining(self, max_tokens=None):
+        """Return the remaining tokens on the line, until an EOL or EOF is seen.
+
+        max_tokens: If not None, stop after this number of tokens.
+
+        Returns a list of tokens.
+        """
+
+        tokens = []
+        while True:
+            token = self.get()
+            if token.is_eol_or_eof():
+                self.unget(token)
+                break
+            tokens.append(token)
+            if len(tokens) == max_tokens:
+                break
+        return tokens
+
     def concatenate_remaining_identifiers(self):
         """Read the remaining tokens on the line, which should be identifiers.
 
@@ -572,6 +592,7 @@ class Tokenizer:
         while True:
             token = self.get().unescape()
             if token.is_eol_or_eof():
+                self.unget(token)
                 break
             if not token.is_identifier():
                 raise dns.exception.SyntaxError
@@ -601,7 +622,7 @@ class Tokenizer:
         token = self.get()
         return self.as_name(token, origin, relativize, relativize_to)
 
-    def get_eol(self):
+    def get_eol_as_token(self):
         """Read the next token and raise an exception if it isn't EOL or
         EOF.
 
@@ -613,7 +634,10 @@ class Tokenizer:
             raise dns.exception.SyntaxError(
                 'expected EOL or EOF, got %d "%s"' % (token.ttype,
                                                       token.value))
-        return token.value
+        return token
+
+    def get_eol(self):
+        return self.get_eol_as_token().value
 
     def get_ttl(self):
         """Read the next token and interpret it as a DNS TTL.
index e8413c0866fae157dfc7eff1ebc372b10e1f75b5..eee5fdb691b10dd44d0477a3e9cb4455bc304561 100644 (file)
@@ -532,7 +532,8 @@ class Zone:
                     for rdata in rds:
                         yield (name, rds.ttl, rdata)
 
-    def to_file(self, f, sorted=True, relativize=True, nl=None):
+    def to_file(self, f, sorted=True, relativize=True, nl=None,
+                want_comments=False):
         """Write a zone to a file.
 
         *f*, a file or `str`.  If *f* is a string, it is treated
@@ -550,6 +551,10 @@ class Zone:
         *nl*, a ``str`` or None.  The end of line string.  If not
         ``None``, the output will use the platform's native
         end-of-line marker (i.e. LF on POSIX, CRLF on Windows).
+
+        *want_comments*, a ``bool``.  If ``True``, emit end-of-line comments
+        as part of writing the file.  If ``False``, the default, do not
+        emit them.
         """
 
         with contextlib.ExitStack() as stack:
@@ -579,7 +584,8 @@ class Zone:
                 names = self.keys()
             for n in names:
                 l = self[n].to_text(n, origin=self.origin,
-                                    relativize=relativize)
+                                    relativize=relativize,
+                                    want_comments=want_comments)
                 if isinstance(l, str):
                     l_b = l.encode(file_enc)
                 else:
@@ -593,7 +599,8 @@ class Zone:
                     f.write(l)
                     f.write(nl)
 
-    def to_text(self, sorted=True, relativize=True, nl=None):
+    def to_text(self, sorted=True, relativize=True, nl=None,
+                want_comments=False):
         """Return a zone's text as though it were written to a file.
 
         *sorted*, a ``bool``.  If True, the default, then the file
@@ -609,10 +616,14 @@ class Zone:
         ``None``, the output will use the platform's native
         end-of-line marker (i.e. LF on POSIX, CRLF on Windows).
 
+        *want_comments*, a ``bool``.  If ``True``, emit end-of-line comments
+        as part of writing the file.  If ``False``, the default, do not
+        emit them.
+
         Returns a ``str``.
         """
         temp_buffer = io.StringIO()
-        self.to_file(temp_buffer, sorted, relativize, nl)
+        self.to_file(temp_buffer, sorted, relativize, nl, want_comments)
         return_value = temp_buffer.getvalue()
         temp_buffer.close()
         return return_value
diff --git a/tests/mx-2-0.pickle b/tests/mx-2-0.pickle
new file mode 100644 (file)
index 0000000..53d094c
Binary files /dev/null and b/tests/mx-2-0.pickle differ
index 4eb48d3a7653c5a24db563c4bc597a22951c8a1c..e64578b049d14bb039d089cec6b0161266099654 100644 (file)
@@ -29,8 +29,7 @@ import dns.rdatatype
 import dns.rrset
 import dns.update
 
-def here(filename):
-    return os.path.join(os.path.dirname(__file__), filename)
+from tests.util import here
 
 query_text = """id 1234
 opcode QUERY
index 022642fc641a351399e1635a2e9dc17806379675..7960dd1974b9b302c78655cf2bd421ffea5b4496 100644 (file)
@@ -34,6 +34,7 @@ from dns.rdtypes.ANY.OPT import OPT
 
 import tests.stxt_module
 import tests.ttxt_module
+from tests.util import here
 
 class RdataTestCase(unittest.TestCase):
 
@@ -94,6 +95,15 @@ class RdataTestCase(unittest.TestCase):
             a1.replace(address="bogus")
         self.assertRaises(dns.exception.SyntaxError, bad)
 
+    def test_replace_comment(self):
+        a1 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A,
+                                 "1.2.3.4 ;foo")
+        self.assertEqual(a1.rdcomment, "foo")
+        a2 = a1.replace(rdcomment="bar")
+        self.assertEqual(a1, a2)
+        self.assertEqual(a1.rdcomment, "foo")
+        self.assertEqual(a2.rdcomment, "bar")
+
     def test_to_generic(self):
         a = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "1.2.3.4")
         self.assertEqual(str(a.to_generic()), r'\# 4 01020304')
@@ -415,5 +425,13 @@ class RdataTestCase(unittest.TestCase):
         rdata = dns.rdata.from_wire('in', 'a', wire, 0, 4)
         self.assertEqual(rdata, dns.rdata.from_text('in', 'a', '1.2.3.4'))
 
+    def test_unpickle(self):
+        expected_mx = dns.rdata.from_text('in', 'mx', '10 mx.example.')
+        with open(here('mx-2-0.pickle'), 'rb') as f:
+            mx = pickle.load(f)
+        self.assertEqual(mx, expected_mx)
+        self.assertIsNone(mx.rdcomment)
+
+
 if __name__ == '__main__':
     unittest.main()
index 36428b1264873caebf03389b2e067097ae81d67f..9857f80514086c0c783e53532d5add129094290c 100644 (file)
@@ -35,8 +35,7 @@ import dns.rrset
 import dns.zone
 import dns.node
 
-def here(filename):
-    return os.path.join(os.path.dirname(__file__), filename)
+from tests.util import here
 
 example_text = """$TTL 3600
 $ORIGIN example.
@@ -175,6 +174,23 @@ $ORIGIN example.
 @ 300 ns ns2
 """
 
+example_comments_text = """$TTL 3600
+$ORIGIN example.
+@ soa foo bar (1 ; not kept
+2 3 4 5) ; kept
+@ ns ns1
+@ ns ns2
+ns1 a 10.0.0.1 ; comment1
+ns2 a 10.0.0.2 ; comment2
+"""
+
+example_comments_text_output = """@ 3600 IN SOA foo bar 1 2 3 4 5 ; kept
+@ 3600 IN NS ns1
+@ 3600 IN NS ns2
+ns1 3600 IN A 10.0.0.1 ; comment1
+ns2 3600 IN A 10.0.0.2 ; comment2
+"""
+
 _keep_output = True
 
 def _rdata_sort(a):
@@ -746,5 +762,14 @@ class ZoneTestCase(unittest.TestCase):
         self.assertEqual(z._validate_name('foo.bar.example.'),
                          dns.name.from_text('foo.bar', None))
 
+    def testComments(self):
+        z = dns.zone.from_text(example_comments_text, 'example.',
+                               relativize=True)
+        f = StringIO()
+        z.to_file(f, want_comments=True)
+        out = f.getvalue()
+        f.close()
+        self.assertEqual(out, example_comments_text_output)
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/tests/util.py b/tests/util.py
new file mode 100644 (file)
index 0000000..df736df
--- /dev/null
@@ -0,0 +1,21 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+import os.path
+
+def here(filename):
+    return os.path.join(os.path.dirname(__file__), filename)
diff --git a/util/generate-mx-pickle.py b/util/generate-mx-pickle.py
new file mode 100644 (file)
index 0000000..ad99942
--- /dev/null
@@ -0,0 +1,19 @@
+import pickle
+import sys
+
+import dns.rdata
+import dns.version
+
+# Generate a pickled mx RR for the current dnspython version
+
+mx = dns.rdata.from_text('in', 'mx', '10 mx.example.')
+filename = f'pickled-{dns.version.MAJOR}-{dns.version.MINOR}.pickle'
+with open(filename, 'wb') as f:
+    pickle.dump(mx, f)
+with open(filename, 'rb') as f:
+    mx2 = pickle.load(f)
+if mx == mx2:
+    print('ok')
+else:
+    print('DIFFERENT!')
+    sys.exit(1)