}
/* called in a client context, if the client advertised more than one ALPN values and the server returned more than one as well, to select the one to use. */
-void libssl_set_npn_select_callback(std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>& ctx, int (*cb)(SSL* s, unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen, void* arg), void* arg)
+void libssl_set_npn_select_callback(SSL_CTX* ctx, int (*cb)(SSL* s, unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen, void* arg), void* arg)
{
#ifdef HAVE_SSL_CTX_SET_NEXT_PROTO_SELECT_CB
- SSL_CTX_set_next_proto_select_cb(ctx.get(), cb, arg);
+ SSL_CTX_set_next_proto_select_cb(ctx, cb, arg);
#endif
}
-void libssl_set_alpn_select_callback(std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>& ctx, int (*cb)(SSL* s, const unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen, void* arg), void* arg)
+void libssl_set_alpn_select_callback(SSL_CTX* ctx, int (*cb)(SSL* s, const unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen, void* arg), void* arg)
{
#ifdef HAVE_SSL_CTX_SET_ALPN_SELECT_CB
- SSL_CTX_set_alpn_select_cb(ctx.get(), cb, arg);
+ SSL_CTX_set_alpn_select_cb(ctx, cb, arg);
#endif
}
-bool libssl_set_alpn_protos(std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>& ctx, const std::vector<std::vector<uint8_t>>& protos)
+bool libssl_set_alpn_protos(SSL_CTX* ctx, const std::vector<std::vector<uint8_t>>& protos)
{
#ifdef HAVE_SSL_CTX_SET_ALPN_PROTOS
std::vector<uint8_t> wire;
wire.push_back(length);
wire.insert(wire.end(), proto.begin(), proto.end());
}
- return SSL_CTX_set_alpn_protos(ctx.get(), wire.data(), wire.size()) == 0;
+ return SSL_CTX_set_alpn_protos(ctx, wire.data(), wire.size()) == 0;
#else
return false;
#endif
}
/* client-side connection */
- OpenSSLTLSConnection(const std::string& hostname, int socket, const struct timeval& timeout, SSL_CTX* tlsCtx): d_conn(std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(tlsCtx), SSL_free)), d_hostname(hostname), d_timeout(timeout)
+ OpenSSLTLSConnection(const std::string& hostname, int socket, const struct timeval& timeout, std::shared_ptr<SSL_CTX>& tlsCtx): d_tlsCtx(tlsCtx), d_conn(std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(tlsCtx.get()), SSL_free)), d_hostname(hostname), d_timeout(timeout)
{
d_socket = socket;
static std::atomic_flag s_initTLSConnIndex;
std::vector<std::unique_ptr<TLSSession>> d_tlsSessions;
+ /* server context */
std::shared_ptr<OpenSSLFrontendContext> d_feContext;
+ /* client context */
+ std::shared_ptr<SSL_CTX> d_tlsCtx;
std::unique_ptr<SSL, void(*)(SSL*)> d_conn;
std::string d_hostname;
struct timeval d_timeout;
{
public:
/* server side context */
- OpenSSLTLSIOCtx(TLSFrontend& fe): d_feContext(std::make_shared<OpenSSLFrontendContext>(fe.d_addr, fe.d_tlsConfig)), d_tlsCtx(std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>(nullptr, SSL_CTX_free))
+ OpenSSLTLSIOCtx(TLSFrontend& fe): d_feContext(std::make_shared<OpenSSLFrontendContext>(fe.d_addr, fe.d_tlsConfig))
{
d_ticketsKeyRotationDelay = fe.d_tlsConfig.d_ticketsKeyRotationDelay;
}
/* client side context */
- OpenSSLTLSIOCtx(const TLSContextParameters& params): d_tlsCtx(std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>(nullptr, SSL_CTX_free))
+ OpenSSLTLSIOCtx(const TLSContextParameters& params)
{
int sslOptions =
SSL_OP_NO_SSLv2 |
registerOpenSSLUser();
#ifdef HAVE_TLS_CLIENT_METHOD
- d_tlsCtx = std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>(SSL_CTX_new(TLS_client_method()), SSL_CTX_free);
+ d_tlsCtx = std::shared_ptr<SSL_CTX>(SSL_CTX_new(TLS_client_method()), SSL_CTX_free);
#else
- d_tlsCtx = std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>(SSL_CTX_new(SSLv23_client_method()), SSL_CTX_free);
+ d_tlsCtx = std::shared_ptr<SSL_CTX>(SSL_CTX_new(SSLv23_client_method()), SSL_CTX_free);
#endif
if (!d_tlsCtx) {
ERR_print_errors_fp(stderr);
std::unique_ptr<TLSConnection> getClientConnection(const std::string& host, int socket, const struct timeval& timeout) override
{
- return std::make_unique<OpenSSLTLSConnection>(host, socket, timeout, d_tlsCtx.get());
+ return std::make_unique<OpenSSLTLSConnection>(host, socket, timeout, d_tlsCtx);
}
void rotateTicketsKey(time_t now) override
{
if (d_feContext && d_feContext->d_tlsCtx) {
d_alpnProtos = protos;
- libssl_set_alpn_select_callback(d_feContext->d_tlsCtx, alpnServerSelectCallback, this);
+ libssl_set_alpn_select_callback(d_feContext->d_tlsCtx.get(), alpnServerSelectCallback, this);
return true;
}
if (d_tlsCtx) {
- return libssl_set_alpn_protos(d_tlsCtx, protos);
+ return libssl_set_alpn_protos(d_tlsCtx.get(), protos);
}
return false;
}
bool setNextProtocolSelectCallback(bool(*cb)(unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen)) override
{
d_nextProtocolSelectCallback = cb;
- libssl_set_npn_select_callback(d_tlsCtx, npnSelectCallback, this);
+ libssl_set_npn_select_callback(d_tlsCtx.get(), npnSelectCallback, this);
return true;
}
}
std::vector<std::vector<uint8_t>> d_alpnProtos; // store the supported ALPN protocols, so that the server can select based on what the client sent
- std::shared_ptr<OpenSSLFrontendContext> d_feContext;
- std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)> d_tlsCtx; // client context, on a server-side the context is stored in d_feContext->d_tlsCtx
+ std::shared_ptr<OpenSSLFrontendContext> d_feContext{nullptr};
+ std::shared_ptr<SSL_CTX> d_tlsCtx{nullptr}; // client context, on a server-side the context is stored in d_feContext->d_tlsCtx
bool (*d_nextProtocolSelectCallback)(unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen){nullptr};
};
{
public:
/* server side connection */
- GnuTLSConnection(int socket, const struct timeval& timeout, const gnutls_certificate_credentials_t creds, const gnutls_priority_t priorityCache, std::shared_ptr<GnuTLSTicketsKey>& ticketsKey, bool enableTickets): d_ticketsKey(ticketsKey), d_conn(std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(nullptr, gnutls_deinit))
+ GnuTLSConnection(int socket, const struct timeval& timeout, std::shared_ptr<gnutls_certificate_credentials_st>& creds, const gnutls_priority_t priorityCache, std::shared_ptr<GnuTLSTicketsKey>& ticketsKey, bool enableTickets): d_creds(creds), d_ticketsKey(ticketsKey), d_conn(std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(nullptr, gnutls_deinit))
{
unsigned int sslOptions = GNUTLS_SERVER | GNUTLS_NONBLOCK;
#ifdef GNUTLS_NO_SIGNAL
d_conn = std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(conn, gnutls_deinit);
conn = nullptr;
- if (gnutls_credentials_set(d_conn.get(), GNUTLS_CRD_CERTIFICATE, creds) != GNUTLS_E_SUCCESS) {
+ if (gnutls_credentials_set(d_conn.get(), GNUTLS_CRD_CERTIFICATE, d_creds.get()) != GNUTLS_E_SUCCESS) {
throw std::runtime_error("Error setting certificate and key to TLS connection");
}
}
/* client-side connection */
- GnuTLSConnection(const std::string& host, int socket, const struct timeval& timeout, const gnutls_certificate_credentials_t creds, const gnutls_priority_t priorityCache, bool validateCerts): d_conn(std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(nullptr, gnutls_deinit)), d_host(host), d_client(true)
+ GnuTLSConnection(const std::string& host, int socket, const struct timeval& timeout, std::shared_ptr<gnutls_certificate_credentials_st>& creds, const gnutls_priority_t priorityCache, bool validateCerts): d_creds(creds), d_conn(std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(nullptr, gnutls_deinit)), d_host(host), d_client(true)
{
unsigned int sslOptions = GNUTLS_CLIENT | GNUTLS_NONBLOCK;
#ifdef GNUTLS_NO_SIGNAL
d_conn = std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(conn, gnutls_deinit);
conn = nullptr;
- int rc = gnutls_credentials_set(d_conn.get(), GNUTLS_CRD_CERTIFICATE, creds);
+ int rc = gnutls_credentials_set(d_conn.get(), GNUTLS_CRD_CERTIFICATE, d_creds.get());
if (rc != GNUTLS_E_SUCCESS) {
throw std::runtime_error("Error setting certificate and key to TLS connection: " + std::string(gnutls_strerror(rc)));
}
}
private:
- std::vector<std::unique_ptr<TLSSession>> d_tlsSessions;
+ std::shared_ptr<gnutls_certificate_credentials_st> d_creds;
std::shared_ptr<GnuTLSTicketsKey> d_ticketsKey;
std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)> d_conn;
+ std::vector<std::unique_ptr<TLSSession>> d_tlsSessions;
std::string d_host;
bool d_client{false};
bool d_handshakeDone{false};
{
public:
/* server side context */
- GnuTLSIOCtx(TLSFrontend& fe): d_creds(std::unique_ptr<gnutls_certificate_credentials_st, void(*)(gnutls_certificate_credentials_t)>(nullptr, gnutls_certificate_free_credentials)), d_enableTickets(fe.d_tlsConfig.d_enableTickets)
+ GnuTLSIOCtx(TLSFrontend& fe): d_enableTickets(fe.d_tlsConfig.d_enableTickets)
{
int rc = 0;
d_ticketsKeyRotationDelay = fe.d_tlsConfig.d_ticketsKeyRotationDelay;
throw std::runtime_error("Error allocating credentials for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
}
- d_creds = std::unique_ptr<gnutls_certificate_credentials_st, void(*)(gnutls_certificate_credentials_t)>(creds, gnutls_certificate_free_credentials);
+ d_creds = std::shared_ptr<gnutls_certificate_credentials_st>(creds, gnutls_certificate_free_credentials);
creds = nullptr;
for (const auto& pair : fe.d_tlsConfig.d_certKeyPairs) {
}
/* client side context */
- GnuTLSIOCtx(const TLSContextParameters& params): d_creds(std::unique_ptr<gnutls_certificate_credentials_st, void(*)(gnutls_certificate_credentials_t)>(nullptr, gnutls_certificate_free_credentials)), d_enableTickets(true), d_validateCerts(params.d_validateCertificates)
+ GnuTLSIOCtx(const TLSContextParameters& params): d_contextParameters(std::make_unique<TLSContextParameters>(params)), d_enableTickets(true), d_validateCerts(params.d_validateCertificates)
{
int rc = 0;
throw std::runtime_error("Error allocating credentials for TLS context: " + std::string(gnutls_strerror(rc)));
}
- d_creds = std::unique_ptr<gnutls_certificate_credentials_st, void(*)(gnutls_certificate_credentials_t)>(creds, gnutls_certificate_free_credentials);
+ d_creds = std::shared_ptr<gnutls_certificate_credentials_st>(creds, gnutls_certificate_free_credentials);
creds = nullptr;
if (params.d_validateCertificates) {
ticketsKey = *(d_ticketsKey.read_lock());
}
- auto connection = std::make_unique<GnuTLSConnection>(socket, timeout, d_creds.get(), d_priorityCache, ticketsKey, d_enableTickets);
+ auto connection = std::make_unique<GnuTLSConnection>(socket, timeout, d_creds, d_priorityCache, ticketsKey, d_enableTickets);
if (!d_protos.empty()) {
connection->setALPNProtos(d_protos);
}
return connection;
}
+ static std::shared_ptr<gnutls_certificate_credentials_st> getPerThreadCredentials(bool validate, const std::string& caStore)
+ {
+ static thread_local std::map<std::pair<bool, std::string>, std::shared_ptr<gnutls_certificate_credentials_st>> t_credentials;
+ auto& entry = t_credentials[{validate, caStore}];
+ if (!entry) {
+ gnutls_certificate_credentials_t creds;
+ int rc = gnutls_certificate_allocate_credentials(&creds);
+ if (rc != GNUTLS_E_SUCCESS) {
+ throw std::runtime_error("Error allocating credentials for TLS context: " + std::string(gnutls_strerror(rc)));
+ }
+
+ entry = std::shared_ptr<gnutls_certificate_credentials_st>(creds, gnutls_certificate_free_credentials);
+ creds = nullptr;
+
+ if (validate) {
+ if (caStore.empty()) {
+ rc = gnutls_certificate_set_x509_system_trust(entry.get());
+ if (rc < 0) {
+ throw std::runtime_error("Error adding the system's default trusted CAs: " + std::string(gnutls_strerror(rc)));
+ }
+ }
+ else {
+ rc = gnutls_certificate_set_x509_trust_file(entry.get(), caStore.c_str(), GNUTLS_X509_FMT_PEM);
+ if (rc < 0) {
+ throw std::runtime_error("Error adding '" + caStore + "' to the trusted CAs: " + std::string(gnutls_strerror(rc)));
+ }
+ }
+ }
+ }
+ return entry;
+ }
+
std::unique_ptr<TLSConnection> getClientConnection(const std::string& host, int socket, const struct timeval& timeout) override
{
- auto connection = std::make_unique<GnuTLSConnection>(host, socket, timeout, d_creds.get(), d_priorityCache, d_validateCerts);
+ auto creds = getPerThreadCredentials(d_contextParameters->d_validateCertificates, d_contextParameters->d_caStore);
+ auto connection = std::make_unique<GnuTLSConnection>(host, socket, timeout, creds, d_priorityCache, d_validateCerts);
if (!d_protos.empty()) {
connection->setALPNProtos(d_protos);
}
}
private:
- std::unique_ptr<gnutls_certificate_credentials_st, void(*)(gnutls_certificate_credentials_t)> d_creds;
+ /* client context parameters */
+ std::unique_ptr<TLSContextParameters> d_contextParameters{nullptr};
+ std::shared_ptr<gnutls_certificate_credentials_st> d_creds;
std::vector<std::vector<uint8_t>> d_protos;
gnutls_priority_t d_priorityCache{nullptr};
SharedLockGuarded<std::shared_ptr<GnuTLSTicketsKey>> d_ticketsKey{nullptr};