return d_tlsContext.loadTicketsKeys(keyFile);
}
+void DOHFrontend::loadTicketsKey(const std::string& key)
+{
+ return d_tlsContext.loadTicketsKey(key);
+}
+
void DOHFrontend::handleTicketsKeyRotation()
{
}
{
}
+ virtual void loadTicketsKey(const std::string& /* key */)
+ {
+ }
+
virtual void handleTicketsKeyRotation()
{
}
virtual void rotateTicketsKey(time_t now);
virtual void loadTicketsKeys(const std::string& keyFile);
+ virtual void loadTicketsKey(const std::string& key);
virtual void handleTicketsKeyRotation();
virtual std::string getNextTicketsKeyRotation() const;
virtual size_t getTicketsKeysCount();
}
});
+ luaCtx.registerFunction<void (std::shared_ptr<DOHFrontend>::*)(const std::string&)>("loadTicketsKey", [](const std::shared_ptr<DOHFrontend>& frontend, const std::string& key) {
+ if (frontend != nullptr) {
+ frontend->loadTicketsKey(key);
+ }
+ });
+
luaCtx.registerFunction<void (std::shared_ptr<DOHFrontend>::*)(const LuaArray<std::shared_ptr<DOHResponseMapEntry>>&)>("setResponsesMap", [](const std::shared_ptr<DOHFrontend>& frontend, const LuaArray<std::shared_ptr<DOHResponseMapEntry>>& map) {
if (frontend != nullptr) {
auto newMap = std::make_shared<std::vector<std::shared_ptr<DOHResponseMapEntry>>>();
}
});
+ luaCtx.registerFunction<void (std::shared_ptr<TLSFrontend>::*)(const std::string&)>("loadTicketsKey", [](std::shared_ptr<TLSFrontend>& frontend, const std::string& key) {
+ if (frontend == nullptr) {
+ return;
+ }
+ auto ctx = frontend->getContext();
+ if (ctx) {
+ ctx->loadTicketsKey(key);
+ }
+ });
+
luaCtx.registerFunction<void (std::shared_ptr<TLSFrontend>::*)()>("reloadCertificates", [](const std::shared_ptr<TLSFrontend>& frontend) {
if (frontend == nullptr) {
return;
file.close();
}
+void OpenSSLTLSTicketKeysRing::loadTicketsKey(const std::string& key)
+{
+ bool keyLoaded = false;
+ try {
+ auto newKey = std::make_shared<OpenSSLTLSTicketKey>(key);
+ addKey(std::move(newKey));
+ keyLoaded = true;
+ }
+ catch (const std::exception& e) {
+ /* if we haven't been able to load at least one key, fail */
+ if (!keyLoaded) {
+ throw;
+ }
+ }
+}
+
void OpenSSLTLSTicketKeysRing::rotateTicketsKey(time_t /* now */)
{
auto newKey = std::make_shared<OpenSSLTLSTicketKey>();
#endif /* HAVE_LIBSODIUM */
}
+OpenSSLTLSTicketKey::OpenSSLTLSTicketKey(const std::string& key)
+{
+ if (key.size() != (sizeof(d_name) + sizeof(d_cipherKey) + sizeof(d_hmacKey))) {
+ throw std::runtime_error("Unable to load a ticket key from given data");
+ }
+ size_t from = 0;
+ memcpy(d_name, &key.at(from), sizeof(d_name));
+ from += sizeof(d_name);
+ memcpy(d_cipherKey, &key.at(from), sizeof(d_cipherKey));
+ from += sizeof(d_cipherKey);
+ memcpy(d_hmacKey, &key.at(from), sizeof(d_hmacKey));
+
+#ifdef HAVE_LIBSODIUM
+ sodium_mlock(d_name, sizeof(d_name));
+ sodium_mlock(d_cipherKey, sizeof(d_cipherKey));
+ sodium_mlock(d_hmacKey, sizeof(d_hmacKey));
+#endif /* HAVE_LIBSODIUM */
+}
+
OpenSSLTLSTicketKey::~OpenSSLTLSTicketKey()
{
#ifdef HAVE_LIBSODIUM
public:
OpenSSLTLSTicketKey();
OpenSSLTLSTicketKey(std::ifstream& file);
+ OpenSSLTLSTicketKey(const std::string& key);
~OpenSSLTLSTicketKey();
bool nameMatches(const unsigned char name[TLS_TICKETS_KEY_NAME_SIZE]) const;
std::shared_ptr<OpenSSLTLSTicketKey> getDecryptionKey(unsigned char name[TLS_TICKETS_KEY_NAME_SIZE], bool& activeKey);
size_t getKeysCount();
void loadTicketsKeys(const std::string& keyFile);
+ void loadTicketsKey(const std::string& key);
void rotateTicketsKey(time_t now);
private:
}
}
+ void loadTicketsKey(const std::string& key) final
+ {
+ d_feContext->d_ticketKeys.loadTicketsKey(key);
+
+ if (d_ticketsKeyRotationDelay > 0) {
+ d_ticketsKeyNextRotation = time(nullptr) + d_ticketsKeyRotationDelay;
+ }
+ }
+
size_t getTicketsKeysCount() override
{
return d_feContext->d_ticketKeys.getKeysCount();
safe_memory_lock(d_key.data, d_key.size);
}
- GnuTLSTicketsKey(const std::string& keyFile)
+ GnuTLSTicketsKey(const std::string& key)
+ {
+ /* to be sure we are loading the correct amount of data, which
+ may change between versions, let's generate a correct key first */
+ if (gnutls_session_ticket_key_generate(&d_key) != GNUTLS_E_SUCCESS) {
+ throw std::runtime_error("Error generating tickets key (before parsing key file) for TLS context");
+ }
+
+ safe_memory_lock(d_key.data, d_key.size);
+ if (key.size() != d_key.size) {
+ safe_memory_release(d_key.data, d_key.size);
+ gnutls_free(d_key.data);
+ d_key.data = nullptr;
+ throw std::runtime_error("Invalid GnuTLS ticket key size");
+ }
+ memcpy(d_key.data, key.data(), key.size());
+ }
+ GnuTLSTicketsKey(std::ifstream& file)
{
/* to be sure we are loading the correct amount of data, which
may change between versions, let's generate a correct key first */
safe_memory_lock(d_key.data, d_key.size);
try {
- ifstream file(keyFile);
file.read(reinterpret_cast<char*>(d_key.data), d_key.size);
if (file.fail()) {
- file.close();
- throw std::runtime_error("Invalid GnuTLS tickets key file " + keyFile);
+ throw std::runtime_error("Invalid GnuTLS tickets key file");
}
- file.close();
}
catch (const std::exception& e) {
+ safe_memory_release(d_key.data, d_key.size);
+ gnutls_free(d_key.data);
+ d_key.data = nullptr;
safe_memory_release(d_key.data, d_key.size);
gnutls_free(d_key.data);
d_key.data = nullptr;
auto newKey = std::make_shared<GnuTLSTicketsKey>();
addTicketsKey(now, std::move(newKey));
}
- void loadTicketsKeys(const std::string& file) final
+ void loadTicketsKey(const std::string& key) final
+ {
+ if (!d_enableTickets) {
+ return;
+ }
+
+ auto newKey = std::make_shared<GnuTLSTicketsKey>(key);
+ addTicketsKey(time(nullptr), std::move(newKey));
+ }
+
+ void loadTicketsKeys(const std::string& keyFile) final
{
if (!d_enableTickets) {
return;
}
+ std::ifstream file(keyFile);
auto newKey = std::make_shared<GnuTLSTicketsKey>(file);
addTicketsKey(time(nullptr), std::move(newKey));
+ file.close();
}
size_t getTicketsKeysCount() override
{
throw std::runtime_error("This TLS backend does not have the capability to load a tickets key from a file");
}
+ virtual void loadTicketsKey(const std::string& /* key */)
+ {
+ throw std::runtime_error("This TLS backend does not have the capability to load a ticket key");
+ }
void handleTicketsKeyRotation(time_t now)
{
if (d_ticketsKeyRotationDelay != 0 && now > d_ticketsKeyNextRotation) {
}
}
+ void loadTicketsKey(const std::string& key)
+ {
+ if (d_ctx != nullptr) {
+ d_ctx->loadTicketsKey(key);
+ }
+ }
+
std::shared_ptr<TLSCtx> getContext()
{
return std::atomic_load_explicit(&d_ctx, std::memory_order_acquire);
import subprocess
import time
import unittest
+import random
+import string
+
from dnsdisttests import DNSDistTest, pickAvailablePort
class TLSTests(object):
cls.startDNSDist()
cls.setUpSockets()
-class TestTLSTicketsKeyAddedCallback(DNSDistTest):
+class TestOpenSSLTLSTicketsKeyCallback(DNSDistTest):
_consoleKey = DNSDistTest.generateConsoleKey()
_consoleKeyB64 = base64.b64encode(_consoleKey).decode('ascii')
newServer{address="127.0.0.1:%s"}
addTLSLocal("127.0.0.1:%s", "%s", "%s", { provider="openssl" })
- callbackCalled = 0
+ lastKey = ""
+ lastKeyLen = 0
+
function keyAddedCallback(key, keyLen)
- callbackCalled = keyLen
+ lastKey = key
+ lastKeyLen = keyLen
end
+ setTicketsKeyAddedHook(keyAddedCallback)
+ """
+
+ def testSetTicketsKey(self):
+ """
+ TLSTicketsKey: test setting new key and the key added hook
+ """
+ newKey = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(80))
+ print("about to send command: `{}`".format("getTLSFrontend(0):setTicketsKey(\"{}\")".format(newKey)))
+ self.sendConsoleCommand("getTLSFrontend(0):loadTicketsKey(\"{}\")".format(newKey))
+ keyLen = self.sendConsoleCommand('lastKeyLen')
+ self.assertEqual(int(keyLen), 80)
+ lastKey = self.sendConsoleCommand('lastKey')
+ self.assertEqual(newKey, lastKey.strip())
+
+class TestGnuTLSTLSTicketsKeyCallback(DNSDistTest):
+ _consoleKey = DNSDistTest.generateConsoleKey()
+ _consoleKeyB64 = base64.b64encode(_consoleKey).decode('ascii')
+
+ _serverKey = 'server.key'
+ _serverCert = 'server.chain'
+ _serverName = 'tls.tests.dnsdist.org'
+ _caCert = 'ca.pem'
+ _tlsServerPort = pickAvailablePort()
+ _numberOfKeys = 5
+
+ _config_params = ['_consoleKeyB64', '_consolePort', '_testServerPort', '_tlsServerPort', '_serverCert', '_serverKey']
+ _config_template = """
+ setKey("%s")
+ controlSocket("127.0.0.1:%s")
+
+ newServer{address="127.0.0.1:%s"}
+ addTLSLocal("127.0.0.1:%s", "%s", "%s", { provider="gnutls" })
+
+ lastKey = ""
+ lastKeyLen = 0
+
+ function keyAddedCallback(key, keyLen)
+ lastKey = key
+ lastKeyLen = keyLen
+ end
+ setTicketsKeyAddedHook(keyAddedCallback)
"""
- def testLuaThreadCounter(self):
+ def testSetTicketsKey(self):
"""
- LuaThread: Test the lua newThread interface
+ TLSTicketsKey: test setting new key and the key added hook
"""
- self.sendConsoleCommand('setTicketsKeyAddedHook(keyAddedCallback)');
- called = self.sendConsoleCommand('callbackCalled')
- self.assertEqual(int(called), 0)
- self.sendConsoleCommand("getTLSFrontend(0):rotateTicketsKey()")
- called = self.sendConsoleCommand('callbackCalled')
- self.assertGreater(int(called), 0)
+
+ newKey = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(64))
+ print("about to send command: `{}`".format("getTLSFrontend(0):setTicketsKey(\"{}\")".format(newKey)))
+ self.sendConsoleCommand("getTLSFrontend(0):loadTicketsKey(\"{}\")".format(newKey))
+ keyLen = self.sendConsoleCommand('lastKeyLen')
+ self.assertEqual(int(keyLen), 64)
+ lastKey = self.sendConsoleCommand('lastKey')
+ self.assertEqual(newKey, lastKey.strip())