From: Remi Gacogne Date: Tue, 15 Mar 2016 14:56:37 +0000 (+0100) Subject: dnsdist: Refactor duplicated response handling code (UDP/TCP) X-Git-Tag: dnsdist-1.0.0-beta1~79^2~2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=fcffc5851fd6e75c6b3223af64bae752d133053e;p=thirdparty%2Fpdns.git dnsdist: Refactor duplicated response handling code (UDP/TCP) --- diff --git a/pdns/dnsdist-tcp.cc b/pdns/dnsdist-tcp.cc index d4320e7aab..1df6b06d24 100644 --- a/pdns/dnsdist-tcp.cc +++ b/pdns/dnsdist-tcp.cc @@ -133,6 +133,33 @@ catch(...) { return false; } +static bool sendResponseToClient(int fd, char* response, uint16_t responseLen, size_t responseSize +#ifdef HAVE_DNSCRYPT + , DnsCryptContext* dnscryptCtx, + std::shared_ptr dnsCryptQuery +#endif + ) +{ +#ifdef HAVE_DNSCRYPT + if (dnscryptCtx && dnsCryptQuery) { + uint16_t encryptedResponseLen = 0; + int res = dnscryptCtx->encryptResponse(response, responseLen, responseSize, dnsCryptQuery, true, &encryptedResponseLen); + if (res == 0) { + responseLen = encryptedResponseLen; + } else { + /* dropping response */ + vinfolog("Error encrypting the response, dropping."); + return false; + } + } +#endif + if (!putNonBlockingMsgLen(fd, responseLen, g_tcpSendTimeout)) + return false; + + writen2WithTimeout(fd, response, responseLen, g_tcpSendTimeout); + return true; +} + std::shared_ptr g_tcpclientthreads; void* tcpClientThread(int pipefd) @@ -173,9 +200,6 @@ void* tcpClientThread(int pipefd) uint16_t qlen, rlen; string poolname; - const uint16_t rdMask = 1 << FLAGS_RD_OFFSET; - const uint16_t cdMask = 1 << FLAGS_CD_OFFSET; - const uint16_t restoreFlagsMask = UINT16_MAX & ~(rdMask | cdMask); string largerQuery; vector rewrittenResponse; bool ednsAdded = false; @@ -214,8 +238,7 @@ void* tcpClientThread(int pipefd) if (!decrypted) { if (response.size() > 0) { - if (putNonBlockingMsgLen(ci.fd, response.size(), g_tcpSendTimeout)) - writen2WithTimeout(ci.fd, (const char *) response.data(), response.size(), g_tcpSendTimeout); + sendResponseToClient(ci.fd, reinterpret_cast(response.data()), response.size(), response.size(), nullptr, nullptr); } break; } @@ -323,9 +346,11 @@ void* tcpClientThread(int pipefd) } if(dq.dh->qr) { // something turned it into a response - if (putNonBlockingMsgLen(ci.fd, dq.len, g_tcpSendTimeout)) - writen2WithTimeout(ci.fd, query, dq.len, g_tcpSendTimeout); - + sendResponseToClient(ci.fd, queryBuffer, dq.len, dq.size +#ifdef HAVE_DNSCRYPT + , ci.cs->dnscryptCtx, dnsCryptQuery +#endif + ); g_stats.selfAnswered++; goto drop; } @@ -359,8 +384,11 @@ void* tcpClientThread(int pipefd) uint16_t cachedResponseSize = sizeof cachedResponse; uint32_t allowExpired = ds ? 0 : g_staleCacheEntriesTTL; if (packetCache->get(dq, consumed, dq.dh->id, cachedResponse, &cachedResponseSize, &cacheKey, allowExpired)) { - if (putNonBlockingMsgLen(ci.fd, cachedResponseSize, g_tcpSendTimeout)) - writen2WithTimeout(ci.fd, cachedResponse, cachedResponseSize, g_tcpSendTimeout); + sendResponseToClient(ci.fd, cachedResponse, cachedResponseSize, sizeof cachedResponse +#ifdef HAVE_DNSCRYPT + , ci.cs->dnscryptCtx, dnsCryptQuery +#endif + ); g_stats.cacheHits++; goto drop; } @@ -430,7 +458,7 @@ void* tcpClientThread(int pipefd) goto retry; } - uint16_t responseSize = rlen; + size_t responseSize = rlen; #ifdef HAVE_DNSCRYPT if (ci.cs->dnscryptCtx && (UINT16_MAX - DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE) > rlen) { responseSize += DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE; @@ -438,14 +466,6 @@ void* tcpClientThread(int pipefd) #endif char answerbuffer[responseSize]; readn2WithTimeout(dsock, answerbuffer, rlen, ds->tcpRecvTimeout); - struct dnsheader* responseHeaders = (struct dnsheader*)answerbuffer; - uint16_t * responseFlags = getFlagsFromDNSHeader(responseHeaders); - /* clear the flags we are about to restore */ - *responseFlags &= restoreFlagsMask; - /* only keep the flags we want to restore */ - origFlags &= ~restoreFlagsMask; - /* set the saved flags as they were */ - *responseFlags |= origFlags; char* response = answerbuffer; uint16_t responseLen = rlen; --ds->outstanding; @@ -455,86 +475,29 @@ void* tcpClientThread(int pipefd) break; } - dh = (struct dnsheader*) response; - DNSName rqname; - uint16_t rqtype, rqclass; - try { - rqname=DNSName(response, responseLen, sizeof(dnsheader), false, &rqtype, &rqclass, &consumed); - } - catch(std::exception& e) { - if(rlen > (ssize_t)sizeof(dnsheader)) - infolog("Backend %s sent us a response with id %d that did not parse: %s", ds->remote.toStringWithPort(), ntohs(dh->id), e.what()); - g_stats.nonCompliantResponses++; - break; - } - - if (rqtype != qtype || rqclass != qclass || rqname != qname) { + if (!responseContentMatches(response, responseLen, qname, qtype, qclass, ds->remote)) { break; } - if (ednsAdded) { - const char * optStart = NULL; - size_t optLen = 0; - bool last = false; - - int res = locateEDNSOptRR(response, responseLen, &optStart, &optLen, &last); - - if (res == 0) { - if (last) { - /* simply remove the last AR */ - responseLen -= optLen; - uint16_t arcount = ntohs(responseHeaders->arcount); - arcount--; - responseHeaders->arcount = htons(arcount); - } - else { - /* Removing an intermediary RR could lead to compression error */ - if (rewriteResponseWithoutEDNS(response, responseLen, rewrittenResponse) == 0) { + if (!fixUpResponse(&response, &responseLen, &responseSize, qname, origFlags, ednsAdded, #ifdef HAVE_DNSCRYPT - if (ci.cs->dnscryptCtx && rewrittenResponse.capacity() < responseSize && ci.cs->dnscryptCtx) { - /* we preserve room for dnscrypt */ - rewrittenResponse.reserve(responseSize); - } + dnsCryptQuery, #endif - responseSize = responseLen; - responseLen = rewrittenResponse.size(); - response = reinterpret_cast(rewrittenResponse.data()); - } - else { - warnlog("Error rewriting content"); - } - } - } + rewrittenResponse)) { + break; } - if(g_fixupCase) { - string realname = qname.toDNSString(); - if (responseLen >= (sizeof(dnsheader) + realname.length())) { - memcpy(response + sizeof(dnsheader), realname.c_str(), realname.length()); - } - } - if (packetCache && !dq.skipCache) { packetCache->insert(cacheKey, qname, qtype, qclass, response, responseLen, true, dh->rcode == RCode::ServFail); } + if (!sendResponseToClient(ci.fd, response, responseLen, responseSize #ifdef HAVE_DNSCRYPT - if (ci.cs->dnscryptCtx) { - uint16_t encryptedResponseLen = 0; - int res = ci.cs->dnscryptCtx->encryptResponse(response, responseLen, responseSize, dnsCryptQuery, true, &encryptedResponseLen); - - if (res == 0) { - responseLen = encryptedResponseLen; - } else { - /* dropping response */ - vinfolog("Error encrypting the response, dropping."); - break; - } - } + , ci.cs->dnscryptCtx, dnsCryptQuery #endif - - if (putNonBlockingMsgLen(ci.fd, responseLen, ds->tcpSendTimeout)) - writen2WithTimeout(ci.fd, response, responseLen, ds->tcpSendTimeout); + )) { + break; + } g_stats.responses++; struct timespec answertime; diff --git a/pdns/dnsdist.cc b/pdns/dnsdist.cc index 0fce07e9b3..cb42bd76b9 100644 --- a/pdns/dnsdist.cc +++ b/pdns/dnsdist.cc @@ -159,6 +159,136 @@ static void doLatencyAverages(double udiff) doAvg(g_stats.latencyAvg1000000, udiff, 1000000); } +bool responseContentMatches(const char* response, const uint16_t responseLen, const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const ComboAddress& remote) +{ + uint16_t rqtype, rqclass; + unsigned int consumed; + DNSName rqname; + const struct dnsheader* dh = (struct dnsheader*) response; + + if (responseLen < sizeof(dnsheader)) { + return false; + } + + try { + rqname=DNSName(response, responseLen, sizeof(dnsheader), false, &rqtype, &rqclass, &consumed); + } + catch(std::exception& e) { + if(responseLen > (ssize_t)sizeof(dnsheader)) + infolog("Backend %s sent us a response with id %d that did not parse: %s", remote.toStringWithPort(), ntohs(dh->id), e.what()); + g_stats.nonCompliantResponses++; + return false; + } + + if (rqtype != qtype || rqclass != qclass || rqname != qname) { + return false; + } + + return true; +} + +bool fixUpResponse(char** response, uint16_t* responseLen, size_t* responseSize, const DNSName& qname, uint16_t origFlags, bool ednsAdded, +#ifdef HAVE_DNSCRYPT + std::shared_ptr dnsCryptQuery, +#endif + std::vector& rewrittenResponse) +{ + static const uint16_t rdMask = 1 << FLAGS_RD_OFFSET; + static const uint16_t cdMask = 1 << FLAGS_CD_OFFSET; + static const uint16_t restoreFlagsMask = UINT16_MAX & ~(rdMask | cdMask); + struct dnsheader* dh = (struct dnsheader*) *response; + + if (*responseLen < sizeof(dnsheader)) { + return false; + } + + if(g_fixupCase) { + string realname = qname.toDNSString(); + if (*responseLen >= (sizeof(dnsheader) + realname.length())) { + memcpy(*response + sizeof(dnsheader), realname.c_str(), realname.length()); + } + } + + uint16_t * flags = getFlagsFromDNSHeader(dh); + /* clear the flags we are about to restore */ + *flags &= restoreFlagsMask; + /* only keep the flags we want to restore */ + origFlags &= ~restoreFlagsMask; + /* set the saved flags as they were */ + *flags |= origFlags; + + if (ednsAdded) { + const char * optStart = NULL; + size_t optLen = 0; + bool last = false; + + int res = locateEDNSOptRR(*response, *responseLen, &optStart, &optLen, &last); + + if (res == 0) { + if (last) { + /* simply remove the last AR */ + *responseLen -= optLen; + uint16_t arcount = ntohs(dh->arcount); + arcount--; + dh->arcount = htons(arcount); + } + else { + /* Removing an intermediary RR could lead to compression error */ + if (rewriteResponseWithoutEDNS(*response, *responseLen, rewrittenResponse) == 0) { + *responseLen = rewrittenResponse.size(); +#ifdef HAVE_DNSCRYPT + if (dnsCryptQuery && (UINT16_MAX - DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE) > *responseLen) { + rewrittenResponse.reserve(*responseLen + DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE); + } + *responseSize = rewrittenResponse.capacity(); +#endif + *response = reinterpret_cast(rewrittenResponse.data()); + } + else { + warnlog("Error rewriting content"); + } + } + } + } + + return true; +} + +static bool sendUDPResponse(int origFD, char* response, uint16_t responseLen, size_t responseSize, +#ifdef HAVE_DNSCRYPT + std::shared_ptr dnsCryptQuery, +#endif + int delayMsec, const ComboAddress& origDest, const ComboAddress& origRemote) +{ +#ifdef HAVE_DNSCRYPT + uint16_t encryptedResponseLen = 0; + if(dnsCryptQuery) { + int res = dnsCryptQuery->ctx->encryptResponse(response, responseLen, responseSize, dnsCryptQuery, false, &encryptedResponseLen); + + if (res == 0) { + responseLen = encryptedResponseLen; + } else { + /* dropping response */ + vinfolog("Error encrypting the response, dropping."); + return false; + } + } +#endif + + if(delayMsec && g_delay) { + DelayedPacket dp{origFD, string(response,responseLen), origRemote, origDest}; + g_delay->submit(dp, delayMsec); + } + else { + if(origDest.sin4.sin_family == 0) + sendto(origFD, response, responseLen, 0, (struct sockaddr*)&origRemote, origRemote.getSocklen()); + else + sendfromto(origFD, response, responseLen, 0, origDest, origRemote); + } + + return true; +} + // listens on a dedicated socket, lobs answers from downstream servers to original requestors void* responderThread(std::shared_ptr state) { @@ -168,24 +298,18 @@ void* responderThread(std::shared_ptr state) char packet[4096]; #endif static_assert(sizeof(packet) <= UINT16_MAX, "Packet size should fit in a uint16_t"); - const uint16_t rdMask = 1 << FLAGS_RD_OFFSET; - const uint16_t cdMask = 1 << FLAGS_CD_OFFSET; - const uint16_t restoreFlagsMask = UINT16_MAX & ~(rdMask | cdMask); vector rewrittenResponse; - uint16_t qtype, qclass; struct dnsheader* dh = (struct dnsheader*)packet; for(;;) { ssize_t got = recv(state->fd, packet, sizeof(packet), 0); char * response = packet; -#ifdef HAVE_DNSCRYPT - uint16_t responseSize = sizeof(packet); -#endif + size_t responseSize = sizeof(packet); if (got < (ssize_t) sizeof(dnsheader)) continue; - size_t responseLen = (size_t) got; + uint16_t responseLen = (size_t) got; if(dh->id >= state->idStates.size()) continue; @@ -202,108 +326,38 @@ void* responderThread(std::shared_ptr state) mostly mess up the outstanding counter. */ ids->age = 0; - unsigned int consumed; - DNSName qname; - try { - qname=DNSName(packet, responseLen, sizeof(dnsheader), false, &qtype, &qclass, &consumed); - } - catch(std::exception& e) { - if(got > (ssize_t)sizeof(dnsheader)) - infolog("Backend %s sent us a response with id %d that did not parse: %s", state->remote.toStringWithPort(), ntohs(dh->id), e.what()); - g_stats.nonCompliantResponses++; + + if (!responseContentMatches(response, responseLen, ids->qname, ids->qtype, ids->qclass, state->remote)) { continue; } - if (qtype != ids->qtype || qclass != ids->qclass || qname != ids->qname) - continue; --state->outstanding; // you'd think an attacker could game this, but we're using connected socket - if(g_fixupCase) { - string realname = ids->qname.toDNSString(); - if (responseLen >= (sizeof(dnsheader) + realname.length())) { - memcpy(packet+12, realname.c_str(), realname.length()); - } - } - if(dh->tc && g_truncateTC) { - truncateTC(packet, (uint16_t*) &responseLen); + truncateTC(response, (uint16_t*) &responseLen); } - uint16_t * flags = getFlagsFromDNSHeader(dh); - uint16_t origFlags = ids->origFlags; - /* clear the flags we are about to restore */ - *flags &= restoreFlagsMask; - /* only keep the flags we want to restore */ - origFlags &= ~restoreFlagsMask; - /* set the saved flags as they were */ - *flags |= origFlags; dh->id = ids->origID; - if (ids->ednsAdded) { - const char * optStart = NULL; - size_t optLen = 0; - bool last = false; - - int res = locateEDNSOptRR(response, responseLen, &optStart, &optLen, &last); - - if (res == 0) { - if (last) { - /* simply remove the last AR */ - responseLen -= optLen; - uint16_t arcount = ntohs(dh->arcount); - arcount--; - dh->arcount = htons(arcount); - } - else { - /* Removing an intermediary RR could lead to compression error */ - if (rewriteResponseWithoutEDNS(response, responseLen, rewrittenResponse) == 0) { - responseLen = rewrittenResponse.size(); + if (!fixUpResponse(&response, &responseLen, &responseSize, ids->qname, ids->origFlags, ids->ednsAdded, #ifdef HAVE_DNSCRYPT - if (ids->dnsCryptQuery && (UINT16_MAX - DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE) > responseLen) { - rewrittenResponse.reserve(responseLen + DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE); - } - responseSize = rewrittenResponse.capacity(); + ids->dnsCryptQuery, #endif - response = reinterpret_cast(rewrittenResponse.data()); - } - else { - warnlog("Error rewriting content"); - } - } - } + rewrittenResponse)) { + continue; } - g_stats.responses++; - if (ids->packetCache && !ids->skipCache) { - ids->packetCache->insert(ids->cacheKey, qname, qtype, qclass, response, responseLen, false, dh->rcode == RCode::ServFail); + ids->packetCache->insert(ids->cacheKey, ids->qname, ids->qtype, ids->qclass, response, responseLen, false, dh->rcode == RCode::ServFail); } + sendUDPResponse(origFD, response, responseLen, responseSize, #ifdef HAVE_DNSCRYPT - uint16_t encryptedResponseLen = 0; - if(ids->dnsCryptQuery) { - int res = ids->dnsCryptQuery->ctx->encryptResponse(response, responseLen, responseSize, ids->dnsCryptQuery, false, &encryptedResponseLen); - - if (res == 0) { - responseLen = encryptedResponseLen; - } else { - /* dropping response */ - vinfolog("Error encrypting the response, dropping."); - continue; - } - } + ids->dnsCryptQuery, #endif + ids->delayMsec, ids->origDest, ids->origRemote); - if(ids->delayMsec && g_delay) { - DelayedPacket dp{origFD, string(response,responseLen), ids->origRemote, ids->origDest}; - g_delay->submit(dp, ids->delayMsec); - } - else { - if(ids->origDest.sin4.sin_family == 0) - sendto(origFD, response, responseLen, 0, (struct sockaddr*)&ids->origRemote, ids->origRemote.getSocklen()); - else - sendfromto(origFD, response, responseLen, 0, ids->origDest, ids->origRemote); - } + g_stats.responses++; double udiff = ids->sentTime.udiff(); vinfolog("Got answer from %s, relayed to %s, took %f usec", state->remote.toStringWithPort(), ids->origRemote.toStringWithPort(), udiff); @@ -312,7 +366,7 @@ void* responderThread(std::shared_ptr state) struct timespec ts; clock_gettime(CLOCK_MONOTONIC, &ts); std::lock_guard lock(g_rings.respMutex); - g_rings.respRing.push_back({ts, ids->origRemote, qname, qtype, (unsigned int)udiff, (unsigned int)got, *dh, state->remote}); + g_rings.respRing.push_back({ts, ids->origRemote, ids->qname, ids->qtype, (unsigned int)udiff, (unsigned int)got, *dh, state->remote}); } if(dh->rcode == RCode::ServFail) g_stats.servfailResponses++; @@ -682,10 +736,14 @@ try if (!decrypted) { if (response.size() > 0) { ComboAddress dest; - if(HarvestDestinationAddress(&msgh, &dest)) - sendfromto(cs->udpFD, (const char *) response.data(), response.size(), 0, dest, remote); - else - sendto(cs->udpFD, response.data(), response.size(), 0, (struct sockaddr*)&remote, remote.getSocklen()); + if(!HarvestDestinationAddress(&msgh, &dest)) { + dest.sin4.sin_family = 0; + } + sendUDPResponse(cs->udpFD, reinterpret_cast(response.data()), response.size(), response.size(), +#ifdef HAVE_DNSCRYPT + nullptr, +#endif + 0, dest, remote); } continue; } @@ -797,31 +855,18 @@ try if(dq.dh->qr) { // something turned it into a response char* response = query; uint16_t responseLen = dq.len; -#ifdef HAVE_DNSCRYPT uint16_t responseSize = dq.size; -#endif g_stats.selfAnswered++; -#ifdef HAVE_DNSCRYPT - uint16_t encryptedResponseLen = 0; - - if(dnsCryptQuery) { - int res = cs->dnscryptCtx->encryptResponse(response, responseLen, responseSize, dnsCryptQuery, false, &encryptedResponseLen); - - if (res == 0) { - responseLen = encryptedResponseLen; - } else { - /* dropping response */ - continue; - } + ComboAddress dest; + if(!HarvestDestinationAddress(&msgh, &dest)) { + dest.sin4.sin_family = 0; } + sendUDPResponse(cs->udpFD, response, responseLen, responseSize, +#ifdef HAVE_DNSCRYPT + dnsCryptQuery, #endif - ComboAddress dest; - if(HarvestDestinationAddress(&msgh, &dest)) - sendfromto(cs->udpFD, response, responseLen, 0, dest, remote); - else - sendto(cs->udpFD, response, responseLen, 0, (struct sockaddr*)&remote, remote.getSocklen()); - + 0, dest, remote); continue; } @@ -847,10 +892,14 @@ try uint32_t allowExpired = ss ? 0 : g_staleCacheEntriesTTL; if (packetCache->get(dq, consumed, dh->id, cachedResponse, &cachedResponseSize, &cacheKey, allowExpired)) { ComboAddress dest; - if(HarvestDestinationAddress(&msgh, &dest)) - sendfromto(cs->udpFD, cachedResponse, cachedResponseSize, 0, dest, remote); - else - sendto(cs->udpFD, cachedResponse, cachedResponseSize, 0, (struct sockaddr*)&remote, remote.getSocklen()); + if(!HarvestDestinationAddress(&msgh, &dest)) { + dest.sin4.sin_family = 0; + } + sendUDPResponse(cs->udpFD, cachedResponse, cachedResponseSize, sizeof cachedResponse, +#ifdef HAVE_DNSCRYPT + dnsCryptQuery, +#endif + 0, dest, remote); g_stats.cacheHits++; g_stats.latency0_1++; // we're not going to measure this doLatencyAverages(0); // same diff --git a/pdns/dnsdist.hh b/pdns/dnsdist.hh index bddbfb2295..946ceed3bd 100644 --- a/pdns/dnsdist.hh +++ b/pdns/dnsdist.hh @@ -506,6 +506,13 @@ void setLuaSideEffect(); // set to report a side effect, cancelling all _no_ s bool getLuaNoSideEffect(); // set if there were only explicit declarations of _no_ side effect void resetLuaSideEffect(); // reset to indeterminate state +bool responseContentMatches(const char* response, const uint16_t responseLen, const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const ComboAddress& remote); +bool fixUpResponse(char** response, uint16_t* responseLen, size_t* responseSize, const DNSName& qname, uint16_t origFlags, bool ednsAdded, +#ifdef HAVE_DNSCRYPT + std::shared_ptr dnsCryptQuery, +#endif + std::vector& rewrittenResponse); + #ifdef HAVE_DNSCRYPT extern std::vector> g_dnsCryptLocals; diff --git a/regression-tests.dnsdist/dnscrypt.py b/regression-tests.dnsdist/dnscrypt.py index f46bd6df78..5426a5fc41 100644 --- a/regression-tests.dnsdist/dnscrypt.py +++ b/regression-tests.dnsdist/dnscrypt.py @@ -70,15 +70,35 @@ class DNSCryptClient(object): self._resolverPort = resolverPort self._resolverCertificates = [] self._publicKey, self._privateKey = libnacl.crypto_box_keypair() + self._timeout = timeout addrType = self._addrToSocketType(self._resolverAddress) self._sock = socket.socket(addrType, socket.SOCK_DGRAM) self._sock.settimeout(timeout) self._sock.connect((self._resolverAddress, self._resolverPort)) - def _sendQuery(self, queryContent): - self._sock.send(queryContent) - data = self._sock.recv(4096) + def _sendQuery(self, queryContent, tcp=False): + if tcp: + addrType = self._addrToSocketType(self._resolverAddress) + sock = socket.socket(addrType, socket.SOCK_STREAM) + sock.settimeout(self._timeout) + sock.connect((self._resolverAddress, self._resolverPort)) + sock.send(struct.pack("!H", len(queryContent))) + else: + sock = self._sock + + sock.send(queryContent) + + data = None + if tcp: + got = sock.recv(2) + print(len(got)) + if got: + (rlen,) = struct.unpack("!H", got) + data = sock.recv(rlen) + else: + data = sock.recv(4096) + return data def _hasValidResolverCertificate(self): @@ -124,12 +144,12 @@ class DNSCryptClient(object): nonce = libnacl.utils.rand_nonce() return nonce[:(DNSCryptClient.DNSCRYPT_NONCE_SIZE / 2)] - def _encryptQuery(self, queryContent, resolverCert, nonce): + def _encryptQuery(self, queryContent, resolverCert, nonce, tcp=False): header = resolverCert.clientMagic + self._publicKey + nonce requiredSize = len(header) + self.DNSCRYPT_MAC_SIZE + len(queryContent) paddingSize = self.DNSCRYPT_PADDED_BLOCK_SIZE - (len(queryContent) % self.DNSCRYPT_PADDED_BLOCK_SIZE) # padding size should be DNSCRYPT_PADDED_BLOCK_SIZE <= padding size <= 4096 - if requiredSize < self.DNSCRYPT_MIN_UDP_LENGTH: + if not tcp and requiredSize < self.DNSCRYPT_MIN_UDP_LENGTH: paddingSize += self.DNSCRYPT_MIN_UDP_LENGTH - requiredSize requiredSize = self.DNSCRYPT_MIN_UDP_LENGTH @@ -168,7 +188,7 @@ class DNSCryptClient(object): return cleartext[:idx+1] - def query(self, queryContent): + def query(self, queryContent, tcp=False): if not self._hasValidResolverCertificate(): self._getResolverCertificates() @@ -177,7 +197,7 @@ class DNSCryptClient(object): resolverCert = self._getResolverCertificate() if resolverCert is None: raise Exception("No valid certificate found") - encryptedQuery = self._encryptQuery(queryContent, resolverCert, nonce) - encryptedResponse = self._sendQuery(encryptedQuery) + encryptedQuery = self._encryptQuery(queryContent, resolverCert, nonce, tcp) + encryptedResponse = self._sendQuery(encryptedQuery, tcp) response = self._decryptResponse(encryptedResponse, resolverCert, nonce) return response diff --git a/regression-tests.dnsdist/test_DNSCrypt.py b/regression-tests.dnsdist/test_DNSCrypt.py index c8dba56de1..abf1da5ca4 100644 --- a/regression-tests.dnsdist/test_DNSCrypt.py +++ b/regression-tests.dnsdist/test_DNSCrypt.py @@ -1,6 +1,5 @@ #!/usr/bin/env python import time -import unittest import dns import dns.message from dnsdisttests import DNSDistTest @@ -43,7 +42,7 @@ class TestDNSCrypt(DNSDistTest): 3600, dns.rdataclass.IN, dns.rdatatype.A, - '127.0.0.1') + '192.2.0.1') response.answer.append(rrset) self._toResponderQueue.put(response) @@ -59,6 +58,19 @@ class TestDNSCrypt(DNSDistTest): self.assertEquals(query, receivedQuery) self.assertEquals(response, receivedResponse) + self._toResponderQueue.put(response) + data = client.query(query.to_wire(), tcp=True) + receivedResponse = dns.message.from_wire(data) + receivedQuery = None + if not self._fromResponderQueue.empty(): + receivedQuery = self._fromResponderQueue.get(query) + + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = query.id + self.assertEquals(query, receivedQuery) + self.assertEquals(response, receivedResponse) + def testResponseLargerThanPaddedQuery(self): """ DNSCrypt: response larger than query @@ -95,6 +107,59 @@ class TestDNSCrypt(DNSDistTest): self.assertTrue(len(receivedResponse.authority) == 0) self.assertTrue(len(receivedResponse.additional) == 0) -if __name__ == '__main__': - unittest.main() - exit(0) +class TestDNSCryptWithCache(DNSDistTest): + _dnsDistPortDNSCrypt = 8443 + _providerFingerprint = 'E1D7:2108:9A59:BF8D:F101:16FA:ED5E:EA6A:9F6C:C78F:7F91:AF6B:027E:62F4:69C3:B1AA' + _providerName = "2.provider.name" + _resolverCertificateSerial = 42 + # valid from 60s ago until 2h from now + _resolverCertificateValidFrom = time.time() - 60 + _resolverCertificateValidUntil = time.time() + 7200 + _config_params = ['_resolverCertificateSerial', '_resolverCertificateValidFrom', '_resolverCertificateValidUntil', '_dnsDistPortDNSCrypt', '_providerName', '_testServerPort'] + _config_template = """ + generateDNSCryptCertificate("DNSCryptProviderPrivate.key", "DNSCryptResolver.cert", "DNSCryptResolver.key", %d, %d, %d) + addDNSCryptBind("127.0.0.1:%d", "%s", "DNSCryptResolver.cert", "DNSCryptResolver.key") + pc = newPacketCache(5, 86400, 1) + getPool(""):setCache(pc) + newServer{address="127.0.0.1:%s"} + """ + + def testCachedSimpleA(self): + """ + DNSCrypt: encrypted A query served from cache + """ + client = dnscrypt.DNSCryptClient(self._providerName, self._providerFingerprint, "127.0.0.1", 8443) + name = 'cacheda.dnscrypt.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + response = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 3600, + dns.rdataclass.IN, + dns.rdatatype.A, + '192.2.0.1') + response.answer.append(rrset) + + # first query to fill the cache + self._toResponderQueue.put(response) + data = client.query(query.to_wire()) + receivedResponse = dns.message.from_wire(data) + receivedQuery = None + if not self._fromResponderQueue.empty(): + receivedQuery = self._fromResponderQueue.get(query) + + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = query.id + self.assertEquals(query, receivedQuery) + self.assertEquals(response, receivedResponse) + + # second query should get a cached response + data = client.query(query.to_wire()) + receivedResponse = dns.message.from_wire(data) + receivedQuery = None + if not self._fromResponderQueue.empty(): + receivedQuery = self._fromResponderQueue.get(query) + + self.assertEquals(receivedQuery, None) + self.assertTrue(receivedResponse) + self.assertEquals(response, receivedResponse)