uint16_t qlen, rlen;
vector<uint8_t> rewrittenResponse;
shared_ptr<DownstreamState> ds;
- ComboAddress dest;
- dest.reset();
- dest.sin4.sin_family = ci.remote.sin4.sin_family;
- socklen_t len = dest.getSocklen();
size_t queriesCount = 0;
time_t connectionStartTime = time(NULL);
std::vector<char> queryBuffer;
std::vector<char> answerBuffer;
- if (getsockname(ci.fd, (sockaddr*)&dest, &len)) {
+ ComboAddress dest;
+ dest.reset();
+ dest.sin4.sin_family = ci.remote.sin4.sin_family;
+ socklen_t socklen = dest.getSocklen();
+ if (getsockname(ci.fd, (sockaddr*)&dest, &socklen)) {
dest = ci.cs->local;
}
break;
}
- bool ednsAdded = false;
- bool ecsAdded = false;
/* allocate a bit more memory to be able to spoof the content,
or to add ECS without allocating a new buffer */
queryBuffer.resize(qlen + 512);
char* query = &queryBuffer[0];
handler.read(query, qlen, g_tcpRecvTimeout, remainingTime);
- /* we need this one to be accurate ("real") for the protobuf message */
- struct timespec queryRealTime;
- struct timespec now;
- gettime(&now);
- gettime(&queryRealTime, true);
+ /* we need an accurate ("real") value for the response and
+ to store into the IDS, but not for insertion into the
+ rings for example */
+ struct timespec now;
+ struct timespec queryRealTime;
+ gettime(&now);
+ gettime(&queryRealTime, true);
-#ifdef HAVE_DNSCRYPT
std::shared_ptr<DNSCryptQuery> dnsCryptQuery = nullptr;
- if (ci.cs->dnscryptCtx) {
- dnsCryptQuery = std::make_shared<DNSCryptQuery>(ci.cs->dnscryptCtx);
- uint16_t decryptedQueryLen = 0;
- vector<uint8_t> response;
- bool decrypted = handleDNSCryptQuery(query, qlen, dnsCryptQuery, &decryptedQueryLen, true, queryRealTime.tv_sec, response);
-
- if (!decrypted) {
- if (response.size() > 0) {
- handler.writeSizeAndMsg(response.data(), response.size(), g_tcpSendTimeout);
- }
- break;
- }
- qlen = decryptedQueryLen;
+#ifdef HAVE_DNSCRYPT
+ auto dnsCryptResponse = checkDNSCryptQuery(*ci.cs, query, qlen, dnsCryptQuery, queryRealTime.tv_sec, true);
+ if (dnsCryptResponse) {
+ handler.writeSizeAndMsg(reinterpret_cast<char*>(dnsCryptResponse->data()), static_cast<uint16_t>(dnsCryptResponse->size()), g_tcpSendTimeout);
+ continue;
}
#endif
- struct dnsheader* dh = reinterpret_cast<struct dnsheader*>(query);
+ struct dnsheader* dh = reinterpret_cast<struct dnsheader*>(query);
if (!checkQueryHeaders(dh)) {
- goto drop;
+ break;
}
- string poolname;
- int delayMsec=0;
-
- const uint16_t* flags = getFlagsFromDNSHeader(dh);
- uint16_t origFlags = *flags;
- uint16_t qtype, qclass;
- unsigned int consumed = 0;
- DNSName qname(query, qlen, sizeof(dnsheader), false, &qtype, &qclass, &consumed);
- DNSQuestion dq(&qname, qtype, qclass, consumed, &dest, &ci.remote, dh, queryBuffer.size(), qlen, true, &queryRealTime);
+ uint16_t qtype, qclass;
+ unsigned int consumed = 0;
+ DNSName qname(query, qlen, sizeof(dnsheader), false, &qtype, &qclass, &consumed);
+ DNSQuestion dq(&qname, qtype, qclass, consumed, &dest, &ci.remote, dh, queryBuffer.size(), qlen, true, &queryRealTime);
+ dq.dnsCryptQuery = std::move(dnsCryptQuery);
- if (!processQuery(holders, dq, poolname, &delayMsec, now)) {
- goto drop;
- }
-
- if(dq.dh->qr) { // something turned it into a response
- fixUpQueryTurnedResponse(dq, origFlags);
-
- DNSResponse dr(dq.qname, dq.qtype, dq.qclass, dq.consumed, dq.local, dq.remote, reinterpret_cast<dnsheader*>(query), dq.size, dq.len, true, &queryRealTime);
-#ifdef HAVE_PROTOBUF
- dr.uniqueId = dq.uniqueId;
-#endif
- dr.qTag = dq.qTag;
+ responseSender sender = [&handler](const ClientState& cs, const char* data, uint16_t dataSize, int delayMsec, const ComboAddress& dest, const ComboAddress& remote) {
+ handler.writeSizeAndMsg(data, dataSize, g_tcpSendTimeout);
+ };
- if (!processResponse(holders.selfAnsweredRespRulactions, dr, &delayMsec)) {
- goto drop;
- }
-
-#ifdef HAVE_DNSCRYPT
- if (!encryptResponse(query, &dq.len, dq.size, true, dnsCryptQuery, nullptr, nullptr)) {
- goto drop;
+ bool dropped = false;
+ auto ds = processQuery(dq, *ci.cs, holders, sender, dropped);
+ if (!ds) {
+ if (dropped) {
+ break;
}
-#endif
- handler.writeSizeAndMsg(query, dq.len, g_tcpSendTimeout);
- ++g_stats.selfAnswered;
continue;
}
- std::shared_ptr<ServerPool> serverPool = getPool(*holders.pools, poolname);
- std::shared_ptr<DNSDistPacketCache> packetCache = serverPool->packetCache;
-
- auto policy = *(holders.policy);
- if (serverPool->policy != nullptr) {
- policy = *(serverPool->policy);
- }
- auto servers = serverPool->getServers();
- if (policy.isLua) {
- std::lock_guard<std::mutex> lock(g_luamutex);
- ds = policy.policy(servers, &dq);
- }
- else {
- ds = policy.policy(servers, &dq);
- }
-
- uint32_t cacheKeyNoECS = 0;
- uint32_t cacheKey = 0;
- boost::optional<Netmask> subnet;
+ // check how that would work!!
char cachedResponse[4096];
uint16_t cachedResponseSize = sizeof cachedResponse;
- uint32_t allowExpired = ds ? 0 : g_staleCacheEntriesTTL;
- bool useZeroScope = false;
-
- bool dnssecOK = false;
- if (packetCache && !dq.skipCache) {
- dnssecOK = (getEDNSZ(dq) & EDNS_HEADER_FLAG_DO);
- }
-
- if (dq.useECS && ((ds && ds->useECS) || (!ds && serverPool->getECS()))) {
- // we special case our cache in case a downstream explicitly gave us a universally valid response with a 0 scope
- if (packetCache && !dq.skipCache && (!ds || !ds->disableZeroScope) && packetCache->isECSParsingEnabled()) {
- if (packetCache->get(dq, consumed, dq.dh->id, cachedResponse, &cachedResponseSize, &cacheKeyNoECS, subnet, dnssecOK, allowExpired)) {
- DNSResponse dr(dq.qname, dq.qtype, dq.qclass, dq.consumed, dq.local, dq.remote, (dnsheader*) cachedResponse, sizeof cachedResponse, cachedResponseSize, true, &queryRealTime);
-#ifdef HAVE_PROTOBUF
- dr.uniqueId = dq.uniqueId;
-#endif
- dr.qTag = dq.qTag;
-
- if (!processResponse(holders.cacheHitRespRulactions, dr, &delayMsec)) {
- goto drop;
- }
-
-#ifdef HAVE_DNSCRYPT
- if (!encryptResponse(cachedResponse, &cachedResponseSize, sizeof cachedResponse, true, dnsCryptQuery, nullptr, nullptr)) {
- goto drop;
- }
-#endif
- handler.writeSizeAndMsg(cachedResponse, cachedResponseSize, g_tcpSendTimeout);
- g_stats.cacheHits++;
- switch (dr.dh->rcode) {
- case RCode::NXDomain:
- ++g_stats.frontendNXDomain;
- break;
- case RCode::ServFail:
- ++g_stats.frontendServFail;
- break;
- case RCode::NoError:
- ++g_stats.frontendNoError;
- break;
- }
- continue;
- }
-
- if (!subnet) {
- /* there was no existing ECS on the query, enable the zero-scope feature */
- useZeroScope = true;
- }
- }
-
- if (!handleEDNSClientSubnet(dq, &(ednsAdded), &(ecsAdded), g_preserveTrailingData)) {
- vinfolog("Dropping query from %s because we couldn't insert the ECS value", ci.remote.toStringWithPort());
- goto drop;
- }
- }
-
- if (packetCache && !dq.skipCache) {
- if (packetCache->get(dq, (uint16_t) consumed, dq.dh->id, cachedResponse, &cachedResponseSize, &cacheKey, subnet, dnssecOK, allowExpired)) {
- DNSResponse dr(dq.qname, dq.qtype, dq.qclass, dq.consumed, dq.local, dq.remote, (dnsheader*) cachedResponse, sizeof cachedResponse, cachedResponseSize, true, &queryRealTime);
-#ifdef HAVE_PROTOBUF
- dr.uniqueId = dq.uniqueId;
-#endif
- dr.qTag = dq.qTag;
-
- if (!processResponse(holders.cacheHitRespRulactions, dr, &delayMsec)) {
- goto drop;
- }
-
-#ifdef HAVE_DNSCRYPT
- if (!encryptResponse(cachedResponse, &cachedResponseSize, sizeof cachedResponse, true, dnsCryptQuery, nullptr, nullptr)) {
- goto drop;
- }
-#endif
- handler.writeSizeAndMsg(cachedResponse, cachedResponseSize, g_tcpSendTimeout);
- ++g_stats.cacheHits;
- switch (dr.dh->rcode) {
- case RCode::NXDomain:
- ++g_stats.frontendNXDomain;
- break;
- case RCode::ServFail:
- ++g_stats.frontendServFail;
- break;
- case RCode::NoError:
- ++g_stats.frontendNoError;
- break;
- }
- continue;
- }
- ++g_stats.cacheMisses;
- }
-
- if(!ds) {
- ++g_stats.noPolicy;
-
- if (g_servFailOnNoPolicy) {
- restoreFlags(dh, origFlags);
- dq.dh->rcode = RCode::ServFail;
- dq.dh->qr = true;
-
- DNSResponse dr(dq.qname, dq.qtype, dq.qclass, dq.consumed, dq.local, dq.remote, reinterpret_cast<dnsheader*>(query), dq.size, dq.len, false, &queryRealTime);
-#ifdef HAVE_PROTOBUF
- dr.uniqueId = dq.uniqueId;
-#endif
- dr.qTag = dq.qTag;
-
- if (!processResponse(holders.selfAnsweredRespRulactions, dr, &delayMsec)) {
- goto drop;
- }
-
-#ifdef HAVE_DNSCRYPT
- if (!encryptResponse(query, &dq.len, dq.size, true, dnsCryptQuery, nullptr, nullptr)) {
- goto drop;
- }
-#endif
- handler.writeSizeAndMsg(query, dq.len, g_tcpSendTimeout);
-
- // no response-only statistics counter to update.
- continue;
- }
-
- break;
- }
-
- if (dq.addXPF && ds->xpfRRCode != 0) {
- addXPF(dq, ds->xpfRRCode, g_preserveTrailingData);
- }
int dsock = -1;
uint16_t downstreamFailures=0;
#endif /* MSG_FASTOPEN */
}
- ds->queries++;
ds->outstanding++;
outstanding = true;
freshConn=true;
#endif /* MSG_FASTOPEN */
if(xfrStarted) {
- goto drop;
+ break;
}
goto retry;
}
size_t responseSize = rlen;
uint16_t addRoom = 0;
#ifdef HAVE_DNSCRYPT
- if (dnsCryptQuery && (UINT16_MAX - rlen) > (uint16_t) DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE) {
+ if (dq.dnsCryptQuery && (UINT16_MAX - rlen) > (uint16_t) DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE) {
addRoom = DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE;
}
#endif
}
firstPacket=false;
bool zeroScope = false;
- if (!fixUpResponse(&response, &responseLen, &responseSize, qname, origFlags, ednsAdded, ecsAdded, rewrittenResponse, addRoom, useZeroScope ? &zeroScope : nullptr)) {
+ if (!fixUpResponse(&response, &responseLen, &responseSize, qname, dq.origFlags, dq.ednsAdded, dq.ecsAdded, rewrittenResponse, addRoom, dq.useZeroScope ? &zeroScope : nullptr)) {
break;
}
#endif
dr.qTag = dq.qTag;
- if (!processResponse(localRespRulactions, dr, &delayMsec)) {
+ if (!processResponse(localRespRulactions, dr, &dq.delayMsec)) {
break;
}
- if (packetCache && !dq.skipCache) {
- if (!useZeroScope) {
+ if (dq.packetCache && !dq.skipCache) {
+ if (!dq.useZeroScope) {
/* if the query was not suitable for zero-scope, for
example because it had an existing ECS entry so the hash is
not really 'no ECS', so just insert it for the existing subnet
zeroScope = false;
}
// if zeroScope, pass the pre-ECS hash-key and do not pass the subnet to the cache
- packetCache->insert(zeroScope ? cacheKeyNoECS : cacheKey, zeroScope ? boost::none : subnet, origFlags, dnssecOK, qname, qtype, qclass, response, responseLen, true, dh->rcode, dq.tempFailureTTL);
+ dq.packetCache->insert(zeroScope ? dq.cacheKeyNoECS : dq.cacheKey, zeroScope ? boost::none : dq.subnet, dq.origFlags, dq.dnssecOK, qname, qtype, qclass, response, responseLen, true, dh->rcode, dq.tempFailureTTL);
}
#ifdef HAVE_DNSCRYPT
- if (!encryptResponse(response, &responseLen, responseSize, true, dnsCryptQuery, &dh, &dhCopy)) {
- goto drop;
+ if (!encryptResponse(response, &responseLen, responseSize, true, dq.dnsCryptQuery, &dh, &dhCopy)) {
+ break;
}
#endif
if (!handler.writeSizeAndMsg(response, responseLen, g_tcpSendTimeout)) {
rewrittenResponse.clear();
}
}
- catch(...) {}
-
- drop:;
+ catch(const std::exception& e) {
+ vinfolog("Got exception while handling TCP query: %s", e.what());
+ }
+ catch(...) {
+ }
vinfolog("Closing TCP client connection with %s", ci.remote.toStringWithPort());
return true;
}
-void restoreFlags(struct dnsheader* dh, uint16_t origFlags)
+static void restoreFlags(struct dnsheader* dh, uint16_t origFlags)
{
static const uint16_t rdMask = 1 << FLAGS_RD_OFFSET;
static const uint16_t cdMask = 1 << FLAGS_CD_OFFSET;
*flags |= origFlags;
}
-bool fixUpQueryTurnedResponse(DNSQuestion& dq, const uint16_t origFlags)
+static bool fixUpQueryTurnedResponse(DNSQuestion& dq, const uint16_t origFlags)
{
restoreFlags(dq.dh, origFlags);
}
#endif
-static bool sendUDPResponse(int origFD, char* response, uint16_t responseLen, int delayMsec, const ComboAddress& origDest, const ComboAddress& origRemote)
+static bool sendUDPResponse(int origFD, const char* response, uint16_t responseLen, int delayMsec, const ComboAddress& origDest, const ComboAddress& origRemote)
{
if(delayMsec && g_delay) {
DelayedPacket dp{origFD, string(response,responseLen), origRemote, origDest};
}
-static int pickBackendSocketForSending(DownstreamState* state)
+static int pickBackendSocketForSending(std::shared_ptr<DownstreamState>& state)
{
return state->sockets[state->socketsOffset++ % state->sockets.size()];
}
}
}
-bool processQuery(LocalHolders& holders, DNSQuestion& dq, string& poolname, int* delayMsec, const struct timespec& now)
+static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, string& poolname, const struct timespec& now)
{
- g_rings.insertQuery(now,*dq.remote,*dq.qname,dq.qtype,dq.len,*dq.dh);
+ g_rings.insertQuery(now, *dq.remote, *dq.qname, dq.qtype, dq.len, *dq.dh);
if(g_qcount.enabled) {
string qname = (*dq.qname).toString(".");
break;
/* non-terminal actions follow */
case DNSAction::Action::Delay:
- *delayMsec = static_cast<int>(pdns_stou(ruleresult)); // sorry
+ dq.delayMsec = static_cast<int>(pdns_stou(ruleresult)); // sorry
break;
case DNSAction::Action::None:
/* fall-through */
return true;
}
-static ssize_t udpClientSendRequestToBackend(DownstreamState* ss, const int sd, const char* request, const size_t requestLen, bool healthCheck=false)
+static ssize_t udpClientSendRequestToBackend(const std::shared_ptr<DownstreamState>& ss, const int sd, const char* request, const size_t requestLen, bool healthCheck=false)
{
ssize_t result;
}
#ifdef HAVE_DNSCRYPT
-static bool checkDNSCryptQuery(const ClientState& cs, const char* query, uint16_t& len, std::shared_ptr<DNSCryptQuery>& dnsCryptQuery, const ComboAddress& dest, const ComboAddress& remote, time_t now)
+boost::optional<std::vector<uint8_t>> checkDNSCryptQuery(const ClientState& cs, const char* query, uint16_t& len, std::shared_ptr<DNSCryptQuery>& dnsCryptQuery, time_t now, bool tcp)
{
if (cs.dnscryptCtx) {
vector<uint8_t> response;
dnsCryptQuery = std::make_shared<DNSCryptQuery>(cs.dnscryptCtx);
- bool decrypted = handleDNSCryptQuery(const_cast<char*>(query), len, dnsCryptQuery, &decryptedQueryLen, false, now, response);
+ bool decrypted = handleDNSCryptQuery(const_cast<char*>(query), len, dnsCryptQuery, &decryptedQueryLen, tcp, now, response);
if (!decrypted) {
if (response.size() > 0) {
- sendUDPResponse(cs.udpFD, reinterpret_cast<char*>(response.data()), static_cast<uint16_t>(response.size()), 0, dest, remote);
+ return response;
}
- return false;
+ throw std::runtime_error("Unable to decrypt DNSCrypt query, dropping.");
}
len = decryptedQueryLen;
}
- return true;
+ return boost::none;
}
#endif /* HAVE_DNSCRYPT */
}
#endif /* defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) */
-static int sendAndEncryptUDPResponse(LocalHolders& holders, ClientState& cs, const DNSQuestion& dq, char* response, uint16_t responseLen, std::shared_ptr<DNSCryptQuery>& dnsCryptQuery, int delayMsec, const ComboAddress& dest, struct mmsghdr* responsesVect, unsigned int* queuedResponses, struct iovec* respIOV, char* respCBuf, bool cacheHit)
+static int sendResponse(LocalHolders& holders, ClientState& cs, DNSQuestion& dq, char* response, uint16_t responseLen, bool cacheHit, responseSender sender)
{
- DNSResponse dr(dq.qname, dq.qtype, dq.qclass, dq.consumed, dq.local, dq.remote, reinterpret_cast<dnsheader*>(response), dq.size, responseLen, false, dq.queryTime);
+ DNSResponse dr(dq.qname, dq.qtype, dq.qclass, dq.consumed, dq.local, dq.remote, reinterpret_cast<dnsheader*>(response), dq.size, responseLen, dq.tcp, dq.queryTime);
+
#ifdef HAVE_PROTOBUF
dr.uniqueId = dq.uniqueId;
#endif
dr.qTag = dq.qTag;
- if (!processResponse(cacheHit ? holders.cacheHitRespRulactions : holders.selfAnsweredRespRulactions, dr, &delayMsec)) {
+ if (!processResponse(cacheHit ? holders.cacheHitRespRulactions : holders.selfAnsweredRespRulactions, dr, &dq.delayMsec)) {
return -1;
}
if (!cs.muted) {
#ifdef HAVE_DNSCRYPT
- if (!encryptResponse(response, &responseLen, dq.size, false, dnsCryptQuery, nullptr, nullptr)) {
+ if (!encryptResponse(response, &responseLen, dq.size, dq.tcp, dq.dnsCryptQuery, nullptr, nullptr)) {
return -1;
}
#endif
-#if defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE)
- if (delayMsec == 0 && responsesVect != nullptr) {
- queueResponse(cs, response, responseLen, dest, *dq.remote, responsesVect[*queuedResponses], respIOV, respCBuf);
- (*queuedResponses)++;
- }
- else
-#endif /* defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) */
- {
- sendUDPResponse(cs.udpFD, response, responseLen, delayMsec, dest, *dq.remote);
- }
+
+ sender(cs, response, responseLen, dq.delayMsec, *dq.local, *dq.remote);
}
if (cacheHit) {
return 0;
}
-static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct msghdr* msgh, const ComboAddress& remote, ComboAddress& dest, char* query, uint16_t len, size_t queryBufferSize, struct mmsghdr* responsesVect, unsigned int* queuedResponses, struct iovec* respIOV, char* respCBuf)
+/* returns nullptr if the query has been taken care of (cache-hit, self-answered or discarded) and a backend it should be sent to otherwise */
+std::shared_ptr<DownstreamState> processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders& holders, responseSender sender, bool& dropped)
{
- assert(responsesVect == nullptr || (queuedResponses != nullptr && respIOV != nullptr && respCBuf != nullptr));
- uint16_t queryId = 0;
+ const uint16_t queryId = ntohs(dq.dh->id);
try {
- if (!isUDPQueryAcceptable(cs, holders, msgh, remote, dest)) {
- return;
- }
-
/* we need an accurate ("real") value for the response and
to store into the IDS, but not for insertion into the
rings for example */
- struct timespec queryRealTime;
struct timespec now;
gettime(&now);
- gettime(&queryRealTime, true);
-
- std::shared_ptr<DNSCryptQuery> dnsCryptQuery = nullptr;
-
-#ifdef HAVE_DNSCRYPT
- if (!checkDNSCryptQuery(cs, query, len, dnsCryptQuery, dest, remote, queryRealTime.tv_sec)) {
- return;
- }
-#endif
-
- struct dnsheader* dh = reinterpret_cast<struct dnsheader*>(query);
- queryId = ntohs(dh->id);
-
- if (!checkQueryHeaders(dh)) {
- return;
- }
string poolname;
- int delayMsec = 0;
- const uint16_t * flags = getFlagsFromDNSHeader(dh);
- const uint16_t origFlags = *flags;
- uint16_t qtype, qclass;
- unsigned int consumed = 0;
- DNSName qname(query, len, sizeof(dnsheader), false, &qtype, &qclass, &consumed);
- DNSQuestion dq(&qname, qtype, qclass, consumed, dest.sin4.sin_family != 0 ? &dest : &cs.local, &remote, dh, queryBufferSize, len, false, &queryRealTime);
- bool dnssecOK = false;
- if (!processQuery(holders, dq, poolname, &delayMsec, now))
- {
- return;
+ if (!applyRulesToQuery(holders, dq, poolname, now)) {
+ dropped = true;
+ return nullptr;
}
if(dq.dh->qr) { // something turned it into a response
- fixUpQueryTurnedResponse(dq, origFlags);
+ fixUpQueryTurnedResponse(dq, dq.origFlags);
if (!cs.muted) {
- char* response = query;
+ char* response = reinterpret_cast<char*>(dq.dh);
uint16_t responseLen = dq.len;
- sendAndEncryptUDPResponse(holders, cs, dq, response, responseLen, dnsCryptQuery, delayMsec, dest, responsesVect, queuedResponses, respIOV, respCBuf, false);
+ sendResponse(holders, cs, dq, response, responseLen, false, sender);
++g_stats.selfAnswered;
}
- return;
+ return nullptr;
}
- DownstreamState* ss = nullptr;
+ std::shared_ptr<DownstreamState> ss{nullptr};
std::shared_ptr<ServerPool> serverPool = getPool(*holders.pools, poolname);
- std::shared_ptr<DNSDistPacketCache> packetCache = serverPool->packetCache;
+ dq.packetCache = serverPool->packetCache;
auto policy = *(holders.policy);
if (serverPool->policy != nullptr) {
policy = *(serverPool->policy);
auto servers = serverPool->getServers();
if (policy.isLua) {
std::lock_guard<std::mutex> lock(g_luamutex);
- ss = policy.policy(servers, &dq).get();
+ ss = policy.policy(servers, &dq);
}
else {
- ss = policy.policy(servers, &dq).get();
+ ss = policy.policy(servers, &dq);
}
- bool ednsAdded = false;
- bool ecsAdded = false;
- uint32_t cacheKeyNoECS = 0;
- uint32_t cacheKey = 0;
- boost::optional<Netmask> subnet;
uint16_t cachedResponseSize = dq.size;
uint32_t allowExpired = ss ? 0 : g_staleCacheEntriesTTL;
- bool useZeroScope = false;
- if (packetCache && !dq.skipCache) {
- dnssecOK = (getEDNSZ(dq) & EDNS_HEADER_FLAG_DO);
+ if (dq.packetCache && !dq.skipCache) {
+ dq.dnssecOK = (getEDNSZ(dq) & EDNS_HEADER_FLAG_DO);
}
if (dq.useECS && ((ss && ss->useECS) || (!ss && serverPool->getECS()))) {
// we special case our cache in case a downstream explicitly gave us a universally valid response with a 0 scope
- if (packetCache && !dq.skipCache && (!ss || !ss->disableZeroScope) && packetCache->isECSParsingEnabled()) {
- if (packetCache->get(dq, consumed, dh->id, query, &cachedResponseSize, &cacheKeyNoECS, subnet, dnssecOK, allowExpired)) {
- sendAndEncryptUDPResponse(holders, cs, dq, query, cachedResponseSize, dnsCryptQuery, delayMsec, dest, responsesVect, queuedResponses, respIOV, respCBuf, true);
- return;
+ if (dq.packetCache && !dq.skipCache && (!ss || !ss->disableZeroScope) && dq.packetCache->isECSParsingEnabled()) {
+ if (dq.packetCache->get(dq, dq.consumed, dq.dh->id, reinterpret_cast<char*>(dq.dh), &cachedResponseSize, &dq.cacheKeyNoECS, dq.subnet, dq.dnssecOK, allowExpired)) {
+ sendResponse(holders, cs, dq, reinterpret_cast<char*>(dq.dh), cachedResponseSize, true, sender);
+ return nullptr;
}
- if (!subnet) {
+ if (!dq.subnet) {
/* there was no existing ECS on the query, enable the zero-scope feature */
- useZeroScope = true;
+ dq.useZeroScope = true;
}
}
- if (!handleEDNSClientSubnet(dq, &(ednsAdded), &(ecsAdded), g_preserveTrailingData)) {
- vinfolog("Dropping query from %s because we couldn't insert the ECS value", remote.toStringWithPort());
- return;
+ if (!handleEDNSClientSubnet(dq, &(dq.ednsAdded), &(dq.ecsAdded), g_preserveTrailingData)) {
+ vinfolog("Dropping query from %s because we couldn't insert the ECS value", dq.remote->toStringWithPort());
+ dropped = true;
+ return nullptr;
}
}
- if (packetCache && !dq.skipCache) {
- if (packetCache->get(dq, consumed, dh->id, query, &cachedResponseSize, &cacheKey, subnet, dnssecOK, allowExpired)) {
- sendAndEncryptUDPResponse(holders, cs, dq, query, cachedResponseSize, dnsCryptQuery, delayMsec, dest, responsesVect, queuedResponses, respIOV, respCBuf, true);
- return;
+ if (dq.packetCache && !dq.skipCache) {
+ if (dq.packetCache->get(dq, dq.consumed, dq.dh->id, reinterpret_cast<char*>(dq.dh), &cachedResponseSize, &dq.cacheKey, dq.subnet, dq.dnssecOK, allowExpired)) {
+ sendResponse(holders, cs, dq, reinterpret_cast<char*>(dq.dh), cachedResponseSize, true, sender);
+ return nullptr;
}
++g_stats.cacheMisses;
}
++g_stats.noPolicy;
if (g_servFailOnNoPolicy && !cs.muted) {
- char* response = query;
+ char* response = reinterpret_cast<char*>(dq.dh);
uint16_t responseLen = dq.len;
- restoreFlags(dh, origFlags);
+ restoreFlags(dq.dh, dq.origFlags);
dq.dh->rcode = RCode::ServFail;
dq.dh->qr = true;
- sendAndEncryptUDPResponse(holders, cs, dq, response, responseLen, dnsCryptQuery, delayMsec, dest, responsesVect, queuedResponses, respIOV, respCBuf, false);
-
+ sendResponse(holders, cs, dq, response, responseLen, false, sender);
// no response-only statistics counter to update.
}
- vinfolog("%s query for %s|%s from %s, no policy applied", g_servFailOnNoPolicy ? "ServFailed" : "Dropped", dq.qname->toString(), QType(dq.qtype).getName(), remote.toStringWithPort());
- return;
+ vinfolog("%s query for %s|%s from %s, no policy applied", g_servFailOnNoPolicy ? "ServFailed" : "Dropped", dq.qname->toString(), QType(dq.qtype).getName(), dq.remote->toStringWithPort());
+ return nullptr;
}
if (dq.addXPF && ss->xpfRRCode != 0) {
}
ss->queries++;
+ return ss;
+ }
+ catch(const std::exception& e){
+ vinfolog("Got an error while parsing a %s query from %s, id %d: %s", (dq.tcp ? "TCP" : "UDP"), dq.remote->toStringWithPort(), queryId, e.what());
+ dropped = true;
+ }
+ return nullptr;
+}
+
+static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct msghdr* msgh, const ComboAddress& remote, ComboAddress& dest, char* query, uint16_t len, size_t queryBufferSize, struct mmsghdr* responsesVect, unsigned int* queuedResponses, struct iovec* respIOV, char* respCBuf)
+{
+ assert(responsesVect == nullptr || (queuedResponses != nullptr && respIOV != nullptr && respCBuf != nullptr));
+ uint16_t queryId = 0;
+
+ try {
+ if (!isUDPQueryAcceptable(cs, holders, msgh, remote, dest)) {
+ return;
+ }
+
+ /* we need an accurate ("real") value for the response and
+ to store into the IDS, but not for insertion into the
+ rings for example */
+ struct timespec queryRealTime;
+ struct timespec now;
+ gettime(&now);
+ gettime(&queryRealTime, true);
+
+ std::shared_ptr<DNSCryptQuery> dnsCryptQuery = nullptr;
+
+#ifdef HAVE_DNSCRYPT
+ auto dnsCryptResponse = checkDNSCryptQuery(cs, query, len, dnsCryptQuery, queryRealTime.tv_sec, false);
+ if (dnsCryptResponse) {
+ sendUDPResponse(cs.udpFD, reinterpret_cast<char*>(dnsCryptResponse->data()), static_cast<uint16_t>(dnsCryptResponse->size()), 0, dest, remote);
+ return;
+ }
+#endif
+
+ struct dnsheader* dh = reinterpret_cast<struct dnsheader*>(query);
+ queryId = ntohs(dh->id);
+
+ if (!checkQueryHeaders(dh)) {
+ return;
+ }
+
+ uint16_t qtype, qclass;
+ unsigned int consumed = 0;
+ DNSName qname(query, len, sizeof(dnsheader), false, &qtype, &qclass, &consumed);
+ DNSQuestion dq(&qname, qtype, qclass, consumed, dest.sin4.sin_family != 0 ? &dest : &cs.local, &remote, dh, queryBufferSize, len, false, &queryRealTime);
+ dq.dnsCryptQuery = std::move(dnsCryptQuery);
+
+ responseSender sender = [&responsesVect, &queuedResponses, &respIOV, &respCBuf](const ClientState& cs, const char* data, uint16_t dataSize, int delayMsec, const ComboAddress& dest, const ComboAddress& remote) -> void {
+#if defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE)
+ if (delayMsec == 0 && responsesVect != nullptr) {
+ queueResponse(cs, data, dataSize, dest, remote, responsesVect[*queuedResponses], respIOV, respCBuf);
+ (*queuedResponses)++;
+ return;
+ }
+#endif /* defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) */
+ sendUDPResponse(cs.udpFD, data, dataSize, delayMsec, dest, remote);
+ };
+
+ bool dropped = false;
+ auto ss = processQuery(dq, cs, holders, sender, dropped);
+
+ if (!ss) {
+ return;
+ }
unsigned int idOffset = (ss->idOffset++) % ss->idStates.size();
IDState* ids = &ss->idStates[idOffset];
ids->origID = dh->id;
ids->origRemote = remote;
ids->sentTime.set(queryRealTime);
- ids->qname = qname;
+ ids->qname = std::move(qname);
ids->qtype = dq.qtype;
ids->qclass = dq.qclass;
- ids->delayMsec = delayMsec;
+ ids->delayMsec = dq.delayMsec;
ids->tempFailureTTL = dq.tempFailureTTL;
- ids->origFlags = origFlags;
- ids->cacheKey = cacheKey;
- ids->cacheKeyNoECS = cacheKeyNoECS;
- ids->subnet = subnet;
+ ids->origFlags = dq.origFlags;
+ ids->cacheKey = dq.cacheKey;
+ ids->cacheKeyNoECS = dq.cacheKeyNoECS;
+ ids->subnet = dq.subnet;
ids->skipCache = dq.skipCache;
- ids->packetCache = packetCache;
- ids->ednsAdded = ednsAdded;
- ids->ecsAdded = ecsAdded;
- ids->useZeroScope = useZeroScope;
+ ids->packetCache = dq.packetCache;
+ ids->ednsAdded = dq.ednsAdded;
+ ids->ecsAdded = dq.ecsAdded;
+ ids->useZeroScope = dq.useZeroScope;
ids->qTag = dq.qTag;
- ids->dnssecOK = dnssecOK;
+ ids->dnssecOK = dq.dnssecOK;
/* If we couldn't harvest the real dest addr, still
write down the listening addr since it will be useful
ids->destHarvested = false;
}
#ifdef HAVE_DNSCRYPT
- ids->dnsCryptQuery = dnsCryptQuery;
+ ids->dnsCryptQuery = std::move(dq.dnsCryptQuery);
#endif
#ifdef HAVE_PROTOBUF
- ids->uniqueId = dq.uniqueId;
+ ids->uniqueId = std::move(dq.uniqueId);
#endif
dh->id = idOffset;
#endif
}
-static bool upCheck(DownstreamState& ds)
+static bool upCheck(const shared_ptr<DownstreamState>& ds)
try
{
- DNSName checkName = ds.checkName;
- uint16_t checkType = ds.checkType.getCode();
- uint16_t checkClass = ds.checkClass;
+ DNSName checkName = ds->checkName;
+ uint16_t checkType = ds->checkType.getCode();
+ uint16_t checkClass = ds->checkClass;
dnsheader checkHeader;
memset(&checkHeader, 0, sizeof(checkHeader));
checkHeader.id = getRandomDNSID();
checkHeader.rd = true;
- if (ds.setCD) {
+ if (ds->setCD) {
checkHeader.cd = true;
}
- if (ds.checkFunction) {
+ if (ds->checkFunction) {
std::lock_guard<std::mutex> lock(g_luamutex);
- auto ret = ds.checkFunction(checkName, checkType, checkClass, &checkHeader);
+ auto ret = ds->checkFunction(checkName, checkType, checkClass, &checkHeader);
checkName = std::get<0>(ret);
checkType = std::get<1>(ret);
checkClass = std::get<2>(ret);
dnsheader * requestHeader = dpw.getHeader();
*requestHeader = checkHeader;
- Socket sock(ds.remote.sin4.sin_family, SOCK_DGRAM);
+ Socket sock(ds->remote.sin4.sin_family, SOCK_DGRAM);
sock.setNonBlocking();
- if (!IsAnyAddress(ds.sourceAddr)) {
+ if (!IsAnyAddress(ds->sourceAddr)) {
sock.setReuseAddr();
- sock.bind(ds.sourceAddr);
+ sock.bind(ds->sourceAddr);
}
- sock.connect(ds.remote);
- ssize_t sent = udpClientSendRequestToBackend(&ds, sock.getHandle(), (char*)&packet[0], packet.size(), true);
+ sock.connect(ds->remote);
+ ssize_t sent = udpClientSendRequestToBackend(ds, sock.getHandle(), (char*)&packet[0], packet.size(), true);
if (sent < 0) {
int ret = errno;
if (g_verboseHealthChecks)
- infolog("Error while sending a health check query to backend %s: %d", ds.getNameWithAddr(), ret);
+ infolog("Error while sending a health check query to backend %s: %d", ds->getNameWithAddr(), ret);
return false;
}
- int ret = waitForRWData(sock.getHandle(), true, /* ms to seconds */ ds.checkTimeout / 1000, /* remaining ms to us */ (ds.checkTimeout % 1000) * 1000);
+ int ret = waitForRWData(sock.getHandle(), true, /* ms to seconds */ ds->checkTimeout / 1000, /* remaining ms to us */ (ds->checkTimeout % 1000) * 1000);
if(ret < 0 || !ret) { // error, timeout, both are down!
if (ret < 0) {
ret = errno;
if (g_verboseHealthChecks)
- infolog("Error while waiting for the health check response from backend %s: %d", ds.getNameWithAddr(), ret);
+ infolog("Error while waiting for the health check response from backend %s: %d", ds->getNameWithAddr(), ret);
}
else {
if (g_verboseHealthChecks)
- infolog("Timeout while waiting for the health check response from backend %s", ds.getNameWithAddr());
+ infolog("Timeout while waiting for the health check response from backend %s", ds->getNameWithAddr());
}
return false;
}
sock.recvFrom(reply, from);
/* we are using a connected socket but hey.. */
- if (from != ds.remote) {
+ if (from != ds->remote) {
if (g_verboseHealthChecks)
- infolog("Invalid health check response received from %s, expecting one from %s", from.toStringWithPort(), ds.remote.toStringWithPort());
+ infolog("Invalid health check response received from %s, expecting one from %s", from.toStringWithPort(), ds->remote.toStringWithPort());
return false;
}
if (reply.size() < sizeof(*responseHeader)) {
if (g_verboseHealthChecks)
- infolog("Invalid health check response of size %d from backend %s, expecting at least %d", reply.size(), ds.getNameWithAddr(), sizeof(*responseHeader));
+ infolog("Invalid health check response of size %d from backend %s, expecting at least %d", reply.size(), ds->getNameWithAddr(), sizeof(*responseHeader));
return false;
}
if (responseHeader->id != requestHeader->id) {
if (g_verboseHealthChecks)
- infolog("Invalid health check response id %d from backend %s, expecting %d", responseHeader->id, ds.getNameWithAddr(), requestHeader->id);
+ infolog("Invalid health check response id %d from backend %s, expecting %d", responseHeader->id, ds->getNameWithAddr(), requestHeader->id);
return false;
}
if (!responseHeader->qr) {
if (g_verboseHealthChecks)
- infolog("Invalid health check response from backend %s, expecting QR to be set", ds.getNameWithAddr());
+ infolog("Invalid health check response from backend %s, expecting QR to be set", ds->getNameWithAddr());
return false;
}
if (responseHeader->rcode == RCode::ServFail) {
if (g_verboseHealthChecks)
- infolog("Backend %s responded to health check with ServFail", ds.getNameWithAddr());
+ infolog("Backend %s responded to health check with ServFail", ds->getNameWithAddr());
return false;
}
- if (ds.mustResolve && (responseHeader->rcode == RCode::NXDomain || responseHeader->rcode == RCode::Refused)) {
+ if (ds->mustResolve && (responseHeader->rcode == RCode::NXDomain || responseHeader->rcode == RCode::Refused)) {
if (g_verboseHealthChecks)
- infolog("Backend %s responded to health check with %s while mustResolve is set", ds.getNameWithAddr(), responseHeader->rcode == RCode::NXDomain ? "NXDomain" : "Refused");
+ infolog("Backend %s responded to health check with %s while mustResolve is set", ds->getNameWithAddr(), responseHeader->rcode == RCode::NXDomain ? "NXDomain" : "Refused");
return false;
}
if (receivedName != checkName || receivedType != checkType || receivedClass != checkClass) {
if (g_verboseHealthChecks)
- infolog("Backend %s responded to health check with an invalid qname (%s vs %s), qtype (%s vs %s) or qclass (%d vs %d)", ds.getNameWithAddr(), receivedName.toLogString(), checkName.toLogString(), QType(receivedType).getName(), QType(checkType).getName(), receivedClass, checkClass);
+ infolog("Backend %s responded to health check with an invalid qname (%s vs %s), qtype (%s vs %s) or qclass (%d vs %d)", ds->getNameWithAddr(), receivedName.toLogString(), checkName.toLogString(), QType(receivedType).getName(), QType(checkType).getName(), receivedClass, checkClass);
return false;
}
catch(const std::exception& e)
{
if (g_verboseHealthChecks)
- infolog("Error checking the health of backend %s: %s", ds.getNameWithAddr(), e.what());
+ infolog("Error checking the health of backend %s: %s", ds->getNameWithAddr(), e.what());
return false;
}
catch(...)
{
if (g_verboseHealthChecks)
- infolog("Unknown exception while checking the health of backend %s", ds.getNameWithAddr());
+ infolog("Unknown exception while checking the health of backend %s", ds->getNameWithAddr());
return false;
}
continue;
dss->lastCheck = 0;
if(dss->availability==DownstreamState::Availability::Auto) {
- bool newState=upCheck(*dss);
+ bool newState=upCheck(dss);
if (newState) {
/* check succeeded */
dss->currentCheckFailures = 0;
for(auto& dss : g_dstates.getCopy()) { // it is a copy, but the internal shared_ptrs are the real deal
if(dss->availability==DownstreamState::Availability::Auto) {
- bool newState=upCheck(*dss);
+ bool newState=upCheck(dss);
warnlog("Marking downstream %s as '%s'", dss->getNameWithAddr(), newState ? "up" : "down");
dss->upStatus = newState;
}