]> git.ipfire.org Git - thirdparty/pdns.git/blob - pdns/dnsdist-cache.cc
Merge pull request #7094 from neilcook/udr
[thirdparty/pdns.git] / pdns / dnsdist-cache.cc
1 /*
2 * This file is part of PowerDNS or dnsdist.
3 * Copyright -- PowerDNS.COM B.V. and its contributors
4 *
5 * This program is free software; you can redistribute it and/or modify
6 * it under the terms of version 2 of the GNU General Public License as
7 * published by the Free Software Foundation.
8 *
9 * In addition, for the avoidance of any doubt, permission is granted to
10 * link this program with OpenSSL and to (re)distribute the binaries
11 * produced as the result of such linking.
12 *
13 * This program is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with this program; if not, write to the Free Software
20 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
21 */
22 #include <cinttypes>
23
24 #include "dnsdist.hh"
25 #include "dolog.hh"
26 #include "dnsparser.hh"
27 #include "dnsdist-cache.hh"
28 #include "dnsdist-ecs.hh"
29 #include "ednsoptions.hh"
30 #include "ednssubnet.hh"
31
32 DNSDistPacketCache::DNSDistPacketCache(size_t maxEntries, uint32_t maxTTL, uint32_t minTTL, uint32_t tempFailureTTL, uint32_t maxNegativeTTL, uint32_t staleTTL, bool dontAge, uint32_t shards, bool deferrableInsertLock, bool parseECS): d_maxEntries(maxEntries), d_shardCount(shards), d_maxTTL(maxTTL), d_tempFailureTTL(tempFailureTTL), d_maxNegativeTTL(maxNegativeTTL), d_minTTL(minTTL), d_staleTTL(staleTTL), d_dontAge(dontAge), d_deferrableInsertLock(deferrableInsertLock), d_parseECS(parseECS)
33 {
34 d_shards.resize(d_shardCount);
35
36 /* we reserve maxEntries + 1 to avoid rehashing from occurring
37 when we get to maxEntries, as it means a load factor of 1 */
38 for (auto& shard : d_shards) {
39 shard.setSize((maxEntries / d_shardCount) + 1);
40 }
41 }
42
43 DNSDistPacketCache::~DNSDistPacketCache()
44 {
45 try {
46 vector<std::unique_ptr<WriteLock>> locks;
47 for (uint32_t shardIndex = 0; shardIndex < d_shardCount; shardIndex++) {
48 locks.push_back(std::unique_ptr<WriteLock>(new WriteLock(&d_shards.at(shardIndex).d_lock)));
49 }
50 }
51 catch(...) {
52 }
53 }
54
55 bool DNSDistPacketCache::getClientSubnet(const char* packet, unsigned int consumed, uint16_t len, boost::optional<Netmask>& subnet)
56 {
57 uint16_t optRDPosition;
58 size_t remaining = 0;
59
60 int res = getEDNSOptionsStart(const_cast<char*>(packet), consumed, len, &optRDPosition, &remaining);
61
62 if (res == 0) {
63 char * ecsOptionStart = nullptr;
64 size_t ecsOptionSize = 0;
65
66 res = getEDNSOption(const_cast<char*>(packet) + optRDPosition, remaining, EDNSOptionCode::ECS, &ecsOptionStart, &ecsOptionSize);
67
68 if (res == 0 && ecsOptionSize > (EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE)) {
69
70 EDNSSubnetOpts eso;
71 if (getEDNSSubnetOptsFromString(ecsOptionStart + (EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE), ecsOptionSize - (EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE), &eso) == true) {
72 subnet = eso.source;
73 return true;
74 }
75 }
76 }
77
78 return false;
79 }
80
81 bool DNSDistPacketCache::cachedValueMatches(const CacheValue& cachedValue, uint16_t queryFlags, const DNSName& qname, uint16_t qtype, uint16_t qclass, bool tcp, const boost::optional<Netmask>& subnet) const
82 {
83 if (cachedValue.queryFlags != queryFlags || cachedValue.tcp != tcp || cachedValue.qtype != qtype || cachedValue.qclass != qclass || cachedValue.qname != qname) {
84 return false;
85 }
86
87 if (d_parseECS && cachedValue.subnet != subnet) {
88 return false;
89 }
90
91 return true;
92 }
93
94 void DNSDistPacketCache::insertLocked(CacheShard& shard, uint32_t key, CacheValue& newValue)
95 {
96 auto& map = shard.d_map;
97 /* check again now that we hold the lock to prevent a race */
98 if (map.size() >= (d_maxEntries / d_shardCount)) {
99 return;
100 }
101
102 std::unordered_map<uint32_t,CacheValue>::iterator it;
103 bool result;
104 tie(it, result) = map.insert({key, newValue});
105
106 if (result) {
107 shard.d_entriesCount++;
108 return;
109 }
110
111 /* in case of collision, don't override the existing entry
112 except if it has expired */
113 CacheValue& value = it->second;
114 bool wasExpired = value.validity <= newValue.added;
115
116 if (!wasExpired && !cachedValueMatches(value, newValue.queryFlags, newValue.qname, newValue.qtype, newValue.qclass, newValue.tcp, newValue.subnet)) {
117 d_insertCollisions++;
118 return;
119 }
120
121 /* if the existing entry had a longer TTD, keep it */
122 if (newValue.validity <= value.validity) {
123 return;
124 }
125
126 value = newValue;
127 }
128
129 void DNSDistPacketCache::insert(uint32_t key, const boost::optional<Netmask>& subnet, uint16_t queryFlags, const DNSName& qname, uint16_t qtype, uint16_t qclass, const char* response, uint16_t responseLen, bool tcp, uint8_t rcode, boost::optional<uint32_t> tempFailureTTL)
130 {
131 if (responseLen < sizeof(dnsheader)) {
132 return;
133 }
134
135 uint32_t minTTL;
136
137 if (rcode == RCode::ServFail || rcode == RCode::Refused) {
138 minTTL = tempFailureTTL == boost::none ? d_tempFailureTTL : *tempFailureTTL;
139 if (minTTL == 0) {
140 return;
141 }
142 }
143 else {
144 bool seenAuthSOA = false;
145 minTTL = getMinTTL(response, responseLen, &seenAuthSOA);
146
147 /* no TTL found, we don't want to cache this */
148 if (minTTL == std::numeric_limits<uint32_t>::max()) {
149 return;
150 }
151
152 if (rcode == RCode::NXDomain || (rcode == RCode::NoError && seenAuthSOA)) {
153 minTTL = std::min(minTTL, d_maxNegativeTTL);
154 }
155 else if (minTTL > d_maxTTL) {
156 minTTL = d_maxTTL;
157 }
158
159 if (minTTL < d_minTTL) {
160 d_ttlTooShorts++;
161 return;
162 }
163 }
164
165 uint32_t shardIndex = getShardIndex(key);
166
167 if (d_shards.at(shardIndex).d_entriesCount >= (d_maxEntries / d_shardCount)) {
168 return;
169 }
170
171 const time_t now = time(NULL);
172 time_t newValidity = now + minTTL;
173 CacheValue newValue;
174 newValue.qname = qname;
175 newValue.qtype = qtype;
176 newValue.qclass = qclass;
177 newValue.queryFlags = queryFlags;
178 newValue.len = responseLen;
179 newValue.validity = newValidity;
180 newValue.added = now;
181 newValue.tcp = tcp;
182 newValue.value = std::string(response, responseLen);
183 newValue.subnet = subnet;
184
185 auto& shard = d_shards.at(shardIndex);
186
187 if (d_deferrableInsertLock) {
188 TryWriteLock w(&shard.d_lock);
189
190 if (!w.gotIt()) {
191 d_deferredInserts++;
192 return;
193 }
194 insertLocked(shard, key, newValue);
195 }
196 else {
197 WriteLock w(&shard.d_lock);
198
199 insertLocked(shard, key, newValue);
200 }
201 }
202
203 bool DNSDistPacketCache::get(const DNSQuestion& dq, uint16_t consumed, uint16_t queryId, char* response, uint16_t* responseLen, uint32_t* keyOut, boost::optional<Netmask>& subnet, uint32_t allowExpired, bool skipAging)
204 {
205 std::string dnsQName(dq.qname->toDNSString());
206 uint32_t key = getKey(dnsQName, consumed, reinterpret_cast<const unsigned char*>(dq.dh), dq.len, dq.tcp);
207
208 if (keyOut)
209 *keyOut = key;
210
211 if (d_parseECS) {
212 getClientSubnet(reinterpret_cast<const char*>(dq.dh), consumed, dq.len, subnet);
213 }
214
215 uint32_t shardIndex = getShardIndex(key);
216 time_t now = time(NULL);
217 time_t age;
218 bool stale = false;
219 auto& shard = d_shards.at(shardIndex);
220 auto& map = shard.d_map;
221 {
222 TryReadLock r(&shard.d_lock);
223 if (!r.gotIt()) {
224 d_deferredLookups++;
225 return false;
226 }
227
228 std::unordered_map<uint32_t,CacheValue>::const_iterator it = map.find(key);
229 if (it == map.end()) {
230 d_misses++;
231 return false;
232 }
233
234 const CacheValue& value = it->second;
235 if (value.validity < now) {
236 if ((now - value.validity) >= static_cast<time_t>(allowExpired)) {
237 d_misses++;
238 return false;
239 }
240 else {
241 stale = true;
242 }
243 }
244
245 if (*responseLen < value.len || value.len < sizeof(dnsheader)) {
246 return false;
247 }
248
249 /* check for collision */
250 if (!cachedValueMatches(value, *(getFlagsFromDNSHeader(dq.dh)), *dq.qname, dq.qtype, dq.qclass, dq.tcp, subnet)) {
251 d_lookupCollisions++;
252 return false;
253 }
254
255 memcpy(response, &queryId, sizeof(queryId));
256 memcpy(response + sizeof(queryId), value.value.c_str() + sizeof(queryId), sizeof(dnsheader) - sizeof(queryId));
257
258 if (value.len == sizeof(dnsheader)) {
259 /* DNS header only, our work here is done */
260 *responseLen = value.len;
261 d_hits++;
262 return true;
263 }
264
265 const size_t dnsQNameLen = dnsQName.length();
266 if (value.len < (sizeof(dnsheader) + dnsQNameLen)) {
267 return false;
268 }
269
270 memcpy(response + sizeof(dnsheader), dnsQName.c_str(), dnsQNameLen);
271 if (value.len > (sizeof(dnsheader) + dnsQNameLen)) {
272 memcpy(response + sizeof(dnsheader) + dnsQNameLen, value.value.c_str() + sizeof(dnsheader) + dnsQNameLen, value.len - (sizeof(dnsheader) + dnsQNameLen));
273 }
274 *responseLen = value.len;
275 if (!stale) {
276 age = now - value.added;
277 }
278 else {
279 age = (value.validity - value.added) - d_staleTTL;
280 }
281 }
282
283 if (!d_dontAge && !skipAging) {
284 ageDNSPacket(response, *responseLen, age);
285 }
286
287 d_hits++;
288 return true;
289 }
290
291 /* Remove expired entries, until the cache has at most
292 upTo entries in it.
293 */
294 void DNSDistPacketCache::purgeExpired(size_t upTo)
295 {
296 time_t now = time(NULL);
297 uint64_t size = getSize();
298
299 if (upTo >= size) {
300 return;
301 }
302
303 size_t toRemove = size - upTo;
304
305 size_t scannedMaps = 0;
306
307 do {
308 uint32_t shardIndex = (d_expungeIndex++ % d_shardCount);
309 WriteLock w(&d_shards.at(shardIndex).d_lock);
310 auto& map = d_shards[shardIndex].d_map;
311
312 for(auto it = map.begin(); toRemove > 0 && it != map.end(); ) {
313 const CacheValue& value = it->second;
314
315 if (value.validity < now) {
316 it = map.erase(it);
317 --toRemove;
318 d_shards[shardIndex].d_entriesCount--;
319 } else {
320 ++it;
321 }
322 }
323
324 scannedMaps++;
325 }
326 while (toRemove > 0 && scannedMaps < d_shardCount);
327 }
328
329 /* Remove all entries, keeping only upTo
330 entries in the cache */
331 void DNSDistPacketCache::expunge(size_t upTo)
332 {
333 const uint64_t size = getSize();
334
335 if (upTo >= size) {
336 return;
337 }
338
339 size_t toRemove = size - upTo;
340 size_t removed = 0;
341
342 for (uint32_t shardIndex = 0; shardIndex < d_shardCount; shardIndex++) {
343 WriteLock w(&d_shards.at(shardIndex).d_lock);
344 auto& map = d_shards[shardIndex].d_map;
345 auto beginIt = map.begin();
346 auto endIt = beginIt;
347 size_t removeFromThisShard = (toRemove - removed) / (d_shardCount - shardIndex);
348 if (map.size() >= removeFromThisShard) {
349 std::advance(endIt, removeFromThisShard);
350 map.erase(beginIt, endIt);
351 d_shards[shardIndex].d_entriesCount -= removeFromThisShard;
352 removed += removeFromThisShard;
353 }
354 else {
355 removed += map.size();
356 map.clear();
357 d_shards[shardIndex].d_entriesCount = 0;
358 }
359 }
360 }
361
362 void DNSDistPacketCache::expungeByName(const DNSName& name, uint16_t qtype, bool suffixMatch)
363 {
364 for (uint32_t shardIndex = 0; shardIndex < d_shardCount; shardIndex++) {
365 WriteLock w(&d_shards.at(shardIndex).d_lock);
366 auto& map = d_shards[shardIndex].d_map;
367
368 for(auto it = map.begin(); it != map.end(); ) {
369 const CacheValue& value = it->second;
370
371 if ((value.qname == name || (suffixMatch && value.qname.isPartOf(name))) && (qtype == QType::ANY || qtype == value.qtype)) {
372 it = map.erase(it);
373 d_shards[shardIndex].d_entriesCount--;
374 } else {
375 ++it;
376 }
377 }
378 }
379 }
380
381 bool DNSDistPacketCache::isFull()
382 {
383 return (getSize() >= d_maxEntries);
384 }
385
386 uint64_t DNSDistPacketCache::getSize()
387 {
388 uint64_t count = 0;
389
390 for (uint32_t shardIndex = 0; shardIndex < d_shardCount; shardIndex++) {
391 count += d_shards.at(shardIndex).d_entriesCount;
392 }
393
394 return count;
395 }
396
397 uint32_t DNSDistPacketCache::getMinTTL(const char* packet, uint16_t length, bool* seenNoDataSOA)
398 {
399 return getDNSPacketMinTTL(packet, length, seenNoDataSOA);
400 }
401
402 uint32_t DNSDistPacketCache::getKey(const std::string& qname, uint16_t consumed, const unsigned char* packet, uint16_t packetLen, bool tcp)
403 {
404 uint32_t result = 0;
405 /* skip the query ID */
406 if (packetLen < sizeof(dnsheader))
407 throw std::range_error("Computing packet cache key for an invalid packet size");
408 result = burtle(packet + 2, sizeof(dnsheader) - 2, result);
409 string lc(toLower(qname));
410 result = burtle((const unsigned char*) lc.c_str(), lc.length(), result);
411 if (packetLen < sizeof(dnsheader) + consumed) {
412 throw std::range_error("Computing packet cache key for an invalid packet");
413 }
414 if (packetLen > ((sizeof(dnsheader) + consumed))) {
415 result = burtle(packet + sizeof(dnsheader) + consumed, packetLen - (sizeof(dnsheader) + consumed), result);
416 }
417 result = burtle((const unsigned char*) &tcp, sizeof(tcp), result);
418 return result;
419 }
420
421 uint32_t DNSDistPacketCache::getShardIndex(uint32_t key) const
422 {
423 return key % d_shardCount;
424 }
425
426 string DNSDistPacketCache::toString()
427 {
428 return std::to_string(getSize()) + "/" + std::to_string(d_maxEntries);
429 }
430
431 uint64_t DNSDistPacketCache::getEntriesCount()
432 {
433 return getSize();
434 }
435
436 uint64_t DNSDistPacketCache::dump(int fd)
437 {
438 FILE * fp = fdopen(dup(fd), "w");
439 if (fp == nullptr) {
440 return 0;
441 }
442
443 fprintf(fp, "; dnsdist's packet cache dump follows\n;\n");
444
445 uint64_t count = 0;
446 time_t now = time(nullptr);
447 for (uint32_t shardIndex = 0; shardIndex < d_shardCount; shardIndex++) {
448 ReadLock w(&d_shards.at(shardIndex).d_lock);
449 auto& map = d_shards[shardIndex].d_map;
450
451 for(const auto entry : map) {
452 const CacheValue& value = entry.second;
453 count++;
454
455 try {
456 fprintf(fp, "%s %" PRId64 " %s ; key %" PRIu32 ", length %" PRIu16 ", tcp %d, added %" PRId64 "\n", value.qname.toString().c_str(), static_cast<int64_t>(value.validity - now), QType(value.qtype).getName().c_str(), entry.first, value.len, value.tcp, static_cast<int64_t>(value.added));
457 }
458 catch(...) {
459 fprintf(fp, "; error printing '%s'\n", value.qname.empty() ? "EMPTY" : value.qname.toString().c_str());
460 }
461 }
462 }
463
464 fclose(fp);
465 return count;
466 }