]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: add support for a callback when a new tickets key is added to the tls context
authorCharles-Henri Bruyand <charles-henri.bruyand@open-xchange.com>
Fri, 14 Jun 2024 15:34:04 +0000 (17:34 +0200)
committerCharles-Henri Bruyand <charles-henri.bruyand@open-xchange.com>
Thu, 27 Jun 2024 08:31:16 +0000 (10:31 +0200)
pdns/dnsdistdist/dnsdist-doh-common.cc
pdns/dnsdistdist/dnsdist-doh-common.hh
pdns/dnsdistdist/dnsdist-lua.cc
pdns/dnsdistdist/docs/reference/config.rst
pdns/libssl.cc
pdns/libssl.hh
pdns/tcpiohandler.cc
pdns/tcpiohandler.hh
regression-tests.dnsdist/test_TLS.py

index dcbd183d7cefac12590bbe41b1a3e18bc9813e02..fc66f286bb342ff1e9324ba4bc6a7459d5a95ff6 100644 (file)
@@ -94,6 +94,11 @@ void DOHFrontend::rotateTicketsKey(time_t now)
   return d_tlsContext.rotateTicketsKey(now);
 }
 
+void DOHFrontend::setTicketsKeyAddedHook(const dnsdist_tickets_key_added_hook& hook)
+{
+  return d_tlsContext.setTicketsKeyAddedHook(hook);
+}
+
 void DOHFrontend::loadTicketsKeys(const std::string& keyFile)
 {
   return d_tlsContext.loadTicketsKeys(keyFile);
index 0dc714df23a3c0273f3d6f6c748189cd4f3a0c4d..82ef70f83b3607f4cd0e130cbb67ba7f65ef1842 100644 (file)
@@ -162,6 +162,10 @@ struct DOHFrontend
   {
   }
 
+  virtual void setTicketsKeyAddedHook(const dnsdist_tickets_key_added_hook& /* hook */)
+  {
+  }
+
   virtual void loadTicketsKeys(const std::string& /* keyFile */)
   {
   }
@@ -185,6 +189,7 @@ struct DOHFrontend
   virtual void setup();
   virtual void reloadCertificates();
 
+  virtual void setTicketsKeyAddedHook(const dnsdist_tickets_key_added_hook& hook);
   virtual void rotateTicketsKey(time_t now);
   virtual void loadTicketsKeys(const std::string& keyFile);
   virtual void handleTicketsKeyRotation();
index c526a93cc2d6d0855b479b36b270b7bc4df8a803..bb5edcd22474bc690f298e24380fecee05dee4ec 100644 (file)
@@ -3011,6 +3011,13 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
     }
   });
 
+
+  luaCtx.registerFunction<void (std::shared_ptr<DOHFrontend>::*)(const dnsdist_tickets_key_added_hook&)>("setTicketsKeyAddedHook", [](const std::shared_ptr<DOHFrontend>& frontend, const dnsdist_tickets_key_added_hook& hook) {
+    if (frontend != nullptr) {
+      frontend->setTicketsKeyAddedHook(hook);
+    }
+  });
+
   luaCtx.registerFunction<void (std::shared_ptr<DOHFrontend>::*)(const LuaArray<std::shared_ptr<DOHResponseMapEntry>>&)>("setResponsesMap", [](const std::shared_ptr<DOHFrontend>& frontend, const LuaArray<std::shared_ptr<DOHResponseMapEntry>>& map) {
     if (frontend != nullptr) {
       auto newMap = std::make_shared<std::vector<std::shared_ptr<DOHResponseMapEntry>>>();
@@ -3208,6 +3215,12 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
     }
   });
 
+  luaCtx.registerFunction<void (std::shared_ptr<TLSCtx>::*)(const dnsdist_tickets_key_added_hook&)>("setTicketsKeyAddedHook", [](const std::shared_ptr<TLSCtx>& frontend, const dnsdist_tickets_key_added_hook& hook) {
+    if (frontend != nullptr) {
+      frontend->setTicketsKeyAddedHook(hook);
+    }
+  });
+
   luaCtx.registerFunction<void (std::shared_ptr<TLSCtx>::*)(const std::string&)>("loadTicketsKeys", [](std::shared_ptr<TLSCtx>& ctx, const std::string& file) {
     if (ctx != nullptr) {
       ctx->loadTicketsKeys(file);
@@ -3221,6 +3234,16 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
     return frontend->d_addr.toStringWithPort();
   });
 
+  luaCtx.registerFunction<void (std::shared_ptr<TLSFrontend>::*)(const dnsdist_tickets_key_added_hook&)>("setTicketsKeyAddedHook", [](const std::shared_ptr<TLSFrontend>& frontend, const dnsdist_tickets_key_added_hook& hook) {
+    if (frontend == nullptr) {
+      return;
+    }
+      auto ctx = frontend->getContext();
+    if (ctx) {
+      ctx->setTicketsKeyAddedHook(hook);
+    }
+  });
+
   luaCtx.registerFunction<void (std::shared_ptr<TLSFrontend>::*)()>("rotateTicketsKey", [](std::shared_ptr<TLSFrontend>& frontend) {
     if (frontend == nullptr) {
       return;
index 7473624276e668f8956cf4e4e9615a4bf219ef85..db66c1eb8267a5fd7d57dc3955c22d417f471cd8 100644 (file)
@@ -2322,6 +2322,17 @@ DOHFrontend
 
      Replace the current TLS tickets key by a new random one.
 
+  .. method:: DOHFrontend:setTicketsKeyAddedHook(callback)
+
+     .. versionadded:: 1.9.0
+
+    Set a Lua function that will be called everytime a new tickets key is added. The function receives:
+
+    * the key content as a string
+    * the keylen as an integer
+
+    See :doc:`../advanced/tls-sessions-management` for more information.
+
   .. method:: DOHFrontend:setResponsesMap(rules)
 
      Set a list of HTTP response rules allowing to intercept HTTP queries very early, before the DNS payload has been processed, and send custom responses including error pages, redirects and static content.
@@ -2464,6 +2475,17 @@ TLSContext
 
      Replace the current TLS tickets key by a new random one.
 
+  .. method:: TLSContext:setTicketsKeyAddedHook(callback)
+
+     .. versionadded:: 1.9.0
+
+    Set a Lua function that will be called everytime a new tickets key is added. The function receives:
+
+    * the key content as a string
+    * the keylen as an integer
+
+    See :doc:`../advanced/tls-sessions-management` for more information.
+
 TLSFrontend
 ~~~~~~~~~~~
 
@@ -2505,6 +2527,17 @@ TLSFrontend
 
      Replace the current TLS tickets key by a new random one.
 
+  .. method:: TLSFrontend:setTicketsKeyAddedHook(callback)
+
+     .. versionadded:: 1.9.0
+
+    Set a Lua function that will be called everytime a new tickets key is added. The function receives:
+
+    * the key content as a string
+    * the keylen as an integer
+
+    See :doc:`../advanced/tls-sessions-management` for more information.
+
 EDNS on Self-generated answers
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
index 3f657326c432af1e6ab2190027be8ae31b5e331c..f72edfaca270510c0b612b94488b6312679df7a3 100644 (file)
@@ -631,6 +631,16 @@ OpenSSLTLSTicketKeysRing::~OpenSSLTLSTicketKeysRing() = default;
 void OpenSSLTLSTicketKeysRing::addKey(std::shared_ptr<OpenSSLTLSTicketKey>&& newKey)
 {
   d_ticketKeys.write_lock()->push_front(std::move(newKey));
+  if (d_ticketsKeyAddedHook) {
+    auto key = d_ticketKeys.read_lock()->front();
+    auto keyContent = key->content();
+    d_ticketsKeyAddedHook(keyContent.c_str(), keyContent.size());
+  }
+}
+
+void OpenSSLTLSTicketKeysRing::setTicketsKeyAddedHook(const dnsdist_tickets_key_added_hook& hook)
+{
+  d_ticketsKeyAddedHook = hook;
 }
 
 std::shared_ptr<OpenSSLTLSTicketKey> OpenSSLTLSTicketKeysRing::getEncryptionKey()
@@ -737,6 +747,17 @@ bool OpenSSLTLSTicketKey::nameMatches(const unsigned char name[TLS_TICKETS_KEY_N
   return (memcmp(d_name, name, sizeof(d_name)) == 0);
 }
 
+std::string OpenSSLTLSTicketKey::content() const
+{
+  std::string result{};
+  result.reserve(TLS_TICKETS_KEY_NAME_SIZE + TLS_TICKETS_CIPHER_KEY_SIZE + TLS_TICKETS_MAC_KEY_SIZE);
+  result.append(reinterpret_cast<const char*>(d_name), TLS_TICKETS_KEY_NAME_SIZE);
+  result.append(reinterpret_cast<const char*>(d_cipherKey), TLS_TICKETS_CIPHER_KEY_SIZE);
+  result.append(reinterpret_cast<const char*>(d_hmacKey), TLS_TICKETS_MAC_KEY_SIZE);
+
+  return result;
+}
+
 #if OPENSSL_VERSION_MAJOR >= 3
 static const std::string sha256KeyName{"sha256"};
 #endif
index 8dd7ff373bf6b49546a225655239ff39aa3567c7..d0ed6a96bca27c8f9c89b1f05c4309ac4c8ad459 100644 (file)
@@ -100,6 +100,7 @@ public:
 #if OPENSSL_VERSION_MAJOR >= 3
   int encrypt(unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char* iv, EVP_CIPHER_CTX* ectx, EVP_MAC_CTX* hctx) const;
   bool decrypt(const unsigned char* iv, EVP_CIPHER_CTX* ectx, EVP_MAC_CTX* hctx) const;
+  std::string content() const;
 #else
   int encrypt(unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char* iv, EVP_CIPHER_CTX* ectx, HMAC_CTX* hctx) const;
   bool decrypt(const unsigned char* iv, EVP_CIPHER_CTX* ectx, HMAC_CTX* hctx) const;
@@ -111,6 +112,8 @@ private:
   unsigned char d_hmacKey[TLS_TICKETS_MAC_KEY_SIZE];
 };
 
+using dnsdist_tickets_key_added_hook = std::function<void(const char* key, size_t keyLen)>;
+
 class OpenSSLTLSTicketKeysRing
 {
 public:
@@ -121,10 +124,11 @@ public:
   size_t getKeysCount();
   void loadTicketsKeys(const std::string& keyFile);
   void rotateTicketsKey(time_t now);
+  void setTicketsKeyAddedHook(const dnsdist_tickets_key_added_hook& hook);
 
 private:
   void addKey(std::shared_ptr<OpenSSLTLSTicketKey>&& newKey);
-
+  dnsdist_tickets_key_added_hook d_ticketsKeyAddedHook;
   SharedLockGuarded<boost::circular_buffer<std::shared_ptr<OpenSSLTLSTicketKey> > > d_ticketKeys;
 };
 
index cf82471ba84dfc36fde17bb78af544533769a952..87391ba2ab18d0392f4378884c2402ee1cd7d383 100644 (file)
@@ -813,6 +813,11 @@ public:
     }
   }
 
+  void setTicketsKeyAddedHook(const dnsdist_tickets_key_added_hook& hook) override
+  {
+    d_feContext->d_ticketKeys.setTicketsKeyAddedHook(hook);
+  }
+
   void loadTicketsKeys(const std::string& keyFile) final
   {
     d_feContext->d_ticketKeys.loadTicketsKeys(keyFile);
@@ -987,6 +992,14 @@ public:
       throw;
     }
   }
+  std::string content() const
+  {
+    std::string result{};
+    if (d_key.data != nullptr && d_key.size > 0) {
+      result.append(reinterpret_cast<const char*>(d_key.data), d_key.size);
+    }
+    return result;
+  }
 
   ~GnuTLSTicketsKey()
   {
@@ -1730,6 +1743,11 @@ public:
     return connection;
   }
 
+  void setTicketsKeyAddedHook(const dnsdist_tickets_key_added_hook& hook) override
+  {
+    d_ticketsKeyAddedHook = hook;
+  }
+
   void rotateTicketsKey(time_t now) override
   {
     if (!d_enableTickets) {
@@ -1745,6 +1763,12 @@ public:
     if (d_ticketsKeyRotationDelay > 0) {
       d_ticketsKeyNextRotation = now + d_ticketsKeyRotationDelay;
     }
+
+    if (d_ticketsKeyAddedHook) {
+      auto ticketsKey = *(d_ticketsKey.read_lock());
+      auto content = ticketsKey->content();
+      d_ticketsKeyAddedHook(content.c_str(), content.size());
+    }
   }
 
   void loadTicketsKeys(const std::string& file) final
@@ -1792,6 +1816,7 @@ private:
   SharedLockGuarded<std::shared_ptr<GnuTLSTicketsKey>> d_ticketsKey{nullptr};
   bool d_enableTickets{true};
   bool d_validateCerts{true};
+  dnsdist_tickets_key_added_hook d_ticketsKeyAddedHook;
 };
 
 #endif /* HAVE_GNUTLS */
index 058d10443b713840304a24e37b87c05f9f5004e7..c592701eedf9c367997a68dbe2423fceed1344d0 100644 (file)
@@ -81,6 +81,10 @@ public:
   {
     throw std::runtime_error("This TLS backend does not have the capability to load a tickets key from a file");
   }
+  virtual void setTicketsKeyAddedHook(const dnsdist_tickets_key_added_hook& /* hook */)
+  {
+    throw std::runtime_error("This TLS backend does not have the capability to setup a hook for added tickets keys");
+  }
 
   void handleTicketsKeyRotation(time_t now)
   {
@@ -152,6 +156,13 @@ public:
     }
   }
 
+  void setTicketsKeyAddedHook(const dnsdist_tickets_key_added_hook& hook)
+  {
+    if (d_ctx != nullptr) {
+      d_ctx->setTicketsKeyAddedHook(hook);
+    }
+  }
+
   void loadTicketsKeys(const std::string& file)
   {
     if (d_ctx != nullptr) {
index 9803ed550f961300ecfb656ede0eddd6edaa9a65..6138c3e6363eb7917d85e1c4bf933eb500794d86 100644 (file)
@@ -516,3 +516,40 @@ class TestPKCSTLSCertificate(DNSDistTest, TLSTests):
         cls.startResponders()
         cls.startDNSDist()
         cls.setUpSockets()
+
+class TestTLSTicketsKeyAddedCallback(DNSDistTest):
+    _consoleKey = DNSDistTest.generateConsoleKey()
+    _consoleKeyB64 = base64.b64encode(_consoleKey).decode('ascii')
+
+    _serverKey = 'server.key'
+    _serverCert = 'server.chain'
+    _serverName = 'tls.tests.dnsdist.org'
+    _caCert = 'ca.pem'
+    _tlsServerPort = pickAvailablePort()
+    _numberOfKeys = 5
+
+    _config_params = ['_consoleKeyB64', '_consolePort', '_testServerPort', '_tlsServerPort', '_serverCert', '_serverKey']
+    _config_template = """
+    setKey("%s")
+    controlSocket("127.0.0.1:%s")
+
+    newServer{address="127.0.0.1:%s"}
+    addTLSLocal("127.0.0.1:%s", "%s", "%s", { provider="openssl" })
+
+    callbackCalled = 0
+    function keyAddedCallback(key, keyLen)
+      callbackCalled = keyLen
+    end
+
+    """
+
+    def testLuaThreadCounter(self):
+        """
+        LuaThread: Test the lua newThread interface
+        """
+        self.sendConsoleCommand('getTLSFrontend(0):setTicketsKeyAddedHook(keyAddedCallback)');
+        called = self.sendConsoleCommand('callbackCalled')
+        self.assertEqual(int(called), 0)
+        self.sendConsoleCommand("getTLSFrontend(0):rotateTicketsKey()")
+        called = self.sendConsoleCommand('callbackCalled')
+        self.assertGreater(int(called), 0)