]> git.ipfire.org Git - thirdparty/pdns.git/blob - regression-tests.dnsdist/dnsdistdohtests.py
Merge pull request #11591 from rgacogne/ddist-mac-netlink
[thirdparty/pdns.git] / regression-tests.dnsdist / dnsdistdohtests.py
1 #!/usr/bin/env python
2 import base64
3 import dns
4 import os
5 import unittest
6
7 from dnsdisttests import DNSDistTest
8
9 import pycurl
10 from io import BytesIO
11
@unittest.skipIf('SKIP_DOH_TESTS' in os.environ, 'DNS over HTTPS tests are disabled')
class DNSDistDOHTest(DNSDistTest):
    """Base class for dnsdist DNS over HTTPS regression tests.

    The helpers below send DoH queries with pycurl over HTTP/2, either as GET
    requests (wire-format query base64url-encoded in the ``dns`` URL
    parameter) or as POST requests (wire-format query as the request body).
    After a query completes, the HTTP status code is stored in ``cls._rcode``
    and the raw response headers (bytes) in ``cls._response_headers``, which
    getHeaderValue(), checkHasHeader() and checkNoHeader() read.
    """

    @classmethod
    def getDOHGetURL(cls, baseurl, query, rawQuery=False):
        """Return the DoH GET URL carrying ``query`` in the ``dns`` parameter.

        :param baseurl: URL of the DoH endpoint, without a query string.
        :param query: a dns.message.Message, or wire-format bytes when
            ``rawQuery`` is true.
        :param rawQuery: use ``query`` as-is instead of calling ``to_wire()``.
        :returns: ``baseurl + "?dns=" + <unpadded base64url of the wire data>``.
        """
        wire = query if rawQuery else query.to_wire()
        # DoH GET requests use base64url without '=' padding (RFC 8484, 4.1).
        param = base64.urlsafe_b64encode(wire).decode('UTF8').rstrip('=')
        return baseurl + "?dns=" + param

    @classmethod
    def openDOHConnection(cls, port, caFile, timeout=2.0):
        """Create a pycurl handle configured for HTTP/2 DoH requests.

        ``port``, ``caFile`` and ``timeout`` are accepted but not used here.
        sendDOHQuery() and sendDOHPostQuery() set the target address and the
        TLS options themselves.

        NOTE(review): ``timeout`` is never applied to the transfer itself, so
        it does not bound a stuck request. In the send helpers it is only
        used for the responder queue operations. The HTTPHEADER list set
        here is replaced by the ``customHeaders`` list in both send helpers,
        so these default headers are not sent by them. Both behaviors are
        kept as-is because existing tests may depend on them.
        """
        conn = pycurl.Curl()
        conn.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2)

        conn.setopt(pycurl.HTTPHEADER, ["Content-type: application/dns-message",
                                        "Accept: application/dns-message"])
        return conn

    @classmethod
    def _configureDOHConnection(cls, conn, port, servername, url, caFile,
                                customHeaders, useHTTPS, responseHeaders):
        """Set URL, address pinning, TLS verification and headers on ``conn``."""
        # For debugging: conn.setopt(pycurl.VERBOSE, True)
        conn.setopt(pycurl.URL, url)
        # Make servername:port resolve to 127.0.0.1. The URL keeps using
        # servername, so certificate host verification is done against it.
        conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)])
        if useHTTPS:
            conn.setopt(pycurl.SSL_VERIFYPEER, 1)
            conn.setopt(pycurl.SSL_VERIFYHOST, 2)
            if caFile:
                conn.setopt(pycurl.CAINFO, caFile)

        # This replaces the header list set by openDOHConnection().
        conn.setopt(pycurl.HTTPHEADER, [] if customHeaders is None else customHeaders)
        # Collect the raw response header lines into the BytesIO buffer.
        conn.setopt(pycurl.HEADERFUNCTION, responseHeaders.write)

    @classmethod
    def _performDOHQuery(cls, conn, responseHeaders, response, timeout,
                         useQueue, rawResponse, fromQueue, toQueue):
        """Queue ``response``, run the transfer on ``conn`` and collect results.

        Stores the HTTP status in ``cls._rcode`` and the raw response headers
        in ``cls._response_headers``. Returns ``(receivedQuery, message)``
        as described in sendDOHQuery().
        """
        if response:
            (toQueue or cls._toResponderQueue).put(response, True, timeout)

        receivedQuery = None
        message = None
        # Reset first, so a failed transfer does not leave the previous
        # query's headers in place (bytes, like the value stored below).
        cls._response_headers = b''
        data = conn.perform_rb()
        cls._rcode = conn.getinfo(pycurl.RESPONSE_CODE)
        if rawResponse:
            message = data
        elif cls._rcode == 200:
            message = dns.message.from_wire(data)

        if useQueue:
            inbound = fromQueue or cls._fromResponderQueue
            if not inbound.empty():
                receivedQuery = inbound.get(True, timeout)

        cls._response_headers = responseHeaders.getvalue()
        return (receivedQuery, message)

    @classmethod
    def sendDOHQuery(cls, port, servername, baseurl, query, response=None,
                     timeout=2.0, caFile=None, useQueue=True, rawQuery=False,
                     rawResponse=False, customHeaders=None, useHTTPS=True,
                     fromQueue=None, toQueue=None):
        """Send ``query`` as a DoH GET request and return what came back.

        :param port: port of the ``servername:port`` entry that is resolved
            to 127.0.0.1.
        :param servername: host name the RESOLVE entry is created for.
        :param baseurl: DoH endpoint URL. getDOHGetURL() appends the query.
        :param query: dns.message.Message, or wire-format bytes if ``rawQuery``.
        :param response: if truthy, put on the responder queue before sending.
        :param timeout: timeout in seconds for the responder queue operations.
        :param caFile: CA bundle used to verify the server (HTTPS only).
        :param useQueue: take an item from the responder's outgoing queue if
            one is available (presumably the query the backend received).
        :param rawQuery: ``query`` is already wire-format bytes.
        :param rawResponse: return the raw response body instead of parsing it.
        :param customHeaders: list of request header lines (default: none).
        :param useHTTPS: enable peer and host certificate verification.
        :param fromQueue: read ``receivedQuery`` from this queue instead of
            ``cls._fromResponderQueue``.
        :param toQueue: put ``response`` on this queue instead of
            ``cls._toResponderQueue``.
        :returns: ``(receivedQuery, message)``. ``receivedQuery`` is None when
            the queue was empty or ``useQueue`` is false. ``message`` is a
            parsed dns.message when the status is 200, the raw body when
            ``rawResponse`` is set, and None otherwise.
        """
        url = cls.getDOHGetURL(baseurl, query, rawQuery)
        responseHeaders = BytesIO()
        conn = cls.openDOHConnection(port, caFile=caFile, timeout=timeout)
        try:
            cls._configureDOHConnection(conn, port, servername, url, caFile,
                                        customHeaders, useHTTPS, responseHeaders)
            return cls._performDOHQuery(conn, responseHeaders, response, timeout,
                                        useQueue, rawResponse, fromQueue, toQueue)
        finally:
            # Release the curl handle even when the transfer raised.
            conn.close()

    @classmethod
    def sendDOHPostQuery(cls, port, servername, baseurl, query, response=None,
                         timeout=2.0, caFile=None, useQueue=True, rawQuery=False,
                         rawResponse=False, customHeaders=None, useHTTPS=True,
                         fromQueue=None, toQueue=None):
        """Send ``query`` as a DoH POST request to ``baseurl``.

        Takes the same parameters and returns the same value as sendDOHQuery().
        The difference is that the wire-format query is sent as the request
        body instead of in the URL. ``fromQueue`` and ``toQueue`` default to
        the class-level responder queues, as in sendDOHQuery().
        """
        responseHeaders = BytesIO()
        conn = cls.openDOHConnection(port, caFile=caFile, timeout=timeout)
        try:
            cls._configureDOHConnection(conn, port, servername, baseurl, caFile,
                                        customHeaders, useHTTPS, responseHeaders)
            conn.setopt(pycurl.POST, True)
            conn.setopt(pycurl.POSTFIELDS, query if rawQuery else query.to_wire())
            return cls._performDOHQuery(conn, responseHeaders, response, timeout,
                                        useQueue, rawResponse, fromQueue, toQueue)
        finally:
            # Release the curl handle even when the transfer raised.
            conn.close()

    def getHeaderValue(self, name):
        """Return the stripped value of the first response header ``name``.

        The name match is case-insensitive. Everything after the first ':'
        is the value, so values that contain colons (dates, URLs) are kept
        whole. Lines without a ':' (such as the status line) are ignored.
        Returns None when no such header was received.
        """
        headers = self._response_headers
        if isinstance(headers, bytes):
            headers = headers.decode()
        wanted = name.lower()
        for line in headers.splitlines():
            key, sep, value = line.partition(':')
            if sep and key.lower() == wanted:
                return value.strip()
        return None

    def checkHasHeader(self, name, value):
        """Assert that response header ``name`` has exactly ``value``."""
        got = self.getHeaderValue(name)
        self.assertEqual(got, value)

    def checkNoHeader(self, name):
        """Assert that no response header called ``name`` was received."""
        self.checkHasHeader(name, None)

    @classmethod
    def setUpClass(cls):
        """Skip when SKIP_DOH_TESTS is set, otherwise start the test setup.

        Calls startResponders(), startDNSDist() and setUpSockets() in that
        order. These are inherited, presumably from DNSDistTest.
        """

        # for some reason, @unittest.skipIf() is not applied to derived classes with some versions of Python
        if 'SKIP_DOH_TESTS' in os.environ:
            raise unittest.SkipTest('DNS over HTTPS tests are disabled')

        cls.startResponders()
        cls.startDNSDist()
        cls.setUpSockets()

        print("Launching tests..")