]> git.ipfire.org Git - thirdparty/pdns.git/blob - regression-tests.dnsdist/dnscrypt.py
Merge pull request #6322 from zeha/dnsdist-tests
[thirdparty/pdns.git] / regression-tests.dnsdist / dnscrypt.py
1 #!/usr/bin/env python2
2 import socket
3 import struct
4 import time
5 import dns
6 import dns.message
7 import libnacl
8 import libnacl.utils
9 import binascii
10
11
class DNSCryptResolverCertificate(object):
    """A resolver certificate parsed from a provider's DNSCrypt TXT record.

    Attributes:
        serial: certificate serial number (int); the client prefers the
            valid certificate with the highest serial.
        validFrom: start of the validity period, seconds since the epoch.
        validUntil: end of the validity period, seconds since the epoch.
        publicKey: the resolver's 32-byte public key.
        clientMagic: the 8-byte magic the client puts in front of queries
            encrypted for this certificate.
    """

    # Byte literals: identical to plain literals on Python 2, and they keep
    # the comparisons in fromBinary() meaningful when the input is bytes.
    DNSCRYPT_CERT_MAGIC = b'\x44\x4e\x53\x43'
    DNSCRYPT_ES_VERSION = b'\x00\x01'
    DNSCRYPT_PROTOCOL_MIN_VERSION = b'\x00\x00'

    # Total size of a binary certificate accepted by fromBinary().
    DNSCRYPT_CERT_SIZE = 124
    # Signed payload layout: 32-byte resolver public key, 8-byte client
    # magic, then serial, validFrom and validUntil as unsigned 32-bit
    # integers in network (big-endian) byte order.
    _SIGNED_DATA_FORMAT = "!32s8sIII"

    def __init__(self, serial, validFrom, validUntil, publicKey, clientMagic):
        self.serial = serial
        self.validFrom = validFrom
        self.validUntil = validUntil
        self.publicKey = publicKey
        self.clientMagic = clientMagic

    def isValid(self):
        """Return True if the current time lies within [validFrom, validUntil]."""
        return self.validFrom <= time.time() <= self.validUntil

    @staticmethod
    def _parseSignedData(signedData):
        """Build a certificate from the already signature-checked payload.

        All three integers are read as big-endian so that the serial is
        decoded the same way on every host, consistently with the two
        timestamps that follow it.

        Raises:
            struct.error: if signedData is shorter than the expected layout.
        """
        resolverPK, clientMagic, serial, validFrom, validUntil = struct.unpack_from(
            DNSCryptResolverCertificate._SIGNED_DATA_FORMAT, signedData)
        return DNSCryptResolverCertificate(serial, validFrom, validUntil, resolverPK, clientMagic)

    @staticmethod
    def fromBinary(binary, providerFP):
        """Parse and verify a binary certificate.

        Args:
            binary: the raw 124-byte certificate (4-byte magic, 2-byte
                es-version, 2-byte protocol minor version, then the signed
                message).
            providerFP: the provider public key, passed to
                libnacl.crypto_sign_open to verify the signed message.

        Returns:
            A DNSCryptResolverCertificate.

        Raises:
            Exception: if the size, magic or version fields are not the
                expected ones. Signature failures are reported by libnacl.
        """
        cls = DNSCryptResolverCertificate
        if len(binary) != cls.DNSCRYPT_CERT_SIZE:
            raise Exception("Invalid binary certificate")

        certMagic = binary[0:4]
        esVersion = binary[4:6]
        protocolMinVersion = binary[6:8]

        if (certMagic != cls.DNSCRYPT_CERT_MAGIC
                or esVersion != cls.DNSCRYPT_ES_VERSION
                or protocolMinVersion != cls.DNSCRYPT_PROTOCOL_MIN_VERSION):
            raise Exception("Invalid binary certificate")

        # Everything after the 8-byte header is the signed message.
        signedData = libnacl.crypto_sign_open(binary[8:cls.DNSCRYPT_CERT_SIZE], providerFP)
        return cls._parseSignedData(signedData)
48
class DNSCryptClient(object):
    """Small DNSCrypt client: fetches resolver certificates and exchanges
    encrypted queries over UDP or TCP.

    A UDP socket connected to the resolver is created at construction time
    and reused for every UDP exchange; each TCP exchange uses a fresh,
    short-lived connection.
    """

    DNSCRYPT_NONCE_SIZE = 24
    DNSCRYPT_MAC_SIZE = 16
    DNSCRYPT_PADDED_BLOCK_SIZE = 64
    DNSCRYPT_MIN_UDP_LENGTH = 256
    DNSCRYPT_RESOLVER_MAGIC = b'\x72\x36\x66\x6e\x76\x57\x6a\x38'

    @staticmethod
    def _addrToSocketType(addr):
        """Return socket.AF_INET6 or socket.AF_INET for a textual address.

        Raises:
            socket.error: if addr is neither a valid IPv6 nor IPv4 address.
        """
        try:
            socket.inet_pton(socket.AF_INET6, addr)
            return socket.AF_INET6
        except socket.error:
            # Not IPv6: let the IPv4 parse raise if it is not IPv4 either.
            socket.inet_pton(socket.AF_INET, addr)
            return socket.AF_INET

    def __init__(self, providerName, providerFingerprint, resolverAddress, resolverPort=443, timeout=2):
        """Create the client key pair and connect the UDP socket.

        Args:
            providerName: name queried (TXT) to obtain the certificates.
            providerFingerprint: provider public key as hex digits,
                optionally separated by ':'.
            resolverAddress: IPv4 or IPv6 address of the resolver.
            resolverPort: resolver port, used for both UDP and TCP.
            timeout: socket timeout in seconds.
        """
        self._providerName = providerName
        self._providerFingerprint = binascii.unhexlify(providerFingerprint.lower().replace(':', ''))
        self._resolverAddress = resolverAddress
        self._resolverPort = resolverPort
        self._resolverCertificates = []
        self._publicKey, self._privateKey = libnacl.crypto_box_keypair()
        self._timeout = timeout

        addrType = self._addrToSocketType(self._resolverAddress)
        self._sock = socket.socket(addrType, socket.SOCK_DGRAM)
        self._sock.settimeout(timeout)
        self._sock.connect((self._resolverAddress, self._resolverPort))

    @staticmethod
    def _recvExactly(sock, size):
        """Read from sock until size bytes were received or the peer closed.

        Returns the bytes collected, which is shorter than size only when
        the connection was closed early.
        """
        chunks = []
        remaining = size
        while remaining > 0:
            chunk = sock.recv(remaining)
            if not chunk:
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    def _sendQuery(self, queryContent, tcp=False):
        """Send queryContent to the resolver and return the raw response.

        Over UDP the shared socket is used and a single datagram (up to
        4096 bytes) is returned. Over TCP a new connection is opened, the
        query is sent with a 2-byte big-endian length prefix, the
        length-prefixed response is read and the connection is closed.

        Returns:
            The response payload, or None if the TCP peer closed the
            connection before sending a complete length prefix.
        """
        if not tcp:
            self._sock.send(queryContent)
            return self._sock.recv(4096)

        addrType = self._addrToSocketType(self._resolverAddress)
        sock = socket.socket(addrType, socket.SOCK_STREAM)
        try:
            sock.settimeout(self._timeout)
            sock.connect((self._resolverAddress, self._resolverPort))
            # sendall() so a short write cannot truncate the query.
            sock.sendall(struct.pack("!H", len(queryContent)) + queryContent)

            # A stream socket may deliver fewer bytes than requested, so
            # both the length prefix and the payload are read in a loop.
            prefix = self._recvExactly(sock, 2)
            if len(prefix) < 2:
                return None
            (rlen,) = struct.unpack("!H", prefix)
            return self._recvExactly(sock, rlen)
        finally:
            # The TCP socket is per-query: always release it.
            sock.close()

    def _hasValidResolverCertificate(self):
        """Return True if at least one stored certificate is currently valid."""
        return any(cert.isValid() for cert in self._resolverCertificates)

    def clearExpiredResolverCertificates(self):
        """Drop every stored certificate that is not currently valid."""
        self._resolverCertificates = [cert for cert in self._resolverCertificates if cert.isValid()]

    def refreshResolverCertificates(self):
        """Query the provider name for TXT records and store valid certificates.

        Expired certificates are dropped first. The query is sent in clear
        text over UDP.

        Raises:
            Exception: if the response is not NOERROR with exactly one
                non-empty IN/TXT answer RRset, or if a certificate cannot
                be parsed.
        """
        self.clearExpiredResolverCertificates()

        query = dns.message.make_query(self._providerName, dns.rdatatype.TXT, dns.rdataclass.IN)
        data = self._sendQuery(query.to_wire())

        response = dns.message.from_wire(data)
        if response.rcode() != dns.rcode.NOERROR or len(response.answer) != 1:
            raise Exception("Invalid response to public key request")

        an = response.answer[0]
        if an.rdclass != dns.rdataclass.IN or an.rdtype != dns.rdatatype.TXT or len(an.items) == 0:
            raise Exception("Invalid response to public key request")

        for item in an.items:
            # Only TXT records made of a single string are considered.
            if len(item.strings) != 1:
                continue

            cert = DNSCryptResolverCertificate.fromBinary(item.strings[0], self._providerFingerprint)
            if cert.isValid():
                self._resolverCertificates.append(cert)

    def getResolverCertificate(self):
        """Return the valid certificate with the highest serial, or None."""
        valid = [cert for cert in self._resolverCertificates if cert.isValid()]
        if not valid:
            return None
        # max() keeps the first of several equal serials, like a strict '>' scan.
        return max(valid, key=lambda cert: cert.serial)

    @staticmethod
    def _generateNonce():
        """Return the client half (DNSCRYPT_NONCE_SIZE // 2 bytes) of a random nonce."""
        nonce = libnacl.utils.rand_nonce()
        return nonce[:DNSCryptClient.DNSCRYPT_NONCE_SIZE // 2]

    @classmethod
    def _padQuery(cls, queryContent, headerSize, tcp=False):
        """Return queryContent followed by a 0x80 byte and zero bytes.

        The base padding brings the query up to the next multiple of
        DNSCRYPT_PADDED_BLOCK_SIZE (a full block when it is already
        aligned). Over UDP, when header + MAC + query is smaller than
        DNSCRYPT_MIN_UDP_LENGTH, the missing bytes are added on top of it.

        Args:
            queryContent: the clear-text query.
            headerSize: size of the unencrypted header preceding the box.
            tcp: whether the query is going to be sent over TCP.
        """
        requiredSize = headerSize + cls.DNSCRYPT_MAC_SIZE + len(queryContent)
        paddingSize = cls.DNSCRYPT_PADDED_BLOCK_SIZE - (len(queryContent) % cls.DNSCRYPT_PADDED_BLOCK_SIZE)
        if not tcp and requiredSize < cls.DNSCRYPT_MIN_UDP_LENGTH:
            paddingSize += cls.DNSCRYPT_MIN_UDP_LENGTH - requiredSize

        return queryContent + b'\x80' + b'\x00' * (paddingSize - 1)

    def _encryptQuery(self, queryContent, resolverCert, nonce, tcp=False):
        """Return the encrypted query: client magic, client public key,
        client nonce, then the box holding the padded query.

        Args:
            queryContent: the clear-text query.
            resolverCert: certificate providing the client magic and the
                resolver public key.
            nonce: the client half of the nonce, as returned by
                _generateNonce().
            tcp: whether the query is going to be sent over TCP (affects
                padding only).
        """
        header = resolverCert.clientMagic + self._publicKey + nonce
        data = self._padQuery(queryContent, len(header), tcp)
        # The box nonce is the client half followed by zero bytes.
        fullNonce = nonce + b'\x00' * (self.DNSCRYPT_NONCE_SIZE // 2)
        box = libnacl.crypto_box(data, fullNonce, resolverCert.publicKey, self._privateKey)
        return header + box

    @staticmethod
    def _stripPadding(cleartext):
        """Remove the trailing zero bytes and the 0x80 marker preceding them.

        Raises:
            Exception: if the marker is missing, or if nothing precedes it.
        """
        stripped = cleartext.rstrip(b'\x00')
        # The marker must exist and must not be the very first byte.
        if len(stripped) < 2 or not stripped.endswith(b'\x80'):
            raise Exception("Invalid encrypted response: invalid padding")
        return stripped[:-1]

    def _decryptResponse(self, encryptedResponse, resolverCert, clientNonce):
        """Validate, decrypt and unpad an encrypted response.

        Args:
            encryptedResponse: raw bytes received from the resolver.
            resolverCert: certificate that was used to encrypt the query.
            clientNonce: the client half of the nonce sent with the query.

        Returns:
            The clear-text response with the padding removed.

        Raises:
            Exception: if no data was received, or on a bad resolver
                magic, a nonce that does not start with clientNonce, or
                invalid padding. Decryption failures are reported by
                libnacl.
        """
        if encryptedResponse is None:
            raise Exception("Invalid encrypted response: no data received")

        resolverMagic = encryptedResponse[:8]
        if resolverMagic != self.DNSCRYPT_RESOLVER_MAGIC:
            raise Exception("Invalid encrypted response: bad resolver magic")

        nonce = encryptedResponse[8:32]
        if nonce[0:self.DNSCRYPT_NONCE_SIZE // 2] != clientNonce:
            raise Exception("Invalid encrypted response: bad nonce")

        cleartext = libnacl.crypto_box_open(encryptedResponse[32:], nonce, resolverCert.publicKey, self._privateKey)
        return self._stripPadding(cleartext)

    def query(self, queryContent, tcp=False):
        """Send queryContent encrypted and return the decrypted response.

        Certificates are refreshed first when none of the stored ones is
        currently valid.

        Raises:
            Exception: if no valid certificate is available, or if the
                response cannot be validated.
        """
        if not self._hasValidResolverCertificate():
            self.refreshResolverCertificates()

        nonce = self._generateNonce()
        resolverCert = self.getResolverCertificate()
        if resolverCert is None:
            raise Exception("No valid certificate found")
        encryptedQuery = self._encryptQuery(queryContent, resolverCert, nonce, tcp)
        encryptedResponse = self._sendQuery(encryptedQuery, tcp)
        return self._decryptResponse(encryptedResponse, resolverCert, nonce)