#include "threadname.hh"
#include "dnsdist-ecs.hh"
#include "dnsdist-proxy-protocol.hh"
+#include "sodcrypto.hh"
+
+static std::string s_quicRetryTokenKey = newKey(false);
static void sendBackDOQUnit(DOQUnitUniquePtr&& du, const char* description);
struct DOQServerConfig
static constexpr size_t MAX_DATAGRAM_SIZE = 1350;
static constexpr size_t LOCAL_CONN_ID_LEN = 16;
-static constexpr size_t TOKEN_LEN = 32; /* check if this needs to be authenticated, via HMAC-SHA256, for example, see rfc9000 section 8.1.1 */
static std::map<PacketBuffer, Connection> s_connections;
return buffer;
}
+static constexpr size_t MAX_TOKEN_LEN = std::tuple_size<decltype(SodiumNonce::value)>{} /* nonce */ + sizeof(uint64_t) /* TTD */ + 16 /* IPv6 */ + QUICHE_MAX_CONN_ID_LEN;
+
static PacketBuffer mintToken(const PacketBuffer& dcid, const ComboAddress& peer)
{
- // FIXME: really check whether this needs to be authenticated, via HMAC for example
- const std::array keyword = {'q', 'u', 'i', 'c', 'h', 'e'};
- auto addrBytes = peer.toByteString();
- PacketBuffer token;
- token.reserve(keyword.size() + addrBytes.size() + dcid.size());
- token.insert(token.end(), keyword.begin(), keyword.end());
- token.insert(token.end(), addrBytes.begin(), addrBytes.end());
- token.insert(token.end(), dcid.begin(), dcid.end());
- return token;
+ try {
+ SodiumNonce nonce;
+ nonce.init();
+
+ const auto addrBytes = peer.toByteString();
+ // this token will be valid for 60s
+ const uint64_t ttd = time(nullptr) + 60U;
+ PacketBuffer plainTextToken;
+ plainTextToken.reserve(sizeof(ttd) + addrBytes.size() + dcid.size());
+ plainTextToken.insert(plainTextToken.end(), reinterpret_cast<const char*>(&ttd), reinterpret_cast<const char*>(&ttd) + sizeof(ttd));
+ plainTextToken.insert(plainTextToken.end(), addrBytes.begin(), addrBytes.end());
+ plainTextToken.insert(plainTextToken.end(), dcid.begin(), dcid.end());
+ const auto encryptedToken = sodEncryptSym(std::string_view(reinterpret_cast<const char*>(plainTextToken.data()), plainTextToken.size()), s_quicRetryTokenKey, nonce, false);
+ // a bit sad, let's see if we can do better later
+ auto encryptedTokenPacket = PacketBuffer(encryptedToken.begin(), encryptedToken.end());
+ encryptedTokenPacket.insert(encryptedTokenPacket.begin(), nonce.value.begin(), nonce.value.end());
+ return encryptedTokenPacket;
+ }
+ catch (const std::exception& exp) {
+ vinfolog("Error while minting DoQ token: %s", exp.what());
+ throw;
+ }
}
// returns the original destination ID if the token is valid, nothing otherwise
static std::optional<PacketBuffer> validateToken(const PacketBuffer& token, const PacketBuffer& dcid, const ComboAddress& peer)
{
- const std::array keyword = {'q', 'u', 'i', 'c', 'h', 'e'};
- auto addrBytes = peer.toByteString();
- auto minimumSize = keyword.size() + addrBytes.size();
- if (token.size() <= minimumSize) {
- return std::nullopt;
- }
- if (std::memcmp(&*keyword.begin(), &*token.begin(), keyword.size()) != 0) {
- return std::nullopt;
+ try {
+ SodiumNonce nonce;
+ auto addrBytes = peer.toByteString();
+ const uint64_t now = time(nullptr);
+ const auto minimumSize = nonce.value.size() + sizeof(now) + addrBytes.size();
+ if (token.size() <= minimumSize) {
+ return std::nullopt;
+ }
+
+ memcpy(nonce.value.data(), token.data(), nonce.value.size());
+
+ auto cipher = std::string_view(reinterpret_cast<const char*>(&token.at(nonce.value.size())), token.size() - nonce.value.size());
+ auto plainText = sodDecryptSym(cipher, s_quicRetryTokenKey, nonce, false);
+
+ if (plainText.size() <= sizeof(now) + addrBytes.size()) {
+ return std::nullopt;
+ }
+
+ uint64_t ttd{0};
+ memcpy(&ttd, plainText.data(), sizeof(ttd));
+ if (ttd < now) {
+ return std::nullopt;
+ }
+
+ if (std::memcmp(&plainText.at(sizeof(ttd)), &*addrBytes.begin(), addrBytes.size()) != 0) {
+ return std::nullopt;
+ }
+ return PacketBuffer(plainText.begin() + (sizeof(ttd) + addrBytes.size()), plainText.end());
}
- if (std::memcmp(&token.at(keyword.size()), &*addrBytes.begin(), addrBytes.size()) != 0) {
+ catch (const std::exception& exp) {
+ vinfolog("Error while validating DoQ token: %s", exp.what());
return std::nullopt;
}
- return PacketBuffer(token.begin() + keyword.size() + addrBytes.size(), token.end());
}
static void handleStatelessRetry(Socket& sock, const PacketBuffer& clientConnID, const PacketBuffer& serverConnID, const ComboAddress& peer, uint32_t version)
processDOQQuery(std::move(du));
}
- catch (const std::exception& e) {
- vinfolog("Had error parsing DoQ DNS packet from %s: %s", remote.toStringWithPort(), e.what());
+ catch (const std::exception& exp) {
+ vinfolog("Had error parsing DoQ DNS packet from %s: %s", remote.toStringWithPort(), exp.what());
}
}
size_t scid_len = scid.size();
std::array<uint8_t, QUICHE_MAX_CONN_ID_LEN> dcid;
size_t dcid_len = dcid.size();
- std::array<uint8_t, TOKEN_LEN> token;
+ std::array<uint8_t, MAX_TOKEN_LEN> token;
size_t token_len = token.size();
auto res = quiche_header_info(reinterpret_cast<const uint8_t*>(bufferStr.data()), bufferStr.size(), LOCAL_CONN_ID_LEN,
#ifdef HAVE_LIBSODIUM
-string newKey()
+string newKey(bool base64Encoded)
{
std::string key;
key.resize(crypto_secretbox_KEYBYTES);
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
randombytes_buf(reinterpret_cast<unsigned char*>(key.data()), key.size());
+ if (!base64Encoded) {
+ return key;
+ }
return "\"" + Base64Encode(key) + "\"";
}
return key.size() == crypto_secretbox_KEYBYTES;
}
-std::string sodEncryptSym(const std::string_view& msg, const std::string& key, SodiumNonce& nonce)
+std::string sodEncryptSym(const std::string_view& msg, const std::string& key, SodiumNonce& nonce, bool incrementNonce)
{
if (!sodIsValidKey(key)) {
- throw std::runtime_error("Invalid encryption key of size " + std::to_string(key.size()) + ", use setKey() to set a valid key");
+ throw std::runtime_error("Invalid encryption key of size " + std::to_string(key.size()) + " (" + std::to_string(crypto_secretbox_KEYBYTES) + " expected), use setKey() to set a valid key");
}
std::string ciphertext;
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
reinterpret_cast<const unsigned char*>(key.data()));
- nonce.increment();
+ if (incrementNonce) {
+ nonce.increment();
+ }
+
return ciphertext;
}
-std::string sodDecryptSym(const std::string_view& msg, const std::string& key, SodiumNonce& nonce)
+std::string sodDecryptSym(const std::string_view& msg, const std::string& key, SodiumNonce& nonce, bool incrementNonce)
{
std::string decrypted;
throw std::runtime_error("Could not decrypt message, please check that the key configured with setKey() is correct");
}
- nonce.increment();
+ if (incrementNonce) {
+ nonce.increment();
+ }
+
return decrypted;
}
{
}
-std::string sodEncryptSym(const std::string& msg, const std::string& key, SodiumNonce& nonce)
+std::string sodEncryptSym(const std::string_view& msg, const std::string& key, SodiumNonce& nonce, bool incrementNonce)
{
- return msg;
+ return std::string(msg);
}
-std::string sodDecryptSym(const std::string& msg, const std::string& key, SodiumNonce& nonce)
+std::string sodDecryptSym(const std::string_view& msg, const std::string& key, SodiumNonce& nonce, bool incrementNonce)
{
- return msg;
+ return std::string(msg);
}
-string newKey()
+string newKey(bool base64Encoded)
{
return "\"plaintext\"";
}