From: Remi Gacogne Date: Wed, 3 May 2023 13:02:34 +0000 (+0200) Subject: dnsdist: Fix cache hit and miss metrics with DoH queries X-Git-Tag: dnsdist-1.8.1~8^2~1 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=ddc8f6237392b122fb5de7ef09742bbac9dc3af3;p=thirdparty%2Fpdns.git dnsdist: Fix cache hit and miss metrics with DoH queries Since we do two lookups for DoH queries forwarded over UDP (first TCP then UDP), we need to be careful to only record a cache miss in our last attempt. (cherry picked from commit bc4d98b7cb2ecad488560d1dbef156708a1166af) --- diff --git a/pdns/dnsdist.cc b/pdns/dnsdist.cc index 7642b3a2d1..9f2ab57dcd 100644 --- a/pdns/dnsdist.cc +++ b/pdns/dnsdist.cc @@ -1347,7 +1347,9 @@ ProcessQueryResult processQueryAfterRules(DNSQuestion& dq, LocalHolders& holders forwardedOverUDP = false; } - if (dq.ids.packetCache->get(dq, dq.getHeader()->id, &dq.ids.cacheKey, dq.ids.subnet, dq.ids.dnssecOK, forwardedOverUDP, allowExpired, false, true, true)) { + /* we do not record a miss for queries received over DoH and forwarded over TCP + yet, as we will do a second-lookup */ + if (dq.ids.packetCache->get(dq, dq.getHeader()->id, &dq.ids.cacheKey, dq.ids.subnet, dq.ids.dnssecOK, forwardedOverUDP, allowExpired, false, true, (dq.ids.protocol != dnsdist::Protocol::DoH || forwardedOverUDP) ? true : false)) { restoreFlags(dq.getHeader(), dq.ids.origFlags); @@ -1363,7 +1365,7 @@ ProcessQueryResult processQueryAfterRules(DNSQuestion& dq, LocalHolders& holders } else if (dq.ids.protocol == dnsdist::Protocol::DoH && !forwardedOverUDP) { /* do a second-lookup for UDP responses, but we do not want TC=1 answers */ - if (dq.ids.packetCache->get(dq, dq.getHeader()->id, &dq.ids.cacheKeyUDP, dq.ids.subnet, dq.ids.dnssecOK, true, allowExpired, false, false, false)) { + if (dq.ids.packetCache->get(dq, dq.getHeader()->id, &dq.ids.cacheKeyUDP, dq.ids.subnet, dq.ids.dnssecOK, true, allowExpired, false, false, true)) { if (!prepareOutgoingResponse(holders, *dq.ids.cs, dq, true)) { return ProcessQueryResult::Drop; } diff --git a/regression-tests.dnsdist/test_Metrics.py b/regression-tests.dnsdist/test_Metrics.py index d907b11354..5a70d0d8ce 100644 --- a/regression-tests.dnsdist/test_Metrics.py +++ b/regression-tests.dnsdist/test_Metrics.py @@ -59,6 +59,18 @@ class TestRuleMetrics(DNSDistTest): self.assertIn(name, stats) return int(stats[name]) + def getPoolMetric(self, poolName, metricName): + headers = {'x-api-key': self._webServerAPIKey} + url = 'http://127.0.0.1:' + str(self._webServerPort) + '/api/v1/servers/localhost/pool?name=' + poolName + r = requests.get(url, headers=headers, timeout=self._webTimeout) + self.assertTrue(r) + self.assertEqual(r.status_code, 200) + self.assertTrue(r.json()) + content = r.json() + stats = content['stats'] + self.assertIn(metricName, stats) + return int(stats[metricName]) + def testRCodeIncreaseMetrics(self): """ Metrics: Check that metrics are correctly updated for RCodeAction @@ -93,6 +105,7 @@ class TestRuleMetrics(DNSDistTest): # self-generated responses should not increase this metric self.assertEqual(self.getMetric('servfail-responses'), servfailBackendResponses) + def testCacheMetrics(self): """ Metrics: Check that metrics are correctly updated for cache misses and hits @@ -114,6 +127,8 @@ class TestRuleMetrics(DNSDistTest): responsesBefore = self.getMetric('responses') cacheHitsBefore = self.getMetric('cache-hits') cacheMissesBefore = self.getMetric('cache-misses') + poolCacheHitsBefore = self.getPoolMetric('cache', 'cacheHits') + poolCacheMissesBefore = self.getPoolMetric('cache', 'cacheMisses') sender = getattr(self, method) # first time, cache miss @@ -128,6 +143,8 @@ class TestRuleMetrics(DNSDistTest): self.assertEqual(self.getMetric('responses'), responsesBefore + 2) self.assertEqual(self.getMetric('cache-hits'), cacheHitsBefore + 1) self.assertEqual(self.getMetric('cache-misses'), cacheMissesBefore + 1) + self.assertEqual(self.getPoolMetric('cache', 'cacheHits'), poolCacheHitsBefore + 1) + self.assertEqual(self.getPoolMetric('cache', 'cacheMisses'), poolCacheMissesBefore + 1) def testServFailMetrics(self): """