return false;
}
+static bool sendResponseToClient(int fd, char* response, uint16_t responseLen, size_t responseSize
+#ifdef HAVE_DNSCRYPT
+ , DnsCryptContext* dnscryptCtx,
+ std::shared_ptr<DnsCryptQuery> dnsCryptQuery
+#endif
+ )
+{
+#ifdef HAVE_DNSCRYPT
+ if (dnscryptCtx && dnsCryptQuery) {
+ uint16_t encryptedResponseLen = 0;
+ int res = dnscryptCtx->encryptResponse(response, responseLen, responseSize, dnsCryptQuery, true, &encryptedResponseLen);
+ if (res == 0) {
+ responseLen = encryptedResponseLen;
+ } else {
+ /* dropping response */
+ vinfolog("Error encrypting the response, dropping.");
+ return false;
+ }
+ }
+#endif
+ if (!putNonBlockingMsgLen(fd, responseLen, g_tcpSendTimeout))
+ return false;
+
+ writen2WithTimeout(fd, response, responseLen, g_tcpSendTimeout);
+ return true;
+}
+
std::shared_ptr<TCPClientCollection> g_tcpclientthreads;
void* tcpClientThread(int pipefd)
uint16_t qlen, rlen;
string poolname;
- const uint16_t rdMask = 1 << FLAGS_RD_OFFSET;
- const uint16_t cdMask = 1 << FLAGS_CD_OFFSET;
- const uint16_t restoreFlagsMask = UINT16_MAX & ~(rdMask | cdMask);
string largerQuery;
vector<uint8_t> rewrittenResponse;
bool ednsAdded = false;
if (!decrypted) {
if (response.size() > 0) {
- if (putNonBlockingMsgLen(ci.fd, response.size(), g_tcpSendTimeout))
- writen2WithTimeout(ci.fd, (const char *) response.data(), response.size(), g_tcpSendTimeout);
+ sendResponseToClient(ci.fd, reinterpret_cast<char*>(response.data()), response.size(), response.size(), nullptr, nullptr);
}
break;
}
}
if(dq.dh->qr) { // something turned it into a response
- if (putNonBlockingMsgLen(ci.fd, dq.len, g_tcpSendTimeout))
- writen2WithTimeout(ci.fd, query, dq.len, g_tcpSendTimeout);
-
+ sendResponseToClient(ci.fd, queryBuffer, dq.len, dq.size
+#ifdef HAVE_DNSCRYPT
+ , ci.cs->dnscryptCtx, dnsCryptQuery
+#endif
+ );
g_stats.selfAnswered++;
goto drop;
}
uint16_t cachedResponseSize = sizeof cachedResponse;
uint32_t allowExpired = ds ? 0 : g_staleCacheEntriesTTL;
if (packetCache->get(dq, consumed, dq.dh->id, cachedResponse, &cachedResponseSize, &cacheKey, allowExpired)) {
- if (putNonBlockingMsgLen(ci.fd, cachedResponseSize, g_tcpSendTimeout))
- writen2WithTimeout(ci.fd, cachedResponse, cachedResponseSize, g_tcpSendTimeout);
+ sendResponseToClient(ci.fd, cachedResponse, cachedResponseSize, sizeof cachedResponse
+#ifdef HAVE_DNSCRYPT
+ , ci.cs->dnscryptCtx, dnsCryptQuery
+#endif
+ );
g_stats.cacheHits++;
goto drop;
}
goto retry;
}
- uint16_t responseSize = rlen;
+ size_t responseSize = rlen;
#ifdef HAVE_DNSCRYPT
if (ci.cs->dnscryptCtx && (UINT16_MAX - DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE) > rlen) {
responseSize += DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE;
#endif
char answerbuffer[responseSize];
readn2WithTimeout(dsock, answerbuffer, rlen, ds->tcpRecvTimeout);
- struct dnsheader* responseHeaders = (struct dnsheader*)answerbuffer;
- uint16_t * responseFlags = getFlagsFromDNSHeader(responseHeaders);
- /* clear the flags we are about to restore */
- *responseFlags &= restoreFlagsMask;
- /* only keep the flags we want to restore */
- origFlags &= ~restoreFlagsMask;
- /* set the saved flags as they were */
- *responseFlags |= origFlags;
char* response = answerbuffer;
uint16_t responseLen = rlen;
--ds->outstanding;
break;
}
- dh = (struct dnsheader*) response;
- DNSName rqname;
- uint16_t rqtype, rqclass;
- try {
- rqname=DNSName(response, responseLen, sizeof(dnsheader), false, &rqtype, &rqclass, &consumed);
- }
- catch(std::exception& e) {
- if(rlen > (ssize_t)sizeof(dnsheader))
- infolog("Backend %s sent us a response with id %d that did not parse: %s", ds->remote.toStringWithPort(), ntohs(dh->id), e.what());
- g_stats.nonCompliantResponses++;
- break;
- }
-
- if (rqtype != qtype || rqclass != qclass || rqname != qname) {
+ if (!responseContentMatches(response, responseLen, qname, qtype, qclass, ds->remote)) {
break;
}
- if (ednsAdded) {
- const char * optStart = NULL;
- size_t optLen = 0;
- bool last = false;
-
- int res = locateEDNSOptRR(response, responseLen, &optStart, &optLen, &last);
-
- if (res == 0) {
- if (last) {
- /* simply remove the last AR */
- responseLen -= optLen;
- uint16_t arcount = ntohs(responseHeaders->arcount);
- arcount--;
- responseHeaders->arcount = htons(arcount);
- }
- else {
- /* Removing an intermediary RR could lead to compression error */
- if (rewriteResponseWithoutEDNS(response, responseLen, rewrittenResponse) == 0) {
+ if (!fixUpResponse(&response, &responseLen, &responseSize, qname, origFlags, ednsAdded,
#ifdef HAVE_DNSCRYPT
- if (ci.cs->dnscryptCtx && rewrittenResponse.capacity() < responseSize && ci.cs->dnscryptCtx) {
- /* we preserve room for dnscrypt */
- rewrittenResponse.reserve(responseSize);
- }
+ dnsCryptQuery,
#endif
- responseSize = responseLen;
- responseLen = rewrittenResponse.size();
- response = reinterpret_cast<char*>(rewrittenResponse.data());
- }
- else {
- warnlog("Error rewriting content");
- }
- }
- }
+ rewrittenResponse)) {
+ break;
}
- if(g_fixupCase) {
- string realname = qname.toDNSString();
- if (responseLen >= (sizeof(dnsheader) + realname.length())) {
- memcpy(response + sizeof(dnsheader), realname.c_str(), realname.length());
- }
- }
-
if (packetCache && !dq.skipCache) {
packetCache->insert(cacheKey, qname, qtype, qclass, response, responseLen, true, dh->rcode == RCode::ServFail);
}
+ if (!sendResponseToClient(ci.fd, response, responseLen, responseSize
#ifdef HAVE_DNSCRYPT
- if (ci.cs->dnscryptCtx) {
- uint16_t encryptedResponseLen = 0;
- int res = ci.cs->dnscryptCtx->encryptResponse(response, responseLen, responseSize, dnsCryptQuery, true, &encryptedResponseLen);
-
- if (res == 0) {
- responseLen = encryptedResponseLen;
- } else {
- /* dropping response */
- vinfolog("Error encrypting the response, dropping.");
- break;
- }
- }
+ , ci.cs->dnscryptCtx, dnsCryptQuery
#endif
-
- if (putNonBlockingMsgLen(ci.fd, responseLen, ds->tcpSendTimeout))
- writen2WithTimeout(ci.fd, response, responseLen, ds->tcpSendTimeout);
+ )) {
+ break;
+ }
g_stats.responses++;
struct timespec answertime;
doAvg(g_stats.latencyAvg1000000, udiff, 1000000);
}
+bool responseContentMatches(const char* response, const uint16_t responseLen, const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const ComboAddress& remote)
+{
+ uint16_t rqtype, rqclass;
+ unsigned int consumed;
+ DNSName rqname;
+ const struct dnsheader* dh = (struct dnsheader*) response;
+
+ if (responseLen < sizeof(dnsheader)) {
+ return false;
+ }
+
+ try {
+ rqname=DNSName(response, responseLen, sizeof(dnsheader), false, &rqtype, &rqclass, &consumed);
+ }
+ catch(std::exception& e) {
+ if(responseLen > (ssize_t)sizeof(dnsheader))
+ infolog("Backend %s sent us a response with id %d that did not parse: %s", remote.toStringWithPort(), ntohs(dh->id), e.what());
+ g_stats.nonCompliantResponses++;
+ return false;
+ }
+
+ if (rqtype != qtype || rqclass != qclass || rqname != qname) {
+ return false;
+ }
+
+ return true;
+}
+
+bool fixUpResponse(char** response, uint16_t* responseLen, size_t* responseSize, const DNSName& qname, uint16_t origFlags, bool ednsAdded,
+#ifdef HAVE_DNSCRYPT
+ std::shared_ptr<DnsCryptQuery> dnsCryptQuery,
+#endif
+ std::vector<uint8_t>& rewrittenResponse)
+{
+ static const uint16_t rdMask = 1 << FLAGS_RD_OFFSET;
+ static const uint16_t cdMask = 1 << FLAGS_CD_OFFSET;
+ static const uint16_t restoreFlagsMask = UINT16_MAX & ~(rdMask | cdMask);
+ struct dnsheader* dh = (struct dnsheader*) *response;
+
+ if (*responseLen < sizeof(dnsheader)) {
+ return false;
+ }
+
+ if(g_fixupCase) {
+ string realname = qname.toDNSString();
+ if (*responseLen >= (sizeof(dnsheader) + realname.length())) {
+ memcpy(*response + sizeof(dnsheader), realname.c_str(), realname.length());
+ }
+ }
+
+ uint16_t * flags = getFlagsFromDNSHeader(dh);
+ /* clear the flags we are about to restore */
+ *flags &= restoreFlagsMask;
+ /* only keep the flags we want to restore */
+ origFlags &= ~restoreFlagsMask;
+ /* set the saved flags as they were */
+ *flags |= origFlags;
+
+ if (ednsAdded) {
+ const char * optStart = NULL;
+ size_t optLen = 0;
+ bool last = false;
+
+ int res = locateEDNSOptRR(*response, *responseLen, &optStart, &optLen, &last);
+
+ if (res == 0) {
+ if (last) {
+ /* simply remove the last AR */
+ *responseLen -= optLen;
+ uint16_t arcount = ntohs(dh->arcount);
+ arcount--;
+ dh->arcount = htons(arcount);
+ }
+ else {
+ /* Removing an intermediary RR could lead to compression error */
+ if (rewriteResponseWithoutEDNS(*response, *responseLen, rewrittenResponse) == 0) {
+ *responseLen = rewrittenResponse.size();
+#ifdef HAVE_DNSCRYPT
+ if (dnsCryptQuery && (UINT16_MAX - DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE) > *responseLen) {
+ rewrittenResponse.reserve(*responseLen + DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE);
+ }
+ *responseSize = rewrittenResponse.capacity();
+#endif
+ *response = reinterpret_cast<char*>(rewrittenResponse.data());
+ }
+ else {
+ warnlog("Error rewriting content");
+ }
+ }
+ }
+ }
+
+ return true;
+}
+
+static bool sendUDPResponse(int origFD, char* response, uint16_t responseLen, size_t responseSize,
+#ifdef HAVE_DNSCRYPT
+ std::shared_ptr<DnsCryptQuery> dnsCryptQuery,
+#endif
+ int delayMsec, const ComboAddress& origDest, const ComboAddress& origRemote)
+{
+#ifdef HAVE_DNSCRYPT
+ uint16_t encryptedResponseLen = 0;
+ if(dnsCryptQuery) {
+ int res = dnsCryptQuery->ctx->encryptResponse(response, responseLen, responseSize, dnsCryptQuery, false, &encryptedResponseLen);
+
+ if (res == 0) {
+ responseLen = encryptedResponseLen;
+ } else {
+ /* dropping response */
+ vinfolog("Error encrypting the response, dropping.");
+ return false;
+ }
+ }
+#endif
+
+ if(delayMsec && g_delay) {
+ DelayedPacket dp{origFD, string(response,responseLen), origRemote, origDest};
+ g_delay->submit(dp, delayMsec);
+ }
+ else {
+ if(origDest.sin4.sin_family == 0)
+ sendto(origFD, response, responseLen, 0, (struct sockaddr*)&origRemote, origRemote.getSocklen());
+ else
+ sendfromto(origFD, response, responseLen, 0, origDest, origRemote);
+ }
+
+ return true;
+}
+
// listens on a dedicated socket, lobs answers from downstream servers to original requestors
void* responderThread(std::shared_ptr<DownstreamState> state)
{
char packet[4096];
#endif
static_assert(sizeof(packet) <= UINT16_MAX, "Packet size should fit in a uint16_t");
- const uint16_t rdMask = 1 << FLAGS_RD_OFFSET;
- const uint16_t cdMask = 1 << FLAGS_CD_OFFSET;
- const uint16_t restoreFlagsMask = UINT16_MAX & ~(rdMask | cdMask);
vector<uint8_t> rewrittenResponse;
- uint16_t qtype, qclass;
struct dnsheader* dh = (struct dnsheader*)packet;
for(;;) {
ssize_t got = recv(state->fd, packet, sizeof(packet), 0);
char * response = packet;
-#ifdef HAVE_DNSCRYPT
- uint16_t responseSize = sizeof(packet);
-#endif
+ size_t responseSize = sizeof(packet);
if (got < (ssize_t) sizeof(dnsheader))
continue;
- size_t responseLen = (size_t) got;
+ uint16_t responseLen = (size_t) got;
if(dh->id >= state->idStates.size())
continue;
mostly mess up the outstanding counter.
*/
ids->age = 0;
- unsigned int consumed;
- DNSName qname;
- try {
- qname=DNSName(packet, responseLen, sizeof(dnsheader), false, &qtype, &qclass, &consumed);
- }
- catch(std::exception& e) {
- if(got > (ssize_t)sizeof(dnsheader))
- infolog("Backend %s sent us a response with id %d that did not parse: %s", state->remote.toStringWithPort(), ntohs(dh->id), e.what());
- g_stats.nonCompliantResponses++;
+
+ if (!responseContentMatches(response, responseLen, ids->qname, ids->qtype, ids->qclass, state->remote)) {
continue;
}
- if (qtype != ids->qtype || qclass != ids->qclass || qname != ids->qname)
- continue;
--state->outstanding; // you'd think an attacker could game this, but we're using connected socket
- if(g_fixupCase) {
- string realname = ids->qname.toDNSString();
- if (responseLen >= (sizeof(dnsheader) + realname.length())) {
- memcpy(packet+12, realname.c_str(), realname.length());
- }
- }
-
if(dh->tc && g_truncateTC) {
- truncateTC(packet, (uint16_t*) &responseLen);
+ truncateTC(response, (uint16_t*) &responseLen);
}
- uint16_t * flags = getFlagsFromDNSHeader(dh);
- uint16_t origFlags = ids->origFlags;
- /* clear the flags we are about to restore */
- *flags &= restoreFlagsMask;
- /* only keep the flags we want to restore */
- origFlags &= ~restoreFlagsMask;
- /* set the saved flags as they were */
- *flags |= origFlags;
dh->id = ids->origID;
- if (ids->ednsAdded) {
- const char * optStart = NULL;
- size_t optLen = 0;
- bool last = false;
-
- int res = locateEDNSOptRR(response, responseLen, &optStart, &optLen, &last);
-
- if (res == 0) {
- if (last) {
- /* simply remove the last AR */
- responseLen -= optLen;
- uint16_t arcount = ntohs(dh->arcount);
- arcount--;
- dh->arcount = htons(arcount);
- }
- else {
- /* Removing an intermediary RR could lead to compression error */
- if (rewriteResponseWithoutEDNS(response, responseLen, rewrittenResponse) == 0) {
- responseLen = rewrittenResponse.size();
+ if (!fixUpResponse(&response, &responseLen, &responseSize, ids->qname, ids->origFlags, ids->ednsAdded,
#ifdef HAVE_DNSCRYPT
- if (ids->dnsCryptQuery && (UINT16_MAX - DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE) > responseLen) {
- rewrittenResponse.reserve(responseLen + DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE);
- }
- responseSize = rewrittenResponse.capacity();
+ ids->dnsCryptQuery,
#endif
- response = reinterpret_cast<char*>(rewrittenResponse.data());
- }
- else {
- warnlog("Error rewriting content");
- }
- }
- }
+ rewrittenResponse)) {
+ continue;
}
- g_stats.responses++;
-
if (ids->packetCache && !ids->skipCache) {
- ids->packetCache->insert(ids->cacheKey, qname, qtype, qclass, response, responseLen, false, dh->rcode == RCode::ServFail);
+ ids->packetCache->insert(ids->cacheKey, ids->qname, ids->qtype, ids->qclass, response, responseLen, false, dh->rcode == RCode::ServFail);
}
+ sendUDPResponse(origFD, response, responseLen, responseSize,
#ifdef HAVE_DNSCRYPT
- uint16_t encryptedResponseLen = 0;
- if(ids->dnsCryptQuery) {
- int res = ids->dnsCryptQuery->ctx->encryptResponse(response, responseLen, responseSize, ids->dnsCryptQuery, false, &encryptedResponseLen);
-
- if (res == 0) {
- responseLen = encryptedResponseLen;
- } else {
- /* dropping response */
- vinfolog("Error encrypting the response, dropping.");
- continue;
- }
- }
+ ids->dnsCryptQuery,
#endif
+ ids->delayMsec, ids->origDest, ids->origRemote);
- if(ids->delayMsec && g_delay) {
- DelayedPacket dp{origFD, string(response,responseLen), ids->origRemote, ids->origDest};
- g_delay->submit(dp, ids->delayMsec);
- }
- else {
- if(ids->origDest.sin4.sin_family == 0)
- sendto(origFD, response, responseLen, 0, (struct sockaddr*)&ids->origRemote, ids->origRemote.getSocklen());
- else
- sendfromto(origFD, response, responseLen, 0, ids->origDest, ids->origRemote);
- }
+ g_stats.responses++;
double udiff = ids->sentTime.udiff();
vinfolog("Got answer from %s, relayed to %s, took %f usec", state->remote.toStringWithPort(), ids->origRemote.toStringWithPort(), udiff);
struct timespec ts;
clock_gettime(CLOCK_MONOTONIC, &ts);
std::lock_guard<std::mutex> lock(g_rings.respMutex);
- g_rings.respRing.push_back({ts, ids->origRemote, qname, qtype, (unsigned int)udiff, (unsigned int)got, *dh, state->remote});
+ g_rings.respRing.push_back({ts, ids->origRemote, ids->qname, ids->qtype, (unsigned int)udiff, (unsigned int)got, *dh, state->remote});
}
if(dh->rcode == RCode::ServFail)
g_stats.servfailResponses++;
if (!decrypted) {
if (response.size() > 0) {
ComboAddress dest;
- if(HarvestDestinationAddress(&msgh, &dest))
- sendfromto(cs->udpFD, (const char *) response.data(), response.size(), 0, dest, remote);
- else
- sendto(cs->udpFD, response.data(), response.size(), 0, (struct sockaddr*)&remote, remote.getSocklen());
+ if(!HarvestDestinationAddress(&msgh, &dest)) {
+ dest.sin4.sin_family = 0;
+ }
+ sendUDPResponse(cs->udpFD, reinterpret_cast<char*>(response.data()), response.size(), response.size(),
+#ifdef HAVE_DNSCRYPT
+ nullptr,
+#endif
+ 0, dest, remote);
}
continue;
}
if(dq.dh->qr) { // something turned it into a response
char* response = query;
uint16_t responseLen = dq.len;
-#ifdef HAVE_DNSCRYPT
uint16_t responseSize = dq.size;
-#endif
g_stats.selfAnswered++;
-#ifdef HAVE_DNSCRYPT
- uint16_t encryptedResponseLen = 0;
-
- if(dnsCryptQuery) {
- int res = cs->dnscryptCtx->encryptResponse(response, responseLen, responseSize, dnsCryptQuery, false, &encryptedResponseLen);
-
- if (res == 0) {
- responseLen = encryptedResponseLen;
- } else {
- /* dropping response */
- continue;
- }
+ ComboAddress dest;
+ if(!HarvestDestinationAddress(&msgh, &dest)) {
+ dest.sin4.sin_family = 0;
}
+ sendUDPResponse(cs->udpFD, response, responseLen, responseSize,
+#ifdef HAVE_DNSCRYPT
+ dnsCryptQuery,
#endif
- ComboAddress dest;
- if(HarvestDestinationAddress(&msgh, &dest))
- sendfromto(cs->udpFD, response, responseLen, 0, dest, remote);
- else
- sendto(cs->udpFD, response, responseLen, 0, (struct sockaddr*)&remote, remote.getSocklen());
-
+ 0, dest, remote);
continue;
}
uint32_t allowExpired = ss ? 0 : g_staleCacheEntriesTTL;
if (packetCache->get(dq, consumed, dh->id, cachedResponse, &cachedResponseSize, &cacheKey, allowExpired)) {
ComboAddress dest;
- if(HarvestDestinationAddress(&msgh, &dest))
- sendfromto(cs->udpFD, cachedResponse, cachedResponseSize, 0, dest, remote);
- else
- sendto(cs->udpFD, cachedResponse, cachedResponseSize, 0, (struct sockaddr*)&remote, remote.getSocklen());
+ if(!HarvestDestinationAddress(&msgh, &dest)) {
+ dest.sin4.sin_family = 0;
+ }
+ sendUDPResponse(cs->udpFD, cachedResponse, cachedResponseSize, sizeof cachedResponse,
+#ifdef HAVE_DNSCRYPT
+ dnsCryptQuery,
+#endif
+ 0, dest, remote);
g_stats.cacheHits++;
g_stats.latency0_1++; // we're not going to measure this
doLatencyAverages(0); // same
bool getLuaNoSideEffect(); // set if there were only explicit declarations of _no_ side effect
void resetLuaSideEffect(); // reset to indeterminate state
+bool responseContentMatches(const char* response, const uint16_t responseLen, const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const ComboAddress& remote);
+bool fixUpResponse(char** response, uint16_t* responseLen, size_t* responseSize, const DNSName& qname, uint16_t origFlags, bool ednsAdded,
+#ifdef HAVE_DNSCRYPT
+ std::shared_ptr<DnsCryptQuery> dnsCryptQuery,
+#endif
+ std::vector<uint8_t>& rewrittenResponse);
+
#ifdef HAVE_DNSCRYPT
extern std::vector<std::tuple<ComboAddress,DnsCryptContext,bool>> g_dnsCryptLocals;
self._resolverPort = resolverPort
self._resolverCertificates = []
self._publicKey, self._privateKey = libnacl.crypto_box_keypair()
+ self._timeout = timeout
addrType = self._addrToSocketType(self._resolverAddress)
self._sock = socket.socket(addrType, socket.SOCK_DGRAM)
self._sock.settimeout(timeout)
self._sock.connect((self._resolverAddress, self._resolverPort))
- def _sendQuery(self, queryContent):
- self._sock.send(queryContent)
- data = self._sock.recv(4096)
+ def _sendQuery(self, queryContent, tcp=False):
+ if tcp:
+ addrType = self._addrToSocketType(self._resolverAddress)
+ sock = socket.socket(addrType, socket.SOCK_STREAM)
+ sock.settimeout(self._timeout)
+ sock.connect((self._resolverAddress, self._resolverPort))
+ sock.send(struct.pack("!H", len(queryContent)))
+ else:
+ sock = self._sock
+
+ sock.send(queryContent)
+
+ data = None
+ if tcp:
+ got = sock.recv(2)
+ print(len(got))
+ if got:
+ (rlen,) = struct.unpack("!H", got)
+ data = sock.recv(rlen)
+ else:
+ data = sock.recv(4096)
+
return data
def _hasValidResolverCertificate(self):
nonce = libnacl.utils.rand_nonce()
return nonce[:(DNSCryptClient.DNSCRYPT_NONCE_SIZE / 2)]
- def _encryptQuery(self, queryContent, resolverCert, nonce):
+ def _encryptQuery(self, queryContent, resolverCert, nonce, tcp=False):
header = resolverCert.clientMagic + self._publicKey + nonce
requiredSize = len(header) + self.DNSCRYPT_MAC_SIZE + len(queryContent)
paddingSize = self.DNSCRYPT_PADDED_BLOCK_SIZE - (len(queryContent) % self.DNSCRYPT_PADDED_BLOCK_SIZE)
# padding size should be DNSCRYPT_PADDED_BLOCK_SIZE <= padding size <= 4096
- if requiredSize < self.DNSCRYPT_MIN_UDP_LENGTH:
+ if not tcp and requiredSize < self.DNSCRYPT_MIN_UDP_LENGTH:
paddingSize += self.DNSCRYPT_MIN_UDP_LENGTH - requiredSize
requiredSize = self.DNSCRYPT_MIN_UDP_LENGTH
return cleartext[:idx+1]
- def query(self, queryContent):
+ def query(self, queryContent, tcp=False):
if not self._hasValidResolverCertificate():
self._getResolverCertificates()
resolverCert = self._getResolverCertificate()
if resolverCert is None:
raise Exception("No valid certificate found")
- encryptedQuery = self._encryptQuery(queryContent, resolverCert, nonce)
- encryptedResponse = self._sendQuery(encryptedQuery)
+ encryptedQuery = self._encryptQuery(queryContent, resolverCert, nonce, tcp)
+ encryptedResponse = self._sendQuery(encryptedQuery, tcp)
response = self._decryptResponse(encryptedResponse, resolverCert, nonce)
return response
#!/usr/bin/env python
import time
-import unittest
import dns
import dns.message
from dnsdisttests import DNSDistTest
3600,
dns.rdataclass.IN,
dns.rdatatype.A,
- '127.0.0.1')
+ '192.2.0.1')
response.answer.append(rrset)
self._toResponderQueue.put(response)
self.assertEquals(query, receivedQuery)
self.assertEquals(response, receivedResponse)
+ self._toResponderQueue.put(response)
+ data = client.query(query.to_wire(), tcp=True)
+ receivedResponse = dns.message.from_wire(data)
+ receivedQuery = None
+ if not self._fromResponderQueue.empty():
+ receivedQuery = self._fromResponderQueue.get(query)
+
+ self.assertTrue(receivedQuery)
+ self.assertTrue(receivedResponse)
+ receivedQuery.id = query.id
+ self.assertEquals(query, receivedQuery)
+ self.assertEquals(response, receivedResponse)
+
def testResponseLargerThanPaddedQuery(self):
"""
DNSCrypt: response larger than query
self.assertTrue(len(receivedResponse.authority) == 0)
self.assertTrue(len(receivedResponse.additional) == 0)
-if __name__ == '__main__':
- unittest.main()
- exit(0)
+class TestDNSCryptWithCache(DNSDistTest):
+ _dnsDistPortDNSCrypt = 8443
+ _providerFingerprint = 'E1D7:2108:9A59:BF8D:F101:16FA:ED5E:EA6A:9F6C:C78F:7F91:AF6B:027E:62F4:69C3:B1AA'
+ _providerName = "2.provider.name"
+ _resolverCertificateSerial = 42
+ # valid from 60s ago until 2h from now
+ _resolverCertificateValidFrom = time.time() - 60
+ _resolverCertificateValidUntil = time.time() + 7200
+ _config_params = ['_resolverCertificateSerial', '_resolverCertificateValidFrom', '_resolverCertificateValidUntil', '_dnsDistPortDNSCrypt', '_providerName', '_testServerPort']
+ _config_template = """
+ generateDNSCryptCertificate("DNSCryptProviderPrivate.key", "DNSCryptResolver.cert", "DNSCryptResolver.key", %d, %d, %d)
+ addDNSCryptBind("127.0.0.1:%d", "%s", "DNSCryptResolver.cert", "DNSCryptResolver.key")
+ pc = newPacketCache(5, 86400, 1)
+ getPool(""):setCache(pc)
+ newServer{address="127.0.0.1:%s"}
+ """
+
+ def testCachedSimpleA(self):
+ """
+ DNSCrypt: encrypted A query served from cache
+ """
+ client = dnscrypt.DNSCryptClient(self._providerName, self._providerFingerprint, "127.0.0.1", 8443)
+ name = 'cacheda.dnscrypt.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN')
+ response = dns.message.make_response(query)
+ rrset = dns.rrset.from_text(name,
+ 3600,
+ dns.rdataclass.IN,
+ dns.rdatatype.A,
+ '192.2.0.1')
+ response.answer.append(rrset)
+
+ # first query to fill the cache
+ self._toResponderQueue.put(response)
+ data = client.query(query.to_wire())
+ receivedResponse = dns.message.from_wire(data)
+ receivedQuery = None
+ if not self._fromResponderQueue.empty():
+ receivedQuery = self._fromResponderQueue.get(query)
+
+ self.assertTrue(receivedQuery)
+ self.assertTrue(receivedResponse)
+ receivedQuery.id = query.id
+ self.assertEquals(query, receivedQuery)
+ self.assertEquals(response, receivedResponse)
+
+ # second query should get a cached response
+ data = client.query(query.to_wire())
+ receivedResponse = dns.message.from_wire(data)
+ receivedQuery = None
+ if not self._fromResponderQueue.empty():
+ receivedQuery = self._fromResponderQueue.get(query)
+
+ self.assertEquals(receivedQuery, None)
+ self.assertTrue(receivedResponse)
+ self.assertEquals(response, receivedResponse)