]> git.ipfire.org Git - thirdparty/pdns.git/blame - regression-tests.dnsdist/test_TCPShort.py
Merge pull request #8795 from omoerbeek/rec-lua-docs-policytag
[thirdparty/pdns.git] / regression-tests.dnsdist / test_TCPShort.py
CommitLineData
6ac8517d
RG
1#!/usr/bin/env python
2import socket
3import struct
4import threading
5import time
6import dns
7from dnsdisttests import DNSDistTest
8
# Python 2 compatibility: when the lazy ``xrange`` builtin exists, make the
# module-level name ``range`` refer to it. On Python 3 the probe below raises
# NameError and the builtin ``range`` is left untouched.
try:
    xrange  # noqa: F821 -- only defined on Python 2
except NameError:
    pass
else:
    range = xrange  # noqa: F821
13
class TestTCPShort(DNSDistTest):
    """
    Check how dnsdist handles partial (short) reads from clients and partial
    (short) writes to clients, over plain TCP and over DNS-over-TLS.
    """

    # this test suite uses a different responder port
    # because, contrary to the other ones, its
    # responders allow trailing data and multiple responses,
    # and we don't want to mix things up.
    _testServerPort = 5361
    _serverKey = 'server.key'
    _serverCert = 'server.chain'
    _serverName = 'tls.tests.dnsdist.org'
    _caCert = 'ca.pem'
    _tlsServerPort = 8453
    _tcpSendTimeout = 60
    _config_template = """
    newServer{address="127.0.0.1:%s"}
    addTLSLocal("127.0.0.1:%s", "%s", "%s")
    setTCPSendTimeout(%d)
    """
    _config_params = ['_testServerPort', '_tlsServerPort', '_serverCert', '_serverKey', '_tcpSendTimeout']

    @classmethod
    def startResponders(cls):
        """
        Start the UDP and TCP backend responders as daemon threads on
        _testServerPort.
        """
        print("Launching responders..")

        # NOTE(review): the trailing True flags presumably enable the
        # trailing-data / multiple-responses behaviour described in the
        # comment above _testServerPort -- confirm against DNSDistTest.
        cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue, True])
        # Thread.setDaemon() is deprecated since Python 3.10; the attribute
        # form works on every Python version this file supports.
        cls._UDPResponder.daemon = True
        cls._UDPResponder.start()

        cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue, True, True])
        cls._TCPResponder.daemon = True
        cls._TCPResponder.start()

    @staticmethod
    def _recvExactly(conn, size):
        """
        Read `size` bytes from `conn`, looping over short reads.

        Returns fewer than `size` bytes only when the peer closed the
        connection before sending everything.
        """
        data = b''
        while len(data) < size:
            chunk = conn.recv(size - len(data))
            if not chunk:
                break
            data += chunk
        return data

    def _recvAllResponses(self, conn):
        """
        Read length-prefixed DNS messages from `conn` until the peer closes
        the connection, and return them parsed, in order.

        A message that is cut short by the connection closing is dropped.
        """
        responses = []
        while True:
            # The two-byte length prefix can itself arrive in two reads.
            header = self._recvExactly(conn, 2)
            if len(header) < 2:
                break

            (datalen,) = struct.unpack("!H", header)
            data = self._recvExactly(conn, datalen)
            if not data or len(data) != datalen:
                break

            responses.append(dns.message.from_wire(data))
        return responses

    @staticmethod
    def _buildAXFRResponses(query, name):
        """
        Build the messages the backend sends back for the AXFR `query`.

        The sequence is:
        - a SOA;
        - 200 copies of one message carrying a single huge TXT RRset;
        - the SOA again.

        EDNS is disabled on every message.
        """
        soa = dns.rrset.from_text(name,
                                  60,
                                  dns.rdataclass.IN,
                                  dns.rdatatype.SOA,
                                  'ns.' + name + ' hostmaster.' + name + ' 1 3600 3600 3600 60')
        soaResponse = dns.message.make_response(query)
        soaResponse.use_edns(edns=False)
        soaResponse.answer.append(soa)

        # One TXT RRset: 200 comma-separated runs of str(i) repeated 50 times.
        content = ', '.join(str(i) * 50 for i in range(200))
        txt = dns.rrset.from_text(name,
                                  3600,
                                  dns.rdataclass.IN,
                                  dns.rdatatype.TXT,
                                  content)
        response = dns.message.make_response(query)
        response.use_edns(edns=False)
        response.answer.append(txt)

        return [soaResponse] + [response] * 200 + [soaResponse]

    def _checkShortRead(self, conn, name):
        """
        Send an A query for `name` over `conn` in several fragments so that
        dnsdist has to deal with short reads.

        The two-byte length prefix is split across two writes, and the
        payload is sent in pieces with pauses in between. The method then
        checks that the backend received the query and that the client got
        the expected response.

        `conn` is always closed before the assertions run.
        """
        query = dns.message.make_query(name, 'A', 'IN')
        expectedResponse = dns.message.make_response(query)
        rrset = dns.rrset.from_text(name,
                                    3600,
                                    dns.rdataclass.IN,
                                    dns.rdatatype.A,
                                    '192.0.2.1')
        expectedResponse.answer.append(rrset)

        # announce 7680 bytes (more than 4096, less than 8192 - the 512 bytes dnsdist is going to add)
        announcedSize = 7680
        wire = query.to_wire()
        paddingSize = announcedSize - len(wire)
        # pad to announcedSize bytes minus 1 so that dnsdist needs a second read
        wire = wire + (b'A' * (paddingSize - 1))
        sizeBytes = struct.pack("!H", announcedSize)

        try:
            self._toResponderQueue.put(expectedResponse, True, 2.0)

            # split the length prefix itself across two writes
            conn.send(sizeBytes[:1])
            time.sleep(1)
            conn.send(sizeBytes[1:])
            conn.send(wire)
            time.sleep(1)
            # 1024 bytes: the query's missing last byte plus 1023 bytes of trailing data
            conn.send(b'A' * 1024)

            (receivedQuery, receivedResponse) = self.recvTCPResponseOverConnection(conn, True)
        finally:
            conn.close()

        self.assertTrue(receivedQuery)
        self.assertTrue(receivedResponse)
        receivedQuery.id = query.id
        self.assertEqual(query, receivedQuery)
        self.assertEqual(receivedResponse, expectedResponse)

    def _checkShortWrite(self, conn, name):
        """
        Send an AXFR query for `name` over `conn` and get back a large
        multi-message answer.

        The method waits before reading anything, so that dnsdist fills its
        TCP window and buffers and has to deal with short writes. It then
        checks that every message arrived intact and in order.

        `conn` is always closed before the assertions run.
        """
        query = dns.message.make_query(name, 'AXFR', 'IN')
        responses = self._buildAXFRResponses(query, name)

        try:
            for response in responses:
                self._toResponderQueue.put(response, True, 2.0)

            self.sendTCPQueryOverConnection(conn, query)

            # we sleep for one second, making sure that dnsdist
            # will fill its TCP window and buffers, which will result
            # in some short writes
            time.sleep(1)

            receivedResponses = self._recvAllResponses(conn)

            receivedQuery = None
            if not self._fromResponderQueue.empty():
                receivedQuery = self._fromResponderQueue.get(True, 2.0)
        finally:
            conn.close()

        self.assertTrue(receivedQuery)
        receivedQuery.id = query.id
        self.assertEqual(query, receivedQuery)
        self.assertEqual(len(receivedResponses), len(responses))
        self.assertEqual(receivedResponses, responses)

    def testTCPShortRead(self):
        """
        TCP: Short read from client
        """
        conn = self.openTCPConnection()
        self._checkShortRead(conn, 'short-read.tcp-short.tests.powerdns.com.')

    def testTCPTLSShortRead(self):
        """
        TCP/TLS: Short read from client
        """
        conn = self.openTLSConnection(self._tlsServerPort, self._serverName, self._caCert)
        self._checkShortRead(conn, 'short-read-tls.tcp-short.tests.powerdns.com.')

    def testTCPShortWrite(self):
        """
        TCP: Short write to client
        """
        conn = self.openTCPConnection()
        self._checkShortWrite(conn, 'short-write.tcp-short.tests.powerdns.com.')

    def testTCPTLSShortWrite(self):
        """
        TCP/TLS: Short write to client
        """
        # same as testTCPShortWrite but over TLS this time
        conn = self.openTLSConnection(self._tlsServerPort, self._serverName, self._caCert)
        self._checkShortWrite(conn, 'short-write-tls.tcp-short.tests.powerdns.com.')