]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Prevent implicit conversions between dnsdist::Protocol and uint8_t
authorRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 11 Oct 2021 13:52:18 +0000 (15:52 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 11 Oct 2021 13:52:18 +0000 (15:52 +0200)
This commit makes sure that we always use the dnsdist::Protocol type,
and that we cannot easily convert to or from a different type by mistake.

C++ really dropped the ball by making it impossible to declare methods
on the 'enum class' that solved the issue of making regular enums
implicitly convertible to int and back, thus making it possible to
assign and compare different types of enums together, as well as
enums and ints.
The result is that we are stuck with declaring classes to hold our
methods, along with a lot of plumbing to make sure that we can convert
in some cases but not allow obvious mistakes from happening.

pdns/dnsdist-lua-actions.cc
pdns/dnsdist-protocols.cc
pdns/dnsdist-protocols.hh
pdns/dnsdist-rings.cc
pdns/dnsdist-rings.hh
pdns/dnsdist.cc
pdns/dnsdist.hh
pdns/dnsdistdist/test-dnsdistdynblocks_hh.cc
pdns/dnsdistdist/test-dnsdistrings_cc.cc
pdns/dnsdistdist/test-dnsdisttcp_cc.cc

index b6e83949fec1a2292c9df806fdbc7f5ba038f17f..239b9f8f85d59b84478ff1acf972a8ac023fa3dc 100644 (file)
@@ -1320,25 +1320,19 @@ private:
 
 static DnstapMessage::ProtocolType ProtocolToDNSTap(dnsdist::Protocol protocol)
 {
-  DnstapMessage::ProtocolType result;
-  switch (protocol) {
-  default:
-  case dnsdist::Protocol::DoUDP:
-  case dnsdist::Protocol::DNSCryptUDP:
-    result = DnstapMessage::ProtocolType::DoUDP;
-    break;
-  case dnsdist::Protocol::DoTCP:
-  case dnsdist::Protocol::DNSCryptTCP:
-    result = DnstapMessage::ProtocolType::DoTCP;
-    break;
-  case dnsdist::Protocol::DoT:
-    result = DnstapMessage::ProtocolType::DoT;
-    break;
-  case dnsdist::Protocol::DoH:
-    result = DnstapMessage::ProtocolType::DoH;
-    break;
-  }
-  return result;
+  if (protocol == dnsdist::Protocol::DoUDP || protocol == dnsdist::Protocol::DNSCryptUDP) {
+    return DnstapMessage::ProtocolType::DoUDP;
+  }
+  else if (protocol == dnsdist::Protocol::DoTCP || protocol == dnsdist::Protocol::DNSCryptTCP) {
+    return DnstapMessage::ProtocolType::DoTCP;
+  }
+  else if (protocol == dnsdist::Protocol::DoT) {
+    return DnstapMessage::ProtocolType::DoT;
+  }
+  else if (protocol == dnsdist::Protocol::DoH) {
+    return DnstapMessage::ProtocolType::DoH;
+  }
+  throw std::runtime_error("Unhandled protocol for dnstap: " + protocol.toPrettyString());
 }
 
 class DnstapLogAction : public DNSAction, public boost::noncopyable
index e9c66ed4f234f7cd806db863aa7408abc2de3bc0..89b65ecac53c223ccb6e2464e6dcc914bc5740be 100644 (file)
@@ -21,6 +21,7 @@
  */
 
 #include <algorithm>
+#include <stdexcept>
 
 #include "dnsdist-protocols.hh"
 
@@ -42,42 +43,38 @@ static const std::vector<std::string> prettyNames = {
   "DNS over TLS",
   "DNS over HTTPS"};
 
-Protocol::Protocol(uint8_t protocol) :
+Protocol::Protocol(Protocol::typeenum protocol) :
   d_protocol(protocol)
 {
+  if (protocol >= names.size()) {
+    throw std::runtime_error("Unknown protocol: '" + std::to_string(protocol) + "'");
+  }
 }
-Protocol& Protocol::operator=(const char* s)
-{
-  std::string str(s);
-  d_protocol = Protocol::fromString(str);
 
-  return *this;
-}
-Protocol& Protocol::operator=(const std::string& s)
+Protocol::Protocol(const std::string& s)
 {
-  d_protocol = Protocol::fromString(s);
+  const auto& it = std::find(names.begin(), names.end(), s);
+  if (it == names.end()) {
+    throw std::runtime_error("Unknown protocol name: '" + s + "'");
+  }
 
-  return *this;
+  auto index = std::distance(names.begin(), it);
+  d_protocol = static_cast<Protocol::typeenum>(index);
 }
-Protocol::operator uint8_t() const
+
+bool Protocol::operator==(Protocol::typeenum type) const
 {
-  return d_protocol;
+  return d_protocol == type;
 }
+
 const std::string& Protocol::toString() const
 {
-  return names.at(static_cast<int>(d_protocol));
+  return names.at(static_cast<uint8_t>(d_protocol));
 }
+
 const std::string& Protocol::toPrettyString() const
 {
-  return prettyNames.at(static_cast<int>(d_protocol));
+  return prettyNames.at(static_cast<uint8_t>(d_protocol));
 }
-uint8_t Protocol::fromString(const std::string& s)
-{
-  const auto& it = std::find(names.begin(), names.end(), s);
-  if (it != names.end()) {
-    return std::distance(names.begin(), it);
-  }
 
-  return 0;
-}
 }
index 2aef0a8f94514f9492fe93b4259eb9a6b9bfebe3..a90e6e0c5602a18fe9c30b7d2227e6bdf22c2215 100644 (file)
@@ -29,13 +29,6 @@ namespace dnsdist
 class Protocol
 {
 public:
-  Protocol(uint8_t protocol = 0);
-  Protocol& operator=(const char*);
-  Protocol& operator=(const std::string&);
-  operator uint8_t() const;
-  const std::string& toString() const;
-  const std::string& toPrettyString() const;
-
   enum typeenum : uint8_t
   {
     DoUDP,
@@ -46,8 +39,15 @@ public:
     DoH
   };
 
+  Protocol(typeenum protocol = DoUDP);
+  explicit Protocol(const std::string& protocol);
+
+  bool operator==(typeenum) const;
+
+  const std::string& toString() const;
+  const std::string& toPrettyString() const;
+
 private:
-  static uint8_t fromString(const std::string& s);
-  uint8_t d_protocol;
+  typeenum d_protocol;
 };
 }
index 81221d1f8f7567c8ab8c1193f93ffa2bb66a7d60..f0c3caacac466e9380a509be8d3db0d5f69d58b4 100644 (file)
@@ -138,9 +138,7 @@ size_t Rings::loadFromFile(const std::string& filepath, const struct timespec& n
 
     ComboAddress from(parts.at(idx++));
     ComboAddress to;
-    dnsdist::Protocol protocol;
-
-    protocol = parts.at(idx++);
+    dnsdist::Protocol protocol(parts.at(idx++));
     if (isResponse) {
       to = ComboAddress(parts.at(idx++));
     }
index 87c8c94211ee91e1f830fb004459b4f03249a2c9..adff0b6ac9459e7bd13b9de2669ebd0ca3d5c760 100644 (file)
@@ -43,7 +43,7 @@ struct Rings {
     uint16_t size;
     uint16_t qtype;
     // incoming protocol
-    uint8_t protocol;
+    dnsdist::Protocol protocol;
   };
   struct Response
   {
@@ -56,7 +56,7 @@ struct Rings {
     unsigned int size;
     uint16_t qtype;
     // outgoing protocol
-    uint8_t protocol;
+    dnsdist::Protocol protocol;
   };
 
   struct Shard
@@ -118,7 +118,7 @@ struct Rings {
     return d_nbResponseEntries;
   }
 
-  void insertQuery(const struct timespec& when, const ComboAddress& requestor, const DNSName& name, uint16_t qtype, uint16_t size, const struct dnsheader& dh, uint8_t protocol)
+  void insertQuery(const struct timespec& when, const ComboAddress& requestor, const DNSName& name, uint16_t qtype, uint16_t size, const struct dnsheader& dh, dnsdist::Protocol protocol)
   {
     for (size_t idx = 0; idx < d_nbLockTries; idx++) {
       auto& shard = getOneShard();
@@ -141,7 +141,7 @@ struct Rings {
     insertQueryLocked(*lock, when, requestor, name, qtype, size, dh, protocol);
   }
 
-  void insertResponse(const struct timespec& when, const ComboAddress& requestor, const DNSName& name, uint16_t qtype, unsigned int usec, unsigned int size, const struct dnsheader& dh, const ComboAddress& backend, uint8_t protocol)
+  void insertResponse(const struct timespec& when, const ComboAddress& requestor, const DNSName& name, uint16_t qtype, unsigned int usec, unsigned int size, const struct dnsheader& dh, const ComboAddress& backend, dnsdist::Protocol protocol)
   {
     for (size_t idx = 0; idx < d_nbLockTries; idx++) {
       auto& shard = getOneShard();
@@ -201,7 +201,7 @@ private:
     return d_shards[getShardId()];
   }
 
-  void insertQueryLocked(boost::circular_buffer<Query>& ring, const struct timespec& when, const ComboAddress& requestor, const DNSName& name, uint16_t qtype, uint16_t size, const struct dnsheader& dh, uint8_t protocol)
+  void insertQueryLocked(boost::circular_buffer<Query>& ring, const struct timespec& when, const ComboAddress& requestor, const DNSName& name, uint16_t qtype, uint16_t size, const struct dnsheader& dh, dnsdist::Protocol protocol)
   {
     if (!ring.full()) {
       d_nbQueryEntries++;
@@ -209,7 +209,7 @@ private:
     ring.push_back({requestor, name, when, dh, size, qtype, protocol});
   }
 
-  void insertResponseLocked(boost::circular_buffer<Response>& ring, const struct timespec& when, const ComboAddress& requestor, const DNSName& name, uint16_t qtype, unsigned int usec, unsigned int size, const struct dnsheader& dh, const ComboAddress& backend, uint8_t protocol)
+  void insertResponseLocked(boost::circular_buffer<Response>& ring, const struct timespec& when, const ComboAddress& requestor, const DNSName& name, uint16_t qtype, unsigned int usec, unsigned int size, const struct dnsheader& dh, const ComboAddress& backend, dnsdist::Protocol protocol)
   {
     if (!ring.full()) {
       d_nbResponseEntries++;
index 50d9f620c4aa6f9269e5899f17274c85d4a1497b..5684922cc77799691150aac7e085fb981a8d049c 100644 (file)
@@ -551,7 +551,7 @@ static void pickBackendSocketsReadyForReceiving(const std::shared_ptr<Downstream
   (*state->mplexer.lock())->getAvailableFDs(ready, 1000);
 }
 
-void handleResponseSent(const IDState& ids, double udiff, const ComboAddress& client, const ComboAddress& backend, unsigned int size, const dnsheader& cleartextDH, uint8_t protocol)
+void handleResponseSent(const IDState& ids, double udiff, const ComboAddress& client, const ComboAddress& backend, unsigned int size, const dnsheader& cleartextDH, dnsdist::Protocol protocol)
 {
   struct timespec ts;
   gettime(&ts);
index e046707b14f4c9f6b275f0babfb956c3f91c445a..fdbb25742318fb6bda5c7abaa1e30e73938595d4 100644 (file)
@@ -1072,6 +1072,6 @@ void setIDStateFromDNSQuestion(IDState& ids, DNSQuestion& dq, DNSName&& qname);
 
 int pickBackendSocketForSending(std::shared_ptr<DownstreamState>& state);
 ssize_t udpClientSendRequestToBackend(const std::shared_ptr<DownstreamState>& ss, const int sd, const PacketBuffer& request, bool healthCheck = false);
-void handleResponseSent(const IDState& ids, double udiff, const ComboAddress& client, const ComboAddress& backend, unsigned int size, const dnsheader& cleartextDH, uint8_t protocol);
+void handleResponseSent(const IDState& ids, double udiff, const ComboAddress& client, const ComboAddress& backend, unsigned int size, const dnsheader& cleartextDH, dnsdist::Protocol protocol);
 
 void carbonDumpThread();
index 91bd26c39a84733d8682fbe9d9e9fd632b19e719..d5ad87afb555cd37d96968ee567f2cf02c9068f8 100644 (file)
@@ -24,8 +24,8 @@ BOOST_AUTO_TEST_CASE(test_DynBlockRulesGroup_QueryRate) {
   ComboAddress backend("192.0.2.42");
   uint16_t qtype = QType::AAAA;
   uint16_t size = 42;
-  uint8_t protocol = dnsdist::Protocol::DoUDP;
-  uint8_t outgoingProtocol = dnsdist::Protocol::DoUDP;
+  dnsdist::Protocol protocol = dnsdist::Protocol::DoUDP;
+  dnsdist::Protocol outgoingProtocol = dnsdist::Protocol::DoUDP;
   unsigned int responseTime = 0;
   struct timespec now;
   gettime(&now);
@@ -162,8 +162,8 @@ BOOST_AUTO_TEST_CASE(test_DynBlockRulesGroup_QueryRate_responses) {
   ComboAddress backend("192.0.2.42");
   uint16_t qtype = QType::AAAA;
   uint16_t size = 42;
-  uint8_t protocol = dnsdist::Protocol::DoUDP;
-  uint8_t outgoingProtocol = dnsdist::Protocol::DoUDP;
+  dnsdist::Protocol protocol = dnsdist::Protocol::DoUDP;
+  dnsdist::Protocol outgoingProtocol = dnsdist::Protocol::DoUDP;
   unsigned int responseTime = 0;
   struct timespec now;
   gettime(&now);
@@ -221,7 +221,7 @@ BOOST_AUTO_TEST_CASE(test_DynBlockRulesGroup_QTypeRate) {
   ComboAddress requestor2("192.0.2.2");
   uint16_t qtype = QType::AAAA;
   uint16_t size = 42;
-  uint8_t protocol = dnsdist::Protocol::DoUDP;
+  dnsdist::Protocol protocol = dnsdist::Protocol::DoUDP;
   struct timespec now;
   gettime(&now);
   NetmaskTree<DynBlock> emptyNMG;
@@ -309,7 +309,7 @@ BOOST_AUTO_TEST_CASE(test_DynBlockRulesGroup_RCodeRate) {
   ComboAddress backend("192.0.2.42");
   uint16_t qtype = QType::AAAA;
   uint16_t size = 42;
-  uint8_t outgoingProtocol = dnsdist::Protocol::DoUDP;
+  dnsdist::Protocol outgoingProtocol = dnsdist::Protocol::DoUDP;
   unsigned int responseTime = 100 * 1000; /* 100ms */
   struct timespec now;
   gettime(&now);
@@ -401,7 +401,7 @@ BOOST_AUTO_TEST_CASE(test_DynBlockRulesGroup_RCodeRatio) {
   ComboAddress backend("192.0.2.42");
   uint16_t qtype = QType::AAAA;
   uint16_t size = 42;
-  uint8_t outgoingProtocol = dnsdist::Protocol::DoUDP;
+  dnsdist::Protocol outgoingProtocol = dnsdist::Protocol::DoUDP;
   unsigned int responseTime = 100 * 1000; /* 100ms */
   struct timespec now;
   gettime(&now);
@@ -519,7 +519,7 @@ BOOST_AUTO_TEST_CASE(test_DynBlockRulesGroup_ResponseByteRate) {
   ComboAddress backend("192.0.2.42");
   uint16_t qtype = QType::AAAA;
   uint16_t size = 100;
-  uint8_t outgoingProtocol = dnsdist::Protocol::DoUDP;
+  dnsdist::Protocol outgoingProtocol = dnsdist::Protocol::DoUDP;
   unsigned int responseTime = 100 * 1000; /* 100ms */
   struct timespec now;
   gettime(&now);
@@ -591,7 +591,7 @@ BOOST_AUTO_TEST_CASE(test_DynBlockRulesGroup_Warning) {
   ComboAddress requestor2("192.0.2.2");
   uint16_t qtype = QType::AAAA;
   uint16_t size = 42;
-  uint8_t protocol = dnsdist::Protocol::DoUDP;
+  dnsdist::Protocol protocol = dnsdist::Protocol::DoUDP;
   struct timespec now;
   gettime(&now);
   NetmaskTree<DynBlock> emptyNMG;
@@ -750,7 +750,7 @@ BOOST_AUTO_TEST_CASE(test_DynBlockRulesGroup_Ranges) {
   ComboAddress requestor2("192.0.2.42");
   uint16_t qtype = QType::AAAA;
   uint16_t size = 42;
-  uint8_t protocol = dnsdist::Protocol::DoUDP;
+  dnsdist::Protocol protocol = dnsdist::Protocol::DoUDP;
   struct timespec now;
   gettime(&now);
   NetmaskTree<DynBlock> emptyNMG;
@@ -805,8 +805,8 @@ BOOST_AUTO_TEST_CASE(test_DynBlockRulesMetricsCache_GetTopN) {
   DNSName qname("rings.powerdns.com.");
   uint16_t qtype = QType::AAAA;
   uint16_t size = 42;
-  uint8_t protocol = dnsdist::Protocol::DoUDP;
-  uint8_t outgoingProtocol = dnsdist::Protocol::DoUDP;
+  dnsdist::Protocol protocol = dnsdist::Protocol::DoUDP;
+  dnsdist::Protocol outgoingProtocol = dnsdist::Protocol::DoUDP;
   struct timespec now;
   gettime(&now);
   NetmaskTree<DynBlock> emptyNMG;
index dac459908e1a2d4304324cfbc81641fb57e752c6..db665bb5fdf6bdc19d880b856ec91c2a595ed19f 100644 (file)
@@ -30,8 +30,8 @@ static void test_ring(size_t maxEntries, size_t numberOfShards, size_t nbLockTri
   ComboAddress requestor2("192.0.2.2");
   uint16_t qtype = QType::AAAA;
   uint16_t size = 42;
-  uint8_t protocol = dnsdist::Protocol::DoUDP;
-  uint8_t outgoingProtocol = dnsdist::Protocol::DoUDP;
+  dnsdist::Protocol protocol = dnsdist::Protocol::DoUDP;
+  dnsdist::Protocol outgoingProtocol = dnsdist::Protocol::DoUDP;
   struct timespec now;
   gettime(&now);
 
@@ -201,8 +201,8 @@ BOOST_AUTO_TEST_CASE(test_Rings_Threaded) {
   unsigned int latency = 100;
   uint16_t qtype = QType::AAAA;
   uint16_t size = 42;
-  uint8_t protocol = dnsdist::Protocol::DoUDP;
-  uint8_t outgoingProtocol = dnsdist::Protocol::DoUDP;
+  dnsdist::Protocol protocol = dnsdist::Protocol::DoUDP;
+  dnsdist::Protocol outgoingProtocol = dnsdist::Protocol::DoUDP;
 
   Rings rings(numberOfEntries, numberOfShards, lockAttempts, true);
   Rings::Query query({requestor, qname, now, dh, size, qtype, protocol});
index a0e5682d5ab3019735eda70981e74bf7b9a01335..fd9c029df1ece48e1fde0b27e1b78f3c9b0c337f 100644 (file)
@@ -61,7 +61,7 @@ uint64_t getLatencyCount(const std::string&)
   return 0;
 }
 
-void handleResponseSent(const IDState& ids, double udiff, const ComboAddress& client, const ComboAddress& backend, unsigned int size, const dnsheader& cleartextDH, uint8_t protocol)
+void handleResponseSent(const IDState& ids, double udiff, const ComboAddress& client, const ComboAddress& backend, unsigned int size, const dnsheader& cleartextDH, dnsdist::Protocol protocol)
 {
 }