]> git.ipfire.org Git - thirdparty/squid.git/blobdiff - src/ssl/crtd_message.cc
Source Format Enforcement (#763)
[thirdparty/squid.git] / src / ssl / crtd_message.cc
index 0d77a96835d2a6e69b7120a99302e61961db7dfc..6b2e1b014b70cd54237870180d66751a345de41d 100644 (file)
@@ -1,22 +1,21 @@
 /*
- * $Id$
+ * Copyright (C) 1996-2021 The Squid Software Foundation and contributors
+ *
+ * Squid software is distributed under GPLv2+ license and includes
+ * contributions from numerous individuals and organizations.
+ * Please see the COPYING and CONTRIBUTORS files for details.
  */
 
 #include "squid.h"
-#include "ssl/gadgets.h"
 #include "ssl/crtd_message.h"
-#if HAVE_CSTDLIB
+#include "ssl/gadgets.h"
+
 #include <cstdlib>
-#endif
-#if HAVE_CSTRING
 #include <cstring>
-#endif
-#if HAVE_STDEXCEPT
 #include <stdexcept>
-#endif
 
-Ssl::CrtdMessage::CrtdMessage()
-        :   body_size(0), state(BEFORE_CODE)
+Ssl::CrtdMessage::CrtdMessage(MessageKind kind)
+    :   body_size(0), state(kind == REPLY ? BEFORE_LENGTH: BEFORE_CODE)
 {}
 
 Ssl::CrtdMessage::ParseResult Ssl::CrtdMessage::parse(const char * buffer, size_t len)
@@ -26,7 +25,7 @@ Ssl::CrtdMessage::ParseResult Ssl::CrtdMessage::parse(const char * buffer, size_
         switch (state) {
         case BEFORE_CODE: {
             if (xisspace(*current_pos)) {
-                current_pos++;
+                ++current_pos;
                 break;
             }
             if (xisalpha(*current_pos)) {
@@ -39,7 +38,7 @@ Ssl::CrtdMessage::ParseResult Ssl::CrtdMessage::parse(const char * buffer, size_
         case CODE: {
             if (xisalnum(*current_pos) || *current_pos == '_') {
                 current_block += *current_pos;
-                current_pos++;
+                ++current_pos;
                 break;
             }
             if (xisspace(*current_pos)) {
@@ -53,7 +52,7 @@ Ssl::CrtdMessage::ParseResult Ssl::CrtdMessage::parse(const char * buffer, size_
         }
         case BEFORE_LENGTH: {
             if (xisspace(*current_pos)) {
-                current_pos++;
+                ++current_pos;
                 break;
             }
             if (xisdigit(*current_pos)) {
@@ -66,7 +65,7 @@ Ssl::CrtdMessage::ParseResult Ssl::CrtdMessage::parse(const char * buffer, size_
         case LENGTH: {
             if (xisdigit(*current_pos)) {
                 current_block += *current_pos;
-                current_pos++;
+                ++current_pos;
                 break;
             }
             if (xisspace(*current_pos)) {
@@ -84,7 +83,7 @@ Ssl::CrtdMessage::ParseResult Ssl::CrtdMessage::parse(const char * buffer, size_
                 break;
             }
             if (xisspace(*current_pos)) {
-                current_pos++;
+                ++current_pos;
                 break;
             } else {
                 state = BODY;
@@ -124,7 +123,6 @@ void Ssl::CrtdMessage::setBody(std::string const & aBody) { body = aBody; }
 
 void Ssl::CrtdMessage::setCode(std::string const & aCode) { code = aCode; }
 
-
 std::string Ssl::CrtdMessage::compose() const
 {
     if (code.empty()) return std::string();
@@ -168,7 +166,7 @@ void Ssl::CrtdMessage::parseBody(CrtdMessage::BodyParams & map, std::string & ot
 void Ssl::CrtdMessage::composeBody(CrtdMessage::BodyParams const & map, std::string const & other_part)
 {
     body.clear();
-    for (BodyParams::const_iterator i = map.begin(); i != map.end(); i++) {
+    for (BodyParams::const_iterator i = map.begin(); i != map.end(); ++i) {
         if (i != map.begin())
             body += "\n";
         body += i->first + "=" + i->second;
@@ -177,7 +175,6 @@ void Ssl::CrtdMessage::composeBody(CrtdMessage::BodyParams const & map, std::str
         body += '\n' + other_part;
 }
 
-
 bool Ssl::CrtdMessage::parseRequest(Ssl::CertificateProperties &certProperties, std::string &error)
 {
     Ssl::CrtdMessage::BodyParams map;
@@ -209,12 +206,21 @@ bool Ssl::CrtdMessage::parseRequest(Ssl::CertificateProperties &certProperties,
     i = map.find(Ssl::CrtdMessage::param_Sign);
     if (i != map.end()) {
         if ((certProperties.signAlgorithm = Ssl::certSignAlgorithmId(i->second.c_str())) == Ssl::algSignEnd) {
-            error = "Wrong signing algoritm: " + i->second;
+            error = "Wrong signing algorithm: ";
+            error += i->second;
             return false;
         }
     } else
         certProperties.signAlgorithm = Ssl::algSignTrusted;
 
+    i = map.find(Ssl::CrtdMessage::param_SignHash);
+    const char *signHashName = i != map.end() ? i->second.c_str() : SQUID_SSL_SIGN_HASH_IF_NONE;
+    if (!(certProperties.signHash = EVP_get_digestbyname(signHashName))) {
+        error = "Wrong signing hash: ";
+        error += signHashName;
+        return false;
+    }
+
     if (!Ssl::readCertAndPrivateKeyFromMemory(certProperties.signWithX509, certProperties.signWithPkey, certs_part.c_str())) {
         error = "Broken signing certificate!";
         return false;
@@ -242,6 +248,8 @@ void Ssl::CrtdMessage::composeRequest(Ssl::CertificateProperties const &certProp
         body +=  "\n" + Ssl::CrtdMessage::param_SetValidBefore + "=on";
     if (certProperties.signAlgorithm != Ssl::algSignEnd)
         body +=  "\n" +  Ssl::CrtdMessage::param_Sign + "=" +  certSignAlgorithm(certProperties.signAlgorithm);
+    if (certProperties.signHash)
+        body +=  "\n" + Ssl::CrtdMessage::param_SignHash + "=" + EVP_MD_name(certProperties.signHash);
 
     std::string certsPart;
     if (!Ssl::writeCertAndPrivateKeyToMemory(certProperties.signWithX509, certProperties.signWithPkey, certsPart))
@@ -259,3 +267,5 @@ const std::string Ssl::CrtdMessage::param_SetValidAfter(Ssl::CertAdaptAlgorithmS
 const std::string Ssl::CrtdMessage::param_SetValidBefore(Ssl::CertAdaptAlgorithmStr[algSetValidBefore]);
 const std::string Ssl::CrtdMessage::param_SetCommonName(Ssl::CertAdaptAlgorithmStr[algSetCommonName]);
 const std::string Ssl::CrtdMessage::param_Sign("Sign");
+const std::string Ssl::CrtdMessage::param_SignHash("SignHash");
+