]> git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
pytests: refactor utils to generate msg ids
authorTomas Krizek <tomas.krizek@nic.cz>
Tue, 13 Nov 2018 17:17:01 +0000 (18:17 +0100)
committerTomas Krizek <tomas.krizek@nic.cz>
Tue, 4 Dec 2018 16:13:42 +0000 (17:13 +0100)
tests/pytests/kresd.py
tests/pytests/test_conn_mgmt.py
tests/pytests/test_tcp_prefix.py
tests/pytests/utils.py

index 1dca9853a9b9f2c0e2fb768bfd7e3ed08396b098..0ef25661e918088a5fec540b0dd35ef2c77ab5d8 100644 (file)
@@ -6,7 +6,6 @@ import ssl
 import subprocess
 import time
 
-import dns
 import jinja2
 import pytest
 
@@ -43,9 +42,8 @@ def make_ssl_context():
 
 
 def ping_alive(sock):
-    msgid = utils.random_msgid()
-    buf = utils.get_msgbuf('localhost.', dns.rdatatype.A, msgid)
-    sock.sendall(buf)
+    buff, msgid = utils.get_msgbuff()
+    sock.sendall(buff)
     answer = utils.receive_parse_answer(sock)
     return answer.id == msgid
 
index 78539dfad28749acaea4076d208d16cf386875fa..5d907c3b0da1e3d40daf31fe87f76db24866004f 100644 (file)
@@ -1,8 +1,5 @@
 """TCP Connection Management tests"""
 
-import dns
-import dns.message
-
 import utils
 
 
@@ -13,16 +10,12 @@ def test_ignore_garbage(kresd_sock):
 
     Expected: garbage must be ignored and the second query must be answered
     """
-    MSG_ID = 1
-
-    msg = utils.get_msgbuf('localhost.', dns.rdatatype.A, MSG_ID)
-    garbage = utils.get_prefixed_garbage(1024)
-    buf = garbage + msg
+    msg_buff, msgid = utils.get_msgbuff()
+    garbage_buff = utils.get_prefixed_garbage(1024)
+    kresd_sock.sendall(garbage_buff + msg_buff)
 
-    kresd_sock.sendall(buf)
     msg_answer = utils.receive_parse_answer(kresd_sock)
-
-    assert msg_answer.id == MSG_ID
+    assert msg_answer.id == msgid
 
 
 def test_pipelining(kresd_sock):
@@ -31,13 +24,10 @@ def test_pipelining(kresd_sock):
 
     Expected: answer to the second query must come first.
     """
-    MSG_ID_FIRST = 1
-    MSG_ID_SECOND = 2
+    buff1, msgid1 = utils.get_msgbuff('1000.delay.getdnsapi.net.', msgid=1)
+    buff2, msgid2 = utils.get_msgbuff('1.delay.getdnsapi.net.', msgid=2)
+    buff = buff1 + buff2
+    kresd_sock.sendall(buff)
 
-    buf = utils.get_msgbuf('1000.delay.getdnsapi.net.', dns.rdatatype.A, MSG_ID_FIRST) \
-        + utils.get_msgbuf('1.delay.getdnsapi.net.', dns.rdatatype.A, MSG_ID_SECOND)
-
-    kresd_sock.sendall(buf)
     msg_answer = utils.receive_parse_answer(kresd_sock)
-
-    assert msg_answer.id == MSG_ID_SECOND
+    assert msg_answer.id == msgid2
index 2ad9d6e838c441c3d1ee3f75a1712b570a7d2ed6..d611320934212d4ebc2563650d0b70105733d39a 100644 (file)
@@ -47,7 +47,7 @@ def send_incorrect_repeatedly(sock, buff, delay=1):
 
 def test_less_than_header(kresd_sock):
     """Prefix is less than the length of the DNS message header."""
-    wire = utils.prepare_wire()
+    wire, _ = utils.prepare_wire()
     datalen = 11  # DNS header size minus 1
     buff = utils.prepare_buffer(wire, datalen)
     send_incorrect_repeatedly(kresd_sock, buff)
@@ -55,7 +55,7 @@ def test_less_than_header(kresd_sock):
 
 def test_greater_than_message(kresd_sock):
     """Prefix is greater than the length of the entire DNS message."""
-    wire = utils.prepare_wire()
+    wire, _ = utils.prepare_wire()
     datalen = len(wire) + 16
     buff = utils.prepare_buffer(wire, datalen)
     send_incorrect_repeatedly(kresd_sock, buff)
@@ -64,7 +64,7 @@ def test_greater_than_message(kresd_sock):
 def test_cuts_message(kresd_sock):
     """Prefix is greater than the length of the DNS message header, but shorter than
     the entire DNS message."""
-    wire = utils.prepare_wire()
+    wire, _ = utils.prepare_wire()
     datalen = 14  # DNS Header size plus 2
     assert datalen < len(wire)
     buff = utils.prepare_buffer(wire, datalen)
@@ -75,11 +75,10 @@ def test_cuts_message_after_ok(kresd_sock):
     """First, normal DNS message is sent. Afterwards, message with incorrect prefix
     (greater than header, less than entire message) is sent. First message must be
     answered, then the connection should be closed after timeout."""
-    normal_msg_id = 1
-    normal_wire = utils.prepare_wire(normal_msg_id)
+    normal_wire, normal_msgid = utils.prepare_wire(msgid=1)
     normal_buff = utils.prepare_buffer(normal_wire)
 
-    cut_wire = utils.prepare_wire()
+    cut_wire, _ = utils.prepare_wire(msgid=2)
     cut_datalen = 14
     assert cut_datalen < len(cut_wire)
     cut_buff = utils.prepare_buffer(cut_wire, cut_datalen)
@@ -97,8 +96,8 @@ def test_trailing_garbage(kresd_sock):
     """Prefix is correct, but the message has trailing garbage. The connection must
     stay open until all message have been sent and answered."""
     for _ in range(10):
-        msgid = utils.random_msgid()
-        wire = utils.prepare_wire(msgid) + utils.get_garbage(8)
+        wire, msgid = utils.prepare_wire()
+        wire += utils.get_garbage(8)
         buff = utils.prepare_buffer(wire)
 
         kresd_sock.sendall(buff)
index 9ba4e6595e7a4a5bd85cd9ff44a17621d89bc23f..9f365d8e4c1b79da2765a8206f0b9784c035ac5b 100644 (file)
@@ -5,10 +5,6 @@ import dns
 import dns.message
 
 
-def random_msgid():
-    return random.randint(1, 65535)
-
-
 def receive_answer(sock):
     answer_total_len = 0
     data = sock.recv(2)
@@ -39,15 +35,15 @@ def receive_parse_answer(sock):
 
 
 def prepare_wire(
-        msgid=None,
         qname='localhost.',
         qtype=dns.rdatatype.A,
-        qclass=dns.rdataclass.IN):
+        qclass=dns.rdataclass.IN,
+        msgid=None):
     """Utility function to generate DNS wire format message"""
     msg = dns.message.make_query(qname, qtype, qclass)
     if msgid is not None:
         msg.id = msgid
-    return msg.to_wire()
+    return msg.to_wire(), msg.id
 
 
 def prepare_buffer(wire, datalen=None):
@@ -58,22 +54,16 @@ def prepare_buffer(wire, datalen=None):
     return struct.pack("!H", datalen) + wire
 
 
-def get_msgbuf(qname, qtype, msgid):
-    # TODO remove/refactor in favor of prepare_wire, prepare_buffer
-    msg = dns.message.make_query(qname, qtype, dns.rdataclass.IN)
-    msg.id = msgid
-    data = msg.to_wire()
-    datalen = len(data)
-    buf = struct.pack("!H", datalen) + data
-    return buf
+def get_msgbuff(qname='localhost.', qtype=dns.rdatatype.A, msgid=None):
+    wire, msgid = prepare_wire(qname, qtype, msgid=msgid)
+    buff = prepare_buffer(wire)
+    return buff, msgid
 
 
 def get_garbage(length):
-    return bytearray(random.getrandbits(8) for _ in range(length))
+    return bytes(random.getrandbits(8) for _ in range(length))
 
 
 def get_prefixed_garbage(length):
     data = get_garbage(length)
-    datalen = len(data)
-    buf = struct.pack("!H", datalen) + data
-    return buf
+    return prepare_buffer(data)