'rdtypes',
'update',
'version',
- 'wiredata',
+ 'wire',
'zone',
]
raise NotImplementedError # pragma: no cover
@classmethod
- def from_wire(cls, otype, wire, current, olen):
+ def from_wire_parser(cls, otype, parser):
"""Build an EDNS option object from wire format.
*otype*, an ``int``, is the option type.
- *wire*, a ``bytes``, is the wire-format message.
-
- *current*, an ``int``, is the offset in *wire* of the beginning
- of the rdata.
-
- *olen*, an ``int``, is the length of the wire-format option data
+ *parser*, a ``dns.wire.Parser``, the parser, which should be
+ restructed to the option length.
Returns a ``dns.edns.Option``.
"""
-
raise NotImplementedError # pragma: no cover
def _cmp(self, other):
return "Generic %d" % self.otype
@classmethod
- def from_wire(cls, otype, wire, current, olen):
- return cls(otype, wire[current: current + olen])
+ def from_wire_parser(cls, otype, parser):
+ return cls(otype, parser.get_remaining())
def __str__(self):
return self.to_text()
return value
@classmethod
- def from_wire(cls, otype, wire, cur, olen):
- family, src, scope = struct.unpack('!HBB', wire[cur:cur + 4])
- cur += 4
-
+ def from_wire_parser(cls, otype, parser):
+ family, src, scope = parser.get_struct('!HBB')
addrlen = int(math.ceil(src / 8.0))
-
+ prefix = parser.get_bytes(addrlen)
if family == 1:
pad = 4 - addrlen
- addr = dns.ipv4.inet_ntoa(wire[cur:cur + addrlen] + b'\x00' * pad)
+ addr = dns.ipv4.inet_ntoa(prefix + b'\x00' * pad)
elif family == 2:
pad = 16 - addrlen
- addr = dns.ipv6.inet_ntoa(wire[cur:cur + addrlen] + b'\x00' * pad)
+ addr = dns.ipv6.inet_ntoa(prefix + b'\x00' * pad)
else:
raise ValueError('unsupported family')
def __str__(self):
return self.to_text()
+
_type_to_class = {
OptionType.ECS: ECSOption
}
return cls
+def option_from_wire_parser(otype, parser):
+ """Build an EDNS option object from wire format.
+
+ *otype*, an ``int``, is the option type.
+
+ *parser*, a ``dns.wire.Parser``, the parser, which should be
+ restricted to the option length.
+
+ Returns an instance of a subclass of ``dns.edns.Option``.
+ """
+ cls = get_option_class(otype)
+ otype = OptionType.make(otype)
+ return cls.from_wire_parser(otype, parser)
+
+
def option_from_wire(otype, wire, current, olen):
"""Build an EDNS option object from wire format.
Returns an instance of a subclass of ``dns.edns.Option``.
"""
-
- cls = get_option_class(otype)
- otype = OptionType.make(otype)
- return cls.from_wire(otype, wire, current, olen)
+ parser = dns.wire.Parser(wire, current)
+ with parser.restrict_to(olen):
+ return option_from_wire_parser(otype, parser)
import contextlib
import io
-import struct
import time
+import dns.wire
import dns.edns
import dns.enum
import dns.exception
import dns.rrset
import dns.renderer
import dns.tsig
-import dns.wiredata
import dns.rdtypes.ANY.OPT
import dns.rdtypes.ANY.TSIG
"""Wire format reader.
- wire: a binary, is the wire-format message.
+ parser: the binary parser
message: The message object being built
- current: When building a message object from wire format, this
- variable contains the offset from the beginning of wire of the next octet
- to be read.
initialize_message: Callback to set message parsing options
question_only: Are we only reading the question?
one_rr_per_rrset: Put each RR into its own RRset?
def __init__(self, wire, initialize_message, question_only=False,
one_rr_per_rrset=False, ignore_trailing=False,
keyring=None, multi=False):
- self.wire = dns.wiredata.maybe_wrap(wire)
+ self.parser = dns.wire.Parser(wire)
self.message = None
- self.current = 0
self.initialize_message = initialize_message
self.question_only = question_only
self.one_rr_per_rrset = one_rr_per_rrset
section = self.message.sections[section_number]
for i in range(qcount):
- (qname, used) = dns.name.from_wire(self.wire, self.current)
- if self.message.origin is not None:
- qname = qname.relativize(self.message.origin)
- self.current += used
- (rdtype, rdclass) = \
- struct.unpack('!HH',
- self.wire[self.current:self.current + 4])
- self.current += 4
+ qname = self.parser.get_name(self.message.origin)
+ (rdtype, rdclass) = self.parser.get_struct('!HH')
(rdclass, rdtype, _, _) = \
self.message._parse_rr_header(section_number, qname, rdclass,
rdtype)
section = self.message.sections[section_number]
force_unique = self.one_rr_per_rrset
for i in range(count):
- rr_start = self.current
- (name, used) = dns.name.from_wire(self.wire, self.current)
- absolute_name = name
+ rr_start = self.parser.current
+ absolute_name = self.parser.get_name()
if self.message.origin is not None:
- name = name.relativize(self.message.origin)
- self.current += used
- (rdtype, rdclass, ttl, rdlen) = \
- struct.unpack('!HHIH',
- self.wire[self.current:self.current + 10])
- self.current += 10
+ name = absolute_name.relativize(self.message.origin)
+ else:
+ name = absolute_name
+ (rdtype, rdclass, ttl, rdlen) = self.parser.get_struct('!HHIH')
if rdtype in (dns.rdatatype.OPT, dns.rdatatype.TSIG):
(rdclass, rdtype, deleting, empty) = \
self.message._parse_special_rr_header(section_number,
rd = None
covers = dns.rdatatype.NONE
else:
- rd = dns.rdata.from_wire(rdclass, rdtype,
- self.wire, self.current, rdlen,
- self.message.origin)
+ with self.parser.restrict_to(rdlen):
+ rd = dns.rdata.from_wire_parser(rdclass, rdtype,
+ self.parser,
+ self.message.origin)
covers = rd.covers()
if self.message.xfr and rdtype == dns.rdatatype.SOA:
force_unique = True
raise UnknownTSIGKey("key '%s' unknown" % name)
self.message.keyring = key
self.message.tsig_ctx = \
- dns.tsig.validate(self.wire,
+ dns.tsig.validate(self.parser.wire,
key,
absolute_name,
rd,
if ttl > 0x7fffffff:
ttl = 0
rrset.add(rd, ttl)
- self.current += rdlen
def read(self):
"""Read a wire format DNS message and build a dns.message.Message
object."""
- l = len(self.wire)
- if l < 12:
+ if self.parser.remaining() < 12:
raise ShortHeader
(id, flags, qcount, ancount, aucount, adcount) = \
- struct.unpack('!HHHHHH', self.wire[:12])
- self.current = 12
+ self.parser.get_struct('!HHHHHH')
factory = _message_factory_from_opcode(dns.opcode.from_flags(flags))
self.message = factory(id=id)
self.message.flags = flags
self._get_section(MessageSection.ANSWER, ancount)
self._get_section(MessageSection.AUTHORITY, aucount)
self._get_section(MessageSection.ADDITIONAL, adcount)
- if not self.ignore_trailing and self.current != l:
+ if not self.ignore_trailing and self.parser.remaining() != 0:
raise TrailingJunk
if self.multi and self.message.tsig_ctx and not self.message.had_tsig:
- self.message.tsig_ctx.update(self.wire)
+ self.message.tsig_ctx.update(self.parser.wire)
return self.message
except ImportError: # pragma: no cover
have_idna_2008 = False
+import dns.wire
import dns.exception
-import dns.wiredata
# fullcompare() result values
return Name(labels)
+def from_wire_parser(parser):
+ """Convert possibly compressed wire format into a Name.
+
+ *parser* is a dns.wire.Parser.
+
+ Raises ``dns.name.BadPointer`` if a compression pointer did not
+ point backwards in the message.
+
+ Raises ``dns.name.BadLabelType`` if an invalid label type was encountered.
+
+ Returns a ``dns.name.Name``
+ """
+
+ labels = []
+ biggest_pointer = parser.current
+ with parser.restore_furthest():
+ count = parser.get_uint8()
+ while count != 0:
+ if count < 64:
+ labels.append(parser.get_bytes(count))
+ elif count >= 192:
+ current = (count & 0x3f) * 256 + parser.get_uint8()
+ if current >= biggest_pointer:
+ raise BadPointer
+ biggest_pointer = current
+ parser.seek(current)
+ else:
+ raise BadLabelType
+ count = parser.get_uint8()
+ labels.append(b'')
+ return Name(labels)
+
+
def from_wire(message, current):
"""Convert possibly compressed wire format into a Name.
if not isinstance(message, bytes):
raise ValueError("input to from_wire() must be a byte string")
- message = dns.wiredata.maybe_wrap(message)
- labels = []
- biggest_pointer = current
- hops = 0
- count = message[current]
- current += 1
- cused = 1
- while count != 0:
- if count < 64:
- labels.append(message[current: current + count].unwrap())
- current += count
- if hops == 0:
- cused += count
- elif count >= 192:
- current = (count & 0x3f) * 256 + message[current]
- if hops == 0:
- cused += 1
- if current >= biggest_pointer:
- raise BadPointer
- biggest_pointer = current
- hops += 1
- else:
- raise BadLabelType
- count = message[current]
- current += 1
- if hops == 0:
- cused += 1
- labels.append('')
- return (Name(labels), cused)
+ parser = dns.wire.Parser(message, current)
+ name = from_wire_parser(parser)
+ return (name, parser.current - current)
import inspect
import itertools
+import dns.wire
import dns.exception
import dns.name
import dns.rdataclass
import dns.rdatatype
import dns.tokenizer
-import dns.wiredata
_hex_chunksize = 32
file.write(self.data)
@classmethod
- def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None):
- return cls(rdclass, rdtype, wire[current: current + rdlen])
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ return cls(rdclass, rdtype, parser.get_remaining())
_rdata_classes = {}
_module_prefix = 'dns.rdtypes'
relativize_to)
+def from_wire_parser(rdclass, rdtype, parser, origin=None):
+ """Build an rdata object from wire format
+
+ This function attempts to dynamically load a class which
+ implements the specified rdata class and type. If there is no
+ class-and-type-specific implementation, the GenericRdata class
+ is used.
+
+ Once a class is chosen, its from_wire() class method is called
+ with the parameters to this function.
+
+ *rdclass*, an ``int``, the rdataclass.
+
+ *rdtype*, an ``int``, the rdatatype.
+
+ *parser*, a ``dns.wire.Parser``, the parser, which should be
+ restricted to the rdata length.
+
+ *origin*, a ``dns.name.Name`` (or ``None``). If not ``None``,
+ then names will be relativized to this origin.
+
+ Returns an instance of the chosen Rdata subclass.
+ """
+
+ rdclass = dns.rdataclass.RdataClass.make(rdclass)
+ rdtype = dns.rdatatype.RdataType.make(rdtype)
+ cls = get_rdata_class(rdclass, rdtype)
+ return cls.from_wire_parser(rdclass, rdtype, parser, origin)
+
+
def from_wire(rdclass, rdtype, wire, current, rdlen, origin=None):
"""Build an rdata object from wire format
Returns an instance of the chosen Rdata subclass.
"""
-
- wire = dns.wiredata.maybe_wrap(wire)
- rdclass = dns.rdataclass.RdataClass.make(rdclass)
- rdtype = dns.rdatatype.RdataType.make(rdtype)
- cls = get_rdata_class(rdclass, rdtype)
- return cls.from_wire(rdclass, rdtype, wire, current, rdlen, origin)
+ parser = dns.wire.Parser(wire, current)
+ with parser.restrict_to(rdlen):
+ return from_wire_parser(rdclass, rdtype, parser, origin)
class RdatatypeExists(dns.exception.DNSException):
canonicalize)
@classmethod
- def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None):
- if rdlen < 2:
- raise dns.exception.FormError
- (precedence, relay_type) = struct.unpack('!BB',
- wire[current: current + 2])
- current += 2
- rdlen -= 2
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ (precedence, relay_type) = parser.get_struct('!BB')
discovery_optional = bool(relay_type >> 7)
relay_type &= 0x7f
- (relay, cused) = Relay(relay_type).from_wire(wire, current, rdlen,
- origin)
- current += cused
- rdlen -= cused
+ relay = Relay(relay_type).from_wire_parser(parser, origin)
return cls(rdclass, rdtype, precedence, discovery_optional, relay_type,
relay)
file.write(self.value)
@classmethod
- def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None):
- (flags, l) = struct.unpack('!BB', wire[current: current + 2])
- current += 2
- tag = wire[current: current + l]
- value = wire[current + l:current + rdlen - 2]
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ flags = parser.get_uint8()
+ tag = parser.get_counted_bytes()
+ value = parser.get_remaining()
return cls(rdclass, rdtype, flags, tag, value)
file.write(self.certificate)
@classmethod
- def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None):
- prefix = wire[current: current + 5].unwrap()
- current += 5
- rdlen -= 5
- if rdlen < 0:
- raise dns.exception.FormError
- (certificate_type, key_tag, algorithm) = struct.unpack("!HHB", prefix)
- certificate = wire[current: current + rdlen].unwrap()
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ (certificate_type, key_tag, algorithm) = parser.get_struct("!HHB")
+ certificate = parser.get_remaining()
return cls(rdclass, rdtype, certificate_type, key_tag, algorithm,
certificate)
file.write(bitmap)
@classmethod
- def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None):
- if rdlen < 6:
- raise dns.exception.FormError("CSYNC too short")
- (serial, flags) = struct.unpack("!IH", wire[current: current + 6])
- current += 6
- rdlen -= 6
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ (serial, flags) = parser.get_struct("!IH")
windows = []
- while rdlen > 0:
- if rdlen < 3:
- raise dns.exception.FormError("CSYNC too short")
- window = wire[current]
- octets = wire[current + 1]
+ while parser.remaining() > 0:
+ window = parser.get_uint8()
+ octets = parser.get_uint8()
if octets == 0 or octets > 32:
raise dns.exception.FormError("bad CSYNC octets")
- current += 2
- rdlen -= 2
- if rdlen < octets:
- raise dns.exception.FormError("bad CSYNC bitmap length")
- bitmap = bytearray(wire[current: current + octets].unwrap())
- current += octets
- rdlen -= octets
+ bitmap = parser.get_bytes(octets)
windows.append((window, bitmap))
return cls(rdclass, rdtype, serial, flags, windows)
file.write(self.altitude)
@classmethod
- def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None):
- l = wire[current]
- current += 1
- rdlen -= 1
- if l > rdlen:
- raise dns.exception.FormError
- latitude = wire[current: current + l].unwrap()
- current += l
- rdlen -= l
- l = wire[current]
- current += 1
- rdlen -= 1
- if l > rdlen:
- raise dns.exception.FormError
- longitude = wire[current: current + l].unwrap()
- current += l
- rdlen -= l
- l = wire[current]
- current += 1
- rdlen -= 1
- if l != rdlen:
- raise dns.exception.FormError
- altitude = wire[current: current + l].unwrap()
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ latitude = parser.get_counted_bytes()
+ longitude = parser.get_counted_bytes()
+ altitude = parser.get_counted_bytes()
return cls(rdclass, rdtype, latitude, longitude, altitude)
@property
file.write(self.os)
@classmethod
- def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None):
- l = wire[current]
- current += 1
- rdlen -= 1
- if l > rdlen:
- raise dns.exception.FormError
- cpu = wire[current:current + l].unwrap()
- current += l
- rdlen -= l
- l = wire[current]
- current += 1
- rdlen -= 1
- if l != rdlen:
- raise dns.exception.FormError
- os = wire[current: current + l].unwrap()
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ cpu = parser.get_counted_bytes()
+ os = parser.get_counted_bytes()
return cls(rdclass, rdtype, cpu, os)
server.to_wire(file, None, origin, False)
@classmethod
- def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None):
- (lh, algorithm, lk) = struct.unpack('!BBH',
- wire[current: current + 4])
- current += 4
- rdlen -= 4
- hit = wire[current: current + lh].unwrap()
- current += lh
- rdlen -= lh
- key = wire[current: current + lk].unwrap()
- current += lk
- rdlen -= lk
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ (lh, algorithm, lk) = parser.get_struct('!BBH')
+ hit = parser.get_bytes(lh)
+ key = parser.get_bytes(lk)
servers = []
- while rdlen > 0:
- (server, cused) = dns.name.from_wire(wire[: current + rdlen],
- current)
- current += cused
- rdlen -= cused
- if origin is not None:
- server = server.relativize(origin)
+ while parser.remaining() > 0:
+ server = parser.get_name(origin)
servers.append(server)
return cls(rdclass, rdtype, hit, algorithm, key, servers)
file.write(self.subaddress)
@classmethod
- def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None):
- l = wire[current]
- current += 1
- rdlen -= 1
- if l > rdlen:
- raise dns.exception.FormError
- address = wire[current: current + l].unwrap()
- current += l
- rdlen -= l
- if rdlen > 0:
- l = wire[current]
- current += 1
- rdlen -= 1
- if l != rdlen:
- raise dns.exception.FormError
- subaddress = wire[current: current + l].unwrap()
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ address = parser.get_counted_bytes()
+ if parser.remaining() > 0:
+ subaddress = parser.get_counted_bytes()
else:
- subaddress = ''
+ subaddress = b''
return cls(rdclass, rdtype, address, subaddress)
file.write(wire)
@classmethod
- def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None):
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
(version, size, hprec, vprec, latitude, longitude, altitude) = \
- struct.unpack("!BBBBIII", wire[current: current + rdlen])
+ parser.get_struct("!BBBBIII")
if latitude < _MIN_LATITUDE or latitude > _MAX_LATITUDE:
raise dns.exception.FormError("bad latitude")
if latitude > 0x80000000:
file.write(bitmap)
@classmethod
- def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None):
- (next, cused) = dns.name.from_wire(wire[: current + rdlen], current)
- current += cused
- rdlen -= cused
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ next = parser.get_name(origin)
windows = []
- while rdlen > 0:
- if rdlen < 3:
- raise dns.exception.FormError("NSEC too short")
- window = wire[current]
- octets = wire[current + 1]
- if octets == 0 or octets > 32:
+ while parser.remaining() > 0:
+ window = parser.get_uint8()
+ bitmap = parser.get_counted_bytes()
+ if len(bitmap) == 0 or len(bitmap) > 32:
raise dns.exception.FormError("bad NSEC octets")
- current += 2
- rdlen -= 2
- if rdlen < octets:
- raise dns.exception.FormError("bad NSEC bitmap length")
- bitmap = bytearray(wire[current: current + octets].unwrap())
- current += octets
- rdlen -= octets
windows.append((window, bitmap))
- if origin is not None:
- next = next.relativize(origin)
return cls(rdclass, rdtype, next, windows)
file.write(bitmap)
@classmethod
- def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None):
- (algorithm, flags, iterations, slen) = \
- struct.unpack('!BBHB', wire[current: current + 5])
-
- current += 5
- rdlen -= 5
- salt = wire[current: current + slen].unwrap()
- current += slen
- rdlen -= slen
- nlen = wire[current]
- current += 1
- rdlen -= 1
- next = wire[current: current + nlen].unwrap()
- current += nlen
- rdlen -= nlen
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ (algorithm, flags, iterations) = parser.get_struct('!BBH')
+ salt = parser.get_counted_bytes()
+ next = parser.get_counted_bytes()
windows = []
- while rdlen > 0:
- if rdlen < 3:
- raise dns.exception.FormError("NSEC3 too short")
- window = wire[current]
- octets = wire[current + 1]
- if octets == 0 or octets > 32:
+ while parser.remaining() > 0:
+ window = parser.get_uint8()
+ bitmap = parser.get_counted_bytes()
+ if len(bitmap) == 0 or len(bitmap) > 32:
raise dns.exception.FormError("bad NSEC3 octets")
- current += 2
- rdlen -= 2
- if rdlen < octets:
- raise dns.exception.FormError("bad NSEC3 bitmap length")
- bitmap = bytearray(wire[current: current + octets].unwrap())
- current += octets
- rdlen -= octets
windows.append((window, bitmap))
return cls(rdclass, rdtype, algorithm, flags, iterations, salt, next,
windows)
file.write(self.salt)
@classmethod
- def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None):
- (algorithm, flags, iterations, slen) = \
- struct.unpack('!BBHB',
- wire[current: current + 5])
- current += 5
- rdlen -= 5
- salt = wire[current: current + slen].unwrap()
- current += slen
- rdlen -= slen
- if rdlen != 0:
- raise dns.exception.FormError
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ (algorithm, flags, iterations) = parser.get_struct('!BBH')
+ salt = parser.get_counted_bytes()
return cls(rdclass, rdtype, algorithm, flags, iterations, salt)
file.write(self.key)
@classmethod
- def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None):
- key = wire[current: current + rdlen].unwrap()
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ key = parser.get_remaining()
return cls(rdclass, rdtype, key)
return ' '.join(opt.to_text() for opt in self.options)
@classmethod
- def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None):
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
options = []
- while rdlen > 0:
- if rdlen < 4:
- raise dns.exception.FormError
- (otype, olen) = struct.unpack('!HH', wire[current:current + 4])
- current += 4
- rdlen -= 4
- if olen > rdlen:
- raise dns.exception.FormError
- opt = dns.edns.option_from_wire(otype, wire, current, olen)
- current += olen
- rdlen -= olen
+ while parser.remaining() > 0:
+ (otype, olen) = parser.get_struct('!HH')
+ with parser.restrict_to(olen):
+ opt = dns.edns.option_from_wire_parser(otype, parser)
options.append(opt)
return cls(rdclass, rdtype, options)
self.txt.to_wire(file, None, origin, canonicalize)
@classmethod
- def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None):
- (mbox, cused) = dns.name.from_wire(wire[: current + rdlen],
- current)
- current += cused
- rdlen -= cused
- if rdlen <= 0:
- raise dns.exception.FormError
- (txt, cused) = dns.name.from_wire(wire[: current + rdlen],
- current)
- if cused != rdlen:
- raise dns.exception.FormError
- if origin is not None:
- mbox = mbox.relativize(origin)
- txt = txt.relativize(origin)
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ mbox = parser.get_name(origin)
+ txt = parser.get_name(origin)
return cls(rdclass, rdtype, mbox, txt)
file.write(self.signature)
@classmethod
- def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None):
- header = struct.unpack('!HBBIIIH', wire[current: current + 18])
- current += 18
- rdlen -= 18
- (signer, cused) = dns.name.from_wire(wire[: current + rdlen], current)
- current += cused
- rdlen -= cused
- if origin is not None:
- signer = signer.relativize(origin)
- signature = wire[current: current + rdlen].unwrap()
- return cls(rdclass, rdtype, header[0], header[1], header[2],
- header[3], header[4], header[5], header[6], signer,
- signature)
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ header = parser.get_struct('!HBBIIIH')
+ signer = parser.get_name(origin)
+ signature = parser.get_remaining()
+ return cls(rdclass, rdtype, *header, signer, signature)
file.write(five_ints)
@classmethod
- def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None):
- (mname, cused) = dns.name.from_wire(wire[: current + rdlen], current)
- current += cused
- rdlen -= cused
- (rname, cused) = dns.name.from_wire(wire[: current + rdlen], current)
- current += cused
- rdlen -= cused
- if rdlen != 20:
- raise dns.exception.FormError
- five_ints = struct.unpack('!IIIII',
- wire[current: current + rdlen])
- if origin is not None:
- mname = mname.relativize(origin)
- rname = rname.relativize(origin)
- return cls(rdclass, rdtype, mname, rname,
- five_ints[0], five_ints[1], five_ints[2], five_ints[3],
- five_ints[4])
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ mname = parser.get_name(origin)
+ rname = parser.get_name(origin)
+ return cls(rdclass, rdtype, mname, rname, *parser.get_struct('!IIIII'))
file.write(self.fingerprint)
@classmethod
- def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None):
- header = struct.unpack("!BB", wire[current: current + 2])
- current += 2
- rdlen -= 2
- fingerprint = wire[current: current + rdlen].unwrap()
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ header = parser.get_struct("BB")
+ fingerprint = parser.get_remaining()
return cls(rdclass, rdtype, header[0], header[1], fingerprint)
file.write(self.cert)
@classmethod
- def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None):
- header = struct.unpack("!BBB", wire[current: current + 3])
- current += 3
- rdlen -= 3
- cert = wire[current: current + rdlen].unwrap()
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ header = parser.get_struct("BBB")
+ cert = parser.get_remaining()
return cls(rdclass, rdtype, header[0], header[1], header[2], cert)
file.write(self.other)
@classmethod
- def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None):
- (algorithm, cused) = dns.name.from_wire(wire[: current + rdlen],
- current)
- current += cused
- rdlen -= cused
- if rdlen < 10:
- raise dns.exception.FormError
- (time_hi, time_lo, fudge, mac_len) = \
- struct.unpack('!HIHH', wire[current: current + 10])
- current += 10
- rdlen -= 10
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ algorithm = parser.get_name(origin)
+ (time_hi, time_lo, fudge) = parser.get_struct('!HIH')
time_signed = (time_hi << 32) + time_lo
- if rdlen < mac_len:
- raise dns.exception.FormError
- mac = wire[current: current + mac_len].unwrap()
- current += mac_len
- rdlen -= mac_len
- if rdlen < 6:
- raise dns.exception.FormError
- (original_id, error, other_len) = \
- struct.unpack('!HHH', wire[current: current + 6])
- current += 6
- rdlen -= 6
- if rdlen < other_len:
- raise dns.exception.FormError
- other = wire[current: current + other_len].unwrap()
- current += other_len
- rdlen -= other_len
+ mac = parser.get_counted_bytes(2)
+ (original_id, error) = parser.get_struct('!HH')
+ other = parser.get_counted_bytes(2)
return cls(rdclass, rdtype, algorithm, time_signed, fudge, mac,
original_id, error, other)
file.write(self.target)
@classmethod
- def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None):
- if rdlen < 5:
- raise dns.exception.FormError('URI RR is shorter than 5 octets')
-
- (priority, weight) = struct.unpack('!HH', wire[current: current + 4])
- current += 4
- rdlen -= 4
- target = wire[current: current + rdlen]
- current += rdlen
-
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ (priority, weight) = parser.get_struct('!HH')
+ target = parser.get_remaining()
+ if len(target) == 0:
+ raise dns.exception.FormError('URI target may not be empty')
return cls(rdclass, rdtype, priority, weight, target)
file.write(self.address)
@classmethod
- def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None):
- l = wire[current]
- current += 1
- rdlen -= 1
- if l != rdlen:
- raise dns.exception.FormError
- address = wire[current: current + l].unwrap()
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ address = parser.get_counted_bytes()
return cls(rdclass, rdtype, address)
import dns.rdtypes.mxbase
import struct
-class A(dns.rdtypes.mxbase.MXBase):
+class A(dns.rdata.Rdata):
"""A record for Chaosnet"""
__slots__ = ['domain', 'address']
- def __init__(self, rdclass, rdtype, address, domain):
- super().__init__(rdclass, rdtype, address, domain)
+ def __init__(self, rdclass, rdtype, domain, address):
+ super().__init__(rdclass, rdtype)
object.__setattr__(self, 'domain', domain)
object.__setattr__(self, 'address', address)
domain = tok.get_name(origin, relativize, relativize_to)
address = tok.get_uint16(base=8)
tok.get_eol()
- return cls(rdclass, rdtype, address, domain)
+ return cls(rdclass, rdtype, domain, address)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
self.domain.to_wire(file, compress, origin, canonicalize)
file.write(pref)
@classmethod
- def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None):
- (domain, cused) = dns.name.from_wire(wire[:current + rdlen - 2],
- current)
- current += cused
- (address,) = struct.unpack('!H', wire[current:current + 2])
- if cused + 2 != rdlen:
- raise dns.exception.FormError
- if origin is not None:
- domain = domain.relativize(origin)
- return cls(rdclass, rdtype, address, domain)
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ domain = parser.get_name(origin)
+ address = parser.get_uint16()
+ return cls(rdclass, rdtype, domain, address)
file.write(dns.ipv4.inet_aton(self.address))
@classmethod
- def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None):
- address = dns.ipv4.inet_ntoa(wire[current: current + rdlen])
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ address = dns.ipv4.inet_ntoa(parser.get_remaining())
return cls(rdclass, rdtype, address)
file.write(dns.ipv6.inet_aton(self.address))
@classmethod
- def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None):
- address = dns.ipv6.inet_ntoa(wire[current: current + rdlen])
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ address = dns.ipv6.inet_ntoa(parser.get_remaining())
return cls(rdclass, rdtype, address)
item.to_wire(file)
@classmethod
- def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None):
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
items = []
- while 1:
- if rdlen == 0:
- break
- if rdlen < 4:
- raise dns.exception.FormError
- header = struct.unpack('!HBB', wire[current: current + 4])
+ while parser.remaining() > 0:
+ header = parser.get_struct('!HBB')
afdlen = header[2]
if afdlen > 127:
negation = True
afdlen -= 128
else:
negation = False
- current += 4
- rdlen -= 4
- if rdlen < afdlen:
- raise dns.exception.FormError
- address = wire[current: current + afdlen].unwrap()
+ address = parser.get_bytes(afdlen)
l = len(address)
if header[0] == 1:
if l < 4:
# seems better than throwing an exception
#
address = codecs.encode(address, 'hex_codec')
- current += afdlen
- rdlen -= afdlen
item = APLItem(header[0], negation, address, header[1])
items.append(item)
return cls(rdclass, rdtype, items)
file.write(self.data)
@classmethod
- def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None):
- data = wire[current: current + rdlen].unwrap()
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ data = parser.get_remaining()
return cls(rdclass, rdtype, data)
file.write(self.key)
@classmethod
- def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None):
- if rdlen < 3:
- raise dns.exception.FormError
- header = struct.unpack('!BBB', wire[current: current + 3])
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ header = parser.get_struct('!BBB')
gateway_type = header[1]
- current += 3
- rdlen -= 3
- (gateway, cused) = Gateway(gateway_type).from_wire(wire, current,
- rdlen, origin)
- current += cused
- rdlen -= cused
- key = wire[current: current + rdlen].unwrap()
- if origin is not None and gateway_type == 3:
- gateway = gateway.relativize(origin)
+ gateway = Gateway(gateway_type).from_wire_parser(parser, origin)
+ key = parser.get_remaining()
return cls(rdclass, rdtype, header[0], gateway_type, header[2],
gateway, key)
self.replacement.to_wire(file, compress, origin, canonicalize)
@classmethod
- def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None):
- (order, preference) = struct.unpack('!HH', wire[current: current + 4])
- current += 4
- rdlen -= 4
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ (order, preference) = parser.get_struct('!HH')
strings = []
for i in range(3):
- l = wire[current]
- current += 1
- rdlen -= 1
- if l > rdlen or rdlen < 0:
- raise dns.exception.FormError
- s = wire[current: current + l].unwrap()
- current += l
- rdlen -= l
+ s = parser.get_counted_bytes()
strings.append(s)
- (replacement, cused) = dns.name.from_wire(wire[: current + rdlen],
- current)
- if cused != rdlen:
- raise dns.exception.FormError
- if origin is not None:
- replacement = replacement.relativize(origin)
+ replacement = parser.get_name(origin)
return cls(rdclass, rdtype, order, preference, strings[0], strings[1],
strings[2], replacement)
file.write(self.address)
@classmethod
- def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None):
- address = wire[current: current + rdlen].unwrap()
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ address = parser.get_remaining()
return cls(rdclass, rdtype, address)
self.mapx400.to_wire(file, None, origin, canonicalize)
@classmethod
- def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None):
- (preference, ) = struct.unpack('!H', wire[current: current + 2])
- current += 2
- rdlen -= 2
- (map822, cused) = dns.name.from_wire(wire[: current + rdlen],
- current)
- if cused > rdlen:
- raise dns.exception.FormError
- current += cused
- rdlen -= cused
- if origin is not None:
- map822 = map822.relativize(origin)
- (mapx400, cused) = dns.name.from_wire(wire[: current + rdlen],
- current)
- if cused != rdlen:
- raise dns.exception.FormError
- if origin is not None:
- mapx400 = mapx400.relativize(origin)
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ preference = parser.get_uint16()
+ map822 = parser.get_name(origin)
+ mapx400 = parser.get_name(origin)
return cls(rdclass, rdtype, preference, map822, mapx400)
self.target.to_wire(file, compress, origin, canonicalize)
@classmethod
- def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None):
- (priority, weight, port) = struct.unpack('!HHH',
- wire[current: current + 6])
- current += 6
- rdlen -= 6
- (target, cused) = dns.name.from_wire(wire[: current + rdlen],
- current)
- if cused != rdlen:
- raise dns.exception.FormError
- if origin is not None:
- target = target.relativize(origin)
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ (priority, weight, port) = parser.get_struct('!HHH')
+ target = parser.get_name(origin)
return cls(rdclass, rdtype, priority, weight, port, target)
file.write(self.bitmap)
@classmethod
- def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None):
- address = dns.ipv4.inet_ntoa(wire[current: current + 4])
- protocol, = struct.unpack('!B', wire[current + 4: current + 5])
- current += 5
- rdlen -= 5
- bitmap = wire[current: current + rdlen].unwrap()
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ address = dns.ipv4.inet_ntoa(parser.get_bytes(4))
+ protocol = parser.get_uint8()
+ bitmap = parser.get_remaining()
return cls(rdclass, rdtype, address, protocol, bitmap)
file.write(self.key)
@classmethod
- def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None):
- if rdlen < 4:
- raise dns.exception.FormError
- header = struct.unpack('!HBB', wire[current: current + 4])
- current += 4
- rdlen -= 4
- key = wire[current: current + rdlen].unwrap()
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ header = parser.get_struct('!HBB')
+ key = parser.get_remaining()
return cls(rdclass, rdtype, header[0], header[1], header[2],
key)
...
@classmethod
- def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None):
+ def from_parser(cls, rdclass, rdtype, parser, origin=None):
...
def flags_to_text_set(self) -> Set[str]:
file.write(self.digest)
@classmethod
- def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None):
- header = struct.unpack("!HBB", wire[current: current + 4])
- current += 4
- rdlen -= 4
- digest = wire[current: current + rdlen].unwrap()
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ header = parser.get_struct("!HBB")
+ digest = parser.get_remaining()
return cls(rdclass, rdtype, header[0], header[1], header[2], digest)
file.write(self.eui)
@classmethod
- def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None):
- eui = wire[current:current + rdlen].unwrap()
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ eui = parser.get_bytes(cls.byte_len)
return cls(rdclass, rdtype, eui)
self.exchange.to_wire(file, compress, origin, canonicalize)
@classmethod
- def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None):
- (preference, ) = struct.unpack('!H', wire[current: current + 2])
- current += 2
- rdlen -= 2
- (exchange, cused) = dns.name.from_wire(wire[: current + rdlen],
- current)
- if cused != rdlen:
- raise dns.exception.FormError
- if origin is not None:
- exchange = exchange.relativize(origin)
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ preference = parser.get_uint16()
+ exchange = parser.get_name(origin)
return cls(rdclass, rdtype, preference, exchange)
self.target.to_wire(file, compress, origin, canonicalize)
@classmethod
- def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None):
- (target, cused) = dns.name.from_wire(wire[: current + rdlen],
- current)
- if cused != rdlen:
- raise dns.exception.FormError
- if origin is not None:
- target = target.relativize(origin)
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ target = parser.get_name(origin)
return cls(rdclass, rdtype, target)
file.write(s)
@classmethod
- def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None):
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
strings = []
- while rdlen > 0:
- l = wire[current]
- current += 1
- rdlen -= 1
- if l > rdlen:
- raise dns.exception.FormError
- s = wire[current: current + l].unwrap()
- current += l
- rdlen -= l
+ while parser.remaining() > 0:
+ s = parser.get_counted_bytes()
strings.append(s)
return cls(rdclass, rdtype, strings)
else:
raise ValueError(self._invalid_type())
- def from_wire(self, wire, current, rdlen, origin=None):
+ def from_wire_parser(self, parser, origin=None):
if self.type == 0:
- return (None, 0)
+ return None
elif self.type == 1:
- return (dns.ipv4.inet_ntoa(wire[current: current + 4]), 4)
+ return dns.ipv4.inet_ntoa(parser.get_bytes(4))
elif self.type == 2:
- return (dns.ipv6.inet_ntoa(wire[current: current + 16]), 16)
+ return dns.ipv6.inet_ntoa(parser.get_bytes(16))
elif self.type == 3:
- return dns.name.from_wire(wire[: current + rdlen], current)
+ return parser.get_name(origin)
else:
raise dns.exception.FormError(self._invalid_type())
--- /dev/null
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+import contextlib
+import struct
+
+import dns.exception
+import dns.name
+
+class Parser:
+ def __init__(self, wire, current=0):
+ self.wire = wire
+ self.current = 0
+ self.end = len(self.wire)
+ if current:
+ self.seek(current)
+ self.furthest = current
+
+ def remaining(self):
+ return self.end - self.current
+
+ def get_bytes(self, size):
+ if size > self.remaining():
+ raise dns.exception.FormError
+ output = self.wire[self.current:self.current + size]
+ self.current += size
+ self.furthest = max(self.furthest, self.current)
+ return output
+
+ def get_counted_bytes(self, length_size=1):
+ length = int.from_bytes(self.get_bytes(length_size), 'big')
+ return self.get_bytes(length)
+
+ def get_remaining(self):
+ return self.get_bytes(self.remaining())
+
+ def get_uint8(self):
+ return struct.unpack('!B', self.get_bytes(1))[0]
+
+ def get_uint16(self):
+ return struct.unpack('!H', self.get_bytes(2))[0]
+
+ def get_uint32(self):
+ return struct.unpack('!I', self.get_bytes(4))[0]
+
+ def get_struct(self, format):
+ return struct.unpack(format, self.get_bytes(struct.calcsize(format)))
+
+ def get_name(self, origin=None):
+ name = dns.name.from_wire_parser(self)
+ if origin:
+ name = name.relativize(origin)
+ return name
+
+ def seek(self, where):
+ # Note that seeking to the end is OK! (If you try to read
+ # after such a seek, you'll get an exception as expected.)
+ if where < 0 or where > self.end:
+ raise dns.exception.FormError
+ self.current = where
+
+ @contextlib.contextmanager
+ def restrict_to(self, size):
+ if size > self.remaining():
+ raise dns.exception.FormError
+ saved_end = self.end
+ try:
+ self.end = self.current + size
+ yield
+ # We make this check here and not in the finally as we
+ # don't want to raise if we're already raising for some
+ # other reason.
+ if self.current != self.end:
+ raise dns.exception.FormError
+ finally:
+ self.end = saved_end
+
+ @contextlib.contextmanager
+ def restore_furthest(self):
+ try:
+ yield None
+ finally:
+ self.current = self.furthest
+++ /dev/null
-# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
-
-# Copyright (C) 2011,2017 Nominum, Inc.
-#
-# Permission to use, copy, modify, and distribute this software and its
-# documentation for any purpose with or without fee is hereby granted,
-# provided that the above copyright notice and this permission notice
-# appear in all copies.
-#
-# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
-# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
-# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
-# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
-# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
-# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
-# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
-
-"""DNS Wire Data Helper"""
-
-import dns.exception
-
-
-class WireData(bytes):
- # WireData is a binary type with stricter slicing
-
- def __getitem__(self, key):
- try:
- if isinstance(key, slice):
- # make sure we are not going outside of valid ranges,
- # do stricter control of boundaries than python does
- # by default
-
- for index in (key.start, key.stop):
- if index is None:
- continue
- elif abs(index) > len(self):
- raise dns.exception.FormError
-
- return WireData(super().__getitem__(key))
- return super().__getitem__(key)
- except IndexError:
- raise dns.exception.FormError
-
- def unwrap(self):
- return bytes(self)
-
-
-def maybe_wrap(wire):
- if isinstance(wire, WireData):
- return wire
- elif isinstance(wire, bytes):
- return WireData(wire)
- elif isinstance(wire, str):
- return WireData(wire.encode())
- raise ValueError("unhandled type %s" % type(wire))
import struct
import unittest
+import dns.wire
import dns.exception
import dns.name
import dns.rdata
def test_opt_short_lengths(self):
def bad1():
- opt = OPT.from_wire(4096, dns.rdatatype.OPT,
- binascii.unhexlify('f00102'), 0, 3)
+ parser = dns.wire.Parser(bytes.fromhex('f00102'))
+ opt = OPT.from_wire_parser(4096, dns.rdatatype.OPT, parser)
self.assertRaises(dns.exception.FormError, bad1)
def bad2():
- opt = OPT.from_wire(4096, dns.rdatatype.OPT,
- binascii.unhexlify('f00100030000'), 0, 6)
+ parser = dns.wire.Parser(bytes.fromhex('f00100030000'))
+ opt = OPT.from_wire_parser(4096, dns.rdatatype.OPT, parser)
self.assertRaises(dns.exception.FormError, bad2)
+ def test_from_wire_parser(self):
+ wire = bytes.fromhex('01020304')
+ rdata = dns.rdata.from_wire('in', 'a', wire, 0, 4)
+ print(rdata)
+ self.assertEqual(rdata, dns.rdata.from_text('in', 'a', '1.2.3.4'))
+
if __name__ == '__main__':
unittest.main()
'''Valid wire format.'''
eui = b'\x01\x23\x45\x67\x89\xab'
pad_len = 100
- wire = dns.wiredata.WireData(b'x' * pad_len + eui + b'y' * pad_len * 2)
- inst = dns.rdtypes.ANY.EUI48.EUI48.from_wire(dns.rdataclass.IN,
- dns.rdatatype.EUI48,
- wire,
- pad_len,
- len(eui))
+ wire = b'x' * pad_len + eui + b'y' * pad_len * 2
+ inst = dns.rdata.from_wire(dns.rdataclass.IN, dns.rdatatype.EUI48,
+ wire, pad_len, len(eui))
self.assertEqual(inst.eui, eui)
def testFromWireLength(self):
'''Valid wire format.'''
eui = b'\x01\x23\x45\x67\x89'
pad_len = 100
- wire = dns.wiredata.WireData(b'x' * pad_len + eui + b'y' * pad_len * 2)
+ wire = b'x' * pad_len + eui + b'y' * pad_len * 2
with self.assertRaises(dns.exception.FormError):
- dns.rdtypes.ANY.EUI48.EUI48.from_wire(dns.rdataclass.IN,
- dns.rdatatype.EUI48,
- wire,
- pad_len,
- len(eui))
+ dns.rdata.from_wire(dns.rdataclass.IN, dns.rdatatype.EUI48,
+ wire, pad_len, len(eui))
class RdtypeAnyEUI64TestCase(unittest.TestCase):
'''Valid wire format.'''
eui = b'\x01\x23\x45\x67\x89\xab\xcd\xef'
pad_len = 100
- wire = dns.wiredata.WireData(b'x' * pad_len + eui + b'y' * pad_len * 2)
- inst = dns.rdtypes.ANY.EUI64.EUI64.from_wire(dns.rdataclass.IN,
- dns.rdatatype.EUI64,
- wire,
- pad_len,
- len(eui))
+ wire = b'x' * pad_len + eui + b'y' * pad_len * 2
+ inst = dns.rdata.from_wire(dns.rdataclass.IN, dns.rdatatype.EUI64,
+ wire, pad_len, len(eui))
self.assertEqual(inst.eui, eui)
def testFromWireLength(self):
'''Valid wire format.'''
eui = b'\x01\x23\x45\x67\x89'
pad_len = 100
- wire = dns.wiredata.WireData(b'x' * pad_len + eui + b'y' * pad_len * 2)
+ wire = b'x' * pad_len + eui + b'y' * pad_len * 2
with self.assertRaises(dns.exception.FormError):
- dns.rdtypes.ANY.EUI64.EUI64.from_wire(dns.rdataclass.IN,
- dns.rdatatype.EUI64,
- wire,
- pad_len,
- len(eui))
+ dns.rdata.from_wire(dns.rdataclass.IN, dns.rdatatype.EUI64,
+ wire, pad_len, len(eui))
if __name__ == '__main__':
--- /dev/null
+
+import unittest
+
+import dns.exception
+import dns.wire
+import dns.name
+
+
+class BinaryTestCase(unittest.TestCase):
+
+ def test_basic(self):
+ wire = bytes.fromhex('0102010203040102')
+ p = dns.wire.Parser(wire)
+ self.assertEqual(p.get_uint16(), 0x0102)
+ with p.restrict_to(5):
+ self.assertEqual(p.get_uint32(), 0x01020304)
+ self.assertEqual(p.get_uint8(), 0x01)
+ self.assertEqual(p.remaining(), 0)
+ with self.assertRaises(dns.exception.FormError):
+ p.get_uint16()
+ self.assertEqual(p.remaining(), 1)
+ self.assertEqual(p.get_uint8(), 0x02)
+ with self.assertRaises(dns.exception.FormError):
+ p.get_uint8()
+
+ def test_name(self):
+ # www.dnspython.org NS IN question
+ wire = b'\x03www\x09dnspython\x03org\x00\x00\x02\x00\x01'
+ expected = dns.name.from_text('www.dnspython.org')
+ p = dns.wire.Parser(wire)
+ self.assertEqual(p.get_name(), expected)
+ self.assertEqual(p.get_uint16(), 2)
+ self.assertEqual(p.get_uint16(), 1)
+ self.assertEqual(p.remaining(), 0)
+
+ def test_relativized_name(self):
+ # www.dnspython.org NS IN question
+ wire = b'\x03www\x09dnspython\x03org\x00\x00\x02\x00\x01'
+ origin = dns.name.from_text('dnspython.org')
+ expected = dns.name.from_text('www', None)
+ p = dns.wire.Parser(wire)
+ self.assertEqual(p.get_name(origin), expected)
+ self.assertEqual(p.remaining(), 4)
+
+ def test_compressed_name(self):
+ # www.dnspython.org NS IN question
+ wire = b'\x09dnspython\x03org\x00\x03www\xc0\x00'
+ expected1 = dns.name.from_text('dnspython.org')
+ expected2 = dns.name.from_text('www.dnspython.org')
+ p = dns.wire.Parser(wire)
+ self.assertEqual(p.get_name(), expected1)
+ self.assertEqual(p.get_name(), expected2)
+ self.assertEqual(p.remaining(), 0)
+ # verify the unseek()
+ self.assertEqual(p.current, len(wire))
+++ /dev/null
-# Copyright (C) 2016
-# Author: Martin Basti <martin.basti@gmail.com>
-#
-# Permission to use, copy, modify, and distribute this software and its
-# documentation for any purpose with or without fee is hereby granted,
-# provided that the above copyright notice and this permission notice
-# appear in all copies.
-
-import unittest
-
-from dns.exception import FormError
-from dns.wiredata import WireData, maybe_wrap
-
-
-class WireDataSlicingTestCase(unittest.TestCase):
-
- def testSliceAll(self):
- """Get all data"""
- inst = WireData(b'0123456789')
- self.assertEqual(inst[:], WireData(b'0123456789'))
-
- def testSliceAllExplicitlyDefined(self):
- """Get all data"""
- inst = WireData(b'0123456789')
- self.assertEqual(inst[0:10], WireData(b'0123456789'))
-
- def testSliceLowerHalf(self):
- """Get lower half of data"""
- inst = WireData(b'0123456789')
- self.assertEqual(inst[:5], WireData(b'01234'))
-
- def testSliceLowerHalfWithNegativeIndex(self):
- """Get lower half of data"""
- inst = WireData(b'0123456789')
- self.assertEqual(inst[:-5], WireData(b'01234'))
-
- def testSliceUpperHalf(self):
- """Get upper half of data"""
- inst = WireData(b'0123456789')
- self.assertEqual(inst[5:], WireData(b'56789'))
-
- def testSliceMiddle(self):
- """Get data from middle"""
- inst = WireData(b'0123456789')
- self.assertEqual(inst[3:6], WireData(b'345'))
-
- def testSliceMiddleWithNegativeIndex(self):
- """Get data from middle"""
- inst = WireData(b'0123456789')
- self.assertEqual(inst[-6:-3], WireData(b'456'))
-
- def testSliceMiddleWithMixedIndex(self):
- """Get data from middle"""
- inst = WireData(b'0123456789')
- self.assertEqual(inst[-8:3], WireData(b'2'))
- self.assertEqual(inst[5:-3], WireData(b'56'))
-
- def testGetOne(self):
- """Get data one by one item"""
- data = b'0123456789'
- inst = WireData(data)
- for i, byte in enumerate(data):
- self.assertEqual(inst[i], byte)
- for i in range(-1, len(data) * -1, -1):
- self.assertEqual(inst[i], data[i])
-
- def testEmptySlice(self):
- """Test empty slice"""
- data = b'0123456789'
- inst = WireData(data)
- for i, byte in enumerate(data):
- self.assertEqual(inst[i:i], b'')
- for i in range(-1, len(data) * -1, -1):
- self.assertEqual(inst[i:i], b'')
- self.assertEqual(inst[-3:-6], b'')
-
- def testSliceStartOutOfLowerBorder(self):
- """Get data from out of lower border"""
- inst = WireData(b'0123456789')
- with self.assertRaises(FormError):
- inst[-11:] # pylint: disable=pointless-statement
-
- def testSliceStopOutOfLowerBorder(self):
- """Get data from out of lower border"""
- inst = WireData(b'0123456789')
- with self.assertRaises(FormError):
- inst[:-11] # pylint: disable=pointless-statement
-
- def testSliceBothOutOfLowerBorder(self):
- """Get data from out of lower border"""
- inst = WireData(b'0123456789')
- with self.assertRaises(FormError):
- inst[-12:-11] # pylint: disable=pointless-statement
-
- def testSliceStartOutOfUpperBorder(self):
- """Get data from out of upper border"""
- inst = WireData(b'0123456789')
- with self.assertRaises(FormError):
- inst[11:] # pylint: disable=pointless-statement
-
- def testSliceStopOutOfUpperBorder(self):
- """Get data from out of upper border"""
- inst = WireData(b'0123456789')
- with self.assertRaises(FormError):
- inst[:11] # pylint: disable=pointless-statement
-
- def testSliceBothOutOfUpperBorder(self):
- """Get data from out of lower border"""
- inst = WireData(b'0123456789')
- with self.assertRaises(FormError):
- inst[10:20] # pylint: disable=pointless-statement
-
- def testGetOneOutOfLowerBorder(self):
- """Get item outside of range"""
- inst = WireData(b'0123456789')
- with self.assertRaises(FormError):
- inst[-11] # pylint: disable=pointless-statement
-
- def testGetOneOutOfUpperBorder(self):
- """Get item outside of range"""
- inst = WireData(b'0123456789')
- with self.assertRaises(FormError):
- inst[10] # pylint: disable=pointless-statement
-
- def testIteration(self):
- bval = b'0123'
- inst = WireData(bval)
- l = list(inst)
- self.assertEqual(l, [x for x in bval])
-
- def testBadWrap(self):
- def bad():
- w = maybe_wrap(123)
- self.assertRaises(ValueError, bad)