]> git.ipfire.org Git - thirdparty/snort3.git/commitdiff
Pull request #5065: appid: prevent multiple oob reads in ssl
authorBohdan Hryniv -X (bhryniv - SOFTSERVE INC at Cisco) <bhryniv@cisco.com>
Tue, 6 Jan 2026 20:43:22 +0000 (20:43 +0000)
committerChris Sherwin (chsherwi) <chsherwi@cisco.com>
Tue, 6 Jan 2026 20:43:22 +0000 (20:43 +0000)
Merge in SNORT/snort3 from ~BHRYNIV/snort3:fix_ssl_oob to master

Squashed commit of the following:

commit e1d42bb9c34f6e2af3ec0a94a404a64291ff8c20
Author: Bohdan Hryniv <bhryniv@cisco>
Date:   Tue Dec 16 08:45:23 2025 -0500

    appid: prevent multiple oob reads in ssl

src/protocols/ssl.cc
src/protocols/ssl.h
src/protocols/test/ssl_protocol_test.cc

index 874610c643a235688d8dfb96924efc45b8a5d08b..4ef9a356c7fc553959c7f558912c45ec1eeb0ad2 100644 (file)
@@ -229,14 +229,17 @@ static uint32_t SSL_decode_handshake_v3(const uint8_t* pkt, int size,
         case SSL_HS_CERT:
             if (server_cert_data != nullptr)
             {
+                if (size < SSL_CERTS_LEN_SIZE)
+                    return retval | SSL_TRUNCATED_FLAG;
+
                 certs_rec = (const ServiceSSLV3CertsRecord*)handshake;
                 server_cert_data->certs_len = ntoh3(certs_rec->certs_len);
-                if (server_cert_data->certs_len + sizeof(certs_rec->certs_len) > (unsigned int)size)
-                {
+
+                if (server_cert_data->certs_len + SSL_CERTS_LEN_SIZE > (unsigned int)size)
                     return retval | SSL_TRUNCATED_FLAG;
-                }
+
                 server_cert_data->certs_data = (uint8_t*)snort_alloc(server_cert_data->certs_len);
-                memcpy(server_cert_data->certs_data, pkt + sizeof(certs_rec->certs_len), server_cert_data->certs_len);
+                memcpy(server_cert_data->certs_data, pkt + SSL_CERTS_LEN_SIZE, server_cert_data->certs_len);
 
                 snort::parse_server_certificates(server_cert_data);
             }
@@ -327,7 +330,7 @@ static uint32_t SSL_decode_v3(const uint8_t* pkt, int size, uint32_t pkt_flags,
             break;
 
         case SSL_ALERT_REC:
-            if (reclen == sizeof(SSL_alert_t))
+            if (reclen == sizeof(SSL_alert_t) && size >= (int)sizeof(SSL_alert_t))
             {
                 const SSL_alert_t* ssl_alert = (const SSL_alert_t*)pkt;
                 if (ssl_alert->level == SSL_ALERT_LEVEL_FATAL && info_flags)
@@ -910,8 +913,9 @@ bool parse_server_certificates(SSLV3ServerCertData* server_cert_data)
             if (lastpos != -1)
             {
                 X509_NAME_ENTRY* e = X509_NAME_get_entry(cert_subject, lastpos);
-                const unsigned char* str_data = ASN1_STRING_get0_data(X509_NAME_ENTRY_get_data(e));
-                int length = strlen((const char*)str_data);
+                ASN1_STRING* asn1_str = X509_NAME_ENTRY_get_data(e);
+                const unsigned char* str_data = ASN1_STRING_get0_data(asn1_str);
+                int length = ASN1_STRING_length(asn1_str);
 
                 bool wildcard = false;
                 if ((wildcard = (length > 2 and *str_data == '*' and *(str_data + 1) == '.')))
@@ -929,8 +933,9 @@ bool parse_server_certificates(SSLV3ServerCertData* server_cert_data)
             if (lastpos != -1)
             {
                 X509_NAME_ENTRY* e = X509_NAME_get_entry(cert_subject, lastpos);
-                const unsigned char* str_data = ASN1_STRING_get0_data(X509_NAME_ENTRY_get_data(e));
-                org_unit_len = strlen((const char*)str_data);
+                ASN1_STRING* asn1_str = X509_NAME_ENTRY_get_data(e);
+                const unsigned char* str_data = ASN1_STRING_get0_data(asn1_str);
+                org_unit_len = ASN1_STRING_length(asn1_str);
                 org_unit = snort_strndup((const char*)(str_data), org_unit_len);
             }
         }
index 564a2cee64d264ad457e543e87e851f6fc32f631..c875c292961db589bf0fde194cb6a3903cda1d76 100644 (file)
@@ -60,6 +60,8 @@
     + (uint32_t)(((const uint8_t*)(msb_ptr))[1] << 8) \
     + (uint32_t)(((const uint8_t*)(msb_ptr))[2])))
 
+#define SSL_CERTS_LEN_SIZE 3
+
 #define SSL_VERFLAGS \
     (SSL_VER_SSLV2_FLAG | SSL_VER_SSLV3_FLAG | \
     SSL_VER_TLS10_FLAG | SSL_VER_TLS11_FLAG | \
index 39c4cfbcf5da01a15deb3d8b73469b8390000576..5892903271d1ce18463665e833c0e85f46f13792 100644 (file)
@@ -37,8 +37,15 @@ void LogMessage(char const*, ...) {}
 
 using namespace snort;
 
+// When set, mocks return valid objects
+static const unsigned char* g_test_asn1_data = nullptr;
+static int g_test_asn1_len = 0;
+
+static char g_mock_mem[64];
+
 typedef struct X509_name_entry_st X509_NAME_ENTRY;
-X509_NAME *X509_get_subject_name(const X509 *a) { return nullptr; }
+X509_NAME *X509_get_subject_name(const X509 *a) 
+{ return g_test_asn1_data ? (X509_NAME*)g_mock_mem : nullptr; }
 X509_NAME *X509_get_issuer_name(const X509 *a) { return nullptr; }
 void X509_free(X509* a) { }
 #if OPENSSL_VERSION_NUMBER < 0x30000000L
@@ -46,16 +53,21 @@ int X509_NAME_get_index_by_NID(X509_NAME *name, int nid, int lastpos)
 #else
 int X509_NAME_get_index_by_NID(const X509_NAME *name, int nid, int lastpos)
 #endif
-{ return -1; }
-X509_NAME_ENTRY *X509_NAME_get_entry(const X509_NAME *name, int loc) { return nullptr; }
-ASN1_STRING *X509_NAME_ENTRY_get_data(const X509_NAME_ENTRY *ne) { return nullptr; }
-const unsigned char *ASN1_STRING_get0_data(const ASN1_STRING *x) { return nullptr; }
-X509* d2i_X509(X509 **a, const unsigned char **in, long len) { return nullptr; }
+{ return g_test_asn1_data ? 0 : -1; }
+X509_NAME_ENTRY *X509_NAME_get_entry(const X509_NAME *name, int loc) 
+{ return g_test_asn1_data ? (X509_NAME_ENTRY*)(g_mock_mem + 16) : nullptr; }
+ASN1_STRING *X509_NAME_ENTRY_get_data(const X509_NAME_ENTRY *ne) 
+{ return g_test_asn1_data ? (ASN1_STRING*)(g_mock_mem + 32) : nullptr; }
+const unsigned char *ASN1_STRING_get0_data(const ASN1_STRING *x) 
+{ return g_test_asn1_data; }
+X509* d2i_X509(X509 **a, const unsigned char **in, long len) 
+{ return g_test_asn1_data ? (X509*)(g_mock_mem + 48) : nullptr; }
 int X509_NAME_print_ex(BIO *out, const X509_NAME *nm, int indent, unsigned long flags) { return 0; }
 BIO *BIO_new(const BIO_METHOD *type) { return nullptr; }
 int BIO_free(BIO *a) { return 0; }
 long BIO_ctrl(BIO *bp, int cmd, long larg, void *parg) { return 0; }
 const BIO_METHOD *BIO_s_mem(void) { return nullptr; }
+int ASN1_STRING_length(const ASN1_STRING *x) { return g_test_asn1_len; }
 
 namespace snort
 {
@@ -64,9 +76,14 @@ char* snort_strdup(const char* str)
     return str ? strdup(str) : nullptr;
 }
 
-char* snort_strndup(const char* src, size_t)
+char* snort_strndup(const char* src, size_t n)
 {
-    return snort_strdup(src);
+    if (!src)
+        return nullptr;
+    char* dst = new char[n + 1];
+    memcpy(dst, src, n);
+    dst[n] = '\0';
+    return dst;
 }
 }
 
@@ -78,6 +95,8 @@ TEST_GROUP(ssl_protocol_tests)
 
     void teardown() override
     {
+        g_test_asn1_data = nullptr;
+        g_test_asn1_len = 0;
     }
 };
 
@@ -182,6 +201,68 @@ TEST(ssl_protocol_tests, parse_server_hello_invalid_packet_len)
     CHECK_EQUAL(ParseHelloResult::FRAGMENTED_PACKET, result);
 }
 
+
+TEST(ssl_protocol_tests, ssl_hs_cert_truncated_certs_len)
+{
+    uint8_t test_data[] = {
+        0x16,                           // Content Type: Handshake
+        0x03, 0x03,                 // Version: TLS 1.2
+        0x00, 0x04,                 // Length: 4 bytes (just the handshake header)
+        0x0b,                           // Handshake Type: Certificate
+        0x00, 0x00, 0x00        // Handshake Length: 0 (no cert data follows)
+        // Missing: certs_len (3 bytes) 
+    };
+
+    SSLV3ClientHelloData client_hello;
+    SSLV3ServerCertData server_cert;
+    uint32_t result = SSL_decode(test_data, sizeof(test_data), 0, 0,
+        nullptr, nullptr, 0, nullptr, &client_hello, &server_cert, nullptr);
+
+    CHECK(result & SSL_TRUNCATED_FLAG);
+}
+
+TEST(ssl_protocol_tests, ssl_alert_rec_zero_size)
+{
+    uint8_t test_data[] = {
+        0x15,                 // Content Type: Alert
+        0x03, 0x03,       // Version: TLS 1.2
+        0x00, 0x02        // Length: 2 (claims 2 bytes but 0 bytes follow)
+    };
+
+    uint32_t info_flags = 0;
+    uint32_t result = SSL_decode(test_data, sizeof(test_data), 0, 0,
+        nullptr, nullptr, 0, &info_flags, nullptr, nullptr, nullptr);
+
+    CHECK(result & SSL_ALERT_FLAG);
+
+    CHECK_EQUAL(0, (info_flags & SSL_ALERT_LVL_FATAL_FLAG));
+}
+
+TEST(ssl_protocol_tests, ssl_cert_common_name_parsing)
+{
+    uint8_t cn_data[4] = { 'T', 'E', 'S', 'T' };
+    
+    g_test_asn1_data = cn_data;
+    g_test_asn1_len = 4;
+    
+    // Minimal cert data to trigger parsing
+    uint8_t cert_data[] = {
+        0x00, 0x00, 0x03,  // cert length: 3
+        0x30, 0x01, 0x00   // minimal DER
+    };
+    
+    SSLV3ServerCertData* server_cert = new SSLV3ServerCertData();
+    server_cert->certs_data = (uint8_t*)snort_alloc(sizeof(cert_data));
+    memcpy(server_cert->certs_data, cert_data, sizeof(cert_data));
+    server_cert->certs_len = sizeof(cert_data);
+    
+    parse_server_certificates(server_cert);
+    
+    delete server_cert;
+    
+    CHECK(true);
+}
+
 int main(int argc, char** argv)
 {
     return CommandLineTestRunner::RunAllTests(argc, argv);