]> git.ipfire.org Git - thirdparty/pdns.git/blobdiff - pdns/dnsdistdist/libssl.cc
dnsdist: Add OCSP stapling (from files) for DoT and DoH
[thirdparty/pdns.git] / pdns / dnsdistdist / libssl.cc
index 684a276c314b05de11411c6dd78ddba9ae16f575..ce7e2660b89c9a16ed75f224dcaf07e7c8722231 100644 (file)
@@ -5,9 +5,13 @@
 #ifdef HAVE_LIBSSL
 
 #include <atomic>
+#include <fstream>
+#include <cstring>
 #include <pthread.h>
+
 #include <openssl/conf.h>
 #include <openssl/err.h>
+#include <openssl/ocsp.h>
 #include <openssl/rand.h>
 #include <openssl/ssl.h>
 
@@ -86,4 +90,131 @@ void unregisterOpenSSLUser()
 #endif
 }
 
+int libssl_ocsp_stapling_callback(SSL* ssl, const std::map<int, std::string>& ocspMap)
+{
+  auto pkey = SSL_get_privatekey(ssl);
+  if (pkey == nullptr) {
+    return SSL_TLSEXT_ERR_NOACK;
+  }
+
+  /* look for an OCSP response for the corresponding private key type (RSA, ECDSA..) */
+  const auto& data = ocspMap.find(EVP_PKEY_base_id(pkey));
+  if (data == ocspMap.end()) {
+    return SSL_TLSEXT_ERR_NOACK;
+  }
+
+  /* we need to allocate a copy because OpenSSL will free the pointer passed to SSL_set_tlsext_status_ocsp_resp() */
+  void* copy = OPENSSL_malloc(data->second.size());
+  if (copy == nullptr) {
+    return SSL_TLSEXT_ERR_NOACK;
+  }
+
+  memcpy(copy, data->second.data(), data->second.size());
+  SSL_set_tlsext_status_ocsp_resp(ssl, copy, data->second.size());
+  return SSL_TLSEXT_ERR_OK;
+}
+
+static bool libssl_validate_ocsp_response(const std::string& response)
+{
+  auto responsePtr = reinterpret_cast<const unsigned char *>(response.data());
+  std::unique_ptr<OCSP_RESPONSE, void(*)(OCSP_RESPONSE*)> resp(d2i_OCSP_RESPONSE(nullptr, &responsePtr, response.size()), OCSP_RESPONSE_free);
+  if (resp == nullptr) {
+    throw std::runtime_error("Unable to parse OCSP response");
+  }
+
+  int status = OCSP_response_status(resp.get());
+  if (status != OCSP_RESPONSE_STATUS_SUCCESSFUL) {
+    throw std::runtime_error("OCSP response status is not successful: " + std::to_string(status));
+  }
+
+  std::unique_ptr<OCSP_BASICRESP, void(*)(OCSP_BASICRESP*)> basic(OCSP_response_get1_basic(resp.get()), OCSP_BASICRESP_free);
+  if (basic == nullptr) {
+    throw std::runtime_error("Error getting a basic OCSP response");
+  }
+
+  if (OCSP_resp_count(basic.get()) != 1) {
+    throw std::runtime_error("More than one single response in an OCSP basic response");
+  }
+
+  auto singleResponse = OCSP_resp_get0(basic.get(), 0);
+  if (singleResponse == nullptr) {
+    throw std::runtime_error("Error getting a single response from the basic OCSP response");
+  }
+
+  int reason;
+  ASN1_GENERALIZEDTIME* revTime = nullptr;
+  ASN1_GENERALIZEDTIME* thisUpdate = nullptr;
+  ASN1_GENERALIZEDTIME* nextUpdate = nullptr;
+
+  auto singleResponseStatus = OCSP_single_get0_status(singleResponse, &reason, &revTime, &thisUpdate, &nextUpdate);
+  if (singleResponseStatus != V_OCSP_CERTSTATUS_GOOD) {
+    throw std::runtime_error("Invalid status for OCSP single response (" + std::to_string(singleResponseStatus) + ")");
+  }
+  if (thisUpdate == nullptr || nextUpdate == nullptr) {
+    throw std::runtime_error("Error getting validity of OCSP single response");
+  }
+
+  auto validityResult = OCSP_check_validity(thisUpdate, nextUpdate, /* 5 minutes of leeway */ 5 * 60, -1);
+  if (validityResult == 0) {
+    throw std::runtime_error("OCSP single response is not yet, or no longer, valid");
+  }
+
+  return true;
+}
+
+std::map<int, std::string> libssl_load_ocsp_responses(const std::vector<std::string>& ocspFiles, std::vector<int> keyTypes)
+{
+  std::map<int, std::string> ocspResponses;
+
+  if (ocspFiles.size() > keyTypes.size()) {
+    throw std::runtime_error("More OCSP files than certificates and keys loaded!");
+  }
+
+  size_t count = 0;
+  for (const auto& filename : ocspFiles) {
+    std::ifstream file(filename, std::ios::binary);
+    std::string content;
+    while(file) {
+      char buffer[4096];
+      file.read(buffer, sizeof(buffer));
+      if (file.bad()) {
+        file.close();
+        throw std::runtime_error("Unable to load OCSP response from '" + filename + "'");
+      }
+      content.append(buffer, file.gcount());
+    }
+    file.close();
+
+    try {
+      libssl_validate_ocsp_response(content);
+      ocspResponses.insert({keyTypes.at(count), std::move(content)});
+    }
+    catch (const std::exception& e) {
+      throw std::runtime_error("Error checking the validity of OCSP response from '" + filename + "': " + e.what());
+    }
+    ++count;
+  }
+
+  return ocspResponses;
+}
+
+int libssl_get_last_key_type(std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>& ctx)
+{
+#if (OPENSSL_VERSION_NUMBER >= 0x10002000L && !defined LIBRESSL_VERSION_NUMBER)
+  auto pkey = SSL_CTX_get0_privatekey(ctx.get());
+#else
+  auto temp = std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(ctx.get()), SSL_free);
+  if (!temp) {
+    return -1;
+  }
+  auto pkey = SSL_get_privatekey(temp.get());
+#endif
+
+  if (!pkey) {
+    return -1;
+  }
+
+  return EVP_PKEY_base_id(pkey);
+}
+
 #endif /* HAVE_LIBSSL */