-class Query:
+import traceback
+import dns.message
+import dns.rrset
+import dns.rcode
- match_fields = []
+
+class Entry:
+ default_ttl = 3600
+ default_cls = 'IN'
+ default_rc = 'NOERROR'
def __init__(self):
- pass
+ self.match_fields = None
+ self.adjust_fields = None
+ self.message = dns.message.Message()
+
+ def match_part(self, code, msg):
+ if code not in self.match_fields and 'all' not in self.match_fields:
+ return True
+ expected = self.message
+ if code == 'opcode':
+ return self.__compare_val(expected.opcode(), msg.opcode())
+ elif code == 'qtype':
+ return self.__compare_val(expected.question[0].rdtype, msg.question[0].rdtype)
+ elif code == 'qname':
+ return self.__compare_val(expected.question[0].name, msg.question[0].name)
+ elif code == 'flags':
+ return self.__compare_val(dns.flags.to_text(expected.flags), dns.flags.to_text(msg.flags))
+ elif code == 'question':
+ return self.__compare_rrs(expected.question, msg.question)
+ elif code == 'answer':
+ return self.__compare_rrs(expected.answer, msg.answer)
+ elif code == 'authority':
+ return self.__compare_rrs(expected.authority, msg.authority)
+ elif code == 'additional':
+ return self.__compare_rrs(expected.additional, msg.additional)
+ else:
+ raise Exception('unknown match request "%s"' % code)
+
+ def match(self, msg):
+ match_fields = self.match_fields
+ if 'all' in match_fields:
+ match_fields = ('flags', 'question', 'answer', 'authority', 'additional')
+ for code in match_fields:
+ try:
+ self.match_part(code, msg)
+ except Exception as e:
+ raise Exception("when matching %s: %s" % (code, str(e)))
- def match(self, fields):
+ def set_match(self, fields):
self.match_fields = fields
- def parse(self, text):
- pass
+ def set_adjust(self, fields):
+ self.adjust_fields = fields
-class Range:
+ def set_reply(self, fields):
+ flags = []
+ rcode = dns.rcode.from_text(self.default_rc)
+ for code in fields:
+ try:
+ rcode = dns.rcode.from_text(code)
+ except:
+ flags.append(code)
+ self.message.flags = dns.flags.from_text(' '.join(flags))
+ self.message.rcode = rcode
+
+ def begin_section(self, section):
+ self.section = section
+
+ def add_record(self, owner, args):
+ rr = self.__rr_from_str(owner, args)
+ if self.section == 'QUESTION':
+ self.message.question.append(rr)
+ elif self.section == 'ANSWER':
+ self.message.answer.append(rr)
+ elif self.section == 'AUTHORITY':
+ self.message.authority.append(rr)
+ elif self.section == 'ADDITIONAL':
+ self.message.additional.append(rr)
+ else:
+ raise Exception('attempted to add record in section %s' % self.section)
+
+
+ def __rr_from_str(self, owner, args):
+ ttl = self.default_ttl
+ rdclass = self.default_cls
+ try:
+ dns.ttl.from_text(args[0])
+ ttl = args.pop(0)
+ except:
+ pass # optional
+ try:
+ dns.rdataclass.from_text(args[0])
+ rdclass = args.pop(0)
+ except:
+ pass # optional
+ rdtype = args.pop(0)
+ if len(args) > 0:
+ return dns.rrset.from_text(owner, ttl, rdclass, rdtype, ' '.join(args))
+ else:
+ return dns.rrset.from_text(owner, ttl, rdclass, rdtype)
+
+ def __compare_rrs(self, name, expected, got):
+ for rr in expected:
+ if rr not in got:
+ raise Exception("expected record '%s'" % rr.to_text())
+ for rr in got:
+ if rr not in expected:
+ raise Exception("unexpected record '%s'" % rr.to_text())
+ return True
- a = 0
- b = 0
- queries = []
+ def __compare_val(self, expected, got):
+ if expected != got:
+ raise Exception("expected '%s', got '%s'" % (expected, got))
+ return True
+
+class Range:
def __init__(self, a, b):
self.a = a
self.b = b
+ self.queries = []
- def add_query(self, query):
- self.queries.append(query)
+ def add(self, entry):
+ self.queries.append(entry)
-class Scenario:
+class Step:
+ def __init__(self, id, type):
+ self.id = int(id)
+ self.type = type
+ self.data = []
- name = ''
- ranges = []
- steps = []
+ def add(self, entry):
+ self.data.append(entry)
- def __init__(self):
- pass
+ def play(self, ctx):
+ if self.type == 'QUERY':
+ return self.__query(ctx)
+ elif self.type == 'CHECK_ANSWER':
+ return self.__check_answer(ctx)
+ else:
+ print '%d %s (%d entries) => NOOP' % (self.id, self.type, len(self.data))
+ return None
+
+ def __check_answer(self, ctx):
+ if len(self.data) == 0:
+ raise Exception("response definition required")
+ if ctx.last_answer is None:
+ raise Exception("no answer from preceding query")
+ expected = self.data[0]
+ expected.match(ctx.last_answer)
+
+ def __query(self, ctx):
+ if len(self.data) == 0:
+ raise Exception("query definition required")
+ msg = self.data[0].message
+ self.answer = ctx.resolve(msg.to_wire())
+ if self.answer is not None:
+ self.answer = dns.message.from_wire(self.answer)
+ ctx.last_answer = self.answer
+
+
+class Scenario:
+ def __init__(self, info):
+ print '# %s' % info
+ self.ranges = []
+ self.steps = []
- def begin(self, explanation):
- print '# %s' % explanation
+ def play(self, ctx):
+ step = None
+ if len(self.steps) == 0:
+ raise ('no steps in this scenario')
+ try:
+ for step in self.steps:
+ step.play(ctx)
+ except Exception as e:
+ raise Exception('on step #%d "%s": %s\n%s' % (step.id, step.type, str(e), traceback.format_exc()))
- def range(self, a, b):
- range_new = Range(a, b)
- self.ranges.append(range_new)
- return range_new
- def step(self, n, step_type):
- pass
#!/usr/bin/env python
import sys, os, fileinput
-import _test_integration
+from pydnstest import scenario
+import _test_integration as mock_ctx
+def get_next(file_in):
+ while True:
+ line = file_in.readline()
+ if len(line) == 0:
+ return False
+ tokens = ' '.join(line.strip().split()).split()
+ if len(tokens) == 0:
+ continue # Skip empty lines
+ op = tokens.pop(0)
+ if op.startswith(';') or op.startswith('#'):
+ continue # Skip comments
+ return op, tokens
-def parse_entry(line, file_in):
+def parse_entry(op, args, file_in):
""" Parse entry definition. """
- print line.split(' ')
- for line in iter(lambda: file_in.readline(), ''):
- if line.startswith('ENTRY_END'):
+ out = scenario.Entry()
+ for op, args in iter(lambda: get_next(file_in), False):
+ if op == 'ENTRY_END':
break
+ elif op == 'REPLY':
+ out.set_reply(args)
+ elif op == 'MATCH':
+ out.set_match(args)
+ elif op == 'ADJUST':
+ out.set_adjust(args)
+ elif op == 'SECTION':
+ out.begin_section(args[0])
+ else:
+ out.add_record(op, args)
+ return out
-def parse_step(line, file_in):
+def parse_step(op, args, file_in):
""" Parse range definition. """
- print line.split(' ')
+ if len(args) < 2:
+ raise Exception('expected STEP <id> <type>')
+ out = scenario.Step(args[0], args[1])
+ op, args = get_next(file_in)
+ # Optional data
+ if op == 'ENTRY_BEGIN':
+ out.add(parse_entry(op, args, file_in))
+ else:
+ raise Exception('expected "ENTRY_BEGIN"')
+ return out
-def parse_range(line, file_in):
+def parse_range(op, args, file_in):
""" Parse range definition. """
- print line.split(' ')
- for line in iter(lambda: file_in.readline(), ''):
- if line.startswith('ENTRY_BEGIN'):
- parse_entry(line, file_in)
- if line.startswith('RANGE_END'):
+ if len(args) < 2:
+ raise Exception('expected RANGE_BEGIN <from> <to>')
+ out = scenario.Range(int(args[0]), int(args[1]))
+ for op, args in iter(lambda: get_next(file_in), False):
+ if op == 'ADDRESS':
+ out.address = args[0]
+ elif op == 'ENTRY_BEGIN':
+ out.add(parse_entry(op, args, file_in))
+ elif op == 'RANGE_END':
break
+ return out
-def parse_scenario(line, file_in):
+def parse_scenario(op, args, file_in):
""" Parse scenario definition. """
- print line.split(' ')
- for line in iter(lambda: file_in.readline(), ''):
- if line.startswith('SCENARIO_END'):
+ out = scenario.Scenario(args[0])
+ for op, args in iter(lambda: get_next(file_in), False):
+ if op == 'SCENARIO_END':
break
- if line.startswith('RANGE_BEGIN'):
- parse_range(line, file_in)
- if line.startswith('STEP'):
- parse_step(line, file_in)
-
+ if op == 'RANGE_BEGIN':
+ out.ranges.append(parse_range(op, args, file_in))
+ if op == 'STEP':
+ out.steps.append(parse_step(op, args, file_in))
+ return out
def parse_file(file_in):
- """ Parse and play scenario from a file. """
+ """ Parse scenario from a file. """
try:
- for line in iter(lambda: file_in.readline(), ''):
- if line.startswith('SCENARIO_BEGIN'):
- return parse_scenario(line, file_in)
+ for op, args in iter(lambda: get_next(file_in), False):
+ if op == 'SCENARIO_BEGIN':
+ return parse_scenario(op, args, file_in)
raise Exception("IGNORE (missing scenario)")
except Exception as e:
raise Exception('line %d: %s' % (file_in.lineno(), str(e)))
-
def parse_object(path):
""" Recursively scan file/directory for scenarios. """
if os.path.isdir(path):
for e in os.listdir(path):
parse_object(os.path.join(path, e))
elif os.path.isfile(path):
- file_in = fileinput.input(path)
- try:
- parse_file(file_in)
- print('%s OK' % os.path.basename(path))
- except Exception as e:
- print('%s %s' % (os.path.basename(path), str(e)))
- file_in.close()
+ play_object(path)
+def play_object(path):
+ """ Play scenario from a file object. """
+ file_in = fileinput.input(path)
+ mock_ctx.init()
+ try:
+ scenario = parse_file(file_in)
+ scenario.play(mock_ctx)
+ print('%s OK' % os.path.basename(path))
+ except Exception as e:
+ print('%s %s' % (os.path.basename(path), str(e)))
+ mock_ctx.deinit()
+ file_in.close()
if __name__ == '__main__':
for arg in sys.argv[1:]: