]> git.ipfire.org Git - thirdparty/squid.git/blobdiff - src/ssl/cert_validate_message.cc
Source Format Enforcement (#763)
[thirdparty/squid.git] / src / ssl / cert_validate_message.cc
index 8983b20a977a1d2376e0a636955db8d988c6b226..48fab388b7e8dd07534c5e660f300a51cf9ce8be 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (C) 1996-2016 The Squid Software Foundation and contributors
+ * Copyright (C) 1996-2021 The Squid Software Foundation and contributors
  *
  * Squid software is distributed under GPLv2+ license and includes
  * contributions from numerous individuals and organizations.
 #include "acl/FilledChecklist.h"
 #include "globals.h"
 #include "helper.h"
+#include "security/CertError.h"
 #include "ssl/cert_validate_message.h"
 #include "ssl/ErrorDetail.h"
 #include "ssl/support.h"
 #include "util.h"
 
+/// Retrieves the certificates chain used to verify the peer.
+/// This is the full chain built by OpenSSL while verifying the server
+/// certificate or, if this is not available, the chain sent by server.
+/// \return the certificates chain or nil
+static STACK_OF(X509) *
+PeerValidationCertificatesChain(const Security::SessionPointer &ssl)
+{
+    assert(ssl);
+    // The full chain built by openSSL while verifying the server cert,
+    // retrieved from verify callback:
+    if (const auto certs = static_cast<STACK_OF(X509) *>(SSL_get_ex_data(ssl.get(), ssl_ex_index_ssl_cert_chain)))
+        return certs;
+
+    /// Last resort: certificates chain sent by server
+    return SSL_get_peer_cert_chain(ssl.get()); // may be nil
+}
+
 void
 Ssl::CertValidationMsg::composeRequest(CertValidationRequest const &vcert)
 {
     body.clear();
     body += Ssl::CertValidationMsg::param_host + "=" + vcert.domainName;
-    STACK_OF(X509) *peerCerts = static_cast<STACK_OF(X509) *>(SSL_get_ex_data(vcert.ssl, ssl_ex_index_ssl_cert_chain));
 
-    if (const char *sslVersion = SSL_get_version(vcert.ssl))
+    if (const char *sslVersion = SSL_get_version(vcert.ssl.get()))
         body += "\n" +  Ssl::CertValidationMsg::param_proto_version + "=" + sslVersion;
 
-    if (const char *cipherName = SSL_CIPHER_get_name(SSL_get_current_cipher(vcert.ssl)))
+    if (const char *cipherName = SSL_CIPHER_get_name(SSL_get_current_cipher(vcert.ssl.get())))
         body += "\n" +  Ssl::CertValidationMsg::param_cipher + "=" + cipherName;
 
-    if (!peerCerts)
-        peerCerts = SSL_get_peer_cert_chain(vcert.ssl);
-
+    STACK_OF(X509) *peerCerts = PeerValidationCertificatesChain(vcert.ssl);
     if (peerCerts) {
         Ssl::BIO_Pointer bio(BIO_new(BIO_s_mem()));
         for (int i = 0; i < sk_X509_num(peerCerts); ++i) {
@@ -48,7 +63,7 @@ Ssl::CertValidationMsg::composeRequest(CertValidationRequest const &vcert)
 
     if (vcert.errors) {
         int i = 0;
-        for (const Ssl::CertErrors *err = vcert.errors; err; err = err->next, ++i) {
+        for (const Security::CertErrors *err = vcert.errors; err; err = err->next, ++i) {
             body +="\n";
             body = body + param_error_name + xitoa(i) + "=" + GetErrorName(err->element.code) + "\n";
             int errorCertPos = -1;
@@ -70,14 +85,16 @@ get_error_id(const char *label, size_t len)
     const char *e = label + len -1;
     while (e != label && xisdigit(*e)) --e;
     if (e != label) ++e;
-    return strtol(e, 0 , 10);
+    return strtol(e, 0, 10);
 }
 
 bool
-Ssl::CertValidationMsg::parseResponse(CertValidationResponse &resp, STACK_OF(X509) *peerCerts, std::string &error)
+Ssl::CertValidationMsg::parseResponse(CertValidationResponse &resp, std::string &error)
 {
     std::vector<CertItem> certs;
 
+    const STACK_OF(X509) *peerCerts = PeerValidationCertificatesChain(resp.ssl);
+
     const char *param = body.c_str();
     while (*param) {
         while (xisspace(*param)) param++;
@@ -154,7 +171,7 @@ Ssl::CertValidationMsg::parseResponse(CertValidationResponse &resp, STACK_OF(X50
             return false;
         }
 
-        param = value + value_len +1;
+        param = value + value_len;
     }
 
     /*Run through parsed errors to check for errors*/
@@ -180,6 +197,13 @@ Ssl::CertValidationMsg::getCertByName(std::vector<CertItem> const &certs, std::s
     return NULL;
 }
 
+uint64_t
+Ssl::CertValidationResponse::MemoryUsedByResponse(const CertValidationResponse::Pointer &)
+{
+    // XXX: This math does not account for most of the response size!
+    return sizeof(CertValidationResponse);
+}
+
 Ssl::CertValidationResponse::RecvdError &
 Ssl::CertValidationResponse::getError(int errorId)
 {
@@ -194,42 +218,12 @@ Ssl::CertValidationResponse::getError(int errorId)
     return errors.back();
 }
 
-Ssl::CertValidationResponse::RecvdError::RecvdError(const RecvdError &old)
-{
-    id = old.id;
-    error_no = old.error_no;
-    error_reason = old.error_reason;
-    setCert(old.cert.get());
-}
-
-Ssl::CertValidationResponse::RecvdError & Ssl::CertValidationResponse::RecvdError::operator = (const RecvdError &old)
-{
-    id = old.id;
-    error_no = old.error_no;
-    error_reason = old.error_reason;
-    setCert(old.cert.get());
-    return *this;
-}
-
 void
 Ssl::CertValidationResponse::RecvdError::setCert(X509 *aCert)
 {
     cert.resetAndLock(aCert);
 }
 
-Ssl::CertValidationMsg::CertItem::CertItem(const CertItem &old)
-{
-    name = old.name;
-    setCert(old.cert.get());
-}
-
-Ssl::CertValidationMsg::CertItem & Ssl::CertValidationMsg::CertItem::operator = (const CertItem &old)
-{
-    name = old.name;
-    setCert(old.cert.get());
-    return *this;
-}
-
 void
 Ssl::CertValidationMsg::CertItem::setCert(X509 *aCert)
 {