From: Charles-Henri Bruyand Date: Fri, 14 Jun 2024 11:53:27 +0000 (+0200) Subject: dnsdist: add the ability to load a given tickets key from lua X-Git-Tag: rec-5.2.0-alpha1~32^2~2 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=0de40340927b6943a44728407fd2f1cc93c38e51;p=thirdparty%2Fpdns.git dnsdist: add the ability to load a given tickets key from lua --- diff --git a/pdns/dnsdistdist/dnsdist-doh-common.cc b/pdns/dnsdistdist/dnsdist-doh-common.cc index dcbd183d7c..4ff9ff2900 100644 --- a/pdns/dnsdistdist/dnsdist-doh-common.cc +++ b/pdns/dnsdistdist/dnsdist-doh-common.cc @@ -99,6 +99,11 @@ void DOHFrontend::loadTicketsKeys(const std::string& keyFile) return d_tlsContext.loadTicketsKeys(keyFile); } +void DOHFrontend::loadTicketsKey(const std::string& key) +{ + return d_tlsContext.loadTicketsKey(key); +} + void DOHFrontend::handleTicketsKeyRotation() { } diff --git a/pdns/dnsdistdist/dnsdist-doh-common.hh b/pdns/dnsdistdist/dnsdist-doh-common.hh index 0dc714df23..9d0a466928 100644 --- a/pdns/dnsdistdist/dnsdist-doh-common.hh +++ b/pdns/dnsdistdist/dnsdist-doh-common.hh @@ -166,6 +166,10 @@ struct DOHFrontend { } + virtual void loadTicketsKey(const std::string& /* key */) + { + } + virtual void handleTicketsKeyRotation() { } @@ -187,6 +191,7 @@ struct DOHFrontend virtual void rotateTicketsKey(time_t now); virtual void loadTicketsKeys(const std::string& keyFile); + virtual void loadTicketsKey(const std::string& key); virtual void handleTicketsKeyRotation(); virtual std::string getNextTicketsKeyRotation() const; virtual size_t getTicketsKeysCount(); diff --git a/pdns/dnsdistdist/dnsdist-lua.cc b/pdns/dnsdistdist/dnsdist-lua.cc index 0a0094c68e..a21066715c 100644 --- a/pdns/dnsdistdist/dnsdist-lua.cc +++ b/pdns/dnsdistdist/dnsdist-lua.cc @@ -2990,6 +2990,12 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) } }); + luaCtx.registerFunction::*)(const std::string&)>("loadTicketsKey", [](const std::shared_ptr& frontend, const std::string& key) { + if (frontend != nullptr) { + frontend->loadTicketsKey(key); + } + }); + luaCtx.registerFunction::*)(const LuaArray>&)>("setResponsesMap", [](const std::shared_ptr& frontend, const LuaArray>& map) { if (frontend != nullptr) { auto newMap = std::make_shared>>(); @@ -3223,6 +3229,16 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) } }); + luaCtx.registerFunction::*)(const std::string&)>("loadTicketsKey", [](std::shared_ptr& frontend, const std::string& key) { + if (frontend == nullptr) { + return; + } + auto ctx = frontend->getContext(); + if (ctx) { + ctx->loadTicketsKey(key); + } + }); + luaCtx.registerFunction::*)()>("reloadCertificates", [](const std::shared_ptr& frontend) { if (frontend == nullptr) { return; diff --git a/pdns/libssl.cc b/pdns/libssl.cc index c81127c1ff..2493f73d1d 100644 --- a/pdns/libssl.cc +++ b/pdns/libssl.cc @@ -688,6 +688,22 @@ void OpenSSLTLSTicketKeysRing::loadTicketsKeys(const std::string& keyFile) file.close(); } +void OpenSSLTLSTicketKeysRing::loadTicketsKey(const std::string& key) +{ + bool keyLoaded = false; + try { + auto newKey = std::make_shared(key); + addKey(std::move(newKey)); + keyLoaded = true; + } + catch (const std::exception& e) { + /* if we haven't been able to load at least one key, fail */ + if (!keyLoaded) { + throw; + } + } +} + void OpenSSLTLSTicketKeysRing::rotateTicketsKey(time_t /* now */) { auto newKey = std::make_shared(); @@ -730,6 +746,25 @@ OpenSSLTLSTicketKey::OpenSSLTLSTicketKey(std::ifstream& file) #endif /* HAVE_LIBSODIUM */ } +OpenSSLTLSTicketKey::OpenSSLTLSTicketKey(const std::string& key) +{ + if (key.size() != (sizeof(d_name) + sizeof(d_cipherKey) + sizeof(d_hmacKey))) { + throw std::runtime_error("Unable to load a ticket key from given data"); + } + size_t from = 0; + memcpy(d_name, &key.at(from), sizeof(d_name)); + from += sizeof(d_name); + memcpy(d_cipherKey, &key.at(from), sizeof(d_cipherKey)); + from += sizeof(d_cipherKey); + memcpy(d_hmacKey, &key.at(from), sizeof(d_hmacKey)); + +#ifdef HAVE_LIBSODIUM + sodium_mlock(d_name, sizeof(d_name)); + sodium_mlock(d_cipherKey, sizeof(d_cipherKey)); + sodium_mlock(d_hmacKey, sizeof(d_hmacKey)); +#endif /* HAVE_LIBSODIUM */ +} + OpenSSLTLSTicketKey::~OpenSSLTLSTicketKey() { #ifdef HAVE_LIBSODIUM diff --git a/pdns/libssl.hh b/pdns/libssl.hh index 96f6dd9a6b..d927f7bfa2 100644 --- a/pdns/libssl.hh +++ b/pdns/libssl.hh @@ -93,6 +93,7 @@ class OpenSSLTLSTicketKey public: OpenSSLTLSTicketKey(); OpenSSLTLSTicketKey(std::ifstream& file); + OpenSSLTLSTicketKey(const std::string& key); ~OpenSSLTLSTicketKey(); bool nameMatches(const unsigned char name[TLS_TICKETS_KEY_NAME_SIZE]) const; @@ -122,6 +123,7 @@ public: std::shared_ptr getDecryptionKey(unsigned char name[TLS_TICKETS_KEY_NAME_SIZE], bool& activeKey); size_t getKeysCount(); void loadTicketsKeys(const std::string& keyFile); + void loadTicketsKey(const std::string& key); void rotateTicketsKey(time_t now); private: diff --git a/pdns/tcpiohandler.cc b/pdns/tcpiohandler.cc index 59382083f7..db07add7be 100644 --- a/pdns/tcpiohandler.cc +++ b/pdns/tcpiohandler.cc @@ -889,6 +889,15 @@ public: } } + void loadTicketsKey(const std::string& key) final + { + d_feContext->d_ticketKeys.loadTicketsKey(key); + + if (d_ticketsKeyRotationDelay > 0) { + d_ticketsKeyNextRotation = time(nullptr) + d_ticketsKeyRotationDelay; + } + } + size_t getTicketsKeysCount() override { return d_feContext->d_ticketKeys.getKeysCount(); @@ -993,7 +1002,24 @@ public: safe_memory_lock(d_key.data, d_key.size); } - GnuTLSTicketsKey(const std::string& keyFile) + GnuTLSTicketsKey(const std::string& key) + { + /* to be sure we are loading the correct amount of data, which + may change between versions, let's generate a correct key first */ + if (gnutls_session_ticket_key_generate(&d_key) != GNUTLS_E_SUCCESS) { + throw std::runtime_error("Error generating tickets key (before parsing key file) for TLS context"); + } + + safe_memory_lock(d_key.data, d_key.size); + if (key.size() != d_key.size) { + safe_memory_release(d_key.data, d_key.size); + gnutls_free(d_key.data); + d_key.data = nullptr; + throw std::runtime_error("Invalid GnuTLS ticket key size"); + } + memcpy(d_key.data, key.data(), key.size()); + } + GnuTLSTicketsKey(std::ifstream& file) { /* to be sure we are loading the correct amount of data, which may change between versions, let's generate a correct key first */ @@ -1004,17 +1030,17 @@ public: safe_memory_lock(d_key.data, d_key.size); try { - ifstream file(keyFile); file.read(reinterpret_cast(d_key.data), d_key.size); if (file.fail()) { - file.close(); - throw std::runtime_error("Invalid GnuTLS tickets key file " + keyFile); + throw std::runtime_error("Invalid GnuTLS tickets key file"); } - file.close(); } catch (const std::exception& e) { + safe_memory_release(d_key.data, d_key.size); + gnutls_free(d_key.data); + d_key.data = nullptr; safe_memory_release(d_key.data, d_key.size); gnutls_free(d_key.data); d_key.data = nullptr; @@ -1804,14 +1830,26 @@ public: auto newKey = std::make_shared(); addTicketsKey(now, std::move(newKey)); } - void loadTicketsKeys(const std::string& file) final + void loadTicketsKey(const std::string& key) final + { + if (!d_enableTickets) { + return; + } + + auto newKey = std::make_shared(key); + addTicketsKey(time(nullptr), std::move(newKey)); + } + + void loadTicketsKeys(const std::string& keyFile) final { if (!d_enableTickets) { return; } + std::ifstream file(keyFile); auto newKey = std::make_shared(file); addTicketsKey(time(nullptr), std::move(newKey)); + file.close(); } size_t getTicketsKeysCount() override diff --git a/pdns/tcpiohandler.hh b/pdns/tcpiohandler.hh index f24735cc88..742431da4a 100644 --- a/pdns/tcpiohandler.hh +++ b/pdns/tcpiohandler.hh @@ -81,6 +81,10 @@ public: { throw std::runtime_error("This TLS backend does not have the capability to load a tickets key from a file"); } + virtual void loadTicketsKey(const std::string& /* key */) + { + throw std::runtime_error("This TLS backend does not have the capability to load a ticket key"); + } void handleTicketsKeyRotation(time_t now) { if (d_ticketsKeyRotationDelay != 0 && now > d_ticketsKeyNextRotation) { @@ -163,6 +167,13 @@ public: } } + void loadTicketsKey(const std::string& key) + { + if (d_ctx != nullptr) { + d_ctx->loadTicketsKey(key); + } + } + std::shared_ptr getContext() { return std::atomic_load_explicit(&d_ctx, std::memory_order_acquire); diff --git a/regression-tests.dnsdist/test_TLS.py b/regression-tests.dnsdist/test_TLS.py index 27c2de52fe..f40c18cfe4 100644 --- a/regression-tests.dnsdist/test_TLS.py +++ b/regression-tests.dnsdist/test_TLS.py @@ -6,6 +6,9 @@ import ssl import subprocess import time import unittest +import random +import string + from dnsdisttests import DNSDistTest, pickAvailablePort class TLSTests(object): @@ -517,7 +520,7 @@ class TestPKCSTLSCertificate(DNSDistTest, TLSTests): cls.startDNSDist() cls.setUpSockets() -class TestTLSTicketsKeyAddedCallback(DNSDistTest): +class TestOpenSSLTLSTicketsKeyCallback(DNSDistTest): _consoleKey = DNSDistTest.generateConsoleKey() _consoleKeyB64 = base64.b64encode(_consoleKey).decode('ascii') @@ -536,20 +539,67 @@ class TestTLSTicketsKeyAddedCallback(DNSDistTest): newServer{address="127.0.0.1:%s"} addTLSLocal("127.0.0.1:%s", "%s", "%s", { provider="openssl" }) - callbackCalled = 0 + lastKey = "" + lastKeyLen = 0 + function keyAddedCallback(key, keyLen) - callbackCalled = keyLen + lastKey = key + lastKeyLen = keyLen end + setTicketsKeyAddedHook(keyAddedCallback) + """ + + def testSetTicketsKey(self): + """ + TLSTicketsKey: test setting new key and the key added hook + """ + newKey = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(80)) + print("about to send command: `{}`".format("getTLSFrontend(0):setTicketsKey(\"{}\")".format(newKey))) + self.sendConsoleCommand("getTLSFrontend(0):loadTicketsKey(\"{}\")".format(newKey)) + keyLen = self.sendConsoleCommand('lastKeyLen') + self.assertEqual(int(keyLen), 80) + lastKey = self.sendConsoleCommand('lastKey') + self.assertEqual(newKey, lastKey.strip()) + +class TestGnuTLSTLSTicketsKeyCallback(DNSDistTest): + _consoleKey = DNSDistTest.generateConsoleKey() + _consoleKeyB64 = base64.b64encode(_consoleKey).decode('ascii') + + _serverKey = 'server.key' + _serverCert = 'server.chain' + _serverName = 'tls.tests.dnsdist.org' + _caCert = 'ca.pem' + _tlsServerPort = pickAvailablePort() + _numberOfKeys = 5 + + _config_params = ['_consoleKeyB64', '_consolePort', '_testServerPort', '_tlsServerPort', '_serverCert', '_serverKey'] + _config_template = """ + setKey("%s") + controlSocket("127.0.0.1:%s") + + newServer{address="127.0.0.1:%s"} + addTLSLocal("127.0.0.1:%s", "%s", "%s", { provider="gnutls" }) + + lastKey = "" + lastKeyLen = 0 + + function keyAddedCallback(key, keyLen) + lastKey = key + lastKeyLen = keyLen + end + setTicketsKeyAddedHook(keyAddedCallback) """ - def testLuaThreadCounter(self): + def testSetTicketsKey(self): """ - LuaThread: Test the lua newThread interface + TLSTicketsKey: test setting new key and the key added hook """ - self.sendConsoleCommand('setTicketsKeyAddedHook(keyAddedCallback)'); - called = self.sendConsoleCommand('callbackCalled') - self.assertEqual(int(called), 0) - self.sendConsoleCommand("getTLSFrontend(0):rotateTicketsKey()") - called = self.sendConsoleCommand('callbackCalled') - self.assertGreater(int(called), 0) + + newKey = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(64)) + print("about to send command: `{}`".format("getTLSFrontend(0):setTicketsKey(\"{}\")".format(newKey))) + self.sendConsoleCommand("getTLSFrontend(0):loadTicketsKey(\"{}\")".format(newKey)) + keyLen = self.sendConsoleCommand('lastKeyLen') + self.assertEqual(int(keyLen), 64) + lastKey = self.sendConsoleCommand('lastKey') + self.assertEqual(newKey, lastKey.strip())