]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: add the ability to load a given tickets key from lua
authorCharles-Henri Bruyand <charles-henri.bruyand@open-xchange.com>
Fri, 14 Jun 2024 11:53:27 +0000 (13:53 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 25 Nov 2024 15:55:05 +0000 (16:55 +0100)
(cherry picked from commit 0de40340927b6943a44728407fd2f1cc93c38e51)

pdns/dnsdist-doh-common.hh
pdns/dnsdist-lua.cc
pdns/dnsdistdist/dnsdist-doh-common.cc
pdns/libssl.cc
pdns/libssl.hh
pdns/tcpiohandler.cc
pdns/tcpiohandler.hh
regression-tests.dnsdist/test_TLS.py

index 0dc714df23a3c0273f3d6f6c748189cd4f3a0c4d..9d0a4669288ffe3db4aa939b76eaa1ee6b83ade7 100644 (file)
@@ -166,6 +166,10 @@ struct DOHFrontend
   {
   }
 
+  virtual void loadTicketsKey(const std::string& /* key */)
+  {
+  }
+
   virtual void handleTicketsKeyRotation()
   {
   }
@@ -187,6 +191,7 @@ struct DOHFrontend
 
   virtual void rotateTicketsKey(time_t now);
   virtual void loadTicketsKeys(const std::string& keyFile);
+  virtual void loadTicketsKey(const std::string& key);
   virtual void handleTicketsKeyRotation();
   virtual std::string getNextTicketsKeyRotation() const;
   virtual size_t getTicketsKeysCount();
index 411deb9d050591959264f7ea915de2ad86be3d45..2f57d547fe5f2d2fe244ca9ad549f25f1cbf5b62 100644 (file)
@@ -3058,6 +3058,12 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
     }
   });
 
+  luaCtx.registerFunction<void (std::shared_ptr<DOHFrontend>::*)(const std::string&)>("loadTicketsKey", [](std::shared_ptr<DOHFrontend> frontend, const std::string& key) {
+    if (frontend != nullptr) {
+      frontend->loadTicketsKey(key);
+    }
+  });
+
   luaCtx.registerFunction<void (std::shared_ptr<DOHFrontend>::*)(const LuaArray<std::shared_ptr<DOHResponseMapEntry>>&)>("setResponsesMap", [](std::shared_ptr<DOHFrontend> frontend, const LuaArray<std::shared_ptr<DOHResponseMapEntry>>& map) {
     if (frontend != nullptr) {
       auto newMap = std::make_shared<std::vector<std::shared_ptr<DOHResponseMapEntry>>>();
@@ -3288,6 +3294,16 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
     }
   });
 
+  luaCtx.registerFunction<void (std::shared_ptr<TLSFrontend>::*)(const std::string&)>("loadTicketsKey", [](std::shared_ptr<TLSFrontend>& frontend, const std::string& key) {
+    if (frontend == nullptr) {
+      return;
+    }
+    auto ctx = frontend->getContext();
+    if (ctx) {
+      ctx->loadTicketsKey(key);
+    }
+  });
+
   luaCtx.registerFunction<void (std::shared_ptr<TLSFrontend>::*)()>("reloadCertificates", [](const std::shared_ptr<TLSFrontend>& frontend) {
     if (frontend == nullptr) {
       return;
index dcbd183d7cefac12590bbe41b1a3e18bc9813e02..4ff9ff2900e3c47b46d1dfe484244442a0c7555c 100644 (file)
@@ -99,6 +99,11 @@ void DOHFrontend::loadTicketsKeys(const std::string& keyFile)
   return d_tlsContext.loadTicketsKeys(keyFile);
 }
 
+void DOHFrontend::loadTicketsKey(const std::string& key)
+{
+  return d_tlsContext.loadTicketsKey(key);
+}
+
 void DOHFrontend::handleTicketsKeyRotation()
 {
 }
index 0b7ce0ed4b0386a1a636920acfe531023c0a0237..451e7133ef798b985efb9d0d9a370eeaa92f305b 100644 (file)
@@ -687,6 +687,22 @@ void OpenSSLTLSTicketKeysRing::loadTicketsKeys(const std::string& keyFile)
   file.close();
 }
 
+void OpenSSLTLSTicketKeysRing::loadTicketsKey(const std::string& key)
+{
+  bool keyLoaded = false;
+  try {
+    auto newKey = std::make_shared<OpenSSLTLSTicketKey>(key);
+    addKey(std::move(newKey));
+    keyLoaded = true;
+  }
+  catch (const std::exception& e) {
+    /* if we haven't been able to load at least one key, fail */
+    if (!keyLoaded) {
+      throw;
+    }
+  }
+}
+
 void OpenSSLTLSTicketKeysRing::rotateTicketsKey(time_t /* now */)
 {
   auto newKey = std::make_shared<OpenSSLTLSTicketKey>();
@@ -729,6 +745,25 @@ OpenSSLTLSTicketKey::OpenSSLTLSTicketKey(std::ifstream& file)
 #endif /* HAVE_LIBSODIUM */
 }
 
+OpenSSLTLSTicketKey::OpenSSLTLSTicketKey(const std::string& key)
+{
+  if (key.size() != (sizeof(d_name) + sizeof(d_cipherKey) + sizeof(d_hmacKey))) {
+    throw std::runtime_error("Unable to load a ticket key from given data");
+  }
+  size_t from = 0;
+  memcpy(d_name, &key.at(from), sizeof(d_name));
+  from += sizeof(d_name);
+  memcpy(d_cipherKey, &key.at(from), sizeof(d_cipherKey));
+  from += sizeof(d_cipherKey);
+  memcpy(d_hmacKey, &key.at(from), sizeof(d_hmacKey));
+
+#ifdef HAVE_LIBSODIUM
+  sodium_mlock(d_name, sizeof(d_name));
+  sodium_mlock(d_cipherKey, sizeof(d_cipherKey));
+  sodium_mlock(d_hmacKey, sizeof(d_hmacKey));
+#endif /* HAVE_LIBSODIUM */
+}
+
 OpenSSLTLSTicketKey::~OpenSSLTLSTicketKey()
 {
 #ifdef HAVE_LIBSODIUM
index f961c1d25f3b0393dbbe30581dbbc1f6c7a2a527..915fa787e24b5b64ca135f9dfc246de21b83a1b6 100644 (file)
@@ -93,6 +93,7 @@ class OpenSSLTLSTicketKey
 public:
   OpenSSLTLSTicketKey();
   OpenSSLTLSTicketKey(std::ifstream& file);
+  OpenSSLTLSTicketKey(const std::string& key);
   ~OpenSSLTLSTicketKey();
 
   bool nameMatches(const unsigned char name[TLS_TICKETS_KEY_NAME_SIZE]) const;
@@ -122,6 +123,7 @@ public:
   std::shared_ptr<OpenSSLTLSTicketKey> getDecryptionKey(unsigned char name[TLS_TICKETS_KEY_NAME_SIZE], bool& activeKey);
   size_t getKeysCount();
   void loadTicketsKeys(const std::string& keyFile);
+  void loadTicketsKey(const std::string& key);
   void rotateTicketsKey(time_t now);
 
 private:
index a4a5995b0d16dec763ff915248d09d642130f741..547aaa40b9e6086470b108154d30e2db525e2409 100644 (file)
@@ -857,6 +857,15 @@ public:
     }
   }
 
+  void loadTicketsKey(const std::string& key) final
+  {
+    d_feContext->d_ticketKeys.loadTicketsKey(key);
+
+    if (d_ticketsKeyRotationDelay > 0) {
+      d_ticketsKeyNextRotation = time(nullptr) + d_ticketsKeyRotationDelay;
+    }
+  }
+
   size_t getTicketsKeysCount() override
   {
     return d_feContext->d_ticketKeys.getKeysCount();
@@ -1002,7 +1011,24 @@ public:
     safe_memory_lock(d_key.data, d_key.size);
   }
 
-  GnuTLSTicketsKey(const std::string& keyFile)
+  GnuTLSTicketsKey(const std::string& key)
+  {
+    /* to be sure we are loading the correct amount of data, which
+       may change between versions, let's generate a correct key first */
+    if (gnutls_session_ticket_key_generate(&d_key) != GNUTLS_E_SUCCESS) {
+      throw std::runtime_error("Error generating tickets key (before parsing key file) for TLS context");
+    }
+
+    safe_memory_lock(d_key.data, d_key.size);
+    if (key.size() != d_key.size) {
+      safe_memory_release(d_key.data, d_key.size);
+      gnutls_free(d_key.data);
+      d_key.data = nullptr;
+      throw std::runtime_error("Invalid GnuTLS ticket key size");
+    }
+    memcpy(d_key.data, key.data(), key.size());
+  }
+  GnuTLSTicketsKey(std::ifstream& file)
   {
     /* to be sure we are loading the correct amount of data, which
        may change between versions, let's generate a correct key first */
@@ -1013,17 +1039,17 @@ public:
     safe_memory_lock(d_key.data, d_key.size);
 
     try {
-      ifstream file(keyFile);
       file.read(reinterpret_cast<char*>(d_key.data), d_key.size);
 
       if (file.fail()) {
-        file.close();
-        throw std::runtime_error("Invalid GnuTLS tickets key file " + keyFile);
+        throw std::runtime_error("Invalid GnuTLS tickets key file");
       }
 
-      file.close();
     }
     catch (const std::exception& e) {
+      safe_memory_release(d_key.data, d_key.size);
+      gnutls_free(d_key.data);
+      d_key.data = nullptr;
       safe_memory_release(d_key.data, d_key.size);
       gnutls_free(d_key.data);
       d_key.data = nullptr;
@@ -1813,14 +1839,26 @@ public:
     auto newKey = std::make_shared<GnuTLSTicketsKey>();
     addTicketsKey(now, std::move(newKey));
   }
-  void loadTicketsKeys(const std::string& file) final
+  void loadTicketsKey(const std::string& key) final
+  {
+    if (!d_enableTickets) {
+      return;
+    }
+
+    auto newKey = std::make_shared<GnuTLSTicketsKey>(key);
+    addTicketsKey(time(nullptr), std::move(newKey));
+  }
+
+  void loadTicketsKeys(const std::string& keyFile) final
   {
     if (!d_enableTickets) {
       return;
     }
 
+    std::ifstream file(keyFile);
     auto newKey = std::make_shared<GnuTLSTicketsKey>(file);
     addTicketsKey(time(nullptr), std::move(newKey));
+    file.close();
   }
 
   size_t getTicketsKeysCount() override
index 8420529811e30b22d75fdee2ec09c040f281dc5a..4fb7006bdb083abb92c71872d05dcd04b2516486 100644 (file)
@@ -81,6 +81,10 @@ public:
   {
     throw std::runtime_error("This TLS backend does not have the capability to load a tickets key from a file");
   }
+  virtual void loadTicketsKey(const std::string& /* key */)
+  {
+    throw std::runtime_error("This TLS backend does not have the capability to load a ticket key");
+  }
   void handleTicketsKeyRotation(time_t now)
   {
     if (d_ticketsKeyRotationDelay != 0 && now > d_ticketsKeyNextRotation) {
@@ -175,6 +179,13 @@ public:
     }
   }
 
+  void loadTicketsKey(const std::string& key)
+  {
+    if (d_ctx != nullptr) {
+      d_ctx->loadTicketsKey(key);
+    }
+  }
+
   std::shared_ptr<TLSCtx> getContext()
   {
     return std::atomic_load_explicit(&d_ctx, std::memory_order_acquire);
index 27c2de52fe19b6ac343eb6da5f0f30830217678f..f40c18cfe450138072bce8454edd07b2379c7ff2 100644 (file)
@@ -6,6 +6,9 @@ import ssl
 import subprocess
 import time
 import unittest
+import random
+import string
+
 from dnsdisttests import DNSDistTest, pickAvailablePort
 
 class TLSTests(object):
@@ -517,7 +520,7 @@ class TestPKCSTLSCertificate(DNSDistTest, TLSTests):
         cls.startDNSDist()
         cls.setUpSockets()
 
-class TestTLSTicketsKeyAddedCallback(DNSDistTest):
+class TestOpenSSLTLSTicketsKeyCallback(DNSDistTest):
     _consoleKey = DNSDistTest.generateConsoleKey()
     _consoleKeyB64 = base64.b64encode(_consoleKey).decode('ascii')
 
@@ -536,20 +539,67 @@ class TestTLSTicketsKeyAddedCallback(DNSDistTest):
     newServer{address="127.0.0.1:%s"}
     addTLSLocal("127.0.0.1:%s", "%s", "%s", { provider="openssl" })
 
-    callbackCalled = 0
+    lastKey = ""
+    lastKeyLen = 0
+
     function keyAddedCallback(key, keyLen)
-      callbackCalled = keyLen
+      lastKey = key
+      lastKeyLen = keyLen
     end
+    setTicketsKeyAddedHook(keyAddedCallback)
+    """
+
+    def testSetTicketsKey(self):
+        """
+        TLSTicketsKey: test setting new key and the key added hook
+        """
 
+        newKey = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(80))
+        print("about to send command: `{}`".format("getTLSFrontend(0):setTicketsKey(\"{}\")".format(newKey)))
+        self.sendConsoleCommand("getTLSFrontend(0):loadTicketsKey(\"{}\")".format(newKey))
+        keyLen = self.sendConsoleCommand('lastKeyLen')
+        self.assertEqual(int(keyLen), 80)
+        lastKey = self.sendConsoleCommand('lastKey')
+        self.assertEqual(newKey, lastKey.strip())
+
+class TestGnuTLSTLSTicketsKeyCallback(DNSDistTest):
+    _consoleKey = DNSDistTest.generateConsoleKey()
+    _consoleKeyB64 = base64.b64encode(_consoleKey).decode('ascii')
+
+    _serverKey = 'server.key'
+    _serverCert = 'server.chain'
+    _serverName = 'tls.tests.dnsdist.org'
+    _caCert = 'ca.pem'
+    _tlsServerPort = pickAvailablePort()
+    _numberOfKeys = 5
+
+    _config_params = ['_consoleKeyB64', '_consolePort', '_testServerPort', '_tlsServerPort', '_serverCert', '_serverKey']
+    _config_template = """
+    setKey("%s")
+    controlSocket("127.0.0.1:%s")
+
+    newServer{address="127.0.0.1:%s"}
+    addTLSLocal("127.0.0.1:%s", "%s", "%s", { provider="gnutls" })
+
+    lastKey = ""
+    lastKeyLen = 0
+
+    function keyAddedCallback(key, keyLen)
+      lastKey = key
+      lastKeyLen = keyLen
+    end
+    setTicketsKeyAddedHook(keyAddedCallback)
     """
 
-    def testLuaThreadCounter(self):
+    def testSetTicketsKey(self):
         """
-        LuaThread: Test the lua newThread interface
+        TLSTicketsKey: test setting new key and the key added hook
         """
-        self.sendConsoleCommand('setTicketsKeyAddedHook(keyAddedCallback)');
-        called = self.sendConsoleCommand('callbackCalled')
-        self.assertEqual(int(called), 0)
-        self.sendConsoleCommand("getTLSFrontend(0):rotateTicketsKey()")
-        called = self.sendConsoleCommand('callbackCalled')
-        self.assertGreater(int(called), 0)
+
+        newKey = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(64))
+        print("about to send command: `{}`".format("getTLSFrontend(0):setTicketsKey(\"{}\")".format(newKey)))
+        self.sendConsoleCommand("getTLSFrontend(0):loadTicketsKey(\"{}\")".format(newKey))
+        keyLen = self.sendConsoleCommand('lastKeyLen')
+        self.assertEqual(int(keyLen), 64)
+        lastKey = self.sendConsoleCommand('lastKey')
+        self.assertEqual(newKey, lastKey.strip())