]> git.ipfire.org Git - thirdparty/pdns.git/blob - regression-tests.dnsdist/dnscrypt.py
Merge pull request #7306 from pieterlexis/hello-2019
[thirdparty/pdns.git] / regression-tests.dnsdist / dnscrypt.py
1 #!/usr/bin/env python2
2 import socket
3 import struct
4 import time
5 import dns
6 import dns.message
7 import libnacl
8 import libnacl.utils
9 import binascii
10 from builtins import bytes
11
class DNSCryptResolverCertificate(object):
    """A DNSCrypt resolver certificate extracted from a provider TXT record.

    Holds the certificate serial, its validity window (Unix timestamps), the
    resolver public key and the client magic that prefixes queries encrypted
    for this certificate.
    """
    DNSCRYPT_CERT_MAGIC = b'\x44\x4e\x53\x43'
    DNSCRYPT_ES_VERSION = b'\x00\x01'
    DNSCRYPT_PROTOCOL_MIN_VERSION = b'\x00\x00'

    def __init__(self, serial, validFrom, validUntil, publicKey, clientMagic):
        self.serial = serial
        self.validFrom = validFrom
        self.validUntil = validUntil
        self.publicKey = publicKey
        self.clientMagic = clientMagic

    def isValid(self, now=None):
        """Return True if ``now`` lies within [validFrom, validUntil], bounds inclusive.

        :param now: timestamp in seconds to check; defaults to ``time.time()``.
        """
        if now is None:
            now = time.time()
        return self.validFrom <= now <= self.validUntil

    @staticmethod
    def fromBinary(binary, providerFP):
        """Parse and signature-check a 124-byte binary certificate.

        Layout checked here: 4-byte cert magic, 2-byte ES version, 2-byte
        protocol minor version, followed by a signed payload that is opened
        with ``providerFP`` as the verification key. The opened payload holds
        the resolver public key (32 bytes), client magic (8 bytes), then the
        serial, valid-from and valid-until fields as big-endian uint32.

        :raises Exception: if the length or the fixed header fields are wrong.
            Signature failures are raised by ``libnacl.crypto_sign_open``.
        """
        if len(binary) != 124:
            raise Exception("Invalid binary certificate")

        certMagic = binary[0:4]
        esVersion = binary[4:6]
        protocolMinVersion = binary[6:8]

        if (certMagic != DNSCryptResolverCertificate.DNSCRYPT_CERT_MAGIC
                or esVersion != DNSCryptResolverCertificate.DNSCRYPT_ES_VERSION
                or protocolMinVersion != DNSCryptResolverCertificate.DNSCRYPT_PROTOCOL_MIN_VERSION):
            raise Exception("Invalid binary certificate")

        orig = libnacl.crypto_sign_open(binary[8:124], providerFP)

        resolverPK = orig[0:32]
        clientMagic = orig[32:40]
        # serial, validFrom, validUntil: three consecutive network-order uint32s
        serial, validFrom, validUntil = struct.unpack_from("!III", orig, 40)
        return DNSCryptResolverCertificate(serial, validFrom, validUntil, resolverPK, clientMagic)
48
class DNSCryptClient(object):
    """Minimal DNSCrypt client that sends encrypted DNS queries to one resolver.

    A crypto_box key pair is generated per instance. Resolver certificates are
    fetched lazily with a plain TXT query for the provider name, and the valid
    certificate with the highest serial is used to encrypt each query.
    """
    DNSCRYPT_NONCE_SIZE = 24
    DNSCRYPT_MAC_SIZE = 16
    DNSCRYPT_PADDED_BLOCK_SIZE = 64
    DNSCRYPT_MIN_UDP_LENGTH = 256
    DNSCRYPT_RESOLVER_MAGIC = b'\x72\x36\x66\x6e\x76\x57\x6a\x38'

    @staticmethod
    def _addrToSocketType(addr):
        """Return socket.AF_INET6 or socket.AF_INET for an IP address literal.

        Raises socket.error if addr is neither a valid IPv6 nor IPv4 literal.
        """
        try:
            socket.inet_pton(socket.AF_INET6, addr)
        except socket.error:
            # Not IPv6; this raises in turn if it is not IPv4 either.
            socket.inet_pton(socket.AF_INET, addr)
            return socket.AF_INET
        return socket.AF_INET6

    def __init__(self, providerName, providerFingerprint, resolverAddress, resolverPort=443, timeout=2):
        """Create the client and a connected UDP socket to the resolver.

        :param providerName: provider name queried (TXT) for resolver certificates.
        :param providerFingerprint: hex-encoded key used to verify certificate
            signatures; ':' separators and letter case are ignored.
        :param resolverAddress: IPv4 or IPv6 address literal of the resolver.
        :param resolverPort: resolver port, used for both UDP and TCP.
        :param timeout: socket timeout in seconds.
        """
        self._providerName = providerName
        self._providerFingerprint = binascii.unhexlify(providerFingerprint.lower().replace(':', ''))
        self._resolverAddress = resolverAddress
        self._resolverPort = resolverPort
        self._resolverCertificates = []
        self._publicKey, self._privateKey = libnacl.crypto_box_keypair()
        self._timeout = timeout

        addrType = self._addrToSocketType(self._resolverAddress)
        self._sock = socket.socket(addrType, socket.SOCK_DGRAM)
        self._sock.settimeout(timeout)
        self._sock.connect((self._resolverAddress, self._resolverPort))

    @staticmethod
    def _recvExact(sock, size):
        """Read exactly ``size`` bytes from ``sock``, or fewer if the peer closes first."""
        chunks = []
        remaining = size
        while remaining > 0:
            chunk = sock.recv(remaining)
            if not chunk:
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    def _sendQuery(self, queryContent, tcp=False):
        """Send raw bytes to the resolver and return the raw reply.

        UDP reuses the connected socket created in __init__ and reads a single
        datagram of up to 4096 bytes. TCP opens a fresh connection, sends the
        2-byte length-prefixed message and reads the length-prefixed reply; the
        connection is always closed. Returns None if the TCP peer closes before
        sending a full length prefix.
        """
        if not tcp:
            self._sock.send(queryContent)
            return self._sock.recv(4096)

        addrType = self._addrToSocketType(self._resolverAddress)
        sock = socket.socket(addrType, socket.SOCK_STREAM)
        try:
            sock.settimeout(self._timeout)
            sock.connect((self._resolverAddress, self._resolverPort))
            # sendall: send() may transmit only part of the buffer on a stream socket
            sock.sendall(struct.pack("!H", len(queryContent)) + queryContent)

            lengthPrefix = self._recvExact(sock, 2)
            if len(lengthPrefix) != 2:
                return None
            (responseLength,) = struct.unpack("!H", lengthPrefix)
            # a single recv() may return a partial TCP segment, so loop until complete
            return self._recvExact(sock, responseLength)
        finally:
            sock.close()

    def _hasValidResolverCertificate(self):
        """Return True if at least one stored certificate is currently valid."""
        return any(cert.isValid() for cert in self._resolverCertificates)

    def clearExpiredResolverCertificates(self):
        """Drop every stored certificate that is not currently valid."""
        self._resolverCertificates = [cert for cert in self._resolverCertificates if cert.isValid()]

    def refreshResolverCertificates(self):
        """Fetch the provider's certificates over UDP and keep the valid ones.

        :raises Exception: if the TXT response is not NOERROR with exactly one
            IN/TXT answer RRset containing at least one record. TXT records made
            of more than one string are skipped; parsing errors from
            DNSCryptResolverCertificate.fromBinary propagate.
        """
        self.clearExpiredResolverCertificates()

        query = dns.message.make_query(self._providerName, dns.rdatatype.TXT, dns.rdataclass.IN)
        data = self._sendQuery(query.to_wire())

        response = dns.message.from_wire(data)
        if response.rcode() != dns.rcode.NOERROR or len(response.answer) != 1:
            raise Exception("Invalid response to public key request")

        an = response.answer[0]
        if an.rdclass != dns.rdataclass.IN or an.rdtype != dns.rdatatype.TXT or len(an.items) == 0:
            raise Exception("Invalid response to public key request")

        self._resolverCertificates = []

        for item in an.items:
            if len(item.strings) != 1:
                continue

            cert = DNSCryptResolverCertificate.fromBinary(item.strings[0], self._providerFingerprint)
            if cert.isValid():
                self._resolverCertificates.append(cert)

    def getResolverCertificate(self):
        """Return the currently valid certificate with the highest serial, or None."""
        valid = [cert for cert in self._resolverCertificates if cert.isValid()]
        # max() keeps the first of equal serials, like the strict '>' comparison did
        return max(valid, key=lambda cert: cert.serial) if valid else None

    def getAllResolverCertificates(self, onlyValid=False):
        """Return a list of stored certificates, optionally only the currently valid ones."""
        return [cert for cert in self._resolverCertificates if not onlyValid or cert.isValid()]

    @staticmethod
    def _generateNonce():
        """Return a random client half-nonce (DNSCRYPT_NONCE_SIZE // 2 bytes)."""
        nonce = libnacl.utils.rand_nonce()
        return nonce[:DNSCryptClient.DNSCRYPT_NONCE_SIZE // 2]

    @classmethod
    def _buildPadding(cls, queryLength, headerLength, tcp=False):
        """Return the padding for a query: a 0x80 byte followed by NUL bytes.

        The base padding (1 to DNSCRYPT_PADDED_BLOCK_SIZE bytes) brings the query
        to the next multiple of DNSCRYPT_PADDED_BLOCK_SIZE. Over UDP, when
        header + MAC + query is shorter than DNSCRYPT_MIN_UDP_LENGTH, the
        shortfall is added, so the encrypted packet is at least that long.
        """
        paddingSize = cls.DNSCRYPT_PADDED_BLOCK_SIZE - (queryLength % cls.DNSCRYPT_PADDED_BLOCK_SIZE)
        requiredSize = headerLength + cls.DNSCRYPT_MAC_SIZE + queryLength
        if not tcp and requiredSize < cls.DNSCRYPT_MIN_UDP_LENGTH:
            paddingSize += cls.DNSCRYPT_MIN_UDP_LENGTH - requiredSize
        return b'\x80' + b'\x00' * (paddingSize - 1)

    def _encryptQuery(self, queryContent, resolverCert, nonce, tcp=False):
        """Pad and encrypt a wire-format query for ``resolverCert``.

        Returns client magic + client public key + half-nonce, followed by the
        crypto_box of the padded query. The box nonce is the half-nonce
        extended with NUL bytes to DNSCRYPT_NONCE_SIZE.
        """
        header = resolverCert.clientMagic + self._publicKey + nonce
        padding = self._buildPadding(len(queryContent), len(header), tcp)
        fullNonce = nonce + b'\x00' * (self.DNSCRYPT_NONCE_SIZE // 2)
        box = libnacl.crypto_box(queryContent + padding, fullNonce, resolverCert.publicKey, self._privateKey)
        return header + box

    @staticmethod
    def _stripPadding(cleartext):
        """Remove trailing NULs and the 0x80 marker; return the preceding message.

        :raises Exception: if the marker is missing or the message would be empty.
        """
        unpadded = bytes(cleartext).rstrip(b'\x00')
        if len(unpadded) < 2 or unpadded[-1:] != b'\x80':
            raise Exception("Invalid encrypted response: invalid padding")
        return cleartext[:len(unpadded) - 1]

    def _decryptResponse(self, encryptedResponse, resolverCert, clientNonce):
        """Check and decrypt a resolver response, returning the DNS message bytes.

        Expects resolver magic (8 bytes), a 24-byte nonce whose first half
        equals ``clientNonce``, then the box.

        :raises Exception: on a missing response, bad magic, nonce mismatch or
            invalid padding. Decryption failures are raised by libnacl.
        """
        if encryptedResponse is None:
            raise Exception("Invalid encrypted response: no response received")

        resolverMagic = encryptedResponse[:8]
        if resolverMagic != self.DNSCRYPT_RESOLVER_MAGIC:
            raise Exception("Invalid encrypted response: bad resolver magic")

        nonce = encryptedResponse[8:32]
        if nonce[:self.DNSCRYPT_NONCE_SIZE // 2] != clientNonce:
            raise Exception("Invalid encrypted response: bad nonce")

        cleartext = libnacl.crypto_box_open(encryptedResponse[32:], nonce, resolverCert.publicKey, self._privateKey)
        return self._stripPadding(cleartext)

    def query(self, queryContent, tcp=False):
        """Encrypt ``queryContent``, send it over UDP or TCP, and return the decrypted reply.

        Refreshes the resolver certificates first if none is currently valid.

        :raises Exception: if no valid certificate is available, or the
            response fails validation.
        """
        if not self._hasValidResolverCertificate():
            self.refreshResolverCertificates()

        nonce = self._generateNonce()
        resolverCert = self.getResolverCertificate()
        if resolverCert is None:
            raise Exception("No valid certificate found")
        encryptedQuery = self._encryptQuery(queryContent, resolverCert, nonce, tcp)
        encryptedResponse = self._sendQuery(encryptedQuery, tcp)
        return self._decryptResponse(encryptedResponse, resolverCert, nonce)