From: Bob Halley Date: Tue, 3 Jun 2008 10:49:12 +0000 (+0000) Subject: set_rcode() was broken when used with extended rcodes; keep ednsflags coherent with... X-Git-Tag: v1.7.0~39 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=edebc929df5c5e880c34817434765d66e3c8e21e;p=thirdparty%2Fdnspython.git set_rcode() was broken when used with extended rcodes; keep ednsflags coherent with edns version --- diff --git a/ChangeLog b/ChangeLog index 80911656..d5468cc4 100644 --- a/ChangeLog +++ b/ChangeLog @@ -1,3 +1,11 @@ +2008-06-03 Bob Halley + + * dns/message.py (Message.set_rcode): The mask used preserved the + extended rcode, instead of everything else in ednsflags. + + * dns/message.py (Message.use_edns): ednsflags was not kept + coherent with the specified edns version. + 2008-02-06 Bob Halley * dns/ipv6.py (inet_aton): We could raise an exception other than diff --git a/dns/message.py b/dns/message.py index 85279359..66e6a478 100644 --- a/dns/message.py +++ b/dns/message.py @@ -161,7 +161,7 @@ class Message(object): def __repr__(self): return '' - + def __str__(self): return self.to_text() @@ -173,7 +173,7 @@ class Message(object): @rtype: string """ - + s = cStringIO.StringIO() print >> s, 'id %d' % self.id print >> s, 'opcode %s' % \ @@ -288,7 +288,7 @@ class Message(object): covers=dns.rdatatype.NONE, deleting=None, create=False, force_unique=False): """Find the RRset with the given attributes in the specified section. - + @param section: the section of the message to look in, e.g. self.answer. @type section: list of dns.rrset.RRset objects @@ -336,7 +336,7 @@ class Message(object): """Get the RRset with the given attributes in the specified section. If the RRset is not found, None is returned. - + @param section: the section of the message to look in, e.g. self.answer. @type section: list of dns.rrset.RRset objects @@ -371,7 +371,7 @@ class Message(object): Additional keyword arguments are passed to the rrset to_wire() method. - + @param origin: The origin to be appended to any relative names. @type origin: dns.name.Name object @param max_size: The maximum size of the wire format output; default @@ -414,7 +414,7 @@ class Message(object): tsig_error=0, other_data=''): """When sending, a TSIG signature using the specified keyring and keyname should be added. - + @param keyring: The TSIG keyring to use; defaults to None. @type keyring: dict @param keyname: The name of the TSIG key to use; defaults to None. @@ -476,6 +476,10 @@ class Message(object): ednsflags = 0 payload = 0 request_payload = 0 + else: + # make sure the EDNS version in ednsflags agrees with edns + ednsflags &= 0xFF00FFFFL + ednsflags |= (edns << 16) self.edns = edns self.ednsflags = ednsflags self.payload = payload @@ -509,7 +513,7 @@ class Message(object): (value, evalue) = dns.rcode.to_flags(rcode) self.flags &= 0xFFF0 self.flags |= value - self.ednsflags &= 0xFF000000L + self.ednsflags &= 0x00FFFFFFL self.ednsflags |= evalue if self.ednsflags != 0 and self.edns < 0: self.edns = 0 @@ -545,7 +549,7 @@ class _WireReader(object): DNS dynamic updates. @type zone_rdclass: int """ - + def __init__(self, wire, message, question_only=False): self.wire = wire self.message = message @@ -553,7 +557,7 @@ class _WireReader(object): self.updating = False self.zone_rdclass = dns.rdataclass.IN self.question_only = question_only - + def _get_question(self, qcount): """Read the next I{qcount} records from the wire data and add them to the question section. @@ -562,7 +566,7 @@ class _WireReader(object): if self.updating and qcount > 1: raise dns.exception.FormError - + for i in xrange(0, qcount): (qname, used) = dns.name.from_wire(self.wire, self.current) if not self.message.origin is None: @@ -577,7 +581,7 @@ class _WireReader(object): force_unique=True) if self.updating: self.zone_rdclass = rdclass - + def _get_section(self, section, count): """Read the next I{count} records from the wire data and add them to the specified section. @@ -585,7 +589,7 @@ class _WireReader(object): @type section: list of dns.rrset.RRset objects @param count: the number of records to read @type count: int""" - + if self.updating: force_unique = True else: @@ -618,7 +622,7 @@ class _WireReader(object): if secret is None: raise UnknownTSIGKey, "key '%s' unknown" % name self.message.tsig_ctx = \ - dns.tsig.validate(self.wire, + dns.tsig.validate(self.wire, name, secret, int(time.time()), @@ -660,7 +664,7 @@ class _WireReader(object): def read(self): """Read a wire format DNS message and build a dns.message.Message object.""" - + l = len(self.wire) if l < 12: raise ShortHeader @@ -730,11 +734,11 @@ def from_wire(wire, keyring=None, request_mac='', xfr=False, origin=None, reader.read() return m - + class _TextReader(object): """Text format reader. - + @ivar tok: the tokenizer @type tok: dns.tokenizer.Tokenizer object @ivar message: The message object being built @@ -758,7 +762,7 @@ class _TextReader(object): def _header_line(self, section): """Process one line from the text format header section.""" - + (ttype, what) = self.tok.get() if what == 'id': self.message.id = self.tok.get_int() @@ -803,7 +807,7 @@ class _TextReader(object): def _question_line(self, section): """Process one line from the text format question section.""" - + token = self.tok.get(want_leading = True) if token[0] != dns.tokenizer.WHITESPACE: self.last_name = dns.name.from_text(token[1], None) @@ -834,7 +838,7 @@ class _TextReader(object): """Process one line from the text format answer, authority, or additional data sections. """ - + deleting = None # Name token = self.tok.get(want_leading = True) @@ -913,7 +917,7 @@ class _TextReader(object): continue self.tok.unget(token) line_method(section) - + def from_text(text): """Convert the text format message into a message object. @@ -927,14 +931,14 @@ def from_text(text): # 'text' can also be a file, but we don't publish that fact # since it's an implementation detail. The official file # interface is from_file(). - + m = Message() reader = _TextReader(text, m) reader.read() - + return m - + def from_file(f): """Read the next text format message from the specified file. @@ -973,7 +977,7 @@ def make_query(qname, rdtype, rdclass = dns.rdataclass.IN, use_edns=None, The query will have a randomly choosen query id, and its DNS flags will be set to dns.flags.RD. - + @param qname: The query name. @type qname: dns.name.Name object or string @param rdtype: The desired rdata type. @@ -987,7 +991,7 @@ def make_query(qname, rdtype, rdclass = dns.rdataclass.IN, use_edns=None, @param want_dnssec: Should the query indicate that DNSSEC is desired? @type want_dnssec: bool @rtype: dns.message.Message object""" - + if isinstance(qname, (str, unicode)): qname = dns.name.from_text(qname) if isinstance(rdtype, str): @@ -1011,7 +1015,7 @@ def make_response(query, recursion_available=False, our_payload=8192): The response's question section is a shallow copy of the query's question section, so the query's question RRsets should not be changed. - + @param query: the query to respond to @type query: dns.message.Message object @param recursion_available: should RA be set in the response? diff --git a/dns/renderer.py b/dns/renderer.py index af7e2cb4..07bce98f 100644 --- a/dns/renderer.py +++ b/dns/renderer.py @@ -35,7 +35,7 @@ class Renderer(object): class and its to_wire() method to generate wire-format messages. This class is for those applications which need finer control over the generation of messages. - + Typical use:: r = dns.renderer.Renderer(id=1, flags=0x80, max_size=512) @@ -85,7 +85,7 @@ class Renderer(object): @param origin: the origin to use when rendering relative names @type origin: dns.name.Namem or None. """ - + self.output = cStringIO.StringIO() if id is None: self.id = random.randint(0, 65535) @@ -108,7 +108,7 @@ class Renderer(object): @param where: the offset @type where: int """ - + self.output.seek(where) self.output.truncate() keys_to_delete = [] @@ -129,7 +129,7 @@ class Renderer(object): @raises dns.exception.FormError: an attempt was made to set a section value less than the current section. """ - + if self.section != section: if self.section > section: raise dns.exception.FormError @@ -145,7 +145,7 @@ class Renderer(object): @param rdclass: the question rdata class @type rdclass: int """ - + self._set_section(QUESTION) before = self.output.tell() qname.to_wire(self.output, self.compress, self.origin) @@ -155,13 +155,13 @@ class Renderer(object): self._rollback(before) raise dns.exception.TooBig self.counts[QUESTION] += 1 - + def add_rrset(self, section, rrset, **kw): """Add the rrset to the specified section. Any keyword arguments are passed on to the rdataset's to_wire() routine. - + @param section: the section @type section: int @param rrset: the rrset @@ -191,7 +191,7 @@ class Renderer(object): @param rdataset: the rdataset @type rdataset: dns.rdataset.Rdataset object """ - + self._set_section(section) before = self.output.tell() n = rdataset.to_wire(name, self.output, self.compress, self.origin, @@ -215,6 +215,9 @@ class Renderer(object): @see: RFC 2671 """ + # make sure the EDNS version in ednsflags agrees with edns + ednsflags &= 0xFF00FFFFL + ednsflags |= (edns << 16) self._set_section(ADDITIONAL) before = self.output.tell() self.output.write(struct.pack('!BHHIH', 0, dns.rdatatype.OPT, payload, @@ -228,7 +231,7 @@ class Renderer(object): def add_tsig(self, keyname, secret, fudge, id, tsig_error, other_data, request_mac): """Add a TSIG signature to the message. - + @param keyname: the TSIG key name @type keyname: dns.name.Name object @param secret: the secret to use @@ -245,7 +248,7 @@ class Renderer(object): had the specified MAC. @type request_mac: string """ - + self._set_section(ADDITIONAL) before = self.output.tell() s = self.output.getvalue() @@ -282,7 +285,7 @@ class Renderer(object): have been rendered, but before the optional TSIG signature is added. """ - + self.output.seek(0) self.output.write(struct.pack('!HHHHHH', self.id, self.flags, self.counts[0], self.counts[1], @@ -294,5 +297,5 @@ class Renderer(object): @rtype: string """ - + return self.output.getvalue() diff --git a/tests/message.py b/tests/message.py index e63910b5..3338efe6 100644 --- a/tests/message.py +++ b/tests/message.py @@ -163,5 +163,17 @@ class MessageTestCase(unittest.TestCase): r2 = dns.message.make_response(r1) self.failUnlessRaises(dns.exception.FormError, bad) + def test_ExtendedRcodeSetting(self): + m = dns.message.make_query('foo', 'A') + m.set_rcode(4095) + self.failUnless(m.rcode() == 4095) + m.set_rcode(2) + self.failUnless(m.rcode() == 2) + + def test_EDNSVersionCoherence(self): + m = dns.message.make_query('foo', 'A') + m.use_edns(1) + self.failUnless((m.ednsflags >> 16) & 0xFF == 1) + if __name__ == '__main__': unittest.main()