]> git.ipfire.org Git - thirdparty/pdns.git/blame - pdns/dnsdist-cache.cc
Merge pull request #7628 from tcely/patch-3
[thirdparty/pdns.git] / pdns / dnsdist-cache.cc
CommitLineData
12471842
PL
/*
 * This file is part of PowerDNS or dnsdist.
 * Copyright -- PowerDNS.COM B.V. and its contributors
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of version 2 of the GNU General Public License as
 * published by the Free Software Foundation.
 *
 * In addition, for the avoidance of any doubt, permission is granted to
 * link this program with OpenSSL and to (re)distribute the binaries
 * produced as the result of such linking.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 */
6432820c
RG
#include <cinttypes>
#include <stdexcept>
#include <unistd.h>

#include "dnsdist.hh"
#include "dolog.hh"
#include "dnsparser.hh"
#include "dnsdist-cache.hh"
#include "dnsdist-ecs.hh"
#include "ednsoptions.hh"
#include "ednssubnet.hh"
886e2cf2 31
78e3ac9e 32DNSDistPacketCache::DNSDistPacketCache(size_t maxEntries, uint32_t maxTTL, uint32_t minTTL, uint32_t tempFailureTTL, uint32_t maxNegativeTTL, uint32_t staleTTL, bool dontAge, uint32_t shards, bool deferrableInsertLock, bool parseECS): d_maxEntries(maxEntries), d_shardCount(shards), d_maxTTL(maxTTL), d_tempFailureTTL(tempFailureTTL), d_maxNegativeTTL(maxNegativeTTL), d_minTTL(minTTL), d_staleTTL(staleTTL), d_dontAge(dontAge), d_deferrableInsertLock(deferrableInsertLock), d_parseECS(parseECS)
886e2cf2 33{
2b3eefc3
RG
34 d_shards.resize(d_shardCount);
35
ccac98a0 36 /* we reserve maxEntries + 1 to avoid rehashing from occurring
886e2cf2 37 when we get to maxEntries, as it means a load factor of 1 */
2b3eefc3
RG
38 for (auto& shard : d_shards) {
39 shard.setSize((maxEntries / d_shardCount) + 1);
40 }
886e2cf2
RG
41}
42
43DNSDistPacketCache::~DNSDistPacketCache()
44{
737a287f 45 try {
2b3eefc3
RG
46 vector<std::unique_ptr<WriteLock>> locks;
47 for (uint32_t shardIndex = 0; shardIndex < d_shardCount; shardIndex++) {
48 locks.push_back(std::unique_ptr<WriteLock>(new WriteLock(&d_shards.at(shardIndex).d_lock)));
49 }
737a287f 50 }
2b3eefc3 51 catch(...) {
737a287f 52 }
886e2cf2
RG
53}
54
78e3ac9e
RG
55bool DNSDistPacketCache::getClientSubnet(const char* packet, unsigned int consumed, uint16_t len, boost::optional<Netmask>& subnet)
56{
cbf4e13a 57 uint16_t optRDPosition;
78e3ac9e
RG
58 size_t remaining = 0;
59
cbf4e13a 60 int res = getEDNSOptionsStart(const_cast<char*>(packet), consumed, len, &optRDPosition, &remaining);
78e3ac9e
RG
61
62 if (res == 0) {
cbf4e13a 63 char * ecsOptionStart = nullptr;
78e3ac9e
RG
64 size_t ecsOptionSize = 0;
65
cbf4e13a 66 res = getEDNSOption(const_cast<char*>(packet) + optRDPosition, remaining, EDNSOptionCode::ECS, &ecsOptionStart, &ecsOptionSize);
78e3ac9e
RG
67
68 if (res == 0 && ecsOptionSize > (EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE)) {
69
70 EDNSSubnetOpts eso;
71 if (getEDNSSubnetOptsFromString(ecsOptionStart + (EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE), ecsOptionSize - (EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE), &eso) == true) {
72 subnet = eso.source;
73 return true;
74 }
75 }
76 }
77
78 return false;
79}
80
d7728daf 81bool DNSDistPacketCache::cachedValueMatches(const CacheValue& cachedValue, uint16_t queryFlags, const DNSName& qname, uint16_t qtype, uint16_t qclass, bool tcp, bool dnssecOK, const boost::optional<Netmask>& subnet) const
886e2cf2 82{
d7728daf 83 if (cachedValue.queryFlags != queryFlags || cachedValue.dnssecOK != dnssecOK || cachedValue.tcp != tcp || cachedValue.qtype != qtype || cachedValue.qclass != qclass || cachedValue.qname != qname) {
886e2cf2 84 return false;
8dcdbdb1
RG
85 }
86
78e3ac9e
RG
87 if (d_parseECS && cachedValue.subnet != subnet) {
88 return false;
89 }
90
886e2cf2
RG
91 return true;
92}
93
78e3ac9e 94void DNSDistPacketCache::insertLocked(CacheShard& shard, uint32_t key, CacheValue& newValue)
2b3eefc3
RG
95{
96 auto& map = shard.d_map;
97 /* check again now that we hold the lock to prevent a race */
98 if (map.size() >= (d_maxEntries / d_shardCount)) {
99 return;
100 }
101
102 std::unordered_map<uint32_t,CacheValue>::iterator it;
103 bool result;
104 tie(it, result) = map.insert({key, newValue});
105
106 if (result) {
107 shard.d_entriesCount++;
108 return;
109 }
110
111 /* in case of collision, don't override the existing entry
112 except if it has expired */
113 CacheValue& value = it->second;
78e3ac9e 114 bool wasExpired = value.validity <= newValue.added;
2b3eefc3 115
d7728daf 116 if (!wasExpired && !cachedValueMatches(value, newValue.queryFlags, newValue.qname, newValue.qtype, newValue.qclass, newValue.tcp, newValue.dnssecOK, newValue.subnet)) {
2b3eefc3
RG
117 d_insertCollisions++;
118 return;
119 }
120
121 /* if the existing entry had a longer TTD, keep it */
78e3ac9e 122 if (newValue.validity <= value.validity) {
2b3eefc3
RG
123 return;
124 }
125
126 value = newValue;
127}
128
d7728daf 129void DNSDistPacketCache::insert(uint32_t key, const boost::optional<Netmask>& subnet, uint16_t queryFlags, bool dnssecOK, const DNSName& qname, uint16_t qtype, uint16_t qclass, const char* response, uint16_t responseLen, bool tcp, uint8_t rcode, boost::optional<uint32_t> tempFailureTTL)
886e2cf2 130{
8dcdbdb1 131 if (responseLen < sizeof(dnsheader)) {
886e2cf2 132 return;
8dcdbdb1 133 }
886e2cf2 134
0f08e82b 135 uint32_t minTTL;
886e2cf2 136
2714396e 137 if (rcode == RCode::ServFail || rcode == RCode::Refused) {
acb8f5d5 138 minTTL = tempFailureTTL == boost::none ? d_tempFailureTTL : *tempFailureTTL;
f4e5b47d
RG
139 if (minTTL == 0) {
140 return;
141 }
0f08e82b
RG
142 }
143 else {
47698274
RG
144 bool seenAuthSOA = false;
145 minTTL = getMinTTL(response, responseLen, &seenAuthSOA);
a3824e43
RG
146
147 /* no TTL found, we don't want to cache this */
148 if (minTTL == std::numeric_limits<uint32_t>::max()) {
149 return;
150 }
151
47698274
RG
152 if (rcode == RCode::NXDomain || (rcode == RCode::NoError && seenAuthSOA)) {
153 minTTL = std::min(minTTL, d_maxNegativeTTL);
154 }
155 else if (minTTL > d_maxTTL) {
0f08e82b 156 minTTL = d_maxTTL;
a3824e43 157 }
0f08e82b 158
cc8cefe1
RG
159 if (minTTL < d_minTTL) {
160 d_ttlTooShorts++;
0f08e82b 161 return;
cc8cefe1 162 }
0f08e82b 163 }
886e2cf2 164
2b3eefc3
RG
165 uint32_t shardIndex = getShardIndex(key);
166
167 if (d_shards.at(shardIndex).d_entriesCount >= (d_maxEntries / d_shardCount)) {
168 return;
886e2cf2
RG
169 }
170
c1b81381 171 const time_t now = time(nullptr);
886e2cf2
RG
172 time_t newValidity = now + minTTL;
173 CacheValue newValue;
174 newValue.qname = qname;
175 newValue.qtype = qtype;
176 newValue.qclass = qclass;
8dcdbdb1 177 newValue.queryFlags = queryFlags;
886e2cf2
RG
178 newValue.len = responseLen;
179 newValue.validity = newValidity;
180 newValue.added = now;
a176d205 181 newValue.tcp = tcp;
d7728daf 182 newValue.dnssecOK = dnssecOK;
886e2cf2 183 newValue.value = std::string(response, responseLen);
78e3ac9e 184 newValue.subnet = subnet;
886e2cf2 185
2b3eefc3
RG
186 auto& shard = d_shards.at(shardIndex);
187
188 if (d_deferrableInsertLock) {
189 TryWriteLock w(&shard.d_lock);
886e2cf2
RG
190
191 if (!w.gotIt()) {
192 d_deferredInserts++;
193 return;
194 }
78e3ac9e 195 insertLocked(shard, key, newValue);
2b3eefc3
RG
196 }
197 else {
198 WriteLock w(&shard.d_lock);
886e2cf2 199
78e3ac9e 200 insertLocked(shard, key, newValue);
886e2cf2
RG
201 }
202}
203
d7728daf 204bool DNSDistPacketCache::get(const DNSQuestion& dq, uint16_t consumed, uint16_t queryId, char* response, uint16_t* responseLen, uint32_t* keyOut, boost::optional<Netmask>& subnet, bool dnssecOK, uint32_t allowExpired, bool skipAging)
886e2cf2 205{
7acad2d5 206 std::string dnsQName(dq.qname->toDNSString());
78e3ac9e 207 uint32_t key = getKey(dnsQName, consumed, reinterpret_cast<const unsigned char*>(dq.dh), dq.len, dq.tcp);
f037144c 208
886e2cf2
RG
209 if (keyOut)
210 *keyOut = key;
211
78e3ac9e
RG
212 if (d_parseECS) {
213 getClientSubnet(reinterpret_cast<const char*>(dq.dh), consumed, dq.len, subnet);
214 }
215
2b3eefc3 216 uint32_t shardIndex = getShardIndex(key);
c1b81381 217 time_t now = time(nullptr);
886e2cf2 218 time_t age;
1ea747c0 219 bool stale = false;
2b3eefc3
RG
220 auto& shard = d_shards.at(shardIndex);
221 auto& map = shard.d_map;
886e2cf2 222 {
2b3eefc3 223 TryReadLock r(&shard.d_lock);
886e2cf2
RG
224 if (!r.gotIt()) {
225 d_deferredLookups++;
226 return false;
227 }
228
2b3eefc3
RG
229 std::unordered_map<uint32_t,CacheValue>::const_iterator it = map.find(key);
230 if (it == map.end()) {
886e2cf2
RG
231 d_misses++;
232 return false;
233 }
234
235 const CacheValue& value = it->second;
236 if (value.validity < now) {
a1a0a75a 237 if ((now - value.validity) >= static_cast<time_t>(allowExpired)) {
1ea747c0
RG
238 d_misses++;
239 return false;
240 }
241 else {
242 stale = true;
243 }
886e2cf2
RG
244 }
245
39a21975 246 if (*responseLen < value.len || value.len < sizeof(dnsheader)) {
886e2cf2
RG
247 return false;
248 }
249
250 /* check for collision */
d7728daf 251 if (!cachedValueMatches(value, *(getFlagsFromDNSHeader(dq.dh)), *dq.qname, dq.qtype, dq.qclass, dq.tcp, dnssecOK, subnet)) {
886e2cf2
RG
252 d_lookupCollisions++;
253 return false;
254 }
255
c8c3d4e4
RG
256 memcpy(response, &queryId, sizeof(queryId));
257 memcpy(response + sizeof(queryId), value.value.c_str() + sizeof(queryId), sizeof(dnsheader) - sizeof(queryId));
258
259 if (value.len == sizeof(dnsheader)) {
260 /* DNS header only, our work here is done */
261 *responseLen = value.len;
262 d_hits++;
263 return true;
264 }
265
f87c4aff
RG
266 const size_t dnsQNameLen = dnsQName.length();
267 if (value.len < (sizeof(dnsheader) + dnsQNameLen)) {
268 return false;
269 }
270
f87c4aff
RG
271 memcpy(response + sizeof(dnsheader), dnsQName.c_str(), dnsQNameLen);
272 if (value.len > (sizeof(dnsheader) + dnsQNameLen)) {
273 memcpy(response + sizeof(dnsheader) + dnsQNameLen, value.value.c_str() + sizeof(dnsheader) + dnsQNameLen, value.len - (sizeof(dnsheader) + dnsQNameLen));
274 }
886e2cf2 275 *responseLen = value.len;
1ea747c0
RG
276 if (!stale) {
277 age = now - value.added;
278 }
279 else {
280 age = (value.validity - value.added) - d_staleTTL;
281 }
886e2cf2
RG
282 }
283
2b67180c 284 if (!d_dontAge && !skipAging) {
886e2cf2 285 ageDNSPacket(response, *responseLen, age);
1ea747c0
RG
286 }
287
886e2cf2
RG
288 d_hits++;
289 return true;
290}
291
4275aaba
RG
292/* Remove expired entries, until the cache has at most
293 upTo entries in it.
294*/
f627611d 295size_t DNSDistPacketCache::purgeExpired(size_t upTo)
886e2cf2 296{
f627611d 297 size_t removed = 0;
2b3eefc3
RG
298 uint64_t size = getSize();
299
c1b81381 300 if (size == 0 || upTo >= size) {
f627611d 301 return removed;
4275aaba 302 }
886e2cf2 303
2b3eefc3 304 size_t toRemove = size - upTo;
886e2cf2 305
2b3eefc3
RG
306 size_t scannedMaps = 0;
307
c1b81381 308 const time_t now = time(nullptr);
2b3eefc3
RG
309 do {
310 uint32_t shardIndex = (d_expungeIndex++ % d_shardCount);
311 WriteLock w(&d_shards.at(shardIndex).d_lock);
312 auto& map = d_shards[shardIndex].d_map;
313
314 for(auto it = map.begin(); toRemove > 0 && it != map.end(); ) {
315 const CacheValue& value = it->second;
316
317 if (value.validity < now) {
318 it = map.erase(it);
886e2cf2 319 --toRemove;
2b3eefc3 320 d_shards[shardIndex].d_entriesCount--;
f627611d 321 ++removed;
2b3eefc3
RG
322 } else {
323 ++it;
324 }
886e2cf2 325 }
2b3eefc3
RG
326
327 scannedMaps++;
886e2cf2 328 }
2b3eefc3 329 while (toRemove > 0 && scannedMaps < d_shardCount);
f627611d
RG
330
331 return removed;
886e2cf2
RG
332}
333
4275aaba
RG
334/* Remove all entries, keeping only upTo
335 entries in the cache */
f627611d 336size_t DNSDistPacketCache::expunge(size_t upTo)
4275aaba 337{
6d1a9248 338 size_t removed = 0;
2b3eefc3 339 const uint64_t size = getSize();
4275aaba 340
2b3eefc3 341 if (upTo >= size) {
f627611d 342 return removed;
4275aaba
RG
343 }
344
2b3eefc3 345 size_t toRemove = size - upTo;
2b3eefc3
RG
346
347 for (uint32_t shardIndex = 0; shardIndex < d_shardCount; shardIndex++) {
348 WriteLock w(&d_shards.at(shardIndex).d_lock);
349 auto& map = d_shards[shardIndex].d_map;
350 auto beginIt = map.begin();
351 auto endIt = beginIt;
352 size_t removeFromThisShard = (toRemove - removed) / (d_shardCount - shardIndex);
353 if (map.size() >= removeFromThisShard) {
354 std::advance(endIt, removeFromThisShard);
355 map.erase(beginIt, endIt);
356 d_shards[shardIndex].d_entriesCount -= removeFromThisShard;
357 removed += removeFromThisShard;
358 }
359 else {
360 removed += map.size();
361 map.clear();
362 d_shards[shardIndex].d_entriesCount = 0;
363 }
364 }
f627611d
RG
365
366 return removed;
4275aaba
RG
367}
368
f627611d 369size_t DNSDistPacketCache::expungeByName(const DNSName& name, uint16_t qtype, bool suffixMatch)
886e2cf2 370{
f627611d
RG
371 size_t removed = 0;
372
2b3eefc3
RG
373 for (uint32_t shardIndex = 0; shardIndex < d_shardCount; shardIndex++) {
374 WriteLock w(&d_shards.at(shardIndex).d_lock);
375 auto& map = d_shards[shardIndex].d_map;
376
377 for(auto it = map.begin(); it != map.end(); ) {
378 const CacheValue& value = it->second;
379
380 if ((value.qname == name || (suffixMatch && value.qname.isPartOf(name))) && (qtype == QType::ANY || qtype == value.qtype)) {
381 it = map.erase(it);
382 d_shards[shardIndex].d_entriesCount--;
f627611d 383 ++removed;
2b3eefc3
RG
384 } else {
385 ++it;
386 }
886e2cf2
RG
387 }
388 }
f627611d
RG
389
390 return removed;
886e2cf2
RG
391}
392
393bool DNSDistPacketCache::isFull()
394{
2b3eefc3
RG
395 return (getSize() >= d_maxEntries);
396}
397
398uint64_t DNSDistPacketCache::getSize()
399{
400 uint64_t count = 0;
401
402 for (uint32_t shardIndex = 0; shardIndex < d_shardCount; shardIndex++) {
403 count += d_shards.at(shardIndex).d_entriesCount;
404 }
405
406 return count;
886e2cf2
RG
407}
408
47698274 409uint32_t DNSDistPacketCache::getMinTTL(const char* packet, uint16_t length, bool* seenNoDataSOA)
886e2cf2 410{
47698274 411 return getDNSPacketMinTTL(packet, length, seenNoDataSOA);
886e2cf2
RG
412}
413
7acad2d5 414uint32_t DNSDistPacketCache::getKey(const std::string& qname, uint16_t consumed, const unsigned char* packet, uint16_t packetLen, bool tcp)
886e2cf2
RG
415{
416 uint32_t result = 0;
417 /* skip the query ID */
cceddbef
RG
418 if (packetLen < sizeof(dnsheader))
419 throw std::range_error("Computing packet cache key for an invalid packet size");
886e2cf2 420 result = burtle(packet + 2, sizeof(dnsheader) - 2, result);
7acad2d5 421 string lc(toLower(qname));
886e2cf2 422 result = burtle((const unsigned char*) lc.c_str(), lc.length(), result);
cceddbef
RG
423 if (packetLen < sizeof(dnsheader) + consumed) {
424 throw std::range_error("Computing packet cache key for an invalid packet");
425 }
426 if (packetLen > ((sizeof(dnsheader) + consumed))) {
427 result = burtle(packet + sizeof(dnsheader) + consumed, packetLen - (sizeof(dnsheader) + consumed), result);
428 }
a176d205 429 result = burtle((const unsigned char*) &tcp, sizeof(tcp), result);
886e2cf2
RG
430 return result;
431}
432
2b3eefc3
RG
433uint32_t DNSDistPacketCache::getShardIndex(uint32_t key) const
434{
435 return key % d_shardCount;
436}
437
886e2cf2
RG
438string DNSDistPacketCache::toString()
439{
2b3eefc3 440 return std::to_string(getSize()) + "/" + std::to_string(d_maxEntries);
886e2cf2 441}
9e9be156
RG
442
443uint64_t DNSDistPacketCache::getEntriesCount()
444{
2b3eefc3 445 return getSize();
9e9be156 446}
f037144c
RG
447
448uint64_t DNSDistPacketCache::dump(int fd)
449{
450 FILE * fp = fdopen(dup(fd), "w");
451 if (fp == nullptr) {
452 return 0;
453 }
454
455 fprintf(fp, "; dnsdist's packet cache dump follows\n;\n");
456
457 uint64_t count = 0;
458 time_t now = time(nullptr);
459 for (uint32_t shardIndex = 0; shardIndex < d_shardCount; shardIndex++) {
460 ReadLock w(&d_shards.at(shardIndex).d_lock);
461 auto& map = d_shards[shardIndex].d_map;
462
463 for(const auto entry : map) {
464 const CacheValue& value = entry.second;
465 count++;
466
467 try {
d09ee158 468 fprintf(fp, "%s %" PRId64 " %s ; key %" PRIu32 ", length %" PRIu16 ", tcp %d, added %" PRId64 "\n", value.qname.toString().c_str(), static_cast<int64_t>(value.validity - now), QType(value.qtype).getName().c_str(), entry.first, value.len, value.tcp, static_cast<int64_t>(value.added));
f037144c
RG
469 }
470 catch(...) {
471 fprintf(fp, "; error printing '%s'\n", value.qname.empty() ? "EMPTY" : value.qname.toString().c_str());
472 }
473 }
474 }
475
476 fclose(fp);
477 return count;
478}