]> git.ipfire.org Git - thirdparty/snort3.git/commitdiff
Pull request #3910: ssl: parse and publish server common name from server certificate
authorSerhii Vlasiuk -X (svlasiuk - SOFTSERVE INC at Cisco) <svlasiuk@cisco.com>
Tue, 18 Jul 2023 18:37:09 +0000 (18:37 +0000)
committerSteve Chew (stechew) <stechew@cisco.com>
Tue, 18 Jul 2023 18:37:09 +0000 (18:37 +0000)
Merge in SNORT/snort3 from ~SVLASIUK/snort3:ssl_server_common_name to master

Squashed commit of the following:

commit f314e115effcbb33b323324fd90b72a1ddca71b4
Author: Serhii Vlasiuk <svlasiuk@cisco.com>
Date:   Tue Jul 11 17:11:46 2023 +0300

    ssl: parse and publish server common name from server certificate

src/network_inspectors/appid/service_plugins/service_ssl.cc
src/protocols/ssl.cc
src/protocols/ssl.h
src/pub_sub/ssl_events.h
src/service_inspectors/ssl/ssl_inspector.cc

index e8e05877963eb9998a7d12307246012a2f14c3af..8b04a8addfe39174e9f2886a64b53bfb6df2df87 100644 (file)
@@ -25,8 +25,6 @@
 
 #include "service_ssl.h"
 
-#include <openssl/x509.h>
-
 #include "app_info_table.h"
 #include "protocols/packet.h"
 #include "protocols/ssl.h"
@@ -46,10 +44,6 @@ enum SSLContentType
 #define SSL2_SERVER_HELLO 4
 #define PCT_SERVER_HELLO 2
 
-#define FIELD_SEPARATOR "/"
-#define COMMON_NAME_STR "/CN="
-#define ORG_NAME_STR "/O="
-
 enum SSLState
 {
     SSL_STATE_INITIATE,    // Client initiates.
@@ -65,16 +59,10 @@ struct ServiceSSLData
     int tot_length;
     /* From client: */
     SSLV3ClientHelloData client_hello;
-    /* While collecting certificates: */
-    unsigned certs_len;   // (Total) length of certificate(s).
-    uint8_t* certs_data;  // Certificate(s) data (each proceeded by length (3 bytes)).
+    /* From server: */
+    SSLV3ServerCertData server_cert;
     int in_certs;         // Currently collecting certificates?
     int certs_curr_len;   // Current amount of collected certificate data.
-    /* Data collected from certificates afterwards: */
-    char* common_name;
-    int common_name_strlen;
-    char* org_name;
-    int org_name_strlen;
     uint8_t* cached_data;
     uint16_t cached_len;
 };
@@ -89,19 +77,6 @@ struct ServiceSSLV3Hdr
     uint16_t len;
 };
 
-/* Usually referred to as a Certificate Handshake. */
-struct ServiceSSLV3CertsRecord
-{
-    uint8_t type;
-    uint8_t length_msb;
-    uint16_t length;
-    uint8_t certs_len[3];  // 3-byte length, network byte order.
-    /* Certificate(s) follow.
-     * For each:
-     *  - Length: 3 bytes
-     *  - Data  : "Length" bytes */
-};
-
 struct ServiceSSLPCTHdr
 {
     uint8_t len;
@@ -137,12 +112,6 @@ struct ServiceSSLV2Hdr
 
 #pragma pack()
 
-/* Convert 3-byte lengths in TLS headers to integers. */
-#define ntoh3(msb_ptr) \
-    ((uint32_t)((uint32_t)(((const uint8_t*)(msb_ptr))[0] << 16) \
-    + (uint32_t)(((const uint8_t*)(msb_ptr))[1] << 8) \
-    + (uint32_t)(((const uint8_t*)(msb_ptr))[2])))
-
 static const uint8_t SSL_PATTERN_PCT[] = { 0x02, 0x00, 0x80, 0x01 };
 static const uint8_t SSL_PATTERN3_0[] = { 0x16, 0x03, 0x00 };
 static const uint8_t SSL_PATTERN3_1[] = { 0x16, 0x03, 0x01 };
@@ -218,10 +187,8 @@ static void ssl_cache_free(uint8_t*& ssl_cache, uint16_t& len)
 static void ssl_free(void* ss)
 {
     ServiceSSLData* ss_tmp = (ServiceSSLData*)ss;
-    snort_free(ss_tmp->certs_data);
-    snort_free(ss_tmp->common_name);
     ss_tmp->client_hello.clear();
-    snort_free(ss_tmp->org_name);
+    ss_tmp->server_cert.clear();
     ssl_cache_free(ss_tmp->cached_data, ss_tmp->cached_len);
     snort_free(ss_tmp);
 }
@@ -247,109 +214,6 @@ static void parse_client_initiation(const uint8_t* data, uint16_t size, ServiceS
     parse_client_hello_data(data, size, &ss->client_hello);
 }
 
-static bool parse_certificates(ServiceSSLData* ss)
-{
-    bool success = false;
-    if (ss->certs_data and ss->certs_len)
-    {
-        char* common_name = nullptr;
-        char* org_name = nullptr;
-        const uint8_t* data = ss->certs_data;
-        int len = ss->certs_len;
-        int common_name_tot_len = 0;
-        int org_name_tot_len  = 0;
-        success = true;
-
-        while (len > 0 and !(common_name and org_name))
-        {
-            X509* cert = nullptr;
-            char* cert_name = nullptr;
-            char* start = nullptr;
-
-            int cert_len = ntoh3(data);
-            data += 3;
-            len -= 3;
-            if (len < cert_len)
-            {
-                success = false;
-                break;
-            }
-            /* d2i_X509() increments the data ptr for us. */
-            cert = d2i_X509(nullptr, (const unsigned char**)&data, cert_len);
-            len -= cert_len;
-            if (!cert)
-            {
-                success = false;
-                break;
-            }
-
-            /* only look for common name or org name if we don't already have one */
-            if (!common_name or !org_name)
-            {
-                if ((cert_name = X509_NAME_oneline(X509_get_subject_name(cert), nullptr, 0)))
-                {
-                    if (!common_name)
-                    {
-                        if ((start = strstr(cert_name, COMMON_NAME_STR)))
-                        {
-                            int length = 0;
-                            start += strlen(COMMON_NAME_STR);
-                            length = strlen(start);
-                            if (length > 2 and *start == '*' and *(start+1) == '.')
-                            {
-                                start += 2; // remove leading *.
-                                length -= 2;
-                            }
-                            common_name = snort_strndup(start, length);
-                            common_name_tot_len += length;
-                            start = nullptr;
-                        }
-                    }
-                    if (!org_name)
-                    {
-                        if ((start = strstr(cert_name, COMMON_NAME_STR)))
-                        {
-                            int length;
-                            start += strlen(COMMON_NAME_STR);
-                            length = strlen(start);
-                            if (length > 2 and *start == '*' and *(start+1) == '.')
-                            {
-                                start += 2; // remove leading *.
-                                length -= 2;
-                            }
-                            org_name = snort_strndup(start, length);
-                            org_name_tot_len += length;
-                            start = nullptr;
-                        }
-                    }
-                    free(cert_name);
-                    cert_name = nullptr;
-                }
-            }
-            X509_free(cert);
-        }
-
-        if (common_name)
-        {
-            ss->common_name = common_name;
-            ss->common_name_strlen = common_name_tot_len;
-        }
-
-        if (org_name)
-        {
-            ss->org_name = org_name;
-            ss->org_name_strlen = org_name_tot_len;
-        }
-
-        /* No longer need entire certificates. We have what we came for. */
-        snort_free(ss->certs_data);
-        ss->certs_data = nullptr;
-        ss->certs_len = 0;
-    }
-
-    return success;
-}
-
 static void save_ssl_cache(ServiceSSLData* ss, uint16_t size, const uint8_t* data)
 {
     ss->cached_data = (uint8_t*)snort_calloc(size, sizeof(uint8_t));
@@ -516,29 +380,29 @@ int SslServiceDetector::validate(AppIdDiscoveryArgs& args)
                 {
                 case SSLV3RecordType::CERTIFICATE:
                     /* Start pulling out certificates. */
-                    if (!ss->certs_data)
+                    if (!ss->server_cert.certs_data)
                     {
                         if (size < sizeof(ServiceSSLV3CertsRecord))
                             goto fail;
 
                         certs_rec = (const ServiceSSLV3CertsRecord*)data;
-                        ss->certs_len = ntoh3(certs_rec->certs_len);
-                        ss->certs_data = (uint8_t*)snort_alloc(ss->certs_len);
-                        if ((size - sizeof(ServiceSSLV3CertsRecord)) < ss->certs_len)
+                        ss->server_cert.certs_len = ntoh3(certs_rec->certs_len);
+                        ss->server_cert.certs_data = (uint8_t*)snort_alloc(ss->server_cert.certs_len);
+                        if ((size - sizeof(ServiceSSLV3CertsRecord)) < ss->server_cert.certs_len)
                         {
                             /* Will have to get more next time around. */
                             ss->in_certs = 1;
                             /* Skip over header to data */
                             ss->certs_curr_len = size - sizeof(ServiceSSLV3CertsRecord);
-                            memcpy(ss->certs_data, data + sizeof(ServiceSSLV3CertsRecord),
+                            memcpy(ss->server_cert.certs_data, data + sizeof(ServiceSSLV3CertsRecord),
                                 ss->certs_curr_len);
                         }
                         else
                         {
                             /* Can get it all this time. */
                             ss->in_certs       = 0;
-                            ss->certs_curr_len = ss->certs_len;
-                            memcpy(ss->certs_data, data + sizeof(ServiceSSLV3CertsRecord),
+                            ss->certs_curr_len = ss->server_cert.certs_len;
+                            memcpy(ss->server_cert.certs_data, data + sizeof(ServiceSSLV3CertsRecord),
                                 ss->certs_curr_len);
                             break;
                         }
@@ -578,22 +442,22 @@ int SslServiceDetector::validate(AppIdDiscoveryArgs& args)
             else
             {
                 /* See if there's more certificate data to grab. */
-                if (ss->in_certs && ss->certs_data)
+                if (ss->in_certs && ss->server_cert.certs_data)
                 {
-                    if (size < (ss->certs_len - ss->certs_curr_len))
+                    if (size < (ss->server_cert.certs_len - ss->certs_curr_len))
                     {
                         /* Will have to get more next time around. */
-                        memcpy(ss->certs_data + ss->certs_curr_len, data, size);
+                        memcpy(ss->server_cert.certs_data + ss->certs_curr_len, data, size);
                         ss->in_certs = 1;
                         ss->certs_curr_len += size;
                     }
                     else
                     {
                         /* Can get it all this time. */
-                        memcpy(ss->certs_data + ss->certs_curr_len, data,
-                            ss->certs_len - ss->certs_curr_len);
+                        memcpy(ss->server_cert.certs_data + ss->certs_curr_len, data,
+                            ss->server_cert.certs_len - ss->certs_curr_len);
                         ss->in_certs = 0;
-                        ss->certs_curr_len = ss->certs_len;
+                        ss->certs_curr_len = ss->server_cert.certs_len;
                     }
                 }
 
@@ -617,37 +481,42 @@ int SslServiceDetector::validate(AppIdDiscoveryArgs& args)
 
 inprocess:
     if (reallocated_data)
+    {
         snort_free(reallocated_data);
+        reallocated_data = nullptr;
+    }
     service_inprocess(args.asd, args.pkt, args.dir);
     return APPID_INPROCESS;
 
 fail:
     if (reallocated_data)
+    {
         snort_free(reallocated_data);
-    snort_free(ss->certs_data);
+        reallocated_data = nullptr;
+    }
     ss->client_hello.clear();
-    snort_free(ss->common_name);
-    snort_free(ss->org_name);
-    ss->certs_data = nullptr;
-    ss->common_name = ss->org_name = nullptr;
+    ss->server_cert.clear();
     fail_service(args.asd, args.pkt, args.dir);
     return APPID_NOMATCH;
 
 success:
     if (reallocated_data)
+    {
         snort_free(reallocated_data);
+        reallocated_data = nullptr;
+    }
         
-    if (ss->certs_data && ss->certs_len)
+    if (ss->server_cert.certs_data && ss->server_cert.certs_len)
     {
         if (!(args.asd.scan_flags & SCAN_CERTVIZ_ENABLED_FLAG) and
-            (!parse_certificates(ss)))
+            (!parse_server_certificates(&ss->server_cert)))
         {
             goto fail;
         }
     }
 
     args.asd.set_session_flags(APPID_SESSION_SSL_SESSION);
-    if (ss->client_hello.host_name || ss->common_name || ss->org_name)
+    if (ss->client_hello.host_name || ss->server_cert.common_name || ss->server_cert.org_name)
     {
         if (!args.asd.tsession)
             args.asd.tsession = new TlsSession();
@@ -658,24 +527,25 @@ success:
             args.asd.tsession->set_tls_host(ss->client_hello.host_name, 0, args.change_bits);
             args.asd.scan_flags |= SCAN_SSL_HOST_FLAG;
         }
-        else if (ss->common_name)
+        else if (ss->server_cert.common_name)
         {
             /* Use common name (from server) if we didn't get host name (from client). */
-            args.asd.tsession->set_tls_host(ss->common_name, ss->common_name_strlen, args.change_bits);
+            args.asd.tsession->set_tls_host(ss->server_cert.common_name, ss->server_cert.common_name_strlen,
+                args.change_bits);
             args.asd.scan_flags |= SCAN_SSL_HOST_FLAG;
         }
 
         /* TLS Common Name */
-        if (ss->common_name)
+        if (ss->server_cert.common_name)
         {
-            args.asd.tsession->set_tls_cname(ss->common_name, 0, args.change_bits);
+            args.asd.tsession->set_tls_cname(ss->server_cert.common_name, 0, args.change_bits);
             args.asd.scan_flags |= SCAN_SSL_CERTIFICATE_FLAG;
         }
         /* TLS Org Unit */
-        if (ss->org_name)
-            args.asd.tsession->set_tls_org_unit(ss->org_name, 0);
+        if (ss->server_cert.org_name)
+            args.asd.tsession->set_tls_org_unit(ss->server_cert.org_name, 0);
 
-        ss->client_hello.host_name = ss->common_name = ss->org_name = nullptr;
+        ss->client_hello.host_name = ss->server_cert.common_name = ss->server_cert.org_name = nullptr;
         args.asd.tsession->set_tls_handshake_done();
     }
     return add_service(args.change_bits, args.asd, args.pkt, args.dir,
index 837dad78b9872fbcb66e85779846742928a54f1a..fba5edc3b2efc0de2e0e20b22bd0d227da7b3c04 100644 (file)
 
 #include "ssl.h"
 
+#include <openssl/x509.h>
+
 #include "packet.h"
 #include "utils/util.h"
 
+#define COMMON_NAME_STR "/CN="
+
 #define THREE_BYTE_LEN(x) ((x)[2] | (x)[1] << 8 | (x)[0] << 16)
 
 #define SSL_ERROR_FLAGS \
@@ -52,6 +56,25 @@ void SSLV3ClientHelloData::clear()
     host_name = nullptr;
 }
 
+SSLV3ServerCertData::~SSLV3ServerCertData()
+{
+    snort_free(certs_data);
+    snort_free(common_name);
+    snort_free(org_name);
+}
+
+void SSLV3ServerCertData::clear()
+{
+    snort_free(certs_data);
+    certs_data = nullptr;
+
+    snort_free(common_name);
+    common_name = nullptr;
+
+    snort_free(org_name);
+    org_name = nullptr;
+}
+
 static uint32_t SSL_decode_version_v3(uint8_t major, uint8_t minor)
 {
     /* Should only be called internally and by functions which have previously
@@ -81,9 +104,11 @@ static uint32_t SSL_decode_version_v3(uint8_t major, uint8_t minor)
 }
 
 static uint32_t SSL_decode_handshake_v3(const uint8_t* pkt, int size,
-    uint32_t cur_flags, uint32_t pkt_flags, SSLV3ClientHelloData* client_hello_data)
+    uint32_t cur_flags, uint32_t pkt_flags, SSLV3ClientHelloData* client_hello_data,
+    SSLV3ServerCertData* server_cert_data)
 {
     const SSL_handshake_hello_t* hello;
+    const ServiceSSLV3CertsRecord* certs_rec;
     uint32_t retval = 0;
 
     while (size > 0)
@@ -174,7 +199,17 @@ static uint32_t SSL_decode_handshake_v3(const uint8_t* pkt, int size,
             break;
 
         case SSL_HS_CERT:
-            retval |= SSL_CERTIFICATE_FLAG;
+            if (server_cert_data != nullptr)
+            {
+                certs_rec = (const ServiceSSLV3CertsRecord*)handshake;
+                server_cert_data->certs_len = ntoh3(certs_rec->certs_len);
+                server_cert_data->certs_data = (uint8_t*)snort_alloc(server_cert_data->certs_len);
+                memcpy(server_cert_data->certs_data, pkt + sizeof(certs_rec->certs_len), server_cert_data->certs_len);
+
+                snort::parse_server_certificates(server_cert_data);
+            }
+
+            retval |= SSL_CERTIFICATE_FLAG; 
             break;
 
         /* The following types are not presently of interest */
@@ -205,7 +240,7 @@ static uint32_t SSL_decode_handshake_v3(const uint8_t* pkt, int size,
 
 static uint32_t SSL_decode_v3(const uint8_t* pkt, int size, uint32_t pkt_flags,
     uint8_t* alert_flags, uint16_t* partial_rec_len, int max_hb_len, uint32_t* info_flags,
-    SSLV3ClientHelloData* client_hello_data)
+    SSLV3ClientHelloData* client_hello_data, SSLV3ServerCertData* server_cert_data)
 {
     uint32_t retval = 0;
     uint16_t hblen;
@@ -299,7 +334,7 @@ static uint32_t SSL_decode_v3(const uint8_t* pkt, int size, uint32_t pkt_flags,
             if (!(retval & SSL_CHANGE_CIPHER_FLAG))
             {
                 int hsize = size < (int)reclen ? size : (int)reclen;
-                retval |= SSL_decode_handshake_v3(pkt, hsize, retval, pkt_flags, client_hello_data);
+                retval |= SSL_decode_handshake_v3(pkt, hsize, retval, pkt_flags, client_hello_data, server_cert_data);
             }
             else if (ccs)
             {
@@ -443,7 +478,7 @@ namespace snort
 uint32_t SSL_decode(
     const uint8_t* pkt, int size, uint32_t pkt_flags, uint32_t prev_flags,
     uint8_t* alert_flags, uint16_t* partial_rec_len, int max_hb_len, uint32_t* info_flags,
-    SSLV3ClientHelloData* sslv3_chello_data)
+    SSLV3ClientHelloData* client_hello_data, SSLV3ServerCertData* server_cert_data)
 {
     if (!pkt || !size)
         return SSL_ARG_ERROR_FLAG;
@@ -464,7 +499,8 @@ uint32_t SSL_decode(
          * SSLv2 as TLS,the decoder will either catch a bad type, bad version, or
          * indicate that it is truncated. */
         if (size == 5)
-            return SSL_decode_v3(pkt, size, pkt_flags, alert_flags, partial_rec_len, max_hb_len, info_flags, sslv3_chello_data);
+            return SSL_decode_v3(pkt, size, pkt_flags, alert_flags, partial_rec_len, max_hb_len, info_flags,
+                client_hello_data, server_cert_data);
 
         /* At this point, 'size' has to be > 5 */
 
@@ -511,7 +547,8 @@ uint32_t SSL_decode(
         }
     }
 
-    return SSL_decode_v3(pkt, size, pkt_flags, alert_flags, partial_rec_len, max_hb_len, info_flags, sslv3_chello_data);
+    return SSL_decode_v3(pkt, size, pkt_flags, alert_flags, partial_rec_len, max_hb_len, info_flags,
+        client_hello_data, server_cert_data);
 }
 
 /* very simplistic - just enough to say this is binary data - the rules will make a final
@@ -562,7 +599,7 @@ bool IsTlsServerHello(const uint8_t* ptr, const uint8_t* end)
 
 bool IsSSL(const uint8_t* ptr, int len, int pkt_flags)
 {
-    uint32_t ssl_flags = SSL_decode(ptr, len, pkt_flags, 0, nullptr, nullptr, 0, nullptr);
+    uint32_t ssl_flags = SSL_decode(ptr, len, pkt_flags, 0, nullptr, nullptr, 0, nullptr, nullptr);
 
     if ((ssl_flags != SSL_ARG_ERROR_FLAG) &&
         !(ssl_flags & SSL_ERROR_FLAGS))
@@ -658,4 +695,80 @@ void parse_client_hello_data(const uint8_t* pkt, uint16_t size, SSLV3ClientHello
     }
 }
 
+bool parse_server_certificates(SSLV3ServerCertData* server_cert_data)
+{
+    if (!server_cert_data->certs_data or !server_cert_data->certs_len)
+        return false;
+
+    char* common_name = nullptr;
+    char* org_name = nullptr;
+    const uint8_t* data = server_cert_data->certs_data;
+    int len = server_cert_data->certs_len;
+    int common_name_len = 0;
+    int org_name_len  = 0;
+
+    while (len > 0 and !(common_name and org_name))
+    {
+        X509* cert = nullptr;
+        char* cert_name = nullptr;
+        char* start = nullptr;
+
+        int cert_len = ntoh3(data);
+        data += 3;
+        len -= 3;
+        if (len < cert_len)
+            return false;
+
+        /* d2i_X509() increments the data ptr for us. */
+        cert = d2i_X509(nullptr, (const unsigned char**)&data, cert_len);
+        len -= cert_len;
+        if (!cert)
+            return false;
+
+        if (nullptr == (cert_name = X509_NAME_oneline(X509_get_subject_name(cert), nullptr, 0)))
+        {
+            X509_free(cert);
+            continue;
+        }
+
+        if (!common_name and (start = strstr(cert_name, COMMON_NAME_STR)))
+        {
+            start += strlen(COMMON_NAME_STR);
+            int length = strlen(start);
+            if (length > 2 and *start == '*' and *(start+1) == '.')
+            {
+                start += 2; // remove leading *.
+                length -= 2;
+            }
+            common_name = snort_strndup(start, length);
+            common_name_len = length;
+
+            org_name = snort_strndup(start, length);
+            org_name_len = length;
+
+            start = nullptr;
+        }
+
+        free(cert_name);
+        cert_name = nullptr;
+        X509_free(cert);
+    }
+
+    if (common_name)
+    {
+        server_cert_data->common_name = common_name;
+        server_cert_data->common_name_strlen = common_name_len;
+
+        server_cert_data->org_name = org_name;
+        server_cert_data->org_name_strlen = org_name_len;
+    }
+
+    /* No longer need entire certificates. We have what we came for. */
+    snort_free(server_cert_data->certs_data);
+    server_cert_data->certs_data = nullptr;
+    server_cert_data->certs_len = 0;
+
+    return true;
+}
+
 } // namespace snort
index fe43cab844ce07fd760550d031e13aa041441954..ed5d1dda4744cac13eaa93e26a4ccc2dafb66bdc 100644 (file)
 #define SSL_VER_TLS11_FLAG      0x00040000
 #define SSL_VER_TLS12_FLAG      0x00080000
 
+/* Convert 3-byte lengths in TLS headers to integers. */
+#define ntoh3(msb_ptr) \
+    ((uint32_t)((uint32_t)(((const uint8_t*)(msb_ptr))[0] << 16) \
+    + (uint32_t)(((const uint8_t*)(msb_ptr))[1] << 8) \
+    + (uint32_t)(((const uint8_t*)(msb_ptr))[2])))
+
 #define SSL_VERFLAGS \
     (SSL_VER_SSLV2_FLAG | SSL_VER_SSLV3_FLAG | \
     SSL_VER_TLS10_FLAG | SSL_VER_TLS11_FLAG | \
@@ -202,6 +208,20 @@ struct SSLV3ClientHelloData
     char* host_name = nullptr;
 };
 
+struct SSLV3ServerCertData
+{
+    ~SSLV3ServerCertData();
+    void clear();
+    /* While collecting certificates: */
+    unsigned certs_len;   // (Total) length of certificate(s).
+    uint8_t* certs_data = nullptr;  // Certificate(s) data (each proceeded by length (3 bytes)).
+    /* Data collected from certificates afterwards: */
+    char* common_name = nullptr;
+    int common_name_strlen;
+    char* org_name = nullptr;
+    int org_name_strlen;
+};
+
 enum class SSLV3RecordType : uint8_t
 {
     CLIENT_HELLO = 1,
@@ -213,6 +233,19 @@ enum class SSLV3RecordType : uint8_t
     CERTIFICATE_STATUS = 22
 };
 
+/* Usually referred to as a Certificate Handshake. */
+struct ServiceSSLV3CertsRecord
+{
+    uint8_t type;
+    uint8_t length_msb;
+    uint16_t length;
+    uint8_t certs_len[3];  // 3-byte length, network byte order.
+    /* Certificate(s) follow.
+     * For each:
+     *  - Length: 3 bytes
+     *  - Data  : "Length" bytes */
+};
+
 /* Usually referred to as a TLS Handshake. */
 struct ServiceSSLV3Record
 {
@@ -275,9 +308,10 @@ namespace snort
 uint32_t SSL_decode(
     const uint8_t* pkt, int size, uint32_t pktflags, uint32_t prevflags,
     uint8_t* alert_flags, uint16_t* partial_rec_len, int hblen, uint32_t* info_flags = nullptr,
-    SSLV3ClientHelloData* data = nullptr);
+    SSLV3ClientHelloData* data = nullptr, SSLV3ServerCertData* server_cert_data = nullptr);
 
     void parse_client_hello_data(const uint8_t* pkt, uint16_t size, SSLV3ClientHelloData*);
+    bool parse_server_certificates(SSLV3ServerCertData* server_cert_data);
 
 SO_PUBLIC bool IsTlsClientHello(const uint8_t* ptr, const uint8_t* end);
 SO_PUBLIC bool IsTlsServerHello(const uint8_t* ptr, const uint8_t* end);
index a0566b66111211e8d69df9b15c373385a130995d..602e7010f3393455a45ee68b673956c62d7f3dcd 100644 (file)
 
 #include "framework/data_bus.h"
 
-struct  SslEventIds { enum : unsigned { CHELLO_SERVER_NAME, num_ids }; };
+struct SslEventIds
+{
+    enum : unsigned
+    {
+        CHELLO_SERVER_NAME,
+        SERVER_COMMON_NAME,
+    
+        num_ids
+    };
+};
 
-const snort::PubKey ssl_chello_pub_key { "ssl_chello", SslEventIds::num_ids };
+const snort::PubKey ssl_pub_key { "ssl", SslEventIds::num_ids };
 
 class SslClientHelloEvent : public snort::DataEvent
 {
@@ -47,4 +56,22 @@ private:
     const snort::Packet* packet;
 };
 
+class SslServerCommonNameEvent : public snort::DataEvent
+{
+public:
+    SslServerCommonNameEvent(const std::string& server_common_name, const snort::Packet* packet) :
+        server_common_name(server_common_name), packet(packet)
+        { }
+
+    const snort::Packet* get_packet() const override
+    { return packet; }
+
+    const std::string& get_common_name() const
+    { return server_common_name; }
+
+private:
+    const std::string server_common_name;
+    const snort::Packet* packet;
+};
+
 #endif
index 8d17364b876e282c4501b13ff87f84afbd8da02e..032cbf51d510ff41302f4bfe614a01ad7c3a8f6f 100644 (file)
@@ -59,7 +59,7 @@ using namespace snort;
 THREAD_LOCAL ProfileStats sslPerfStats;
 THREAD_LOCAL SslStats sslstats;
 
-static unsigned ssl_chello_pub_id = 0;
+static unsigned pub_id = 0;
 
 const PegInfo ssl_peg_names[] =
 {
@@ -309,13 +309,21 @@ static void snort_ssl(SSL_PROTO_CONF* config, Packet* p)
     uint8_t heartbleed_type = 0;
     uint32_t info_flags = 0;
     SSLV3ClientHelloData client_hello_data;
+    SSLV3ServerCertData server_cert_data;
     uint32_t new_flags = SSL_decode(p->data, (int)p->dsize, p->packet_flags, sd->ssn_flags,
-        &heartbleed_type, &(sd->partial_rec_len[dir+index]), config->max_heartbeat_len, &info_flags, &client_hello_data);
+        &heartbleed_type, &(sd->partial_rec_len[dir+index]), config->max_heartbeat_len, &info_flags, &client_hello_data,
+        &server_cert_data);
 
     if (client_hello_data.host_name != nullptr)
     {
         SslClientHelloEvent event(client_hello_data.host_name, p);
-        DataBus::publish(ssl_chello_pub_id, SslEventIds::CHELLO_SERVER_NAME, event);
+        DataBus::publish(pub_id, SslEventIds::CHELLO_SERVER_NAME, event);
+    }
+
+    if (server_cert_data.common_name != nullptr)
+    {
+        SslServerCommonNameEvent event(server_cert_data.common_name, p);
+        DataBus::publish(pub_id, SslEventIds::SERVER_COMMON_NAME, event);
     }
 
     if (heartbleed_type & SSL_HEARTBLEED_REQUEST)
@@ -500,7 +508,7 @@ void Ssl::eval(Packet* p)
 
 bool Ssl::configure(SnortConfig*)
 {
-    ssl_chello_pub_id = DataBus::get_id(ssl_chello_pub_key);
+    pub_id = DataBus::get_id(ssl_pub_key);
 
     DataBus::subscribe(intrinsic_pub_key, IntrinsicEventIds::FINALIZE_PACKET, new SslFinalizePacketHandler());
     DataBus::subscribe(intrinsic_pub_key, IntrinsicEventIds::OPPORTUNISTIC_TLS, new SslStartTlsEventtHandler());