import subprocess
import time
-import dns
import jinja2
import pytest
def ping_alive(sock):
- msgid = utils.random_msgid()
- buf = utils.get_msgbuf('localhost.', dns.rdatatype.A, msgid)
- sock.sendall(buf)
+ buff, msgid = utils.get_msgbuff()
+ sock.sendall(buff)
answer = utils.receive_parse_answer(sock)
return answer.id == msgid
"""TCP Connection Management tests"""
-import dns
-import dns.message
-
import utils
Expected: garbage must be ignored and the second query must be answered
"""
- MSG_ID = 1
-
- msg = utils.get_msgbuf('localhost.', dns.rdatatype.A, MSG_ID)
- garbage = utils.get_prefixed_garbage(1024)
- buf = garbage + msg
+ msg_buff, msgid = utils.get_msgbuff()
+ garbage_buff = utils.get_prefixed_garbage(1024)
+ kresd_sock.sendall(garbage_buff + msg_buff)
- kresd_sock.sendall(buf)
msg_answer = utils.receive_parse_answer(kresd_sock)
-
- assert msg_answer.id == MSG_ID
+ assert msg_answer.id == msgid
def test_pipelining(kresd_sock):
Expected: answer to the second query must come first.
"""
- MSG_ID_FIRST = 1
- MSG_ID_SECOND = 2
+ buff1, msgid1 = utils.get_msgbuff('1000.delay.getdnsapi.net.', msgid=1)
+ buff2, msgid2 = utils.get_msgbuff('1.delay.getdnsapi.net.', msgid=2)
+ buff = buff1 + buff2
+ kresd_sock.sendall(buff)
- buf = utils.get_msgbuf('1000.delay.getdnsapi.net.', dns.rdatatype.A, MSG_ID_FIRST) \
- + utils.get_msgbuf('1.delay.getdnsapi.net.', dns.rdatatype.A, MSG_ID_SECOND)
-
- kresd_sock.sendall(buf)
msg_answer = utils.receive_parse_answer(kresd_sock)
-
- assert msg_answer.id == MSG_ID_SECOND
+ assert msg_answer.id == msgid2
def test_less_than_header(kresd_sock):
"""Prefix is less than the length of the DNS message header."""
- wire = utils.prepare_wire()
+ wire, _ = utils.prepare_wire()
datalen = 11 # DNS header size minus 1
buff = utils.prepare_buffer(wire, datalen)
send_incorrect_repeatedly(kresd_sock, buff)
def test_greater_than_message(kresd_sock):
"""Prefix is greater than the length of the entire DNS message."""
- wire = utils.prepare_wire()
+ wire, _ = utils.prepare_wire()
datalen = len(wire) + 16
buff = utils.prepare_buffer(wire, datalen)
send_incorrect_repeatedly(kresd_sock, buff)
def test_cuts_message(kresd_sock):
"""Prefix is greater than the length of the DNS message header, but shorter than
the entire DNS message."""
- wire = utils.prepare_wire()
+ wire, _ = utils.prepare_wire()
datalen = 14 # DNS Header size plus 2
assert datalen < len(wire)
buff = utils.prepare_buffer(wire, datalen)
"""First, normal DNS message is sent. Afterwards, message with incorrect prefix
(greater than header, less than entire message) is sent. First message must be
answered, then the connection should be closed after timeout."""
- normal_msg_id = 1
- normal_wire = utils.prepare_wire(normal_msg_id)
+ normal_wire, normal_msgid = utils.prepare_wire(msgid=1)
normal_buff = utils.prepare_buffer(normal_wire)
- cut_wire = utils.prepare_wire()
+ cut_wire, _ = utils.prepare_wire(msgid=2)
cut_datalen = 14
assert cut_datalen < len(cut_wire)
cut_buff = utils.prepare_buffer(cut_wire, cut_datalen)
"""Prefix is correct, but the message has trailing garbage. The connection must
stay open until all message have been sent and answered."""
for _ in range(10):
- msgid = utils.random_msgid()
- wire = utils.prepare_wire(msgid) + utils.get_garbage(8)
+ wire, msgid = utils.prepare_wire()
+ wire += utils.get_garbage(8)
buff = utils.prepare_buffer(wire)
kresd_sock.sendall(buff)
import dns.message
-def random_msgid():
- return random.randint(1, 65535)
-
-
def receive_answer(sock):
answer_total_len = 0
data = sock.recv(2)
def prepare_wire(
- msgid=None,
qname='localhost.',
qtype=dns.rdatatype.A,
- qclass=dns.rdataclass.IN):
+ qclass=dns.rdataclass.IN,
+ msgid=None):
"""Utility function to generate DNS wire format message"""
msg = dns.message.make_query(qname, qtype, qclass)
if msgid is not None:
msg.id = msgid
- return msg.to_wire()
+ return msg.to_wire(), msg.id
def prepare_buffer(wire, datalen=None):
return struct.pack("!H", datalen) + wire
-def get_msgbuf(qname, qtype, msgid):
- # TODO remove/refactor in favor of prepare_wire, prepare_buffer
- msg = dns.message.make_query(qname, qtype, dns.rdataclass.IN)
- msg.id = msgid
- data = msg.to_wire()
- datalen = len(data)
- buf = struct.pack("!H", datalen) + data
- return buf
+def get_msgbuff(qname='localhost.', qtype=dns.rdatatype.A, msgid=None):
+ wire, msgid = prepare_wire(qname, qtype, msgid=msgid)
+ buff = prepare_buffer(wire)
+ return buff, msgid
def get_garbage(length):
- return bytearray(random.getrandbits(8) for _ in range(length))
+ return bytes(random.getrandbits(8) for _ in range(length))
def get_prefixed_garbage(length):
data = get_garbage(length)
- datalen = len(data)
- buf = struct.pack("!H", datalen) + data
- return buf
+ return prepare_buffer(data)