]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
There is life in TLS...
authorOtto <otto.moerbeek@open-xchange.com>
Wed, 3 Feb 2021 13:29:15 +0000 (14:29 +0100)
committerOtto <otto.moerbeek@open-xchange.com>
Wed, 3 Feb 2021 13:29:15 +0000 (14:29 +0100)
pdns/sdig.cc
pdns/tcpiohandler.cc
pdns/tcpiohandler.hh

index ac89a2a6d7779e50e3ea3fa6c9b47625867e0fe3..82d4c16088cdf639f10893d94d2dff9e1f25a407 100644 (file)
 #include "minicurl.hh"
 #endif
 
+#include "tcpiohandler.hh"
+
 StatBag S;
 
+// Vars below used by tcpiohandler.cc
+bool g_verbose = true;
+bool g_syslog = false;
+
 static bool hidettl = false;
 
 static string ttl(uint32_t ttl)
@@ -33,7 +39,7 @@ static void usage()
   cerr << "sdig" << endl;
   cerr << "Syntax: sdig IP-ADDRESS-OR-DOH-URL PORT QNAME QTYPE "
           "[dnssec] [ednssubnet SUBNET/MASK] [hidesoadetails] [hidettl] "
-          "[recurse] [showflags] [tcp] [xpf XPFDATA] [class CLASSNUM] "
+          "[recurse] [showflags] [tcp] [dot] [xpf XPFDATA] [class CLASSNUM] "
           "[proxy UDP(0)/TCP(1) SOURCE-IP-ADDRESS-AND-PORT DESTINATION-IP-ADDRESS-AND-PORT]"
        << endl;
 }
@@ -195,12 +201,17 @@ static void printReply(const string& reply, bool showflags, bool hidesoadetails)
 
 int main(int argc, char** argv)
 try {
+    /* default timeout of 10s */
+  int timeout = 10;
   bool dnssec = false;
   bool recurse = false;
   bool tcp = false;
   bool showflags = false;
   bool hidesoadetails = false;
   bool doh = false;
+  bool dot = false;
+  bool fastOpen = false;
+  bool insecureDoT = false;
   bool fromstdin = false;
   boost::optional<Netmask> ednsnm;
   uint16_t xpfcode = 0, xpfversion = 0, xpfproto = 0;
@@ -241,6 +252,10 @@ try {
         hidettl = true;
       if (strcmp(argv[i], "tcp") == 0)
         tcp = true;
+      if (strcmp(argv[i], "dot") == 0)
+        dot = true;
+      if (strcmp(argv[i], "insecure") == 0)
+        insecureDoT = true;
       if (strcmp(argv[i], "ednssubnet") == 0) {
         if (argc < i + 2) {
           cerr << "ednssubnet needs an argument" << endl;
@@ -279,6 +294,17 @@ try {
     }
   }
 
+  if (dot) {
+    tcp = true;
+  }
+
+#ifndef HAVE_DNS_OVER_TLS
+  if (dot) {
+    cerr << "DoT requested but not compiled in" << endl;
+    exit(EXIT_FAILURE);
+  }
+#endif
+
   string reply;
   ComboAddress dest;
   if (*argv[1] == 'h') {
@@ -344,41 +370,46 @@ try {
 
     printReply(reply, showflags, hidesoadetails);
   } else if (tcp) {
+    std::shared_ptr<TLSCtx> tlsCtx{nullptr};
+    if (dot) {
+      TLSContextParameters tlsParams;
+      tlsParams.d_provider = "openssl";
+      tlsParams.d_validateCertificates = !insecureDoT;
+      tlsCtx = getTLSContext(tlsParams);
+    }
     uint16_t counter = 0;
     Socket sock(dest.sin4.sin_family, SOCK_STREAM);
-    sock.connect(dest);
-    sock.writen(proxyheader);
+    SConnectWithTimeout(sock.getHandle(), dest, timeout);
+    TCPIOHandler handler("buab", sock.getHandle(), timeout, tlsCtx, time(nullptr));
+    handler.connect(fastOpen, dest, timeout);
+    // we are writing the proxyheader inside the TLS connection. Is that right?
+    if (proxyheader.size() > 0 && handler.write(proxyheader.data(), proxyheader.size(), timeout) != proxyheader.size()) {
+      throw PDNSException("tcp write failed");
+    }
+
     for (const auto& it : questions) {
       vector<uint8_t> packet;
       s_expectedIDs.insert(counter);
       fillPacket(packet, it.first, it.second, dnssec, ednsnm, recurse, xpfcode,
                  xpfversion, xpfproto, xpfsrc, xpfdst, qclass, counter);
       counter++;
-
       uint16_t len = htons(packet.size());
-      if (sock.write((const char *)&len, 2) != 2)
+      if (handler.write(&len, sizeof(len), timeout) != sizeof(len))
         throw PDNSException("tcp write failed");
-      string question(packet.begin(), packet.end());
-      sock.writen(question);
+      if (handler.write(packet.data(), packet.size(), timeout) != packet.size()) {
+        throw PDNSException("tcp write failed");
+      }
     }
     for (size_t i = 0; i < questions.size(); i++) {
       uint16_t len;
-      if (sock.read((char *)&len, 2) != 2)
+      if (handler.read((char *)&len, sizeof(len), timeout) != sizeof(len)) {
         throw PDNSException("tcp read failed");
-
+      }
       len = ntohs(len);
-      char* creply = new char[len];
-      int n = 0;
-      int numread;
-      while (n < len) {
-        numread = sock.read(creply + n, len - n);
-        if (numread < 0)
-          throw PDNSException("tcp read failed");
-        n += numread;
+      reply.resize(len);
+      if (handler.read(&reply[0], len, timeout) != len) {
+        throw PDNSException("tcp read failed");
       }
-
-      reply = string(creply, len);
-      delete[] creply;
       printReply(reply, showflags, hidesoadetails);
     }
   } else // udp
@@ -391,7 +422,7 @@ try {
     Socket sock(dest.sin4.sin_family, SOCK_DGRAM);
     question = proxyheader + question;
     sock.sendTo(question, dest);
-    int result = waitForData(sock.getHandle(), 10);
+    int result = waitForData(sock.getHandle(), timeout);
     if (result < 0)
       throw std::runtime_error("Error waiting for data: " + stringerror());
     if (!result)
index e308e2791bb80bcb3671a599b0c5e67ee37682b3..3cbebeff8c85fe10a3f903b66787536043e62be5 100644 (file)
 
 #ifdef HAVE_DNS_OVER_TLS
 #ifdef HAVE_LIBSSL
+
+#ifdef ___OpenBSD__
+#define LIBRESSL_HAS_TLS1_3
+#endif
+
 #include <openssl/conf.h>
 #include <openssl/err.h>
 #include <openssl/rand.h>
@@ -48,6 +53,7 @@ public:
 class OpenSSLTLSConnection: public TLSConnection
 {
 public:
+  /* server side connection */
   OpenSSLTLSConnection(int socket, unsigned int timeout, std::shared_ptr<OpenSSLFrontendContext> feContext): d_feContext(feContext), d_conn(std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(d_feContext->d_tlsCtx.get()), SSL_free)), d_timeout(timeout)
   {
     d_socket = socket;
@@ -75,6 +81,40 @@ public:
     SSL_set_ex_data(d_conn.get(), s_tlsConnIndex, this);
   }
 
+  /* client-side connection */
+  OpenSSLTLSConnection(const std::string& hostname, int socket, unsigned int timeout, SSL_CTX* tlsCtx): d_conn(std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(tlsCtx), SSL_free)), d_hostname(hostname), d_timeout(timeout)
+  {
+    d_socket = socket;
+
+    if (!d_conn) {
+      vinfolog("Error creating TLS object");
+      if (g_verbose) {
+        ERR_print_errors_fp(stderr);
+      }
+      throw std::runtime_error("Error creating TLS object");
+    }
+
+    if (!SSL_set_fd(d_conn.get(), d_socket)) {
+      throw std::runtime_error("Error assigning socket");
+    }
+
+#if (OPENSSL_VERSION_NUMBER >= 0x1010000fL)
+    // XXX SSL_set_hostflags(d_conn.get(), X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS);
+    if (SSL_set1_host(d_conn.get(), d_hostname.c_str()) != 1) {
+      throw std::runtime_error("Error setting TLS hostname for certificate validation");
+    }
+#elif (OPENSSL_VERSION_NUMBER >= 0x10002000L)
+    X509_VERIFY_PARAM *param = SSL_get0_param(d_conn.get());
+    /* Enable automatic hostname checks */
+    X509_VERIFY_PARAM_set_hostflags(param, X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS);
+    if (X509_VERIFY_PARAM_set1_host(param, d_hostname.c_str(), d_hostname.size()) != 1) {
+      throw std::runtime_error("Error setting TLS hostname for certificate validation");
+    }
+#else
+    /* no hostname validation for you, see https://wiki.openssl.org/index.php/Hostname_validation */
+#endif
+  }
+
   IOState convertIORequestToIOState(int res) const
   {
     int error = SSL_get_error(d_conn.get(), res);
@@ -115,6 +155,55 @@ public:
     }
   }
 
+  IOState tryConnect(bool fastOpen, const ComboAddress& remote) override
+  {
+    /* sorry */
+    (void) fastOpen;
+    (void) remote;
+
+    int res = SSL_connect(d_conn.get());
+    if (res == 1) {
+      return IOState::Done;
+    }
+    else if (res < 0) {
+      return convertIORequestToIOState(res);
+    }
+
+    throw std::runtime_error("Error establishing a TLS connection");
+  }
+
+  void connect(bool fastOpen, const ComboAddress& remote, unsigned int timeout) override
+  {
+    /* sorry */
+    (void) fastOpen;
+    (void) remote;
+
+    time_t start = 0;
+    unsigned int remainingTime = timeout;
+    if (timeout) {
+      start = time(nullptr);
+    }
+
+    int res = 0;
+    do {
+      res = SSL_connect(d_conn.get());
+      if (res < 0) {
+        handleIORequest(res, remainingTime);
+      }
+
+      if (timeout) {
+        time_t now = time(nullptr);
+        unsigned int elapsed = now - start;
+        if (now < start || elapsed >= remainingTime) {
+          throw runtime_error("Timeout while establishing TLS connection");
+        }
+        start = now;
+        remainingTime -= elapsed;
+      }
+    }
+    while (res != 1);
+  }
+
   IOState tryHandshake() override
   {
     int res = SSL_accept(d_conn.get());
@@ -285,6 +374,7 @@ private:
 
   std::shared_ptr<OpenSSLFrontendContext> d_feContext;
   std::unique_ptr<SSL, void(*)(SSL*)> d_conn;
+  std::string d_hostname;
   unsigned int d_timeout;
 };
 
@@ -294,7 +384,8 @@ int OpenSSLTLSConnection::s_tlsConnIndex = -1;
 class OpenSSLTLSIOCtx: public TLSCtx
 {
 public:
-  OpenSSLTLSIOCtx(TLSFrontend& fe): d_feContext(std::make_shared<OpenSSLFrontendContext>(fe.d_addr, fe.d_tlsConfig))
+  /* server side context */
+  OpenSSLTLSIOCtx(TLSFrontend& fe): d_feContext(std::make_shared<OpenSSLFrontendContext>(fe.d_addr, fe.d_tlsConfig)), d_ticketKeys{0}, d_tlsCtx(std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>(nullptr, SSL_CTX_free))
   {
     d_ticketsKeyRotationDelay = fe.d_tlsConfig.d_ticketsKeyRotationDelay;
 
@@ -328,8 +419,76 @@ public:
     }
   }
 
+  /* client side context */
+  OpenSSLTLSIOCtx(const TLSContextParameters& params): d_ticketKeys(0), d_tlsCtx(std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>(nullptr, SSL_CTX_free))
+  {
+    int sslOptions =
+      SSL_OP_NO_SSLv2 |
+      SSL_OP_NO_SSLv3 |
+      SSL_OP_NO_COMPRESSION |
+      SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION |
+      SSL_OP_SINGLE_DH_USE |
+      SSL_OP_SINGLE_ECDH_USE |
+      SSL_OP_CIPHER_SERVER_PREFERENCE;
+
+#if 0 // XXX
+    if (s_users.fetch_add(1) == 0) {
+      registerOpenSSLUser();
+
+      s_ticketsKeyIndex = SSL_CTX_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
+
+      if (s_ticketsKeyIndex == -1) {
+        throw std::runtime_error("Error getting an index for tickets key");
+      }
+    }
+#endif
+
+#ifdef HAVE_TLS_CLIENT_METHOD
+    d_tlsCtx = std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>(SSL_CTX_new(TLS_client_method()), SSL_CTX_free);
+#else
+    d_tlsCtx = std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>(SSL_CTX_new(SSLv23_client_method()), SSL_CTX_free);
+#endif
+    if (!d_tlsCtx) {
+      ERR_print_errors_fp(stderr);
+      throw std::runtime_error("Error creating TLS context");
+    }
+
+    SSL_CTX_set_options(d_tlsCtx.get(), sslOptions);
+#if defined(SSL_CTX_set_ecdh_auto)
+    SSL_CTX_set_ecdh_auto(d_tlsCtx.get(), 1);
+#endif
+
+    if (!params.d_ciphers.empty()) {
+      if (SSL_CTX_set_cipher_list(d_tlsCtx.get(), params.d_ciphers.c_str()) != 1) {
+        ERR_print_errors_fp(stderr);
+        throw std::runtime_error("Error setting the cipher list to '" + params.d_ciphers + "' for the TLS context");
+      }
+    }
+#ifdef HAVE_SSL_CTX_SET_CIPHERSUITES
+    if (!params.d_ciphers13.empty()) {
+      if (SSL_CTX_set_ciphersuites(d_tlsCtx.get(), params.d_ciphers13.c_str()) != 1) {
+        ERR_print_errors_fp(stderr);
+        throw std::runtime_error("Error setting the TLS 1.3 cipher list to '" + params.d_ciphers13 + "' for the TLS context");
+      }
+    }
+#endif /* HAVE_SSL_CTX_SET_CIPHERSUITES */
+
+    if (params.d_validateCertificates) {
+      SSL_CTX_set_verify(d_tlsCtx.get(), SSL_VERIFY_PEER, nullptr);
+#if (OPENSSL_VERSION_NUMBER < 0x10002000L)
+      warnlog("TLS hostname validation requested but not supported for OpenSSL < 1.0.2");
+#endif
+    }
+  }
+
   ~OpenSSLTLSIOCtx() override
   {
+    d_tlsCtx.reset();
+#if 0 // XXX
+    if (s_users.fetch_sub(1) == 1) {
+      unregisterOpenSSLUser();
+    }
+#endif
   }
 
   static int ticketKeyCb(SSL *s, unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char *iv, EVP_CIPHER_CTX *ectx, HMAC_CTX *hctx, int enc)
@@ -373,6 +532,11 @@ public:
     return std::unique_ptr<OpenSSLTLSConnection>(new OpenSSLTLSConnection(socket, timeout, d_feContext));
   }
 
+  std::unique_ptr<TLSConnection> getClientConnection(const std::string& host, int socket, unsigned int timeout) override
+  {
+    return std::unique_ptr<OpenSSLTLSConnection>(new OpenSSLTLSConnection(host, socket, timeout, d_tlsCtx.get()));
+  }
+
   void rotateTicketsKey(time_t now) override
   {
     d_feContext->d_ticketKeys.rotateTicketsKey(now);
@@ -398,6 +562,9 @@ public:
 
 private:
   std::shared_ptr<OpenSSLFrontendContext> d_feContext;
+  OpenSSLTLSTicketKeysRing d_ticketKeys;
+  std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)> d_tlsCtx;
+  static std::atomic<uint64_t> s_users;
 };
 
 #endif /* HAVE_LIBSSL */
@@ -498,7 +665,7 @@ private:
 class GnuTLSConnection: public TLSConnection
 {
 public:
-
+  /* server side connection */
   GnuTLSConnection(int socket, unsigned int timeout, const gnutls_certificate_credentials_t creds, const gnutls_priority_t priorityCache, std::shared_ptr<GnuTLSTicketsKey>& ticketsKey, bool enableTickets): d_conn(std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(nullptr, gnutls_deinit)), d_ticketsKey(ticketsKey)
   {
     unsigned int sslOptions = GNUTLS_SERVER | GNUTLS_NONBLOCK;
@@ -538,6 +705,116 @@ public:
     gnutls_record_set_timeout(d_conn.get(), timeout * 1000);
   }
 
+  /* client-side connection */
+  GnuTLSConnection(const std::string& host, int socket, unsigned int timeout, const gnutls_certificate_credentials_t creds, const gnutls_priority_t priorityCache, bool validateCerts): d_conn(std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(nullptr, gnutls_deinit)), d_host(host)
+  {
+    unsigned int sslOptions = GNUTLS_CLIENT | GNUTLS_NONBLOCK;
+#ifdef GNUTLS_NO_SIGNAL
+    sslOptions |= GNUTLS_NO_SIGNAL;
+#endif
+
+    d_socket = socket;
+
+    gnutls_session_t conn;
+    if (gnutls_init(&conn, sslOptions) != GNUTLS_E_SUCCESS) {
+      throw std::runtime_error("Error creating TLS connection");
+    }
+
+    d_conn = std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(conn, gnutls_deinit);
+    conn = nullptr;
+
+    int rc = gnutls_credentials_set(d_conn.get(), GNUTLS_CRD_CERTIFICATE, creds);
+    if (rc != GNUTLS_E_SUCCESS) {
+      throw std::runtime_error("Error setting certificate and key to TLS connection: " + std::string(gnutls_strerror(rc)));
+    }
+
+    rc = gnutls_priority_set(d_conn.get(), priorityCache);
+    if (rc != GNUTLS_E_SUCCESS) {
+      throw std::runtime_error("Error setting ciphers to TLS connection: " + std::string(gnutls_strerror(rc)));
+    }
+
+    gnutls_transport_set_int(d_conn.get(), d_socket);
+
+    /* timeouts are in milliseconds */
+    gnutls_handshake_set_timeout(d_conn.get(), timeout * 1000);
+    gnutls_record_set_timeout(d_conn.get(), timeout * 1000);
+
+    if (!d_host.empty()) {
+      gnutls_session_set_verify_cert(d_conn.get(), d_host.c_str(), GNUTLS_VERIFY_ALLOW_UNSORTED_CHAIN);
+      rc = gnutls_server_name_set(d_conn.get(), GNUTLS_NAME_DNS, d_host.c_str(), d_host.size());
+      if (rc != GNUTLS_E_SUCCESS) {
+        throw std::runtime_error("Error setting the SNI value to '" + d_host + "' on TLS connection: " + std::string(gnutls_strerror(rc)));
+      }
+    }
+  }
+
+  IOState tryConnect(bool fastOpen, const ComboAddress& remote) override
+  {
+    int ret = 0;
+
+    if (fastOpen) {
+#ifdef HAVE_GNUTLS_TRANSPORT_SET_FASTOPEN
+      gnutls_transport_set_fastopen(d_conn.get(), d_socket, const_cast<struct sockaddr*>(reinterpret_cast<const struct sockaddr*>(&remote)), remote.getSocklen(), 0);
+#endif
+    }
+
+    do {
+      ret = gnutls_handshake(d_conn.get());
+      if (ret == GNUTLS_E_SUCCESS) {
+        return IOState::Done;
+      }
+      else if (ret == GNUTLS_E_AGAIN) {
+        int direction = gnutls_record_get_direction(d_conn.get());
+        return direction == 0 ? IOState::NeedRead : IOState::NeedWrite;
+      }
+      else if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) {
+        throw std::runtime_error("Error establishing a new connection: " + std::string(gnutls_strerror(ret)));
+      }
+    } while (ret == GNUTLS_E_INTERRUPTED);
+
+    throw std::runtime_error("Error establishing a new connection: " + std::string(gnutls_strerror(ret)));
+  }
+
+  void connect(bool fastOpen, const ComboAddress& remote, unsigned int timeout) override
+  {
+    time_t start = 0;
+    unsigned int remainingTime = timeout;
+    if (timeout) {
+      start = time(nullptr);
+    }
+
+    IOState state;
+    do {
+      state = tryConnect(fastOpen, remote);
+      if (state == IOState::Done) {
+        return;
+      }
+      else if (state == IOState::NeedRead) {
+        int result = waitForData(d_socket, remainingTime);
+        if (result <= 0) {
+          throw std::runtime_error("Error reading from TLS connection: " + std::to_string(result));
+        }
+      }
+      else if (state == IOState::NeedWrite) {
+        int result = waitForRWData(d_socket, false, remainingTime, 0);
+        if (result <= 0) {
+          throw std::runtime_error("Error reading from TLS connection: " + std::to_string(result));
+        }
+      }
+
+      if (timeout) {
+        time_t now = time(nullptr);
+        unsigned int elapsed = now - start;
+        if (now < start || elapsed >= remainingTime) {
+          throw runtime_error("Timeout while establishing TLS connection");
+        }
+        start = now;
+        remainingTime -= elapsed;
+      }
+    }
+    while (state != IOState::Done);
+  }
+
   void doHandshake() override
   {
     int ret = 0;
@@ -760,11 +1037,13 @@ public:
 private:
   std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)> d_conn;
   std::shared_ptr<GnuTLSTicketsKey> d_ticketsKey;
+  std::string d_host;
 };
 
 class GnuTLSIOCtx: public TLSCtx
 {
 public:
+  /* server side context */
   GnuTLSIOCtx(TLSFrontend& fe): d_creds(std::unique_ptr<gnutls_certificate_credentials_st, void(*)(gnutls_certificate_credentials_t)>(nullptr, gnutls_certificate_free_credentials)), d_enableTickets(fe.d_tlsConfig.d_enableTickets)
   {
     int rc = 0;
@@ -820,6 +1099,41 @@ public:
     }
   }
 
+  /* client side context */
+  GnuTLSIOCtx(const TLSContextParameters& params): d_creds(std::unique_ptr<gnutls_certificate_credentials_st, void(*)(gnutls_certificate_credentials_t)>(nullptr, gnutls_certificate_free_credentials)), d_enableTickets(true), d_validateCerts(params.d_validateCertificates)
+  {
+    int rc = 0;
+
+    gnutls_certificate_credentials_t creds;
+    rc = gnutls_certificate_allocate_credentials(&creds);
+    if (rc != GNUTLS_E_SUCCESS) {
+      throw std::runtime_error("Error allocating credentials for TLS context: " + std::string(gnutls_strerror(rc)));
+    }
+
+    d_creds = std::unique_ptr<gnutls_certificate_credentials_st, void(*)(gnutls_certificate_credentials_t)>(creds, gnutls_certificate_free_credentials);
+    creds = nullptr;
+
+    if (params.d_validateCertificates) {
+      if (params.d_caStore.empty()) {
+        rc = gnutls_certificate_set_x509_system_trust(d_creds.get());
+        if (rc < 0) {
+          throw std::runtime_error("Error adding the system's default trusted CAs: " + std::string(gnutls_strerror(rc)));
+        }
+      }
+      else {
+        rc = gnutls_certificate_set_x509_trust_file(d_creds.get(), params.d_caStore.c_str(), GNUTLS_X509_FMT_PEM);
+        if (rc < 0) {
+          throw std::runtime_error("Error adding '" + params.d_caStore + "' to the trusted CAs: " + std::string(gnutls_strerror(rc)));
+        }
+      }
+    }
+
+    rc = gnutls_priority_init(&d_priorityCache, params.d_ciphers.empty() ? "NORMAL" : params.d_ciphers.c_str(), nullptr);
+    if (rc != GNUTLS_E_SUCCESS) {
+      throw std::runtime_error("Error setting up TLS cipher preferences to 'NORMAL' (" + std::string(gnutls_strerror(rc)) + ")");
+    }
+  }
+
   virtual ~GnuTLSIOCtx() override
   {
     d_creds.reset();
@@ -842,6 +1156,11 @@ public:
     return std::unique_ptr<GnuTLSConnection>(new GnuTLSConnection(socket, timeout, d_creds.get(), d_priorityCache, ticketsKey, d_enableTickets));
   }
 
+  std::unique_ptr<TLSConnection> getClientConnection(const std::string& host, int socket, unsigned int timeout) override
+  {
+    return std::unique_ptr<GnuTLSConnection>(new GnuTLSConnection(host, socket, timeout, d_creds.get(), d_priorityCache, d_validateCerts));
+  }
+
   void rotateTicketsKey(time_t now) override
   {
     if (!d_enableTickets) {
@@ -889,6 +1208,7 @@ private:
   std::shared_ptr<GnuTLSTicketsKey> d_ticketsKey{nullptr};
   ReadWriteLock d_lock;
   bool d_enableTickets{true};
+  bool d_validateCerts{true};
 };
 
 #endif /* HAVE_GNUTLS */
@@ -924,3 +1244,31 @@ bool TLSFrontend::setupTLS()
 #endif /* HAVE_DNS_OVER_TLS */
   return true;
 }
+
+std::shared_ptr<TLSCtx> getTLSContext(const TLSContextParameters& params)
+{
+#ifdef HAVE_DNS_OVER_TLS
+  /* get the "best" available provider */
+  if (!params.d_provider.empty()) {
+#ifdef HAVE_GNUTLS
+    if (params.d_provider == "gnutls") {
+      return std::make_shared<GnuTLSIOCtx>(params);
+    }
+#endif /* HAVE_GNUTLS */
+#ifdef HAVE_LIBSSL
+    if (params.d_provider == "openssl") {
+      return std::make_shared<OpenSSLTLSIOCtx>(params);
+    }
+#endif /* HAVE_LIBSSL */
+  }
+#ifdef HAVE_GNUTLS
+  return std::make_shared<GnuTLSIOCtx>(params);
+#else /* HAVE_GNUTLS */
+#ifdef HAVE_LIBSSL
+  return std::make_shared<OpenSSLTLSIOCtx>(params);
+#endif /* HAVE_LIBSSL */
+#endif /* HAVE_GNUTLS */
+
+#endif /* HAVE_DNS_OVER_TLS */
+  return nullptr;
+}
index 98f1e8ed55557ed05a4e5f1cc5f58ac3620ca08c..f326438290f5b09ad842a358d8a912c0a4b5d5f8 100644 (file)
@@ -13,6 +13,8 @@ class TLSConnection
 public:
   virtual ~TLSConnection() { }
   virtual void doHandshake() = 0;
+  virtual IOState tryConnect(bool fastOpen, const ComboAddress& remote) = 0;
+  virtual void connect(bool fastOpen, const ComboAddress& remote, unsigned int timeout) = 0;
   virtual IOState tryHandshake() = 0;
   virtual size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout=0) = 0;
   virtual size_t write(const void* buffer, size_t bufferSize, unsigned int writeTimeout) = 0;
@@ -59,6 +61,7 @@ public:
   }
   virtual ~TLSCtx() {}
   virtual std::unique_ptr<TLSConnection> getConnection(int socket, unsigned int timeout, time_t now) = 0;
+  virtual std::unique_ptr<TLSConnection> getClientConnection(const std::string& host, int socket, unsigned int timeout) = 0;
   virtual void rotateTicketsKey(time_t now) = 0;
   virtual void loadTicketsKeys(const std::string& file)
   {
@@ -177,6 +180,14 @@ private:
 class TCPIOHandler
 {
 public:
+  enum class Type { Client, Server };
+
+  TCPIOHandler(const std::string& host, int socket, unsigned int timeout, std::shared_ptr<TLSCtx> ctx, time_t now): d_socket(socket)
+  {
+    if (ctx) {
+      d_conn = ctx->getClientConnection(host, d_socket, timeout);
+    }
+  }
 
   TCPIOHandler(int socket, unsigned int timeout, std::shared_ptr<TLSCtx> ctx, time_t now): d_socket(socket)
   {
@@ -195,6 +206,28 @@ public:
     }
   }
 
+  IOState tryConnect(bool fastOpen, const ComboAddress& remote)
+  {
+    /* yes, this is only the TLS connect not the socket one,
+       sorry about that */
+    if (d_conn) {
+      return d_conn->tryConnect(fastOpen, remote);
+    }
+    d_fastOpen = fastOpen;
+
+    return IOState::Done;
+  }
+
+  void connect(bool fastOpen, const ComboAddress& remote, unsigned int timeout)
+  {
+    /* yes, this is only the TLS connect not the socket one,
+       sorry about that */
+    if (d_conn) {
+      d_conn->connect(fastOpen, remote, timeout);
+    }
+    d_fastOpen = fastOpen;
+  }
+
   IOState tryHandshake()
   {
     if (d_conn) {
@@ -342,4 +375,17 @@ public:
 private:
   std::unique_ptr<TLSConnection> d_conn{nullptr};
   int d_socket{-1};
+  bool d_fastOpen{false};
 };
+
+struct TLSContextParameters
+{
+  std::string d_provider;
+  std::string d_ciphers;
+  std::string d_ciphers13;
+  std::string d_caStore;
+  bool d_validateCertificates{true};
+};
+
+std::shared_ptr<TLSCtx> getTLSContext(const TLSContextParameters& params);
+