]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: move the setTicketsKeyAddedHook to a unique callback for every tls context
authorCharles-Henri Bruyand <charles-henri.bruyand@open-xchange.com>
Thu, 27 Jun 2024 13:02:39 +0000 (15:02 +0200)
committerCharles-Henri Bruyand <charles-henri.bruyand@open-xchange.com>
Thu, 27 Jun 2024 13:25:35 +0000 (15:25 +0200)
pdns/dnsdistdist/dnsdist-doh-common.cc
pdns/dnsdistdist/dnsdist-doh-common.hh
pdns/dnsdistdist/dnsdist-lua-hooks.cc
pdns/dnsdistdist/dnsdist-lua-hooks.hh
pdns/dnsdistdist/dnsdist-lua.cc
pdns/dnsdistdist/docs/reference/config.rst
pdns/libssl.cc
pdns/libssl.hh
pdns/tcpiohandler.cc
pdns/tcpiohandler.hh
regression-tests.dnsdist/test_TLS.py

index fc66f286bb342ff1e9324ba4bc6a7459d5a95ff6..dcbd183d7cefac12590bbe41b1a3e18bc9813e02 100644 (file)
@@ -94,11 +94,6 @@ void DOHFrontend::rotateTicketsKey(time_t now)
   return d_tlsContext.rotateTicketsKey(now);
 }
 
-void DOHFrontend::setTicketsKeyAddedHook(const dnsdist_tickets_key_added_hook& hook)
-{
-  return d_tlsContext.setTicketsKeyAddedHook(hook);
-}
-
 void DOHFrontend::loadTicketsKeys(const std::string& keyFile)
 {
   return d_tlsContext.loadTicketsKeys(keyFile);
index 82ef70f83b3607f4cd0e130cbb67ba7f65ef1842..0dc714df23a3c0273f3d6f6c748189cd4f3a0c4d 100644 (file)
@@ -162,10 +162,6 @@ struct DOHFrontend
   {
   }
 
-  virtual void setTicketsKeyAddedHook(const dnsdist_tickets_key_added_hook& /* hook */)
-  {
-  }
-
   virtual void loadTicketsKeys(const std::string& /* keyFile */)
   {
   }
@@ -189,7 +185,6 @@ struct DOHFrontend
   virtual void setup();
   virtual void reloadCertificates();
 
-  virtual void setTicketsKeyAddedHook(const dnsdist_tickets_key_added_hook& hook);
   virtual void rotateTicketsKey(time_t now);
   virtual void loadTicketsKeys(const std::string& keyFile);
   virtual void handleTicketsKeyRotation();
index c5ccb48915c14d171c456a7811477d3fff2d59eb..621e73451205fbcff62b7a14da4d4b6bc2d1e98f 100644 (file)
@@ -2,6 +2,7 @@
 #include "dnsdist-lua-hooks.hh"
 #include "dnsdist-lua.hh"
 #include "lock.hh"
+#include "tcpiohandler.hh"
 
 namespace dnsdist::lua::hooks
 {
@@ -26,12 +27,28 @@ void clearMaintenanceHooks()
   s_maintenanceHooks.lock()->clear();
 }
 
+void setTicketsKeyAddedHook(const LuaContext& context, const TicketsKeyAddedHook& hook)
+{
+  TLSCtx::setTicketsKeyAddedHook([hook](const std::string& key) {
+    try {
+      hook(key.c_str(), key.size());
+    }
+    catch (const std::exception& exp) {
+      warnlog("Error calling the Lua hook after new tickets key has been added", exp.what());
+    }
+  });
+}
+
 void setupLuaHooks(LuaContext& luaCtx)
 {
   luaCtx.writeFunction("addMaintenanceCallback", [&luaCtx](const MaintenanceCallback& callback) {
     setLuaSideEffect();
     addMaintenanceCallback(luaCtx, callback);
   });
+  luaCtx.writeFunction("setTicketsKeyAddedHook", [&luaCtx](const TicketsKeyAddedHook& hook) {
+    setLuaSideEffect();
+    setTicketsKeyAddedHook(luaCtx, hook);
+  });
 }
 
 }
index 11a9084883ee8a74585acc611bb10bdb9ea3fae8..8cbb7c903ae9130945b5f1c8d7b592a4cb352e37 100644 (file)
@@ -28,8 +28,11 @@ class LuaContext;
 namespace dnsdist::lua::hooks
 {
 using MaintenanceCallback = std::function<void()>;
+using TicketsKeyAddedHook = std::function<void(const char*, size_t)>;
+
 void runMaintenanceHooks(const LuaContext& context);
 void addMaintenanceCallback(const LuaContext& context, MaintenanceCallback callback);
+void setTicketsKeyAddedHook(const LuaContext& context, const TicketsKeyAddedHook& hook);
 void clearMaintenanceHooks();
 void setupLuaHooks(LuaContext& luaCtx);
 }
index bb5edcd22474bc690f298e24380fecee05dee4ec..c526a93cc2d6d0855b479b36b270b7bc4df8a803 100644 (file)
@@ -3011,13 +3011,6 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
     }
   });
 
-
-  luaCtx.registerFunction<void (std::shared_ptr<DOHFrontend>::*)(const dnsdist_tickets_key_added_hook&)>("setTicketsKeyAddedHook", [](const std::shared_ptr<DOHFrontend>& frontend, const dnsdist_tickets_key_added_hook& hook) {
-    if (frontend != nullptr) {
-      frontend->setTicketsKeyAddedHook(hook);
-    }
-  });
-
   luaCtx.registerFunction<void (std::shared_ptr<DOHFrontend>::*)(const LuaArray<std::shared_ptr<DOHResponseMapEntry>>&)>("setResponsesMap", [](const std::shared_ptr<DOHFrontend>& frontend, const LuaArray<std::shared_ptr<DOHResponseMapEntry>>& map) {
     if (frontend != nullptr) {
       auto newMap = std::make_shared<std::vector<std::shared_ptr<DOHResponseMapEntry>>>();
@@ -3215,12 +3208,6 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
     }
   });
 
-  luaCtx.registerFunction<void (std::shared_ptr<TLSCtx>::*)(const dnsdist_tickets_key_added_hook&)>("setTicketsKeyAddedHook", [](const std::shared_ptr<TLSCtx>& frontend, const dnsdist_tickets_key_added_hook& hook) {
-    if (frontend != nullptr) {
-      frontend->setTicketsKeyAddedHook(hook);
-    }
-  });
-
   luaCtx.registerFunction<void (std::shared_ptr<TLSCtx>::*)(const std::string&)>("loadTicketsKeys", [](std::shared_ptr<TLSCtx>& ctx, const std::string& file) {
     if (ctx != nullptr) {
       ctx->loadTicketsKeys(file);
@@ -3234,16 +3221,6 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
     return frontend->d_addr.toStringWithPort();
   });
 
-  luaCtx.registerFunction<void (std::shared_ptr<TLSFrontend>::*)(const dnsdist_tickets_key_added_hook&)>("setTicketsKeyAddedHook", [](const std::shared_ptr<TLSFrontend>& frontend, const dnsdist_tickets_key_added_hook& hook) {
-    if (frontend == nullptr) {
-      return;
-    }
-      auto ctx = frontend->getContext();
-    if (ctx) {
-      ctx->setTicketsKeyAddedHook(hook);
-    }
-  });
-
   luaCtx.registerFunction<void (std::shared_ptr<TLSFrontend>::*)()>("rotateTicketsKey", [](std::shared_ptr<TLSFrontend>& frontend) {
     if (frontend == nullptr) {
       return;
index db66c1eb8267a5fd7d57dc3955c22d417f471cd8..80ad8ab46546d0123f712c2af5012089b033455f 100644 (file)
@@ -2173,6 +2173,17 @@ Other functions
   Code is supplied as a string, not as a function object.
   Note that this function does nothing in 'client' or 'config-check' modes.
 
+.. function:: setTicketsKeyAddedHook(callback)
+
+  .. versionadded:: 1.9.0
+
+  Set a Lua function that will be called everytime a new tickets key is added. The function receives:
+
+  * the key content as a string
+  * the keylen as an integer
+
+  See :doc:`../advanced/tls-sessions-management` for more information.
+
 .. function:: submitToMainThread(cmd, dict)
 
   .. versionadded:: 1.8.0
@@ -2322,17 +2333,6 @@ DOHFrontend
 
      Replace the current TLS tickets key by a new random one.
 
-  .. method:: DOHFrontend:setTicketsKeyAddedHook(callback)
-
-     .. versionadded:: 1.9.0
-
-    Set a Lua function that will be called everytime a new tickets key is added. The function receives:
-
-    * the key content as a string
-    * the keylen as an integer
-
-    See :doc:`../advanced/tls-sessions-management` for more information.
-
   .. method:: DOHFrontend:setResponsesMap(rules)
 
      Set a list of HTTP response rules allowing to intercept HTTP queries very early, before the DNS payload has been processed, and send custom responses including error pages, redirects and static content.
@@ -2475,17 +2475,6 @@ TLSContext
 
      Replace the current TLS tickets key by a new random one.
 
-  .. method:: TLSContext:setTicketsKeyAddedHook(callback)
-
-     .. versionadded:: 1.9.0
-
-    Set a Lua function that will be called everytime a new tickets key is added. The function receives:
-
-    * the key content as a string
-    * the keylen as an integer
-
-    See :doc:`../advanced/tls-sessions-management` for more information.
-
 TLSFrontend
 ~~~~~~~~~~~
 
@@ -2527,17 +2516,6 @@ TLSFrontend
 
      Replace the current TLS tickets key by a new random one.
 
-  .. method:: TLSFrontend:setTicketsKeyAddedHook(callback)
-
-     .. versionadded:: 1.9.0
-
-    Set a Lua function that will be called everytime a new tickets key is added. The function receives:
-
-    * the key content as a string
-    * the keylen as an integer
-
-    See :doc:`../advanced/tls-sessions-management` for more information.
-
 EDNS on Self-generated answers
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
index f72edfaca270510c0b612b94488b6312679df7a3..cd9ad076fe947809b792ecce66621cfcabad9aef 100644 (file)
@@ -42,6 +42,7 @@
 
 #undef CERT
 #include "misc.hh"
+#include "tcpiohandler.hh"
 
 #if (OPENSSL_VERSION_NUMBER < 0x1010000fL || (defined LIBRESSL_VERSION_NUMBER) && LIBRESSL_VERSION_NUMBER < 0x2090100fL)
 /* OpenSSL < 1.1.0 needs support for threading/locking in the calling application. */
@@ -631,18 +632,13 @@ OpenSSLTLSTicketKeysRing::~OpenSSLTLSTicketKeysRing() = default;
 void OpenSSLTLSTicketKeysRing::addKey(std::shared_ptr<OpenSSLTLSTicketKey>&& newKey)
 {
   d_ticketKeys.write_lock()->push_front(std::move(newKey));
-  if (d_ticketsKeyAddedHook) {
+  if (TLSCtx::hasTicketsKeyAddedHook()) {
     auto key = d_ticketKeys.read_lock()->front();
     auto keyContent = key->content();
-    d_ticketsKeyAddedHook(keyContent.c_str(), keyContent.size());
+    TLSCtx::getTicketsKeyAddedHook()(keyContent);
   }
 }
 
-void OpenSSLTLSTicketKeysRing::setTicketsKeyAddedHook(const dnsdist_tickets_key_added_hook& hook)
-{
-  d_ticketsKeyAddedHook = hook;
-}
-
 std::shared_ptr<OpenSSLTLSTicketKey> OpenSSLTLSTicketKeysRing::getEncryptionKey()
 {
   return d_ticketKeys.read_lock()->front();
index d0ed6a96bca27c8f9c89b1f05c4309ac4c8ad459..c1ed2067407e7f56498fbb66f62d33417b4e547f 100644 (file)
@@ -112,8 +112,6 @@ private:
   unsigned char d_hmacKey[TLS_TICKETS_MAC_KEY_SIZE];
 };
 
-using dnsdist_tickets_key_added_hook = std::function<void(const char* key, size_t keyLen)>;
-
 class OpenSSLTLSTicketKeysRing
 {
 public:
@@ -124,11 +122,9 @@ public:
   size_t getKeysCount();
   void loadTicketsKeys(const std::string& keyFile);
   void rotateTicketsKey(time_t now);
-  void setTicketsKeyAddedHook(const dnsdist_tickets_key_added_hook& hook);
 
 private:
   void addKey(std::shared_ptr<OpenSSLTLSTicketKey>&& newKey);
-  dnsdist_tickets_key_added_hook d_ticketsKeyAddedHook;
   SharedLockGuarded<boost::circular_buffer<std::shared_ptr<OpenSSLTLSTicketKey> > > d_ticketKeys;
 };
 
index 87391ba2ab18d0392f4378884c2402ee1cd7d383..1fb91ef5e0ab7b060ef29bd6268faedae5a89ee4 100644 (file)
@@ -22,6 +22,7 @@ const bool TCPIOHandler::s_disableConnectForUnitTests = false;
 
 #include "libssl.hh"
 
+dnsdist_tickets_key_added_hook TLSCtx::s_ticketsKeyAddedHook{nullptr};
 
 class OpenSSLFrontendContext
 {
@@ -813,11 +814,6 @@ public:
     }
   }
 
-  void setTicketsKeyAddedHook(const dnsdist_tickets_key_added_hook& hook) override
-  {
-    d_feContext->d_ticketKeys.setTicketsKeyAddedHook(hook);
-  }
-
   void loadTicketsKeys(const std::string& keyFile) final
   {
     d_feContext->d_ticketKeys.loadTicketsKeys(keyFile);
@@ -1743,19 +1739,12 @@ public:
     return connection;
   }
 
-  void setTicketsKeyAddedHook(const dnsdist_tickets_key_added_hook& hook) override
-  {
-    d_ticketsKeyAddedHook = hook;
-  }
-
-  void rotateTicketsKey(time_t now) override
+  void addTicketsKey(time_t now, std::shared_ptr<GnuTLSTicketsKey>&& newKey)
   {
     if (!d_enableTickets) {
       return;
     }
 
-    auto newKey = std::make_shared<GnuTLSTicketsKey>();
-
     {
       *(d_ticketsKey.write_lock()) = std::move(newKey);
     }
@@ -1764,13 +1753,21 @@ public:
       d_ticketsKeyNextRotation = now + d_ticketsKeyRotationDelay;
     }
 
-    if (d_ticketsKeyAddedHook) {
+    if (TLSCtx::hasTicketsKeyAddedHook()) {
       auto ticketsKey = *(d_ticketsKey.read_lock());
       auto content = ticketsKey->content();
-      d_ticketsKeyAddedHook(content.c_str(), content.size());
+      TLSCtx::getTicketsKeyAddedHook()(content);
     }
   }
+  void rotateTicketsKey(time_t now) override
+  {
+    if (!d_enableTickets) {
+      return;
+    }
 
+    auto newKey = std::make_shared<GnuTLSTicketsKey>();
+    addTicketsKey(now, std::move(newKey));
+  }
   void loadTicketsKeys(const std::string& file) final
   {
     if (!d_enableTickets) {
@@ -1778,13 +1775,7 @@ public:
     }
 
     auto newKey = std::make_shared<GnuTLSTicketsKey>(file);
-    {
-      *(d_ticketsKey.write_lock()) = std::move(newKey);
-    }
-
-    if (d_ticketsKeyRotationDelay > 0) {
-      d_ticketsKeyNextRotation = time(nullptr) + d_ticketsKeyRotationDelay;
-    }
+    addTicketsKey(time(nullptr), std::move(newKey));
   }
 
   size_t getTicketsKeysCount() override
@@ -1816,7 +1807,6 @@ private:
   SharedLockGuarded<std::shared_ptr<GnuTLSTicketsKey>> d_ticketsKey{nullptr};
   bool d_enableTickets{true};
   bool d_validateCerts{true};
-  dnsdist_tickets_key_added_hook d_ticketsKeyAddedHook;
 };
 
 #endif /* HAVE_GNUTLS */
index c592701eedf9c367997a68dbe2423fceed1344d0..59817beefe5dc540d79c7192446f7cb951600f8c 100644 (file)
@@ -66,6 +66,8 @@ protected:
   bool d_resumedFromInactiveTicketKey{false};
 };
 
+using dnsdist_tickets_key_added_hook = std::function<void(const std::string& key)>;
+
 class TLSCtx
 {
 public:
@@ -81,11 +83,6 @@ public:
   {
     throw std::runtime_error("This TLS backend does not have the capability to load a tickets key from a file");
   }
-  virtual void setTicketsKeyAddedHook(const dnsdist_tickets_key_added_hook& /* hook */)
-  {
-    throw std::runtime_error("This TLS backend does not have the capability to setup a hook for added tickets keys");
-  }
-
   void handleTicketsKeyRotation(time_t now)
   {
     if (d_ticketsKeyRotationDelay != 0 && now > d_ticketsKeyNextRotation) {
@@ -128,10 +125,25 @@ public:
     return false;
   }
 
+  static void setTicketsKeyAddedHook(const dnsdist_tickets_key_added_hook& hook)
+  {
+    TLSCtx::s_ticketsKeyAddedHook = hook;
+  }
+  static const dnsdist_tickets_key_added_hook& getTicketsKeyAddedHook()
+  {
+    return TLSCtx::s_ticketsKeyAddedHook;
+  }
+  static bool hasTicketsKeyAddedHook()
+  {
+    return TLSCtx::s_ticketsKeyAddedHook != nullptr;
+  }
 protected:
   std::atomic_flag d_rotatingTicketsKey;
   std::atomic<time_t> d_ticketsKeyNextRotation{0};
   time_t d_ticketsKeyRotationDelay{0};
+
+private:
+  static dnsdist_tickets_key_added_hook s_ticketsKeyAddedHook;
 };
 
 class TLSFrontend
@@ -156,13 +168,6 @@ public:
     }
   }
 
-  void setTicketsKeyAddedHook(const dnsdist_tickets_key_added_hook& hook)
-  {
-    if (d_ctx != nullptr) {
-      d_ctx->setTicketsKeyAddedHook(hook);
-    }
-  }
-
   void loadTicketsKeys(const std::string& file)
   {
     if (d_ctx != nullptr) {
index 6138c3e6363eb7917d85e1c4bf933eb500794d86..27c2de52fe19b6ac343eb6da5f0f30830217678f 100644 (file)
@@ -547,7 +547,7 @@ class TestTLSTicketsKeyAddedCallback(DNSDistTest):
         """
         LuaThread: Test the lua newThread interface
         """
-        self.sendConsoleCommand('getTLSFrontend(0):setTicketsKeyAddedHook(keyAddedCallback)');
+        self.sendConsoleCommand('setTicketsKeyAddedHook(keyAddedCallback)');
         called = self.sendConsoleCommand('callbackCalled')
         self.assertEqual(int(called), 0)
         self.sendConsoleCommand("getTLSFrontend(0):rotateTicketsKey()")