case SSL_HS_CERT:
if (server_cert_data != nullptr)
{
+ if (size < SSL_CERTS_LEN_SIZE)
+ return retval | SSL_TRUNCATED_FLAG;
+
certs_rec = (const ServiceSSLV3CertsRecord*)handshake;
server_cert_data->certs_len = ntoh3(certs_rec->certs_len);
- if (server_cert_data->certs_len + sizeof(certs_rec->certs_len) > (unsigned int)size)
- {
+
+ if (server_cert_data->certs_len + SSL_CERTS_LEN_SIZE > (unsigned int)size)
return retval | SSL_TRUNCATED_FLAG;
- }
+
server_cert_data->certs_data = (uint8_t*)snort_alloc(server_cert_data->certs_len);
- memcpy(server_cert_data->certs_data, pkt + sizeof(certs_rec->certs_len), server_cert_data->certs_len);
+ memcpy(server_cert_data->certs_data, pkt + SSL_CERTS_LEN_SIZE, server_cert_data->certs_len);
snort::parse_server_certificates(server_cert_data);
}
break;
case SSL_ALERT_REC:
- if (reclen == sizeof(SSL_alert_t))
+ if (reclen == sizeof(SSL_alert_t) && size >= (int)sizeof(SSL_alert_t))
{
const SSL_alert_t* ssl_alert = (const SSL_alert_t*)pkt;
if (ssl_alert->level == SSL_ALERT_LEVEL_FATAL && info_flags)
if (lastpos != -1)
{
X509_NAME_ENTRY* e = X509_NAME_get_entry(cert_subject, lastpos);
- const unsigned char* str_data = ASN1_STRING_get0_data(X509_NAME_ENTRY_get_data(e));
- int length = strlen((const char*)str_data);
+ ASN1_STRING* asn1_str = X509_NAME_ENTRY_get_data(e);
+ const unsigned char* str_data = ASN1_STRING_get0_data(asn1_str);
+ int length = ASN1_STRING_length(asn1_str);
bool wildcard = false;
if ((wildcard = (length > 2 and *str_data == '*' and *(str_data + 1) == '.')))
if (lastpos != -1)
{
X509_NAME_ENTRY* e = X509_NAME_get_entry(cert_subject, lastpos);
- const unsigned char* str_data = ASN1_STRING_get0_data(X509_NAME_ENTRY_get_data(e));
- org_unit_len = strlen((const char*)str_data);
+ ASN1_STRING* asn1_str = X509_NAME_ENTRY_get_data(e);
+ const unsigned char* str_data = ASN1_STRING_get0_data(asn1_str);
+ org_unit_len = ASN1_STRING_length(asn1_str);
org_unit = snort_strndup((const char*)(str_data), org_unit_len);
}
}
using namespace snort;
+// When set, mocks return valid objects
+static const unsigned char* g_test_asn1_data = nullptr;
+static int g_test_asn1_len = 0;
+
+static char g_mock_mem[64];
+
typedef struct X509_name_entry_st X509_NAME_ENTRY;
-X509_NAME *X509_get_subject_name(const X509 *a) { return nullptr; }
+X509_NAME *X509_get_subject_name(const X509 *a)
+{ return g_test_asn1_data ? (X509_NAME*)g_mock_mem : nullptr; }
X509_NAME *X509_get_issuer_name(const X509 *a) { return nullptr; }
void X509_free(X509* a) { }
#if OPENSSL_VERSION_NUMBER < 0x30000000L
#else
int X509_NAME_get_index_by_NID(const X509_NAME *name, int nid, int lastpos)
#endif
-{ return -1; }
-X509_NAME_ENTRY *X509_NAME_get_entry(const X509_NAME *name, int loc) { return nullptr; }
-ASN1_STRING *X509_NAME_ENTRY_get_data(const X509_NAME_ENTRY *ne) { return nullptr; }
-const unsigned char *ASN1_STRING_get0_data(const ASN1_STRING *x) { return nullptr; }
-X509* d2i_X509(X509 **a, const unsigned char **in, long len) { return nullptr; }
+{ return g_test_asn1_data ? 0 : -1; }
+X509_NAME_ENTRY *X509_NAME_get_entry(const X509_NAME *name, int loc)
+{ return g_test_asn1_data ? (X509_NAME_ENTRY*)(g_mock_mem + 16) : nullptr; }
+ASN1_STRING *X509_NAME_ENTRY_get_data(const X509_NAME_ENTRY *ne)
+{ return g_test_asn1_data ? (ASN1_STRING*)(g_mock_mem + 32) : nullptr; }
+const unsigned char *ASN1_STRING_get0_data(const ASN1_STRING *x)
+{ return g_test_asn1_data; }
+X509* d2i_X509(X509 **a, const unsigned char **in, long len)
+{ return g_test_asn1_data ? (X509*)(g_mock_mem + 48) : nullptr; }
int X509_NAME_print_ex(BIO *out, const X509_NAME *nm, int indent, unsigned long flags) { return 0; }
BIO *BIO_new(const BIO_METHOD *type) { return nullptr; }
int BIO_free(BIO *a) { return 0; }
long BIO_ctrl(BIO *bp, int cmd, long larg, void *parg) { return 0; }
const BIO_METHOD *BIO_s_mem(void) { return nullptr; }
+int ASN1_STRING_length(const ASN1_STRING *x) { return g_test_asn1_len; }
namespace snort
{
return str ? strdup(str) : nullptr;
}
-char* snort_strndup(const char* src, size_t)
+char* snort_strndup(const char* src, size_t n)
{
- return snort_strdup(src);
+ if (!src)
+ return nullptr;
+ char* dst = new char[n + 1];
+ memcpy(dst, src, n);
+ dst[n] = '\0';
+ return dst;
}
}
void teardown() override
{
+ g_test_asn1_data = nullptr;
+ g_test_asn1_len = 0;
}
};
CHECK_EQUAL(ParseHelloResult::FRAGMENTED_PACKET, result);
}
+
+TEST(ssl_protocol_tests, ssl_hs_cert_truncated_certs_len)
+{
+ uint8_t test_data[] = {
+ 0x16, // Content Type: Handshake
+ 0x03, 0x03, // Version: TLS 1.2
+ 0x00, 0x04, // Length: 4 bytes (just the handshake header)
+ 0x0b, // Handshake Type: Certificate
+ 0x00, 0x00, 0x00 // Handshake Length: 0 (no cert data follows)
+ // Missing: certs_len (3 bytes)
+ };
+
+ SSLV3ClientHelloData client_hello;
+ SSLV3ServerCertData server_cert;
+ uint32_t result = SSL_decode(test_data, sizeof(test_data), 0, 0,
+ nullptr, nullptr, 0, nullptr, &client_hello, &server_cert, nullptr);
+
+ CHECK(result & SSL_TRUNCATED_FLAG);
+}
+
+TEST(ssl_protocol_tests, ssl_alert_rec_zero_size)
+{
+ uint8_t test_data[] = {
+ 0x15, // Content Type: Alert
+ 0x03, 0x03, // Version: TLS 1.2
+ 0x00, 0x02 // Length: 2 (claims 2 bytes but 0 bytes follow)
+ };
+
+ uint32_t info_flags = 0;
+ uint32_t result = SSL_decode(test_data, sizeof(test_data), 0, 0,
+ nullptr, nullptr, 0, &info_flags, nullptr, nullptr, nullptr);
+
+ CHECK(result & SSL_ALERT_FLAG);
+
+ CHECK_EQUAL(0, (info_flags & SSL_ALERT_LVL_FATAL_FLAG));
+}
+
+TEST(ssl_protocol_tests, ssl_cert_common_name_parsing)
+{
+ uint8_t cn_data[4] = { 'T', 'E', 'S', 'T' };
+
+ g_test_asn1_data = cn_data;
+ g_test_asn1_len = 4;
+
+ // Minimal cert data to trigger parsing
+ uint8_t cert_data[] = {
+ 0x00, 0x00, 0x03, // cert length: 3
+ 0x30, 0x01, 0x00 // minimal DER
+ };
+
+ SSLV3ServerCertData* server_cert = new SSLV3ServerCertData();
+ server_cert->certs_data = (uint8_t*)snort_alloc(sizeof(cert_data));
+ memcpy(server_cert->certs_data, cert_data, sizeof(cert_data));
+ server_cert->certs_len = sizeof(cert_data);
+
+ parse_server_certificates(server_cert);
+
+ delete server_cert;
+
+ CHECK(true);
+}
+
int main(int argc, char** argv)
{
return CommandLineTestRunner::RunAllTests(argc, argv);