]> git.ipfire.org Git - thirdparty/pdns.git/blob - pdns/dnsdist-cache.cc
Merge pull request #7350 from sjvs/patch-2
[thirdparty/pdns.git] / pdns / dnsdist-cache.cc
1 /*
2 * This file is part of PowerDNS or dnsdist.
3 * Copyright -- PowerDNS.COM B.V. and its contributors
4 *
5 * This program is free software; you can redistribute it and/or modify
6 * it under the terms of version 2 of the GNU General Public License as
7 * published by the Free Software Foundation.
8 *
9 * In addition, for the avoidance of any doubt, permission is granted to
10 * link this program with OpenSSL and to (re)distribute the binaries
11 * produced as the result of such linking.
12 *
13 * This program is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with this program; if not, write to the Free Software
20 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
21 */
22 #include <cinttypes>
23
24 #include "dnsdist.hh"
25 #include "dolog.hh"
26 #include "dnsparser.hh"
27 #include "dnsdist-cache.hh"
28 #include "dnsdist-ecs.hh"
29 #include "ednsoptions.hh"
30 #include "ednssubnet.hh"
31
32 DNSDistPacketCache::DNSDistPacketCache(size_t maxEntries, uint32_t maxTTL, uint32_t minTTL, uint32_t tempFailureTTL, uint32_t maxNegativeTTL, uint32_t staleTTL, bool dontAge, uint32_t shards, bool deferrableInsertLock, bool parseECS): d_maxEntries(maxEntries), d_shardCount(shards), d_maxTTL(maxTTL), d_tempFailureTTL(tempFailureTTL), d_maxNegativeTTL(maxNegativeTTL), d_minTTL(minTTL), d_staleTTL(staleTTL), d_dontAge(dontAge), d_deferrableInsertLock(deferrableInsertLock), d_parseECS(parseECS)
33 {
34 d_shards.resize(d_shardCount);
35
36 /* we reserve maxEntries + 1 to avoid rehashing from occurring
37 when we get to maxEntries, as it means a load factor of 1 */
38 for (auto& shard : d_shards) {
39 shard.setSize((maxEntries / d_shardCount) + 1);
40 }
41 }
42
43 DNSDistPacketCache::~DNSDistPacketCache()
44 {
45 try {
46 vector<std::unique_ptr<WriteLock>> locks;
47 for (uint32_t shardIndex = 0; shardIndex < d_shardCount; shardIndex++) {
48 locks.push_back(std::unique_ptr<WriteLock>(new WriteLock(&d_shards.at(shardIndex).d_lock)));
49 }
50 }
51 catch(...) {
52 }
53 }
54
55 bool DNSDistPacketCache::getClientSubnet(const char* packet, unsigned int consumed, uint16_t len, boost::optional<Netmask>& subnet)
56 {
57 uint16_t optRDPosition;
58 size_t remaining = 0;
59
60 int res = getEDNSOptionsStart(const_cast<char*>(packet), consumed, len, &optRDPosition, &remaining);
61
62 if (res == 0) {
63 char * ecsOptionStart = nullptr;
64 size_t ecsOptionSize = 0;
65
66 res = getEDNSOption(const_cast<char*>(packet) + optRDPosition, remaining, EDNSOptionCode::ECS, &ecsOptionStart, &ecsOptionSize);
67
68 if (res == 0 && ecsOptionSize > (EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE)) {
69
70 EDNSSubnetOpts eso;
71 if (getEDNSSubnetOptsFromString(ecsOptionStart + (EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE), ecsOptionSize - (EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE), &eso) == true) {
72 subnet = eso.source;
73 return true;
74 }
75 }
76 }
77
78 return false;
79 }
80
81 bool DNSDistPacketCache::cachedValueMatches(const CacheValue& cachedValue, uint16_t queryFlags, const DNSName& qname, uint16_t qtype, uint16_t qclass, bool tcp, bool dnssecOK, const boost::optional<Netmask>& subnet) const
82 {
83 if (cachedValue.queryFlags != queryFlags || cachedValue.dnssecOK != dnssecOK || cachedValue.tcp != tcp || cachedValue.qtype != qtype || cachedValue.qclass != qclass || cachedValue.qname != qname) {
84 return false;
85 }
86
87 if (d_parseECS && cachedValue.subnet != subnet) {
88 return false;
89 }
90
91 return true;
92 }
93
94 void DNSDistPacketCache::insertLocked(CacheShard& shard, uint32_t key, CacheValue& newValue)
95 {
96 auto& map = shard.d_map;
97 /* check again now that we hold the lock to prevent a race */
98 if (map.size() >= (d_maxEntries / d_shardCount)) {
99 return;
100 }
101
102 std::unordered_map<uint32_t,CacheValue>::iterator it;
103 bool result;
104 tie(it, result) = map.insert({key, newValue});
105
106 if (result) {
107 shard.d_entriesCount++;
108 return;
109 }
110
111 /* in case of collision, don't override the existing entry
112 except if it has expired */
113 CacheValue& value = it->second;
114 bool wasExpired = value.validity <= newValue.added;
115
116 if (!wasExpired && !cachedValueMatches(value, newValue.queryFlags, newValue.qname, newValue.qtype, newValue.qclass, newValue.tcp, newValue.dnssecOK, newValue.subnet)) {
117 d_insertCollisions++;
118 return;
119 }
120
121 /* if the existing entry had a longer TTD, keep it */
122 if (newValue.validity <= value.validity) {
123 return;
124 }
125
126 value = newValue;
127 }
128
129 void DNSDistPacketCache::insert(uint32_t key, const boost::optional<Netmask>& subnet, uint16_t queryFlags, bool dnssecOK, const DNSName& qname, uint16_t qtype, uint16_t qclass, const char* response, uint16_t responseLen, bool tcp, uint8_t rcode, boost::optional<uint32_t> tempFailureTTL)
130 {
131 if (responseLen < sizeof(dnsheader)) {
132 return;
133 }
134
135 uint32_t minTTL;
136
137 if (rcode == RCode::ServFail || rcode == RCode::Refused) {
138 minTTL = tempFailureTTL == boost::none ? d_tempFailureTTL : *tempFailureTTL;
139 if (minTTL == 0) {
140 return;
141 }
142 }
143 else {
144 bool seenAuthSOA = false;
145 minTTL = getMinTTL(response, responseLen, &seenAuthSOA);
146
147 /* no TTL found, we don't want to cache this */
148 if (minTTL == std::numeric_limits<uint32_t>::max()) {
149 return;
150 }
151
152 if (rcode == RCode::NXDomain || (rcode == RCode::NoError && seenAuthSOA)) {
153 minTTL = std::min(minTTL, d_maxNegativeTTL);
154 }
155 else if (minTTL > d_maxTTL) {
156 minTTL = d_maxTTL;
157 }
158
159 if (minTTL < d_minTTL) {
160 d_ttlTooShorts++;
161 return;
162 }
163 }
164
165 uint32_t shardIndex = getShardIndex(key);
166
167 if (d_shards.at(shardIndex).d_entriesCount >= (d_maxEntries / d_shardCount)) {
168 return;
169 }
170
171 const time_t now = time(nullptr);
172 time_t newValidity = now + minTTL;
173 CacheValue newValue;
174 newValue.qname = qname;
175 newValue.qtype = qtype;
176 newValue.qclass = qclass;
177 newValue.queryFlags = queryFlags;
178 newValue.len = responseLen;
179 newValue.validity = newValidity;
180 newValue.added = now;
181 newValue.tcp = tcp;
182 newValue.dnssecOK = dnssecOK;
183 newValue.value = std::string(response, responseLen);
184 newValue.subnet = subnet;
185
186 auto& shard = d_shards.at(shardIndex);
187
188 if (d_deferrableInsertLock) {
189 TryWriteLock w(&shard.d_lock);
190
191 if (!w.gotIt()) {
192 d_deferredInserts++;
193 return;
194 }
195 insertLocked(shard, key, newValue);
196 }
197 else {
198 WriteLock w(&shard.d_lock);
199
200 insertLocked(shard, key, newValue);
201 }
202 }
203
204 bool DNSDistPacketCache::get(const DNSQuestion& dq, uint16_t consumed, uint16_t queryId, char* response, uint16_t* responseLen, uint32_t* keyOut, boost::optional<Netmask>& subnet, bool dnssecOK, uint32_t allowExpired, bool skipAging)
205 {
206 std::string dnsQName(dq.qname->toDNSString());
207 uint32_t key = getKey(dnsQName, consumed, reinterpret_cast<const unsigned char*>(dq.dh), dq.len, dq.tcp);
208
209 if (keyOut)
210 *keyOut = key;
211
212 if (d_parseECS) {
213 getClientSubnet(reinterpret_cast<const char*>(dq.dh), consumed, dq.len, subnet);
214 }
215
216 uint32_t shardIndex = getShardIndex(key);
217 time_t now = time(nullptr);
218 time_t age;
219 bool stale = false;
220 auto& shard = d_shards.at(shardIndex);
221 auto& map = shard.d_map;
222 {
223 TryReadLock r(&shard.d_lock);
224 if (!r.gotIt()) {
225 d_deferredLookups++;
226 return false;
227 }
228
229 std::unordered_map<uint32_t,CacheValue>::const_iterator it = map.find(key);
230 if (it == map.end()) {
231 d_misses++;
232 return false;
233 }
234
235 const CacheValue& value = it->second;
236 if (value.validity < now) {
237 if ((now - value.validity) >= static_cast<time_t>(allowExpired)) {
238 d_misses++;
239 return false;
240 }
241 else {
242 stale = true;
243 }
244 }
245
246 if (*responseLen < value.len || value.len < sizeof(dnsheader)) {
247 return false;
248 }
249
250 /* check for collision */
251 if (!cachedValueMatches(value, *(getFlagsFromDNSHeader(dq.dh)), *dq.qname, dq.qtype, dq.qclass, dq.tcp, dnssecOK, subnet)) {
252 d_lookupCollisions++;
253 return false;
254 }
255
256 memcpy(response, &queryId, sizeof(queryId));
257 memcpy(response + sizeof(queryId), value.value.c_str() + sizeof(queryId), sizeof(dnsheader) - sizeof(queryId));
258
259 if (value.len == sizeof(dnsheader)) {
260 /* DNS header only, our work here is done */
261 *responseLen = value.len;
262 d_hits++;
263 return true;
264 }
265
266 const size_t dnsQNameLen = dnsQName.length();
267 if (value.len < (sizeof(dnsheader) + dnsQNameLen)) {
268 return false;
269 }
270
271 memcpy(response + sizeof(dnsheader), dnsQName.c_str(), dnsQNameLen);
272 if (value.len > (sizeof(dnsheader) + dnsQNameLen)) {
273 memcpy(response + sizeof(dnsheader) + dnsQNameLen, value.value.c_str() + sizeof(dnsheader) + dnsQNameLen, value.len - (sizeof(dnsheader) + dnsQNameLen));
274 }
275 *responseLen = value.len;
276 if (!stale) {
277 age = now - value.added;
278 }
279 else {
280 age = (value.validity - value.added) - d_staleTTL;
281 }
282 }
283
284 if (!d_dontAge && !skipAging) {
285 ageDNSPacket(response, *responseLen, age);
286 }
287
288 d_hits++;
289 return true;
290 }
291
292 /* Remove expired entries, until the cache has at most
293 upTo entries in it.
294 */
295 void DNSDistPacketCache::purgeExpired(size_t upTo)
296 {
297 uint64_t size = getSize();
298
299 if (size == 0 || upTo >= size) {
300 return;
301 }
302
303 size_t toRemove = size - upTo;
304
305 size_t scannedMaps = 0;
306
307 const time_t now = time(nullptr);
308 do {
309 uint32_t shardIndex = (d_expungeIndex++ % d_shardCount);
310 WriteLock w(&d_shards.at(shardIndex).d_lock);
311 auto& map = d_shards[shardIndex].d_map;
312
313 for(auto it = map.begin(); toRemove > 0 && it != map.end(); ) {
314 const CacheValue& value = it->second;
315
316 if (value.validity < now) {
317 it = map.erase(it);
318 --toRemove;
319 d_shards[shardIndex].d_entriesCount--;
320 } else {
321 ++it;
322 }
323 }
324
325 scannedMaps++;
326 }
327 while (toRemove > 0 && scannedMaps < d_shardCount);
328 }
329
330 /* Remove all entries, keeping only upTo
331 entries in the cache */
332 void DNSDistPacketCache::expunge(size_t upTo)
333 {
334 const uint64_t size = getSize();
335
336 if (upTo >= size) {
337 return;
338 }
339
340 size_t toRemove = size - upTo;
341 size_t removed = 0;
342
343 for (uint32_t shardIndex = 0; shardIndex < d_shardCount; shardIndex++) {
344 WriteLock w(&d_shards.at(shardIndex).d_lock);
345 auto& map = d_shards[shardIndex].d_map;
346 auto beginIt = map.begin();
347 auto endIt = beginIt;
348 size_t removeFromThisShard = (toRemove - removed) / (d_shardCount - shardIndex);
349 if (map.size() >= removeFromThisShard) {
350 std::advance(endIt, removeFromThisShard);
351 map.erase(beginIt, endIt);
352 d_shards[shardIndex].d_entriesCount -= removeFromThisShard;
353 removed += removeFromThisShard;
354 }
355 else {
356 removed += map.size();
357 map.clear();
358 d_shards[shardIndex].d_entriesCount = 0;
359 }
360 }
361 }
362
363 void DNSDistPacketCache::expungeByName(const DNSName& name, uint16_t qtype, bool suffixMatch)
364 {
365 for (uint32_t shardIndex = 0; shardIndex < d_shardCount; shardIndex++) {
366 WriteLock w(&d_shards.at(shardIndex).d_lock);
367 auto& map = d_shards[shardIndex].d_map;
368
369 for(auto it = map.begin(); it != map.end(); ) {
370 const CacheValue& value = it->second;
371
372 if ((value.qname == name || (suffixMatch && value.qname.isPartOf(name))) && (qtype == QType::ANY || qtype == value.qtype)) {
373 it = map.erase(it);
374 d_shards[shardIndex].d_entriesCount--;
375 } else {
376 ++it;
377 }
378 }
379 }
380 }
381
382 bool DNSDistPacketCache::isFull()
383 {
384 return (getSize() >= d_maxEntries);
385 }
386
387 uint64_t DNSDistPacketCache::getSize()
388 {
389 uint64_t count = 0;
390
391 for (uint32_t shardIndex = 0; shardIndex < d_shardCount; shardIndex++) {
392 count += d_shards.at(shardIndex).d_entriesCount;
393 }
394
395 return count;
396 }
397
398 uint32_t DNSDistPacketCache::getMinTTL(const char* packet, uint16_t length, bool* seenNoDataSOA)
399 {
400 return getDNSPacketMinTTL(packet, length, seenNoDataSOA);
401 }
402
403 uint32_t DNSDistPacketCache::getKey(const std::string& qname, uint16_t consumed, const unsigned char* packet, uint16_t packetLen, bool tcp)
404 {
405 uint32_t result = 0;
406 /* skip the query ID */
407 if (packetLen < sizeof(dnsheader))
408 throw std::range_error("Computing packet cache key for an invalid packet size");
409 result = burtle(packet + 2, sizeof(dnsheader) - 2, result);
410 string lc(toLower(qname));
411 result = burtle((const unsigned char*) lc.c_str(), lc.length(), result);
412 if (packetLen < sizeof(dnsheader) + consumed) {
413 throw std::range_error("Computing packet cache key for an invalid packet");
414 }
415 if (packetLen > ((sizeof(dnsheader) + consumed))) {
416 result = burtle(packet + sizeof(dnsheader) + consumed, packetLen - (sizeof(dnsheader) + consumed), result);
417 }
418 result = burtle((const unsigned char*) &tcp, sizeof(tcp), result);
419 return result;
420 }
421
422 uint32_t DNSDistPacketCache::getShardIndex(uint32_t key) const
423 {
424 return key % d_shardCount;
425 }
426
427 string DNSDistPacketCache::toString()
428 {
429 return std::to_string(getSize()) + "/" + std::to_string(d_maxEntries);
430 }
431
432 uint64_t DNSDistPacketCache::getEntriesCount()
433 {
434 return getSize();
435 }
436
437 uint64_t DNSDistPacketCache::dump(int fd)
438 {
439 FILE * fp = fdopen(dup(fd), "w");
440 if (fp == nullptr) {
441 return 0;
442 }
443
444 fprintf(fp, "; dnsdist's packet cache dump follows\n;\n");
445
446 uint64_t count = 0;
447 time_t now = time(nullptr);
448 for (uint32_t shardIndex = 0; shardIndex < d_shardCount; shardIndex++) {
449 ReadLock w(&d_shards.at(shardIndex).d_lock);
450 auto& map = d_shards[shardIndex].d_map;
451
452 for(const auto entry : map) {
453 const CacheValue& value = entry.second;
454 count++;
455
456 try {
457 fprintf(fp, "%s %" PRId64 " %s ; key %" PRIu32 ", length %" PRIu16 ", tcp %d, added %" PRId64 "\n", value.qname.toString().c_str(), static_cast<int64_t>(value.validity - now), QType(value.qtype).getName().c_str(), entry.first, value.len, value.tcp, static_cast<int64_t>(value.added));
458 }
459 catch(...) {
460 fprintf(fp, "; error printing '%s'\n", value.qname.empty() ? "EMPTY" : value.qname.toString().c_str());
461 }
462 }
463 }
464
465 fclose(fp);
466 return count;
467 }