]> git.ipfire.org Git - thirdparty/pdns.git/blame - regression-tests.dnsdist/test_Protobuf.py
dnsdist: Add bindings for the `DNSResponse` object
[thirdparty/pdns.git] / regression-tests.dnsdist / test_Protobuf.py
CommitLineData
1d0bd88a
RG
1#!/usr/bin/env python
2import Queue
3import threading
4import socket
5import struct
6import sys
7import time
8from dnsdisttests import DNSDistTest
9
10import dns
11import dnsmessage_pb2
12
class TestProtobuf(DNSDistTest):
    """
    Check that dnsdist exports queries and responses to a remote protobuf
    logger (RemoteLogAction / RemoteLogResponseAction).
    """

    _protobufServerPort = 4242
    # Raw protobuf payloads received by the listener thread, in arrival order.
    _protobufQueue = Queue.Queue()
    _protobufCounter = 0
    _config_params = ['_testServerPort', '_protobufServerPort']
    _config_template = """
    newServer{address="127.0.0.1:%s", useClientSubnet=true}
    rl = newRemoteLogger('127.0.0.1:%s')
    addAction(AllRule(), RemoteLogAction(rl))
    addResponseAction(AllRule(), RemoteLogResponseAction(rl))
    """

    @staticmethod
    def _recvExactly(conn, length):
        """
        Read exactly `length` bytes from the stream socket `conn`.

        A single recv() on a TCP stream may return fewer bytes than
        requested, so keep reading until the requested amount has arrived.
        Returns the bytes read, or None if the peer closed the connection
        before `length` bytes were received.
        """
        chunks = []
        remaining = length
        while remaining > 0:
            chunk = conn.recv(remaining)
            if not chunk:
                return None
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    @classmethod
    def ProtobufListener(cls, port):
        """
        Accept TCP connections on 127.0.0.1:`port` and queue every
        length-prefixed (2-byte, network order) message received.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        try:
            sock.bind(("127.0.0.1", port))
        except socket.error as e:
            print("Error binding in the protobuf listener: %s" % str(e))
            sys.exit(1)

        sock.listen(100)
        while True:
            (conn, _) = sock.accept()
            try:
                while True:
                    header = cls._recvExactly(conn, 2)
                    if not header:
                        break
                    (datalen,) = struct.unpack("!H", header)
                    data = cls._recvExactly(conn, datalen)
                    if not data:
                        break

                    cls._protobufQueue.put(data, True, timeout=2.0)
            finally:
                conn.close()
        sock.close()

    @classmethod
    def startResponders(cls):
        """Start the UDP/TCP backend responders and the protobuf listener."""
        cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort])
        cls._UDPResponder.setDaemon(True)
        cls._UDPResponder.start()
        cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort])
        cls._TCPResponder.setDaemon(True)
        cls._TCPResponder.start()

        cls._protobufListener = threading.Thread(name='Protobuf Listener', target=cls.ProtobufListener, args=[cls._protobufServerPort])
        cls._protobufListener.setDaemon(True)
        cls._protobufListener.start()

    def getFirstProtobufMessage(self):
        """Pop the oldest queued payload and parse it as a PBDNSMessage."""
        self.assertFalse(self._protobufQueue.empty())
        data = self._protobufQueue.get(False)
        self.assertTrue(data)
        msg = dnsmessage_pb2.PBDNSMessage()
        msg.ParseFromString(data)
        return msg

    def checkProtobufBase(self, msg, protocol, query):
        """Check the fields shared by protobuf query and response messages."""
        self.assertTrue(msg)
        self.assertTrue(msg.HasField('timeSec'))
        self.assertTrue(msg.HasField('socketFamily'))
        self.assertEquals(msg.socketFamily, dnsmessage_pb2.PBDNSMessage.INET)
        self.assertTrue(msg.HasField('from'))
        # 'from' is a Python keyword, hence getattr()
        fromvalue = getattr(msg, 'from')
        self.assertEquals(socket.inet_ntop(socket.AF_INET, fromvalue), '127.0.0.1')
        self.assertTrue(msg.HasField('socketProtocol'))
        self.assertEquals(msg.socketProtocol, protocol)
        self.assertTrue(msg.HasField('messageId'))
        self.assertTrue(msg.HasField('id'))
        self.assertEquals(msg.id, query.id)
        self.assertTrue(msg.HasField('inBytes'))
        self.assertEquals(msg.inBytes, len(query.to_wire()))
        # dnsdist doesn't set the existing EDNS Subnet for now,
        # although it might be set from Lua
        # self.assertTrue(msg.HasField('originalRequestorSubnet'))
        # self.assertEquals(len(msg.originalRequestorSubnet), 4)
        # self.assertEquals(socket.inet_ntop(socket.AF_INET, msg.originalRequestorSubnet), '127.0.0.1')

    def checkProtobufQuery(self, msg, protocol, query, qclass, qtype, qname):
        """Check a protobuf message describing the query `query`."""
        self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSQueryType)
        self.checkProtobufBase(msg, protocol, query)
        # The responder ('to') field is only checked for queries: dnsdist
        # doesn't fill it for responses because it doesn't keep the
        # information around.
        self.assertTrue(msg.HasField('to'))
        self.assertEquals(socket.inet_ntop(socket.AF_INET, msg.to), '127.0.0.1')
        self.assertTrue(msg.HasField('question'))
        self.assertTrue(msg.question.HasField('qClass'))
        self.assertEquals(msg.question.qClass, qclass)
        self.assertTrue(msg.question.HasField('qType'))
        self.assertEquals(msg.question.qType, qtype)
        self.assertTrue(msg.question.HasField('qName'))
        self.assertEquals(msg.question.qName, qname)

    def checkProtobufResponse(self, msg, protocol, response):
        """Check a protobuf message describing the response `response`."""
        self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSResponseType)
        self.checkProtobufBase(msg, protocol, response)
        self.assertTrue(msg.HasField('response'))
        self.assertTrue(msg.response.HasField('queryTimeSec'))

    def checkProtobufSingleARecord(self, msg, name):
        """Check that `msg` carries exactly one IN A 127.0.0.1 record for `name` (TTL 3600)."""
        self.assertEquals(len(msg.response.rrs), 1)
        for rr in msg.response.rrs:
            self.assertTrue(rr.HasField('class'))
            # 'class' is a Python keyword, hence getattr()
            self.assertEquals(getattr(rr, 'class'), dns.rdataclass.IN)
            self.assertTrue(rr.HasField('type'))
            self.assertEquals(rr.type, dns.rdatatype.A)
            self.assertTrue(rr.HasField('name'))
            self.assertEquals(rr.name, name)
            self.assertTrue(rr.HasField('ttl'))
            self.assertEquals(rr.ttl, 3600)
            self.assertTrue(rr.HasField('rdata'))
            self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1')

    def testProtobuf(self):
        """
        Protobuf: Send data to a protobuf server
        """
        name = 'query.protobuf.tests.powerdns.com.'
        query = dns.message.make_query(name, 'A', 'IN')
        response = dns.message.make_response(query)
        rrset = dns.rrset.from_text(name,
                                    3600,
                                    dns.rdataclass.IN,
                                    dns.rdatatype.A,
                                    '127.0.0.1')
        response.answer.append(rrset)

        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
        self.assertTrue(receivedQuery)
        self.assertTrue(receivedResponse)
        receivedQuery.id = query.id
        self.assertEquals(query, receivedQuery)
        self.assertEquals(response, receivedResponse)

        # give the protobuf messages time to get there
        time.sleep(1)

        # check the protobuf message corresponding to the UDP query
        msg = self.getFirstProtobufMessage()
        self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.UDP, query, dns.rdataclass.IN, dns.rdatatype.A, name)

        # check the protobuf message corresponding to the UDP response
        msg = self.getFirstProtobufMessage()
        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, response)
        self.checkProtobufSingleARecord(msg, name)

        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
        self.assertTrue(receivedQuery)
        self.assertTrue(receivedResponse)
        receivedQuery.id = query.id
        self.assertEquals(query, receivedQuery)
        self.assertEquals(response, receivedResponse)

        # give the protobuf messages time to get there
        time.sleep(1)

        # check the protobuf message corresponding to the TCP query
        msg = self.getFirstProtobufMessage()
        self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.TCP, query, dns.rdataclass.IN, dns.rdatatype.A, name)

        # check the protobuf message corresponding to the TCP response
        msg = self.getFirstProtobufMessage()
        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, response)
        self.checkProtobufSingleARecord(msg, name)