From: Remi Gacogne Date: Wed, 20 Sep 2023 14:54:37 +0000 (+0200) Subject: dnsdist: Proper retry token generation and validation for DoQ X-Git-Tag: rec-5.0.0-alpha2~6^2~40 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=fc7c753845599ec8d2921794be4290504008c567;p=thirdparty%2Fpdns.git dnsdist: Proper retry token generation and validation for DoQ --- diff --git a/pdns/dnsdistdist/doq.cc b/pdns/dnsdistdist/doq.cc index 6acc201279..2df6b6a404 100644 --- a/pdns/dnsdistdist/doq.cc +++ b/pdns/dnsdistdist/doq.cc @@ -32,6 +32,9 @@ #include "threadname.hh" #include "dnsdist-ecs.hh" #include "dnsdist-proxy-protocol.hh" +#include "sodcrypto.hh" + +static std::string s_quicRetryTokenKey = newKey(false); static void sendBackDOQUnit(DOQUnitUniquePtr&& du, const char* description); struct DOQServerConfig @@ -69,7 +72,6 @@ struct DOQServerConfig static constexpr size_t MAX_DATAGRAM_SIZE = 1350; static constexpr size_t LOCAL_CONN_ID_LEN = 16; -static constexpr size_t TOKEN_LEN = 32; /* check if this needs to be authenticated, via HMAC-SHA256, for example, see rfc9000 section 8.1.1 */ static std::map s_connections; @@ -287,35 +289,70 @@ static std::optional getCID() return buffer; } +static constexpr size_t MAX_TOKEN_LEN = std::tuple_size{} /* nonce */ + sizeof(uint64_t) /* TTD */ + 16 /* IPv6 */ + QUICHE_MAX_CONN_ID_LEN; + static PacketBuffer mintToken(const PacketBuffer& dcid, const ComboAddress& peer) { - // FIXME: really check whether this needs to be authenticated, via HMAC for example - const std::array keyword = {'q', 'u', 'i', 'c', 'h', 'e'}; - auto addrBytes = peer.toByteString(); - PacketBuffer token; - token.reserve(keyword.size() + addrBytes.size() + dcid.size()); - token.insert(token.end(), keyword.begin(), keyword.end()); - token.insert(token.end(), addrBytes.begin(), addrBytes.end()); - token.insert(token.end(), dcid.begin(), dcid.end()); - return token; + try { + SodiumNonce nonce; + nonce.init(); + + const auto addrBytes = peer.toByteString(); + // this token will be valid for 60s + const uint64_t ttd = time(nullptr) + 60U; + PacketBuffer plainTextToken; + plainTextToken.reserve(sizeof(ttd) + addrBytes.size() + dcid.size()); + plainTextToken.insert(plainTextToken.end(), reinterpret_cast(&ttd), reinterpret_cast(&ttd) + sizeof(ttd)); + plainTextToken.insert(plainTextToken.end(), addrBytes.begin(), addrBytes.end()); + plainTextToken.insert(plainTextToken.end(), dcid.begin(), dcid.end()); + const auto encryptedToken = sodEncryptSym(std::string_view(reinterpret_cast(plainTextToken.data()), plainTextToken.size()), s_quicRetryTokenKey, nonce, false); + // a bit sad, let's see if we can do better later + auto encryptedTokenPacket = PacketBuffer(encryptedToken.begin(), encryptedToken.end()); + encryptedTokenPacket.insert(encryptedTokenPacket.begin(), nonce.value.begin(), nonce.value.end()); + return encryptedTokenPacket; + } + catch (const std::exception& exp) { + vinfolog("Error while minting DoQ token: %s", exp.what()); + throw; + } } // returns the original destination ID if the token is valid, nothing otherwise static std::optional validateToken(const PacketBuffer& token, const PacketBuffer& dcid, const ComboAddress& peer) { - const std::array keyword = {'q', 'u', 'i', 'c', 'h', 'e'}; - auto addrBytes = peer.toByteString(); - auto minimumSize = keyword.size() + addrBytes.size(); - if (token.size() <= minimumSize) { - return std::nullopt; - } - if (std::memcmp(&*keyword.begin(), &*token.begin(), keyword.size()) != 0) { - return std::nullopt; + try { + SodiumNonce nonce; + auto addrBytes = peer.toByteString(); + const uint64_t now = time(nullptr); + const auto minimumSize = nonce.value.size() + sizeof(now) + addrBytes.size(); + if (token.size() <= minimumSize) { + return std::nullopt; + } + + memcpy(nonce.value.data(), token.data(), nonce.value.size()); + + auto cipher = std::string_view(reinterpret_cast(&token.at(nonce.value.size())), token.size() - nonce.value.size()); + auto plainText = sodDecryptSym(cipher, s_quicRetryTokenKey, nonce, false); + + if (plainText.size() <= sizeof(now) + addrBytes.size()) { + return std::nullopt; + } + + uint64_t ttd{0}; + memcpy(&ttd, plainText.data(), sizeof(ttd)); + if (ttd < now) { + return std::nullopt; + } + + if (std::memcmp(&plainText.at(sizeof(ttd)), &*addrBytes.begin(), addrBytes.size()) != 0) { + return std::nullopt; + } + return PacketBuffer(plainText.begin() + (sizeof(ttd) + addrBytes.size()), plainText.end()); } - if (std::memcmp(&token.at(keyword.size()), &*addrBytes.begin(), addrBytes.size()) != 0) { + catch (const std::exception& exp) { + vinfolog("Error while validating DoQ token: %s", exp.what()); return std::nullopt; } - return PacketBuffer(token.begin() + keyword.size() + addrBytes.size(), token.end()); } static void handleStatelessRetry(Socket& sock, const PacketBuffer& clientConnID, const PacketBuffer& serverConnID, const ComboAddress& peer, uint32_t version) @@ -607,8 +644,8 @@ static void doq_dispatch_query(DOQServerConfig& dsc, PacketBuffer&& query, const processDOQQuery(std::move(du)); } - catch (const std::exception& e) { - vinfolog("Had error parsing DoQ DNS packet from %s: %s", remote.toStringWithPort(), e.what()); + catch (const std::exception& exp) { + vinfolog("Had error parsing DoQ DNS packet from %s: %s", remote.toStringWithPort(), exp.what()); } } @@ -670,7 +707,7 @@ void doqThread(ClientState* cs) size_t scid_len = scid.size(); std::array dcid; size_t dcid_len = dcid.size(); - std::array token; + std::array token; size_t token_len = token.size(); auto res = quiche_header_info(reinterpret_cast(bufferStr.data()), bufferStr.size(), LOCAL_CONN_ID_LEN, diff --git a/pdns/sodcrypto.cc b/pdns/sodcrypto.cc index 9e9f91ce20..d6e9e7e906 100644 --- a/pdns/sodcrypto.cc +++ b/pdns/sodcrypto.cc @@ -30,7 +30,7 @@ #ifdef HAVE_LIBSODIUM -string newKey() +string newKey(bool base64Encoded) { std::string key; key.resize(crypto_secretbox_KEYBYTES); @@ -38,6 +38,9 @@ string newKey() // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) randombytes_buf(reinterpret_cast(key.data()), key.size()); + if (!base64Encoded) { + return key; + } return "\"" + Base64Encode(key) + "\""; } @@ -46,10 +49,10 @@ bool sodIsValidKey(const std::string& key) return key.size() == crypto_secretbox_KEYBYTES; } -std::string sodEncryptSym(const std::string_view& msg, const std::string& key, SodiumNonce& nonce) +std::string sodEncryptSym(const std::string_view& msg, const std::string& key, SodiumNonce& nonce, bool incrementNonce) { if (!sodIsValidKey(key)) { - throw std::runtime_error("Invalid encryption key of size " + std::to_string(key.size()) + ", use setKey() to set a valid key"); + throw std::runtime_error("Invalid encryption key of size " + std::to_string(key.size()) + " (" + std::to_string(crypto_secretbox_KEYBYTES) + " expected), use setKey() to set a valid key"); } std::string ciphertext; @@ -63,11 +66,14 @@ std::string sodEncryptSym(const std::string_view& msg, const std::string& key, S // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) reinterpret_cast(key.data())); - nonce.increment(); + if (incrementNonce) { + nonce.increment(); + } + return ciphertext; } -std::string sodDecryptSym(const std::string_view& msg, const std::string& key, SodiumNonce& nonce) +std::string sodDecryptSym(const std::string_view& msg, const std::string& key, SodiumNonce& nonce, bool incrementNonce) { std::string decrypted; @@ -92,7 +98,10 @@ std::string sodDecryptSym(const std::string_view& msg, const std::string& key, S throw std::runtime_error("Could not decrypt message, please check that the key configured with setKey() is correct"); } - nonce.increment(); + if (incrementNonce) { + nonce.increment(); + } + return decrypted; } @@ -129,16 +138,16 @@ void SodiumNonce::increment() { } -std::string sodEncryptSym(const std::string& msg, const std::string& key, SodiumNonce& nonce) +std::string sodEncryptSym(const std::string_view& msg, const std::string& key, SodiumNonce& nonce, bool incrementNonce) { - return msg; + return std::string(msg); } -std::string sodDecryptSym(const std::string& msg, const std::string& key, SodiumNonce& nonce) +std::string sodDecryptSym(const std::string_view& msg, const std::string& key, SodiumNonce& nonce, bool incrementNonce) { - return msg; + return std::string(msg); } -string newKey() +string newKey(bool base64Encoded) { return "\"plaintext\""; } diff --git a/pdns/sodcrypto.hh b/pdns/sodcrypto.hh index 8538702fc8..776248cc99 100644 --- a/pdns/sodcrypto.hh +++ b/pdns/sodcrypto.hh @@ -51,7 +51,7 @@ struct SodiumNonce }; std::string newKeypair(); -std::string sodEncryptSym(const std::string_view& msg, const std::string& key, SodiumNonce&); -std::string sodDecryptSym(const std::string_view& msg, const std::string& key, SodiumNonce&); -std::string newKey(); +std::string sodEncryptSym(const std::string_view& msg, const std::string& key, SodiumNonce& nonce, bool incrementNonce = true); +std::string sodDecryptSym(const std::string_view& msg, const std::string& key, SodiumNonce& nonce, bool incrementNonce = true); +std::string newKey(bool base64Encoded = true); bool sodIsValidKey(const std::string& key);