]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: add the ability to load a given tickets key from lua
authorCharles-Henri Bruyand <charles-henri.bruyand@open-xchange.com>
Fri, 14 Jun 2024 11:53:27 +0000 (13:53 +0200)
committerCharles-Henri Bruyand <charles-henri.bruyand@open-xchange.com>
Thu, 10 Oct 2024 10:12:42 +0000 (12:12 +0200)
pdns/dnsdistdist/dnsdist-doh-common.cc
pdns/dnsdistdist/dnsdist-doh-common.hh
pdns/dnsdistdist/dnsdist-lua.cc
pdns/libssl.cc
pdns/libssl.hh
pdns/tcpiohandler.cc
pdns/tcpiohandler.hh
regression-tests.dnsdist/test_TLS.py

index dcbd183d7cefac12590bbe41b1a3e18bc9813e02..4ff9ff2900e3c47b46d1dfe484244442a0c7555c 100644 (file)
@@ -99,6 +99,11 @@ void DOHFrontend::loadTicketsKeys(const std::string& keyFile)
   return d_tlsContext.loadTicketsKeys(keyFile);
 }
 
+void DOHFrontend::loadTicketsKey(const std::string& key)
+{
+  return d_tlsContext.loadTicketsKey(key);
+}
+
 void DOHFrontend::handleTicketsKeyRotation()
 {
 }
index 0dc714df23a3c0273f3d6f6c748189cd4f3a0c4d..9d0a4669288ffe3db4aa939b76eaa1ee6b83ade7 100644 (file)
@@ -166,6 +166,10 @@ struct DOHFrontend
   {
   }
 
+  virtual void loadTicketsKey(const std::string& /* key */)
+  {
+  }
+
   virtual void handleTicketsKeyRotation()
   {
   }
@@ -187,6 +191,7 @@ struct DOHFrontend
 
   virtual void rotateTicketsKey(time_t now);
   virtual void loadTicketsKeys(const std::string& keyFile);
+  virtual void loadTicketsKey(const std::string& key);
   virtual void handleTicketsKeyRotation();
   virtual std::string getNextTicketsKeyRotation() const;
   virtual size_t getTicketsKeysCount();
index 0a0094c68ec1ff75d9e088bfc9d31ef0b510012f..a21066715c04c40b22f30d678913c12537e366ee 100644 (file)
@@ -2990,6 +2990,12 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
     }
   });
 
+  luaCtx.registerFunction<void (std::shared_ptr<DOHFrontend>::*)(const std::string&)>("loadTicketsKey", [](const std::shared_ptr<DOHFrontend>& frontend, const std::string& key) {
+    if (frontend != nullptr) {
+      frontend->loadTicketsKey(key);
+    }
+  });
+
   luaCtx.registerFunction<void (std::shared_ptr<DOHFrontend>::*)(const LuaArray<std::shared_ptr<DOHResponseMapEntry>>&)>("setResponsesMap", [](const std::shared_ptr<DOHFrontend>& frontend, const LuaArray<std::shared_ptr<DOHResponseMapEntry>>& map) {
     if (frontend != nullptr) {
       auto newMap = std::make_shared<std::vector<std::shared_ptr<DOHResponseMapEntry>>>();
@@ -3223,6 +3229,16 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
     }
   });
 
+  luaCtx.registerFunction<void (std::shared_ptr<TLSFrontend>::*)(const std::string&)>("loadTicketsKey", [](std::shared_ptr<TLSFrontend>& frontend, const std::string& key) {
+    if (frontend == nullptr) {
+      return;
+    }
+    auto ctx = frontend->getContext();
+    if (ctx) {
+      ctx->loadTicketsKey(key);
+    }
+  });
+
   luaCtx.registerFunction<void (std::shared_ptr<TLSFrontend>::*)()>("reloadCertificates", [](const std::shared_ptr<TLSFrontend>& frontend) {
     if (frontend == nullptr) {
       return;
index c81127c1ff466b704c9cc9ff401a67e002f51e89..2493f73d1d0b5c3d43eaa57bd1ec230bfc9c56af 100644 (file)
@@ -688,6 +688,22 @@ void OpenSSLTLSTicketKeysRing::loadTicketsKeys(const std::string& keyFile)
   file.close();
 }
 
+void OpenSSLTLSTicketKeysRing::loadTicketsKey(const std::string& key)
+{
+  bool keyLoaded = false;
+  try {
+    auto newKey = std::make_shared<OpenSSLTLSTicketKey>(key);
+    addKey(std::move(newKey));
+    keyLoaded = true;
+  }
+  catch (const std::exception& e) {
+    /* if we haven't been able to load at least one key, fail */
+    if (!keyLoaded) {
+      throw;
+    }
+  }
+}
+
 void OpenSSLTLSTicketKeysRing::rotateTicketsKey(time_t /* now */)
 {
   auto newKey = std::make_shared<OpenSSLTLSTicketKey>();
@@ -730,6 +746,25 @@ OpenSSLTLSTicketKey::OpenSSLTLSTicketKey(std::ifstream& file)
 #endif /* HAVE_LIBSODIUM */
 }
 
+OpenSSLTLSTicketKey::OpenSSLTLSTicketKey(const std::string& key)
+{
+  if (key.size() != (sizeof(d_name) + sizeof(d_cipherKey) + sizeof(d_hmacKey))) {
+    throw std::runtime_error("Unable to load a ticket key from given data");
+  }
+  size_t from = 0;
+  memcpy(d_name, &key.at(from), sizeof(d_name));
+  from += sizeof(d_name);
+  memcpy(d_cipherKey, &key.at(from), sizeof(d_cipherKey));
+  from += sizeof(d_cipherKey);
+  memcpy(d_hmacKey, &key.at(from), sizeof(d_hmacKey));
+
+#ifdef HAVE_LIBSODIUM
+  sodium_mlock(d_name, sizeof(d_name));
+  sodium_mlock(d_cipherKey, sizeof(d_cipherKey));
+  sodium_mlock(d_hmacKey, sizeof(d_hmacKey));
+#endif /* HAVE_LIBSODIUM */
+}
+
 OpenSSLTLSTicketKey::~OpenSSLTLSTicketKey()
 {
 #ifdef HAVE_LIBSODIUM
index 96f6dd9a6b1a5cd05eb47547a49461f0abbaeae4..d927f7bfa2b58f53821ec58ee450a5e3ad6d2236 100644 (file)
@@ -93,6 +93,7 @@ class OpenSSLTLSTicketKey
 public:
   OpenSSLTLSTicketKey();
   OpenSSLTLSTicketKey(std::ifstream& file);
+  OpenSSLTLSTicketKey(const std::string& key);
   ~OpenSSLTLSTicketKey();
 
   bool nameMatches(const unsigned char name[TLS_TICKETS_KEY_NAME_SIZE]) const;
@@ -122,6 +123,7 @@ public:
   std::shared_ptr<OpenSSLTLSTicketKey> getDecryptionKey(unsigned char name[TLS_TICKETS_KEY_NAME_SIZE], bool& activeKey);
   size_t getKeysCount();
   void loadTicketsKeys(const std::string& keyFile);
+  void loadTicketsKey(const std::string& key);
   void rotateTicketsKey(time_t now);
 
 private:
index 59382083f7ee0e735b02f3cac0de2b4ccee37169..db07add7bea81962e91d42ac67761ca12f8f7dc2 100644 (file)
@@ -889,6 +889,15 @@ public:
     }
   }
 
+  void loadTicketsKey(const std::string& key) final
+  {
+    d_feContext->d_ticketKeys.loadTicketsKey(key);
+
+    if (d_ticketsKeyRotationDelay > 0) {
+      d_ticketsKeyNextRotation = time(nullptr) + d_ticketsKeyRotationDelay;
+    }
+  }
+
   size_t getTicketsKeysCount() override
   {
     return d_feContext->d_ticketKeys.getKeysCount();
@@ -993,7 +1002,24 @@ public:
     safe_memory_lock(d_key.data, d_key.size);
   }
 
-  GnuTLSTicketsKey(const std::string& keyFile)
+  GnuTLSTicketsKey(const std::string& key)
+  {
+    /* to be sure we are loading the correct amount of data, which
+       may change between versions, let's generate a correct key first */
+    if (gnutls_session_ticket_key_generate(&d_key) != GNUTLS_E_SUCCESS) {
+      throw std::runtime_error("Error generating tickets key (before parsing key file) for TLS context");
+    }
+
+    safe_memory_lock(d_key.data, d_key.size);
+    if (key.size() != d_key.size) {
+      safe_memory_release(d_key.data, d_key.size);
+      gnutls_free(d_key.data);
+      d_key.data = nullptr;
+      throw std::runtime_error("Invalid GnuTLS ticket key size");
+    }
+    memcpy(d_key.data, key.data(), key.size());
+  }
+  GnuTLSTicketsKey(std::ifstream& file)
   {
     /* to be sure we are loading the correct amount of data, which
        may change between versions, let's generate a correct key first */
@@ -1004,17 +1030,17 @@ public:
     safe_memory_lock(d_key.data, d_key.size);
 
     try {
-      ifstream file(keyFile);
       file.read(reinterpret_cast<char*>(d_key.data), d_key.size);
 
       if (file.fail()) {
-        file.close();
-        throw std::runtime_error("Invalid GnuTLS tickets key file " + keyFile);
+        throw std::runtime_error("Invalid GnuTLS tickets key file");
       }
 
-      file.close();
     }
     catch (const std::exception& e) {
+      safe_memory_release(d_key.data, d_key.size);
+      gnutls_free(d_key.data);
+      d_key.data = nullptr;
       safe_memory_release(d_key.data, d_key.size);
       gnutls_free(d_key.data);
       d_key.data = nullptr;
@@ -1804,14 +1830,26 @@ public:
     auto newKey = std::make_shared<GnuTLSTicketsKey>();
     addTicketsKey(now, std::move(newKey));
   }
-  void loadTicketsKeys(const std::string& file) final
+  void loadTicketsKey(const std::string& key) final
+  {
+    if (!d_enableTickets) {
+      return;
+    }
+
+    auto newKey = std::make_shared<GnuTLSTicketsKey>(key);
+    addTicketsKey(time(nullptr), std::move(newKey));
+  }
+
+  void loadTicketsKeys(const std::string& keyFile) final
   {
     if (!d_enableTickets) {
       return;
     }
 
+    std::ifstream file(keyFile);
     auto newKey = std::make_shared<GnuTLSTicketsKey>(file);
     addTicketsKey(time(nullptr), std::move(newKey));
+    file.close();
   }
 
   size_t getTicketsKeysCount() override
index f24735cc88d7b25362d656654f5a73288845b7a3..742431da4ad8b5d7304227fc5a67878cbcc1ce1e 100644 (file)
@@ -81,6 +81,10 @@ public:
   {
     throw std::runtime_error("This TLS backend does not have the capability to load a tickets key from a file");
   }
+  virtual void loadTicketsKey(const std::string& /* key */)
+  {
+    throw std::runtime_error("This TLS backend does not have the capability to load a ticket key");
+  }
   void handleTicketsKeyRotation(time_t now)
   {
     if (d_ticketsKeyRotationDelay != 0 && now > d_ticketsKeyNextRotation) {
@@ -163,6 +167,13 @@ public:
     }
   }
 
+  void loadTicketsKey(const std::string& key)
+  {
+    if (d_ctx != nullptr) {
+      d_ctx->loadTicketsKey(key);
+    }
+  }
+
   std::shared_ptr<TLSCtx> getContext()
   {
     return std::atomic_load_explicit(&d_ctx, std::memory_order_acquire);
index 27c2de52fe19b6ac343eb6da5f0f30830217678f..f40c18cfe450138072bce8454edd07b2379c7ff2 100644 (file)
@@ -6,6 +6,9 @@ import ssl
 import subprocess
 import time
 import unittest
+import random
+import string
+
 from dnsdisttests import DNSDistTest, pickAvailablePort
 
 class TLSTests(object):
@@ -517,7 +520,7 @@ class TestPKCSTLSCertificate(DNSDistTest, TLSTests):
         cls.startDNSDist()
         cls.setUpSockets()
 
-class TestTLSTicketsKeyAddedCallback(DNSDistTest):
+class TestOpenSSLTLSTicketsKeyCallback(DNSDistTest):
     _consoleKey = DNSDistTest.generateConsoleKey()
     _consoleKeyB64 = base64.b64encode(_consoleKey).decode('ascii')
 
@@ -536,20 +539,67 @@ class TestTLSTicketsKeyAddedCallback(DNSDistTest):
     newServer{address="127.0.0.1:%s"}
     addTLSLocal("127.0.0.1:%s", "%s", "%s", { provider="openssl" })
 
-    callbackCalled = 0
+    lastKey = ""
+    lastKeyLen = 0
+
     function keyAddedCallback(key, keyLen)
-      callbackCalled = keyLen
+      lastKey = key
+      lastKeyLen = keyLen
     end
+    setTicketsKeyAddedHook(keyAddedCallback)
+    """
+
+    def testSetTicketsKey(self):
+        """
+        TLSTicketsKey: test setting new key and the key added hook
+        """
 
+        newKey = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(80))
+        print("about to send command: `{}`".format("getTLSFrontend(0):setTicketsKey(\"{}\")".format(newKey)))
+        self.sendConsoleCommand("getTLSFrontend(0):loadTicketsKey(\"{}\")".format(newKey))
+        keyLen = self.sendConsoleCommand('lastKeyLen')
+        self.assertEqual(int(keyLen), 80)
+        lastKey = self.sendConsoleCommand('lastKey')
+        self.assertEqual(newKey, lastKey.strip())
+
+class TestGnuTLSTLSTicketsKeyCallback(DNSDistTest):
+    _consoleKey = DNSDistTest.generateConsoleKey()
+    _consoleKeyB64 = base64.b64encode(_consoleKey).decode('ascii')
+
+    _serverKey = 'server.key'
+    _serverCert = 'server.chain'
+    _serverName = 'tls.tests.dnsdist.org'
+    _caCert = 'ca.pem'
+    _tlsServerPort = pickAvailablePort()
+    _numberOfKeys = 5
+
+    _config_params = ['_consoleKeyB64', '_consolePort', '_testServerPort', '_tlsServerPort', '_serverCert', '_serverKey']
+    _config_template = """
+    setKey("%s")
+    controlSocket("127.0.0.1:%s")
+
+    newServer{address="127.0.0.1:%s"}
+    addTLSLocal("127.0.0.1:%s", "%s", "%s", { provider="gnutls" })
+
+    lastKey = ""
+    lastKeyLen = 0
+
+    function keyAddedCallback(key, keyLen)
+      lastKey = key
+      lastKeyLen = keyLen
+    end
+    setTicketsKeyAddedHook(keyAddedCallback)
     """
 
-    def testLuaThreadCounter(self):
+    def testSetTicketsKey(self):
         """
-        LuaThread: Test the lua newThread interface
+        TLSTicketsKey: test setting new key and the key added hook
         """
-        self.sendConsoleCommand('setTicketsKeyAddedHook(keyAddedCallback)');
-        called = self.sendConsoleCommand('callbackCalled')
-        self.assertEqual(int(called), 0)
-        self.sendConsoleCommand("getTLSFrontend(0):rotateTicketsKey()")
-        called = self.sendConsoleCommand('callbackCalled')
-        self.assertGreater(int(called), 0)
+
+        newKey = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(64))
+        print("about to send command: `{}`".format("getTLSFrontend(0):setTicketsKey(\"{}\")".format(newKey)))
+        self.sendConsoleCommand("getTLSFrontend(0):loadTicketsKey(\"{}\")".format(newKey))
+        keyLen = self.sendConsoleCommand('lastKeyLen')
+        self.assertEqual(int(keyLen), 64)
+        lastKey = self.sendConsoleCommand('lastKey')
+        self.assertEqual(newKey, lastKey.strip())