]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
set_rcode() was broken when used with extended rcodes; keep ednsflags coherent with...
authorBob Halley <halley@dnspython.org>
Tue, 3 Jun 2008 10:49:12 +0000 (10:49 +0000)
committerBob Halley <halley@dnspython.org>
Tue, 3 Jun 2008 10:49:12 +0000 (10:49 +0000)
ChangeLog
dns/message.py
dns/renderer.py
tests/message.py

index 80911656f6a3f82de04f83d78607aed0399cb935..d5468cc407b122074a3ebf9d2f3e54a483fe37ad 100644 (file)
--- a/ChangeLog
+++ b/ChangeLog
@@ -1,3 +1,11 @@
+2008-06-03  Bob Halley  <halley@dnspython.org>
+
+       * dns/message.py (Message.set_rcode): The mask used preserved the
+         extended rcode, instead of everything else in ednsflags.
+
+       * dns/message.py (Message.use_edns): ednsflags was not kept
+         coherent with the specified edns version.
+
 2008-02-06  Bob Halley  <halley@dnspython.org>
 
        * dns/ipv6.py (inet_aton):  We could raise an exception other than
index 852793591654cb4cdb4262d1c0af2b0592cb1a13..66e6a4783c995c6ae0b13c8e6f5fcf6182d312c0 100644 (file)
@@ -161,7 +161,7 @@ class Message(object):
 
     def __repr__(self):
         return '<DNS message, ID ' + `self.id` + '>'
-    
+
     def __str__(self):
         return self.to_text()
 
@@ -173,7 +173,7 @@ class Message(object):
 
         @rtype: string
         """
-        
+
         s = cStringIO.StringIO()
         print >> s, 'id %d' % self.id
         print >> s, 'opcode %s' % \
@@ -288,7 +288,7 @@ class Message(object):
                    covers=dns.rdatatype.NONE, deleting=None, create=False,
                    force_unique=False):
         """Find the RRset with the given attributes in the specified section.
-        
+
         @param section: the section of the message to look in, e.g.
         self.answer.
         @type section: list of dns.rrset.RRset objects
@@ -336,7 +336,7 @@ class Message(object):
         """Get the RRset with the given attributes in the specified section.
 
         If the RRset is not found, None is returned.
-        
+
         @param section: the section of the message to look in, e.g.
         self.answer.
         @type section: list of dns.rrset.RRset objects
@@ -371,7 +371,7 @@ class Message(object):
 
         Additional keyword arguments are passed to the rrset to_wire()
         method.
-        
+
         @param origin: The origin to be appended to any relative names.
         @type origin: dns.name.Name object
         @param max_size: The maximum size of the wire format output; default
@@ -414,7 +414,7 @@ class Message(object):
                  tsig_error=0, other_data=''):
         """When sending, a TSIG signature using the specified keyring
         and keyname should be added.
-        
+
         @param keyring: The TSIG keyring to use; defaults to None.
         @type keyring: dict
         @param keyname: The name of the TSIG key to use; defaults to None.
@@ -476,6 +476,10 @@ class Message(object):
             ednsflags = 0
             payload = 0
             request_payload = 0
+        else:
+            # make sure the EDNS version in ednsflags agrees with edns
+            ednsflags &= 0xFF00FFFFL
+            ednsflags |= (edns << 16)
         self.edns = edns
         self.ednsflags = ednsflags
         self.payload = payload
@@ -509,7 +513,7 @@ class Message(object):
         (value, evalue) = dns.rcode.to_flags(rcode)
         self.flags &= 0xFFF0
         self.flags |= value
-        self.ednsflags &= 0xFF000000L
+        self.ednsflags &= 0x00FFFFFFL
         self.ednsflags |= evalue
         if self.ednsflags != 0 and self.edns < 0:
             self.edns = 0
@@ -545,7 +549,7 @@ class _WireReader(object):
     DNS dynamic updates.
     @type zone_rdclass: int
     """
-    
+
     def __init__(self, wire, message, question_only=False):
         self.wire = wire
         self.message = message
@@ -553,7 +557,7 @@ class _WireReader(object):
         self.updating = False
         self.zone_rdclass = dns.rdataclass.IN
         self.question_only = question_only
-        
+
     def _get_question(self, qcount):
         """Read the next I{qcount} records from the wire data and add them to
         the question section.
@@ -562,7 +566,7 @@ class _WireReader(object):
 
         if self.updating and qcount > 1:
             raise dns.exception.FormError
-        
+
         for i in xrange(0, qcount):
             (qname, used) = dns.name.from_wire(self.wire, self.current)
             if not self.message.origin is None:
@@ -577,7 +581,7 @@ class _WireReader(object):
                                     force_unique=True)
             if self.updating:
                 self.zone_rdclass = rdclass
-        
+
     def _get_section(self, section, count):
         """Read the next I{count} records from the wire data and add them to
         the specified section.
@@ -585,7 +589,7 @@ class _WireReader(object):
         @type section: list of dns.rrset.RRset objects
         @param count: the number of records to read
         @type count: int"""
-        
+
         if self.updating:
             force_unique = True
         else:
@@ -618,7 +622,7 @@ class _WireReader(object):
                 if secret is None:
                     raise UnknownTSIGKey, "key '%s' unknown" % name
                 self.message.tsig_ctx = \
-                       dns.tsig.validate(self.wire,
+                                      dns.tsig.validate(self.wire,
                                           name,
                                           secret,
                                           int(time.time()),
@@ -660,7 +664,7 @@ class _WireReader(object):
     def read(self):
         """Read a wire format DNS message and build a dns.message.Message
         object."""
-        
+
         l = len(self.wire)
         if l < 12:
             raise ShortHeader
@@ -730,11 +734,11 @@ def from_wire(wire, keyring=None, request_mac='', xfr=False, origin=None,
     reader.read()
 
     return m
-    
+
 
 class _TextReader(object):
     """Text format reader.
-    
+
     @ivar tok: the tokenizer
     @type tok: dns.tokenizer.Tokenizer object
     @ivar message: The message object being built
@@ -758,7 +762,7 @@ class _TextReader(object):
 
     def _header_line(self, section):
         """Process one line from the text format header section."""
-        
+
         (ttype, what) = self.tok.get()
         if what == 'id':
             self.message.id = self.tok.get_int()
@@ -803,7 +807,7 @@ class _TextReader(object):
 
     def _question_line(self, section):
         """Process one line from the text format question section."""
-        
+
         token = self.tok.get(want_leading = True)
         if token[0] != dns.tokenizer.WHITESPACE:
             self.last_name = dns.name.from_text(token[1], None)
@@ -834,7 +838,7 @@ class _TextReader(object):
         """Process one line from the text format answer, authority, or
         additional data sections.
         """
-        
+
         deleting = None
         # Name
         token = self.tok.get(want_leading = True)
@@ -913,7 +917,7 @@ class _TextReader(object):
                 continue
             self.tok.unget(token)
             line_method(section)
-        
+
 
 def from_text(text):
     """Convert the text format message into a message object.
@@ -927,14 +931,14 @@ def from_text(text):
     # 'text' can also be a file, but we don't publish that fact
     # since it's an implementation detail.  The official file
     # interface is from_file().
-    
+
     m = Message()
 
     reader = _TextReader(text, m)
     reader.read()
-    
+
     return m
-    
+
 def from_file(f):
     """Read the next text format message from the specified file.
 
@@ -973,7 +977,7 @@ def make_query(qname, rdtype, rdclass = dns.rdataclass.IN, use_edns=None,
 
     The query will have a randomly choosen query id, and its DNS flags
     will be set to dns.flags.RD.
-    
+
     @param qname: The query name.
     @type qname: dns.name.Name object or string
     @param rdtype: The desired rdata type.
@@ -987,7 +991,7 @@ def make_query(qname, rdtype, rdclass = dns.rdataclass.IN, use_edns=None,
     @param want_dnssec: Should the query indicate that DNSSEC is desired?
     @type want_dnssec: bool
     @rtype: dns.message.Message object"""
-    
+
     if isinstance(qname, (str, unicode)):
         qname = dns.name.from_text(qname)
     if isinstance(rdtype, str):
@@ -1011,7 +1015,7 @@ def make_response(query, recursion_available=False, our_payload=8192):
     The response's question section is a shallow copy of the query's
     question section, so the query's question RRsets should not be
     changed.
-    
+
     @param query: the query to respond to
     @type query: dns.message.Message object
     @param recursion_available: should RA be set in the response?
index af7e2cb4469dd15949189952ad3efac9fbc898f0..07bce98f6e7243b2b08a1e16e2889b91111a6a87 100644 (file)
@@ -35,7 +35,7 @@ class Renderer(object):
     class and its to_wire() method to generate wire-format messages.
     This class is for those applications which need finer control
     over the generation of messages.
-    
+
     Typical use::
 
         r = dns.renderer.Renderer(id=1, flags=0x80, max_size=512)
@@ -85,7 +85,7 @@ class Renderer(object):
         @param origin: the origin to use when rendering relative names
         @type origin: dns.name.Namem or None.
         """
-        
+
         self.output = cStringIO.StringIO()
         if id is None:
             self.id = random.randint(0, 65535)
@@ -108,7 +108,7 @@ class Renderer(object):
         @param where: the offset
         @type where: int
         """
-        
+
         self.output.seek(where)
         self.output.truncate()
         keys_to_delete = []
@@ -129,7 +129,7 @@ class Renderer(object):
         @raises dns.exception.FormError: an attempt was made to set
         a section value less than the current section.
         """
-        
+
         if self.section != section:
             if self.section > section:
                 raise dns.exception.FormError
@@ -145,7 +145,7 @@ class Renderer(object):
         @param rdclass: the question rdata class
         @type rdclass: int
         """
-        
+
         self._set_section(QUESTION)
         before = self.output.tell()
         qname.to_wire(self.output, self.compress, self.origin)
@@ -155,13 +155,13 @@ class Renderer(object):
             self._rollback(before)
             raise dns.exception.TooBig
         self.counts[QUESTION] += 1
-        
+
     def add_rrset(self, section, rrset, **kw):
         """Add the rrset to the specified section.
 
         Any keyword arguments are passed on to the rdataset's to_wire()
         routine.
-        
+
         @param section: the section
         @type section: int
         @param rrset: the rrset
@@ -191,7 +191,7 @@ class Renderer(object):
         @param rdataset: the rdataset
         @type rdataset: dns.rdataset.Rdataset object
         """
-        
+
         self._set_section(section)
         before = self.output.tell()
         n = rdataset.to_wire(name, self.output, self.compress, self.origin,
@@ -215,6 +215,9 @@ class Renderer(object):
         @see: RFC 2671
         """
 
+        # make sure the EDNS version in ednsflags agrees with edns
+        ednsflags &= 0xFF00FFFFL
+        ednsflags |= (edns << 16)
         self._set_section(ADDITIONAL)
         before = self.output.tell()
         self.output.write(struct.pack('!BHHIH', 0, dns.rdatatype.OPT, payload,
@@ -228,7 +231,7 @@ class Renderer(object):
     def add_tsig(self, keyname, secret, fudge, id, tsig_error, other_data,
                  request_mac):
         """Add a TSIG signature to the message.
-        
+
         @param keyname: the TSIG key name
         @type keyname: dns.name.Name object
         @param secret: the secret to use
@@ -245,7 +248,7 @@ class Renderer(object):
         had the specified MAC.
         @type request_mac: string
         """
+
         self._set_section(ADDITIONAL)
         before = self.output.tell()
         s = self.output.getvalue()
@@ -282,7 +285,7 @@ class Renderer(object):
         have been rendered, but before the optional TSIG signature
         is added.
         """
-        
+
         self.output.seek(0)
         self.output.write(struct.pack('!HHHHHH', self.id, self.flags,
                                       self.counts[0], self.counts[1],
@@ -294,5 +297,5 @@ class Renderer(object):
 
         @rtype: string
         """
-        
+
         return self.output.getvalue()
index e63910b51c0b99d6b9b31b3f00a9acdb14f5067a..3338efe6df7c4f09f501709d5afe10653915d885 100644 (file)
@@ -163,5 +163,17 @@ class MessageTestCase(unittest.TestCase):
             r2 = dns.message.make_response(r1)
         self.failUnlessRaises(dns.exception.FormError, bad)
 
+    def test_ExtendedRcodeSetting(self):
+        m = dns.message.make_query('foo', 'A')
+        m.set_rcode(4095)
+        self.failUnless(m.rcode() == 4095)
+        m.set_rcode(2)
+        self.failUnless(m.rcode() == 2)
+
+    def test_EDNSVersionCoherence(self):
+        m = dns.message.make_query('foo', 'A')
+        m.use_edns(1)
+        self.failUnless((m.ednsflags >> 16) & 0xFF == 1)
+
 if __name__ == '__main__':
     unittest.main()