]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Proper retry token generation and validation for DoQ
authorRemi Gacogne <remi.gacogne@powerdns.com>
Wed, 20 Sep 2023 14:54:37 +0000 (16:54 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 9 Oct 2023 11:37:54 +0000 (13:37 +0200)
pdns/dnsdistdist/doq.cc
pdns/sodcrypto.cc
pdns/sodcrypto.hh

index 6acc201279d0798382c14cffcbc2ef8a228a55d8..2df6b6a40405f388cc8233047072589ba759c80b 100644 (file)
@@ -32,6 +32,9 @@
 #include "threadname.hh"
 #include "dnsdist-ecs.hh"
 #include "dnsdist-proxy-protocol.hh"
+#include "sodcrypto.hh"
+
+static std::string s_quicRetryTokenKey = newKey(false);
 
 static void sendBackDOQUnit(DOQUnitUniquePtr&& du, const char* description);
 struct DOQServerConfig
@@ -69,7 +72,6 @@ struct DOQServerConfig
 
 static constexpr size_t MAX_DATAGRAM_SIZE = 1350;
 static constexpr size_t LOCAL_CONN_ID_LEN = 16;
-static constexpr size_t TOKEN_LEN = 32; /* check if this needs to be authenticated, via HMAC-SHA256, for example, see rfc9000 section 8.1.1 */
 
 static std::map<PacketBuffer, Connection> s_connections;
 
@@ -287,35 +289,70 @@ static std::optional<PacketBuffer> getCID()
   return buffer;
 }
 
+static constexpr size_t MAX_TOKEN_LEN = std::tuple_size<decltype(SodiumNonce::value)>{} /* nonce */ + sizeof(uint64_t) /* TTD */ + 16 /* IPv6 */ + QUICHE_MAX_CONN_ID_LEN;
+
 static PacketBuffer mintToken(const PacketBuffer& dcid, const ComboAddress& peer)
 {
-  // FIXME: really check whether this needs to be authenticated, via HMAC for example
-  const std::array keyword = {'q', 'u', 'i', 'c', 'h', 'e'};
-  auto addrBytes = peer.toByteString();
-  PacketBuffer token;
-  token.reserve(keyword.size() + addrBytes.size() + dcid.size());
-  token.insert(token.end(), keyword.begin(), keyword.end());
-  token.insert(token.end(), addrBytes.begin(), addrBytes.end());
-  token.insert(token.end(), dcid.begin(), dcid.end());
-  return token;
+  try {
+    SodiumNonce nonce;
+    nonce.init();
+
+    const auto addrBytes = peer.toByteString();
+    // this token will be valid for 60s
+    const uint64_t ttd = time(nullptr) + 60U;
+    PacketBuffer plainTextToken;
+    plainTextToken.reserve(sizeof(ttd) + addrBytes.size() + dcid.size());
+    plainTextToken.insert(plainTextToken.end(), reinterpret_cast<const char*>(&ttd), reinterpret_cast<const char*>(&ttd) + sizeof(ttd));
+    plainTextToken.insert(plainTextToken.end(), addrBytes.begin(), addrBytes.end());
+    plainTextToken.insert(plainTextToken.end(), dcid.begin(), dcid.end());
+    const auto encryptedToken = sodEncryptSym(std::string_view(reinterpret_cast<const char*>(plainTextToken.data()), plainTextToken.size()), s_quicRetryTokenKey, nonce, false);
+    // a bit sad, let's see if we can do better later
+    auto encryptedTokenPacket = PacketBuffer(encryptedToken.begin(), encryptedToken.end());
+    encryptedTokenPacket.insert(encryptedTokenPacket.begin(), nonce.value.begin(), nonce.value.end());
+    return encryptedTokenPacket;
+  }
+  catch (const std::exception& exp) {
+    vinfolog("Error while minting DoQ token: %s", exp.what());
+    throw;
+  }
 }
 
 // returns the original destination ID if the token is valid, nothing otherwise
 static std::optional<PacketBuffer> validateToken(const PacketBuffer& token, const PacketBuffer& dcid, const ComboAddress& peer)
 {
-  const std::array keyword = {'q', 'u', 'i', 'c', 'h', 'e'};
-  auto addrBytes = peer.toByteString();
-  auto minimumSize = keyword.size() + addrBytes.size();
-  if (token.size() <= minimumSize) {
-    return std::nullopt;
-  }
-  if (std::memcmp(&*keyword.begin(), &*token.begin(), keyword.size()) != 0) {
-    return std::nullopt;
+  try {
+    SodiumNonce nonce;
+    auto addrBytes = peer.toByteString();
+    const uint64_t now = time(nullptr);
+    const auto minimumSize = nonce.value.size() + sizeof(now) + addrBytes.size();
+    if (token.size() <= minimumSize) {
+      return std::nullopt;
+    }
+
+    memcpy(nonce.value.data(), token.data(), nonce.value.size());
+
+    auto cipher = std::string_view(reinterpret_cast<const char*>(&token.at(nonce.value.size())), token.size() - nonce.value.size());
+    auto plainText = sodDecryptSym(cipher, s_quicRetryTokenKey, nonce, false);
+
+    if (plainText.size() <= sizeof(now) + addrBytes.size()) {
+      return std::nullopt;
+    }
+
+    uint64_t ttd{0};
+    memcpy(&ttd, plainText.data(), sizeof(ttd));
+    if (ttd < now) {
+      return std::nullopt;
+    }
+
+    if (std::memcmp(&plainText.at(sizeof(ttd)), &*addrBytes.begin(), addrBytes.size()) != 0) {
+      return std::nullopt;
+    }
+    return PacketBuffer(plainText.begin() + (sizeof(ttd) + addrBytes.size()), plainText.end());
   }
-  if (std::memcmp(&token.at(keyword.size()), &*addrBytes.begin(), addrBytes.size()) != 0) {
+  catch (const std::exception& exp) {
+    vinfolog("Error while validating DoQ token: %s", exp.what());
     return std::nullopt;
   }
-  return PacketBuffer(token.begin() + keyword.size() + addrBytes.size(), token.end());
 }
 
 static void handleStatelessRetry(Socket& sock, const PacketBuffer& clientConnID, const PacketBuffer& serverConnID, const ComboAddress& peer, uint32_t version)
@@ -607,8 +644,8 @@ static void doq_dispatch_query(DOQServerConfig& dsc, PacketBuffer&& query, const
 
     processDOQQuery(std::move(du));
   }
-  catch (const std::exception& e) {
-    vinfolog("Had error parsing DoQ DNS packet from %s: %s", remote.toStringWithPort(), e.what());
+  catch (const std::exception& exp) {
+    vinfolog("Had error parsing DoQ DNS packet from %s: %s", remote.toStringWithPort(), exp.what());
   }
 }
 
@@ -670,7 +707,7 @@ void doqThread(ClientState* cs)
         size_t scid_len = scid.size();
         std::array<uint8_t, QUICHE_MAX_CONN_ID_LEN> dcid;
         size_t dcid_len = dcid.size();
-        std::array<uint8_t, TOKEN_LEN> token;
+        std::array<uint8_t, MAX_TOKEN_LEN> token;
         size_t token_len = token.size();
 
         auto res = quiche_header_info(reinterpret_cast<const uint8_t*>(bufferStr.data()), bufferStr.size(), LOCAL_CONN_ID_LEN,
index 9e9f91ce203d6bc6c144801bb795868811640984..d6e9e7e906e41eaa6a0255d525d7d7db24fce20b 100644 (file)
@@ -30,7 +30,7 @@
 
 #ifdef HAVE_LIBSODIUM
 
-string newKey()
+string newKey(bool base64Encoded)
 {
   std::string key;
   key.resize(crypto_secretbox_KEYBYTES);
@@ -38,6 +38,9 @@ string newKey()
   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
   randombytes_buf(reinterpret_cast<unsigned char*>(key.data()), key.size());
 
+  if (!base64Encoded) {
+    return key;
+  }
   return "\"" + Base64Encode(key) + "\"";
 }
 
@@ -46,10 +49,10 @@ bool sodIsValidKey(const std::string& key)
   return key.size() == crypto_secretbox_KEYBYTES;
 }
 
-std::string sodEncryptSym(const std::string_view& msg, const std::string& key, SodiumNonce& nonce)
+std::string sodEncryptSym(const std::string_view& msg, const std::string& key, SodiumNonce& nonce, bool incrementNonce)
 {
   if (!sodIsValidKey(key)) {
-    throw std::runtime_error("Invalid encryption key of size " + std::to_string(key.size()) + ", use setKey() to set a valid key");
+    throw std::runtime_error("Invalid encryption key of size " + std::to_string(key.size()) + " (" + std::to_string(crypto_secretbox_KEYBYTES) + " expected), use setKey() to set a valid key");
   }
 
   std::string ciphertext;
@@ -63,11 +66,14 @@ std::string sodEncryptSym(const std::string_view& msg, const std::string& key, S
                         // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
                         reinterpret_cast<const unsigned char*>(key.data()));
 
-  nonce.increment();
+  if (incrementNonce) {
+    nonce.increment();
+  }
+
   return ciphertext;
 }
 
-std::string sodDecryptSym(const std::string_view& msg, const std::string& key, SodiumNonce& nonce)
+std::string sodDecryptSym(const std::string_view& msg, const std::string& key, SodiumNonce& nonce, bool incrementNonce)
 {
   std::string decrypted;
 
@@ -92,7 +98,10 @@ std::string sodDecryptSym(const std::string_view& msg, const std::string& key, S
     throw std::runtime_error("Could not decrypt message, please check that the key configured with setKey() is correct");
   }
 
-  nonce.increment();
+  if (incrementNonce) {
+    nonce.increment();
+  }
+
   return decrypted;
 }
 
@@ -129,16 +138,16 @@ void SodiumNonce::increment()
 {
 }
 
-std::string sodEncryptSym(const std::string& msg, const std::string& key, SodiumNonce& nonce)
+std::string sodEncryptSym(const std::string_view& msg, const std::string& key, SodiumNonce& nonce, bool incrementNonce)
 {
-  return msg;
+  return std::string(msg);
 }
-std::string sodDecryptSym(const std::string& msg, const std::string& key, SodiumNonce& nonce)
+std::string sodDecryptSym(const std::string_view& msg, const std::string& key, SodiumNonce& nonce, bool incrementNonce)
 {
-  return msg;
+  return std::string(msg);
 }
 
-string newKey()
+string newKey(bool base64Encoded)
 {
   return "\"plaintext\"";
 }
index 8538702fc8114b8dcc21e097c5d25183b04ae66b..776248cc99feba475dff967f3a919c75e9f6e51d 100644 (file)
@@ -51,7 +51,7 @@ struct SodiumNonce
 };
 
 std::string newKeypair();
-std::string sodEncryptSym(const std::string_view& msg, const std::string& key, SodiumNonce&);
-std::string sodDecryptSym(const std::string_view& msg, const std::string& key, SodiumNonce&);
-std::string newKey();
+std::string sodEncryptSym(const std::string_view& msg, const std::string& key, SodiumNonce& nonce, bool incrementNonce = true);
+std::string sodDecryptSym(const std::string_view& msg, const std::string& key, SodiumNonce& nonce, bool incrementNonce = true);
+std::string newKey(bool base64Encoded = true);
 bool sodIsValidKey(const std::string& key);