From: Remi Gacogne Date: Mon, 11 Oct 2021 13:52:18 +0000 (+0200) Subject: dnsdist: Prevent implicit conversions between dnsdist::Protocol and uint8_t X-Git-Tag: dnsdist-1.7.0-alpha2~25^2^2~1 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=426ccc675bb77250bd83dd0ceab308b5ccf8b839;p=thirdparty%2Fpdns.git dnsdist: Prevent implicit conversions between dnsdist::Protocol and uint8_t This commit makes sure that we always use the dnsdist::Protocol type, and that we cannot easily convert to or from a different type by mistake. C++ really dropped the ball by making it impossible to declare methods on the 'enum class' that solved the issue of making regular enums implicitly convertible to int and back, thus making it possible to assign and compare different types of enums together, as well as enums and ints. The result is that we are stuck with declaring classes to hold our methods, along with a lot of plumbing to make sure that we can convert in some cases but not allow obvious mistakes from happening. --- diff --git a/pdns/dnsdist-lua-actions.cc b/pdns/dnsdist-lua-actions.cc index b6e83949fe..239b9f8f85 100644 --- a/pdns/dnsdist-lua-actions.cc +++ b/pdns/dnsdist-lua-actions.cc @@ -1320,25 +1320,19 @@ private: static DnstapMessage::ProtocolType ProtocolToDNSTap(dnsdist::Protocol protocol) { - DnstapMessage::ProtocolType result; - switch (protocol) { - default: - case dnsdist::Protocol::DoUDP: - case dnsdist::Protocol::DNSCryptUDP: - result = DnstapMessage::ProtocolType::DoUDP; - break; - case dnsdist::Protocol::DoTCP: - case dnsdist::Protocol::DNSCryptTCP: - result = DnstapMessage::ProtocolType::DoTCP; - break; - case dnsdist::Protocol::DoT: - result = DnstapMessage::ProtocolType::DoT; - break; - case dnsdist::Protocol::DoH: - result = DnstapMessage::ProtocolType::DoH; - break; - } - return result; + if (protocol == dnsdist::Protocol::DoUDP || protocol == dnsdist::Protocol::DNSCryptUDP) { + return DnstapMessage::ProtocolType::DoUDP; + } + else if (protocol == dnsdist::Protocol::DoTCP || protocol == dnsdist::Protocol::DNSCryptTCP) { + return DnstapMessage::ProtocolType::DoTCP; + } + else if (protocol == dnsdist::Protocol::DoT) { + return DnstapMessage::ProtocolType::DoT; + } + else if (protocol == dnsdist::Protocol::DoH) { + return DnstapMessage::ProtocolType::DoH; + } + throw std::runtime_error("Unhandled protocol for dnstap: " + protocol.toPrettyString()); } class DnstapLogAction : public DNSAction, public boost::noncopyable diff --git a/pdns/dnsdist-protocols.cc b/pdns/dnsdist-protocols.cc index e9c66ed4f2..89b65ecac5 100644 --- a/pdns/dnsdist-protocols.cc +++ b/pdns/dnsdist-protocols.cc @@ -21,6 +21,7 @@ */ #include +#include #include "dnsdist-protocols.hh" @@ -42,42 +43,38 @@ static const std::vector prettyNames = { "DNS over TLS", "DNS over HTTPS"}; -Protocol::Protocol(uint8_t protocol) : +Protocol::Protocol(Protocol::typeenum protocol) : d_protocol(protocol) { + if (protocol >= names.size()) { + throw std::runtime_error("Unknown protocol: '" + std::to_string(protocol) + "'"); + } } -Protocol& Protocol::operator=(const char* s) -{ - std::string str(s); - d_protocol = Protocol::fromString(str); - return *this; -} -Protocol& Protocol::operator=(const std::string& s) +Protocol::Protocol(const std::string& s) { - d_protocol = Protocol::fromString(s); + const auto& it = std::find(names.begin(), names.end(), s); + if (it == names.end()) { + throw std::runtime_error("Unknown protocol name: '" + s + "'"); + } - return *this; + auto index = std::distance(names.begin(), it); + d_protocol = static_cast(index); } -Protocol::operator uint8_t() const + +bool Protocol::operator==(Protocol::typeenum type) const { - return d_protocol; + return d_protocol == type; } + const std::string& Protocol::toString() const { - return names.at(static_cast(d_protocol)); + return names.at(static_cast(d_protocol)); } + const std::string& Protocol::toPrettyString() const { - return prettyNames.at(static_cast(d_protocol)); + return prettyNames.at(static_cast(d_protocol)); } -uint8_t Protocol::fromString(const std::string& s) -{ - const auto& it = std::find(names.begin(), names.end(), s); - if (it != names.end()) { - return std::distance(names.begin(), it); - } - return 0; -} } diff --git a/pdns/dnsdist-protocols.hh b/pdns/dnsdist-protocols.hh index 2aef0a8f94..a90e6e0c56 100644 --- a/pdns/dnsdist-protocols.hh +++ b/pdns/dnsdist-protocols.hh @@ -29,13 +29,6 @@ namespace dnsdist class Protocol { public: - Protocol(uint8_t protocol = 0); - Protocol& operator=(const char*); - Protocol& operator=(const std::string&); - operator uint8_t() const; - const std::string& toString() const; - const std::string& toPrettyString() const; - enum typeenum : uint8_t { DoUDP, @@ -46,8 +39,15 @@ public: DoH }; + Protocol(typeenum protocol = DoUDP); + explicit Protocol(const std::string& protocol); + + bool operator==(typeenum) const; + + const std::string& toString() const; + const std::string& toPrettyString() const; + private: - static uint8_t fromString(const std::string& s); - uint8_t d_protocol; + typeenum d_protocol; }; } diff --git a/pdns/dnsdist-rings.cc b/pdns/dnsdist-rings.cc index 81221d1f8f..f0c3caacac 100644 --- a/pdns/dnsdist-rings.cc +++ b/pdns/dnsdist-rings.cc @@ -138,9 +138,7 @@ size_t Rings::loadFromFile(const std::string& filepath, const struct timespec& n ComboAddress from(parts.at(idx++)); ComboAddress to; - dnsdist::Protocol protocol; - - protocol = parts.at(idx++); + dnsdist::Protocol protocol(parts.at(idx++)); if (isResponse) { to = ComboAddress(parts.at(idx++)); } diff --git a/pdns/dnsdist-rings.hh b/pdns/dnsdist-rings.hh index 87c8c94211..adff0b6ac9 100644 --- a/pdns/dnsdist-rings.hh +++ b/pdns/dnsdist-rings.hh @@ -43,7 +43,7 @@ struct Rings { uint16_t size; uint16_t qtype; // incoming protocol - uint8_t protocol; + dnsdist::Protocol protocol; }; struct Response { @@ -56,7 +56,7 @@ struct Rings { unsigned int size; uint16_t qtype; // outgoing protocol - uint8_t protocol; + dnsdist::Protocol protocol; }; struct Shard @@ -118,7 +118,7 @@ struct Rings { return d_nbResponseEntries; } - void insertQuery(const struct timespec& when, const ComboAddress& requestor, const DNSName& name, uint16_t qtype, uint16_t size, const struct dnsheader& dh, uint8_t protocol) + void insertQuery(const struct timespec& when, const ComboAddress& requestor, const DNSName& name, uint16_t qtype, uint16_t size, const struct dnsheader& dh, dnsdist::Protocol protocol) { for (size_t idx = 0; idx < d_nbLockTries; idx++) { auto& shard = getOneShard(); @@ -141,7 +141,7 @@ struct Rings { insertQueryLocked(*lock, when, requestor, name, qtype, size, dh, protocol); } - void insertResponse(const struct timespec& when, const ComboAddress& requestor, const DNSName& name, uint16_t qtype, unsigned int usec, unsigned int size, const struct dnsheader& dh, const ComboAddress& backend, uint8_t protocol) + void insertResponse(const struct timespec& when, const ComboAddress& requestor, const DNSName& name, uint16_t qtype, unsigned int usec, unsigned int size, const struct dnsheader& dh, const ComboAddress& backend, dnsdist::Protocol protocol) { for (size_t idx = 0; idx < d_nbLockTries; idx++) { auto& shard = getOneShard(); @@ -201,7 +201,7 @@ private: return d_shards[getShardId()]; } - void insertQueryLocked(boost::circular_buffer& ring, const struct timespec& when, const ComboAddress& requestor, const DNSName& name, uint16_t qtype, uint16_t size, const struct dnsheader& dh, uint8_t protocol) + void insertQueryLocked(boost::circular_buffer& ring, const struct timespec& when, const ComboAddress& requestor, const DNSName& name, uint16_t qtype, uint16_t size, const struct dnsheader& dh, dnsdist::Protocol protocol) { if (!ring.full()) { d_nbQueryEntries++; @@ -209,7 +209,7 @@ private: ring.push_back({requestor, name, when, dh, size, qtype, protocol}); } - void insertResponseLocked(boost::circular_buffer& ring, const struct timespec& when, const ComboAddress& requestor, const DNSName& name, uint16_t qtype, unsigned int usec, unsigned int size, const struct dnsheader& dh, const ComboAddress& backend, uint8_t protocol) + void insertResponseLocked(boost::circular_buffer& ring, const struct timespec& when, const ComboAddress& requestor, const DNSName& name, uint16_t qtype, unsigned int usec, unsigned int size, const struct dnsheader& dh, const ComboAddress& backend, dnsdist::Protocol protocol) { if (!ring.full()) { d_nbResponseEntries++; diff --git a/pdns/dnsdist.cc b/pdns/dnsdist.cc index 50d9f620c4..5684922cc7 100644 --- a/pdns/dnsdist.cc +++ b/pdns/dnsdist.cc @@ -551,7 +551,7 @@ static void pickBackendSocketsReadyForReceiving(const std::shared_ptrmplexer.lock())->getAvailableFDs(ready, 1000); } -void handleResponseSent(const IDState& ids, double udiff, const ComboAddress& client, const ComboAddress& backend, unsigned int size, const dnsheader& cleartextDH, uint8_t protocol) +void handleResponseSent(const IDState& ids, double udiff, const ComboAddress& client, const ComboAddress& backend, unsigned int size, const dnsheader& cleartextDH, dnsdist::Protocol protocol) { struct timespec ts; gettime(&ts); diff --git a/pdns/dnsdist.hh b/pdns/dnsdist.hh index e046707b14..fdbb257423 100644 --- a/pdns/dnsdist.hh +++ b/pdns/dnsdist.hh @@ -1072,6 +1072,6 @@ void setIDStateFromDNSQuestion(IDState& ids, DNSQuestion& dq, DNSName&& qname); int pickBackendSocketForSending(std::shared_ptr& state); ssize_t udpClientSendRequestToBackend(const std::shared_ptr& ss, const int sd, const PacketBuffer& request, bool healthCheck = false); -void handleResponseSent(const IDState& ids, double udiff, const ComboAddress& client, const ComboAddress& backend, unsigned int size, const dnsheader& cleartextDH, uint8_t protocol); +void handleResponseSent(const IDState& ids, double udiff, const ComboAddress& client, const ComboAddress& backend, unsigned int size, const dnsheader& cleartextDH, dnsdist::Protocol protocol); void carbonDumpThread(); diff --git a/pdns/dnsdistdist/test-dnsdistdynblocks_hh.cc b/pdns/dnsdistdist/test-dnsdistdynblocks_hh.cc index 91bd26c39a..d5ad87afb5 100644 --- a/pdns/dnsdistdist/test-dnsdistdynblocks_hh.cc +++ b/pdns/dnsdistdist/test-dnsdistdynblocks_hh.cc @@ -24,8 +24,8 @@ BOOST_AUTO_TEST_CASE(test_DynBlockRulesGroup_QueryRate) { ComboAddress backend("192.0.2.42"); uint16_t qtype = QType::AAAA; uint16_t size = 42; - uint8_t protocol = dnsdist::Protocol::DoUDP; - uint8_t outgoingProtocol = dnsdist::Protocol::DoUDP; + dnsdist::Protocol protocol = dnsdist::Protocol::DoUDP; + dnsdist::Protocol outgoingProtocol = dnsdist::Protocol::DoUDP; unsigned int responseTime = 0; struct timespec now; gettime(&now); @@ -162,8 +162,8 @@ BOOST_AUTO_TEST_CASE(test_DynBlockRulesGroup_QueryRate_responses) { ComboAddress backend("192.0.2.42"); uint16_t qtype = QType::AAAA; uint16_t size = 42; - uint8_t protocol = dnsdist::Protocol::DoUDP; - uint8_t outgoingProtocol = dnsdist::Protocol::DoUDP; + dnsdist::Protocol protocol = dnsdist::Protocol::DoUDP; + dnsdist::Protocol outgoingProtocol = dnsdist::Protocol::DoUDP; unsigned int responseTime = 0; struct timespec now; gettime(&now); @@ -221,7 +221,7 @@ BOOST_AUTO_TEST_CASE(test_DynBlockRulesGroup_QTypeRate) { ComboAddress requestor2("192.0.2.2"); uint16_t qtype = QType::AAAA; uint16_t size = 42; - uint8_t protocol = dnsdist::Protocol::DoUDP; + dnsdist::Protocol protocol = dnsdist::Protocol::DoUDP; struct timespec now; gettime(&now); NetmaskTree emptyNMG; @@ -309,7 +309,7 @@ BOOST_AUTO_TEST_CASE(test_DynBlockRulesGroup_RCodeRate) { ComboAddress backend("192.0.2.42"); uint16_t qtype = QType::AAAA; uint16_t size = 42; - uint8_t outgoingProtocol = dnsdist::Protocol::DoUDP; + dnsdist::Protocol outgoingProtocol = dnsdist::Protocol::DoUDP; unsigned int responseTime = 100 * 1000; /* 100ms */ struct timespec now; gettime(&now); @@ -401,7 +401,7 @@ BOOST_AUTO_TEST_CASE(test_DynBlockRulesGroup_RCodeRatio) { ComboAddress backend("192.0.2.42"); uint16_t qtype = QType::AAAA; uint16_t size = 42; - uint8_t outgoingProtocol = dnsdist::Protocol::DoUDP; + dnsdist::Protocol outgoingProtocol = dnsdist::Protocol::DoUDP; unsigned int responseTime = 100 * 1000; /* 100ms */ struct timespec now; gettime(&now); @@ -519,7 +519,7 @@ BOOST_AUTO_TEST_CASE(test_DynBlockRulesGroup_ResponseByteRate) { ComboAddress backend("192.0.2.42"); uint16_t qtype = QType::AAAA; uint16_t size = 100; - uint8_t outgoingProtocol = dnsdist::Protocol::DoUDP; + dnsdist::Protocol outgoingProtocol = dnsdist::Protocol::DoUDP; unsigned int responseTime = 100 * 1000; /* 100ms */ struct timespec now; gettime(&now); @@ -591,7 +591,7 @@ BOOST_AUTO_TEST_CASE(test_DynBlockRulesGroup_Warning) { ComboAddress requestor2("192.0.2.2"); uint16_t qtype = QType::AAAA; uint16_t size = 42; - uint8_t protocol = dnsdist::Protocol::DoUDP; + dnsdist::Protocol protocol = dnsdist::Protocol::DoUDP; struct timespec now; gettime(&now); NetmaskTree emptyNMG; @@ -750,7 +750,7 @@ BOOST_AUTO_TEST_CASE(test_DynBlockRulesGroup_Ranges) { ComboAddress requestor2("192.0.2.42"); uint16_t qtype = QType::AAAA; uint16_t size = 42; - uint8_t protocol = dnsdist::Protocol::DoUDP; + dnsdist::Protocol protocol = dnsdist::Protocol::DoUDP; struct timespec now; gettime(&now); NetmaskTree emptyNMG; @@ -805,8 +805,8 @@ BOOST_AUTO_TEST_CASE(test_DynBlockRulesMetricsCache_GetTopN) { DNSName qname("rings.powerdns.com."); uint16_t qtype = QType::AAAA; uint16_t size = 42; - uint8_t protocol = dnsdist::Protocol::DoUDP; - uint8_t outgoingProtocol = dnsdist::Protocol::DoUDP; + dnsdist::Protocol protocol = dnsdist::Protocol::DoUDP; + dnsdist::Protocol outgoingProtocol = dnsdist::Protocol::DoUDP; struct timespec now; gettime(&now); NetmaskTree emptyNMG; diff --git a/pdns/dnsdistdist/test-dnsdistrings_cc.cc b/pdns/dnsdistdist/test-dnsdistrings_cc.cc index dac459908e..db665bb5fd 100644 --- a/pdns/dnsdistdist/test-dnsdistrings_cc.cc +++ b/pdns/dnsdistdist/test-dnsdistrings_cc.cc @@ -30,8 +30,8 @@ static void test_ring(size_t maxEntries, size_t numberOfShards, size_t nbLockTri ComboAddress requestor2("192.0.2.2"); uint16_t qtype = QType::AAAA; uint16_t size = 42; - uint8_t protocol = dnsdist::Protocol::DoUDP; - uint8_t outgoingProtocol = dnsdist::Protocol::DoUDP; + dnsdist::Protocol protocol = dnsdist::Protocol::DoUDP; + dnsdist::Protocol outgoingProtocol = dnsdist::Protocol::DoUDP; struct timespec now; gettime(&now); @@ -201,8 +201,8 @@ BOOST_AUTO_TEST_CASE(test_Rings_Threaded) { unsigned int latency = 100; uint16_t qtype = QType::AAAA; uint16_t size = 42; - uint8_t protocol = dnsdist::Protocol::DoUDP; - uint8_t outgoingProtocol = dnsdist::Protocol::DoUDP; + dnsdist::Protocol protocol = dnsdist::Protocol::DoUDP; + dnsdist::Protocol outgoingProtocol = dnsdist::Protocol::DoUDP; Rings rings(numberOfEntries, numberOfShards, lockAttempts, true); Rings::Query query({requestor, qname, now, dh, size, qtype, protocol}); diff --git a/pdns/dnsdistdist/test-dnsdisttcp_cc.cc b/pdns/dnsdistdist/test-dnsdisttcp_cc.cc index a0e5682d5a..fd9c029df1 100644 --- a/pdns/dnsdistdist/test-dnsdisttcp_cc.cc +++ b/pdns/dnsdistdist/test-dnsdisttcp_cc.cc @@ -61,7 +61,7 @@ uint64_t getLatencyCount(const std::string&) return 0; } -void handleResponseSent(const IDState& ids, double udiff, const ComboAddress& client, const ComboAddress& backend, unsigned int size, const dnsheader& cleartextDH, uint8_t protocol) +void handleResponseSent(const IDState& ids, double udiff, const ComboAddress& client, const ComboAddress& backend, unsigned int size, const dnsheader& cleartextDH, dnsdist::Protocol protocol) { }