From: Steve Chew (stechew) Date: Thu, 2 Feb 2023 14:33:44 +0000 (+0000) Subject: Pull request #3745: ssl: refactor client hello sni parsing X-Git-Tag: 3.1.55.0~7 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=46d5e51bb2e363d1a6b000de3f386a648280406c;p=thirdparty%2Fsnort3.git Pull request #3745: ssl: refactor client hello sni parsing Merge in SNORT/snort3 from ~SVLASIUK/snort3:ch_sni_parser to master Squashed commit of the following: commit afe66704e8e0249f023fdd6952092227d1af3e64 Author: Serhii Vlasiuk Date: Tue Jan 17 13:25:56 2023 +0200 ssl: refactor ssl client hello parser to be used by appid/ssl inspectors --- diff --git a/src/network_inspectors/appid/service_plugins/service_ssl.cc b/src/network_inspectors/appid/service_plugins/service_ssl.cc index 819a4eb31..c12e4d1f2 100644 --- a/src/network_inspectors/appid/service_plugins/service_ssl.cc +++ b/src/network_inspectors/appid/service_plugins/service_ssl.cc @@ -29,6 +29,7 @@ #include "app_info_table.h" #include "protocols/packet.h" +#include "protocols/ssl.h" using namespace snort; @@ -42,13 +43,6 @@ enum SSLContentType SSL_APPLICATION_DATA = 23 }; -#define SSL_CLIENT_HELLO 1 -#define SSL_SERVER_HELLO 2 -#define SSL_CERTIFICATE 11 -#define SSL_SERVER_KEY_XCHG 12 -#define SSL_SERVER_CERT_REQ 13 -#define SSL_SERVER_HELLO_DONE 14 -#define SSL_CERTIFICATE_STATUS 22 #define SSL2_SERVER_HELLO 4 #define PCT_SERVER_HELLO 2 @@ -56,9 +50,6 @@ enum SSLContentType #define COMMON_NAME_STR "/CN=" #define ORG_NAME_STR "/O=" -/* Extension types. */ -#define SSL_EXT_SERVER_NAME 0 - enum SSLState { SSL_STATE_INITIATE, // Client initiates. @@ -73,8 +64,7 @@ struct ServiceSSLData int length; int tot_length; /* From client: */ - char* host_name; - int host_name_strlen; + SSLV3ClientHelloData client_hello; /* While collecting certificates: */ unsigned certs_len; // (Total) length of certificate(s). uint8_t* certs_data; // Certificate(s) data (each proceeded by length (3 bytes)). @@ -99,20 +89,6 @@ struct ServiceSSLV3Hdr uint16_t len; }; -/* Usually referred to as a TLS Handshake. */ -struct ServiceSSLV3Record -{ - uint8_t type; - uint8_t length_msb; - uint16_t length; - uint16_t version; - struct - { - uint32_t time; - uint8_t data[28]; - } random; -}; - /* Usually referred to as a Certificate Handshake. */ struct ServiceSSLV3CertsRecord { @@ -126,16 +102,6 @@ struct ServiceSSLV3CertsRecord * - Data : "Length" bytes */ }; -struct ServiceSSLV3ExtensionServerName -{ - uint16_t type; - uint16_t length; - uint16_t list_length; - uint8_t string_length_msb; - uint16_t string_length; - /* String follows. */ -}; - struct ServiceSSLPCTHdr { uint8_t len; @@ -253,8 +219,8 @@ static void ssl_free(void* ss) { ServiceSSLData* ss_tmp = (ServiceSSLData*)ss; snort_free(ss_tmp->certs_data); - snort_free(ss_tmp->host_name); snort_free(ss_tmp->common_name); + ss_tmp->client_hello.clear(); snort_free(ss_tmp->org_name); ssl_cache_free(ss_tmp->cached_data, ss_tmp->cached_len); snort_free(ss_tmp); @@ -263,8 +229,6 @@ static void ssl_free(void* ss) static void parse_client_initiation(const uint8_t* data, uint16_t size, ServiceSSLData* ss) { const ServiceSSLV3Hdr* hdr3; - const ServiceSSLV3Record* rec; - unsigned length; uint16_t ver; /* Sanity check header stuff. */ @@ -280,85 +244,7 @@ static void parse_client_initiation(const uint8_t* data, uint16_t size, ServiceS data += sizeof(ServiceSSLV3Hdr); size -= sizeof(ServiceSSLV3Hdr); - if (size < sizeof(ServiceSSLV3Record)) - return; - rec = (const ServiceSSLV3Record*)data; - ver = ntohs(rec->version); - if (rec->type != SSL_CLIENT_HELLO || (ver != 0x0300 && ver != 0x0301 && ver != 0x0302 && - ver != 0x0303) || rec->length_msb) - { - return; - } - length = ntohs(rec->length) + offsetof(ServiceSSLV3Record, version); - if (size < length) - return; - data += sizeof(ServiceSSLV3Record); - size -= sizeof(ServiceSSLV3Record); - - /* Session ID (1-byte length). */ - if (size < 1) - return; - length = *((const uint8_t*)data); - data += length + 1; - if (size < (length + 1)) - return; - size -= length + 1; - - /* Cipher Suites (2-byte length). */ - if (size < 2) - return; - length = ntohs(*((const uint16_t*)data)); - data += length + 2; - if (size < (length + 2)) - return; - size -= length + 2; - - /* Compression Methods (1-byte length). */ - if (size < 1) - return; - length = *((const uint8_t*)data); - data += length + 1; - if (size < (length + 1)) - return; - size -= length + 1; - - /* Extensions (2-byte length) */ - if (size < 2) - return; - length = ntohs(*((const uint16_t*)data)); - data += 2; - size -= 2; - if (size < length) - return; - - /* We need at least type (2 bytes) and length (2 bytes) in the extension. */ - while (length >= 4) - { - const ServiceSSLV3ExtensionServerName* ext = (const ServiceSSLV3ExtensionServerName*)data; - if (ntohs(ext->type) == SSL_EXT_SERVER_NAME) - { - /* Found server host name. */ - if (length < sizeof(ServiceSSLV3ExtensionServerName)) - return; - - unsigned len = ntohs(ext->string_length); - if ((length - sizeof(ServiceSSLV3ExtensionServerName)) < len) - return; - - const uint8_t* str = data + offsetof(ServiceSSLV3ExtensionServerName, string_length) + - sizeof(ext->string_length); - ss->host_name = snort_strndup((const char*)str, len); - ss->host_name_strlen = len; - return; - } - - unsigned len = ntohs(ext->length) + offsetof(ServiceSSLV3ExtensionServerName, list_length); - if (len > length) - return; - - data += len; - length -= len; - } + parse_client_hello_data(data, size, &ss->client_hello); } static bool parse_certificates(ServiceSSLData* ss) @@ -561,7 +447,7 @@ int SslServiceDetector::validate(AppIdDiscoveryArgs& args) data += sizeof(ServiceSSLV3Hdr); size -= sizeof(ServiceSSLV3Hdr); rec = (const ServiceSSLV3Record*)data; - if (size < sizeof(ServiceSSLV3Record) || rec->type != SSL_SERVER_HELLO || + if (size < sizeof(ServiceSSLV3Record) || rec->type != SSLV3RecordType::SERVER_HELLO || (ntohs(rec->version) != 0x0300 && ntohs(rec->version) != 0x0301 && ntohs(rec->version) != 0x0302 && ntohs(rec->version) != 0x0303) || rec->length_msb) @@ -622,13 +508,13 @@ int SslServiceDetector::validate(AppIdDiscoveryArgs& args) } rec = (const ServiceSSLV3Record*)data; - if (rec->type != SSL_SERVER_HELLO_DONE and rec->length_msb) + if (rec->type != SSLV3RecordType::SERVER_HELLO_DONE and rec->length_msb) { goto fail; } switch (rec->type) { - case SSL_CERTIFICATE: + case SSLV3RecordType::CERTIFICATE: /* Start pulling out certificates. */ if (!ss->certs_data) { @@ -655,9 +541,9 @@ int SslServiceDetector::validate(AppIdDiscoveryArgs& args) } } /* fall through */ - case SSL_CERTIFICATE_STATUS: - case SSL_SERVER_KEY_XCHG: - case SSL_SERVER_CERT_REQ: + case SSLV3RecordType::CERTIFICATE_STATUS: + case SSLV3RecordType::SERVER_KEY_XCHG: + case SSLV3RecordType::SERVER_CERT_REQ: ss->length = ntohs(rec->length) + offsetof(ServiceSSLV3Record, version); if (ss->tot_length < ss->length) goto fail; @@ -674,7 +560,7 @@ int SslServiceDetector::validate(AppIdDiscoveryArgs& args) ss->pos = 0; } break; - case SSL_SERVER_HELLO_DONE: + case SSLV3RecordType::SERVER_HELLO_DONE: if (size < offsetof(ServiceSSLV3Record, version)) goto success; if (rec->length) @@ -736,11 +622,11 @@ fail: if (reallocated_data) snort_free(reallocated_data); snort_free(ss->certs_data); - snort_free(ss->host_name); + ss->client_hello.clear(); snort_free(ss->common_name); snort_free(ss->org_name); ss->certs_data = nullptr; - ss->host_name = ss->common_name = ss->org_name = nullptr; + ss->common_name = ss->org_name = nullptr; fail_service(args.asd, args.pkt, args.dir); return APPID_NOMATCH; @@ -758,15 +644,15 @@ success: } args.asd.set_session_flags(APPID_SESSION_SSL_SESSION); - if (ss->host_name || ss->common_name || ss->org_name) + if (ss->client_hello.host_name || ss->common_name || ss->org_name) { if (!args.asd.tsession) args.asd.tsession = new TlsSession(); /* TLS Host */ - if (ss->host_name) + if (ss->client_hello.host_name) { - args.asd.tsession->set_tls_host(ss->host_name, 0, args.change_bits); + args.asd.tsession->set_tls_host(ss->client_hello.host_name, 0, args.change_bits); args.asd.scan_flags |= SCAN_SSL_HOST_FLAG; } else if (ss->common_name) @@ -786,7 +672,7 @@ success: if (ss->org_name) args.asd.tsession->set_tls_org_unit(ss->org_name, 0); - ss->host_name = ss->common_name = ss->org_name = nullptr; + ss->client_hello.host_name = ss->common_name = ss->org_name = nullptr; args.asd.tsession->set_tls_handshake_done(); } return add_service(args.change_bits, args.asd, args.pkt, args.dir, diff --git a/src/protocols/ssl.cc b/src/protocols/ssl.cc index 11ea94b49..f8d70b622 100644 --- a/src/protocols/ssl.cc +++ b/src/protocols/ssl.cc @@ -26,6 +26,7 @@ #include "ssl.h" #include "packet.h" +#include "utils/util.h" #define THREE_BYTE_LEN(x) ((x)[2] | (x)[1] << 8 | (x)[0] << 16) @@ -40,6 +41,17 @@ #define SSL2_CHELLO_BYTE 0x01 #define SSL2_SHELLO_BYTE 0x04 +SSLV3ClientHelloData::~SSLV3ClientHelloData() +{ + snort_free(host_name); +} + +void SSLV3ClientHelloData::clear() +{ + snort_free(host_name); + host_name = nullptr; +} + static uint32_t SSL_decode_version_v3(uint8_t major, uint8_t minor) { /* Should only be called internally and by functions which have previously @@ -69,7 +81,7 @@ static uint32_t SSL_decode_version_v3(uint8_t major, uint8_t minor) } static uint32_t SSL_decode_handshake_v3(const uint8_t* pkt, int size, - uint32_t cur_flags, uint32_t pkt_flags) + uint32_t cur_flags, uint32_t pkt_flags, SSLV3ClientHelloData* client_hello_data) { const SSL_handshake_hello_t* hello; uint32_t retval = 0; @@ -114,6 +126,8 @@ static uint32_t SSL_decode_handshake_v3(const uint8_t* pkt, int size, hello = (const SSL_handshake_hello_t*)handshake; retval |= SSL_decode_version_v3(hello->major, hello->minor); + snort::parse_client_hello_data((const uint8_t*)handshake, size + SSL_HS_PAYLOAD_OFFSET, client_hello_data); + break; case SSL_HS_SHELLO: @@ -190,7 +204,8 @@ static uint32_t SSL_decode_handshake_v3(const uint8_t* pkt, int size, } static uint32_t SSL_decode_v3(const uint8_t* pkt, int size, uint32_t pkt_flags, - uint8_t* alert_flags, uint16_t* partial_rec_len, int max_hb_len, uint32_t* info_flags) + uint8_t* alert_flags, uint16_t* partial_rec_len, int max_hb_len, uint32_t* info_flags, + SSLV3ClientHelloData* client_hello_data) { uint32_t retval = 0; uint16_t hblen; @@ -284,7 +299,7 @@ static uint32_t SSL_decode_v3(const uint8_t* pkt, int size, uint32_t pkt_flags, if (!(retval & SSL_CHANGE_CIPHER_FLAG)) { int hsize = size < (int)reclen ? size : (int)reclen; - retval |= SSL_decode_handshake_v3(pkt, hsize, retval, pkt_flags); + retval |= SSL_decode_handshake_v3(pkt, hsize, retval, pkt_flags, client_hello_data); } else if (ccs) { @@ -427,7 +442,8 @@ namespace snort { uint32_t SSL_decode( const uint8_t* pkt, int size, uint32_t pkt_flags, uint32_t prev_flags, - uint8_t* alert_flags, uint16_t* partial_rec_len, int max_hb_len, uint32_t* info_flags) + uint8_t* alert_flags, uint16_t* partial_rec_len, int max_hb_len, uint32_t* info_flags, + SSLV3ClientHelloData* sslv3_chello_data) { if (!pkt || !size) return SSL_ARG_ERROR_FLAG; @@ -448,7 +464,7 @@ uint32_t SSL_decode( * SSLv2 as TLS,the decoder will either catch a bad type, bad version, or * indicate that it is truncated. */ if (size == 5) - return SSL_decode_v3(pkt, size, pkt_flags, alert_flags, partial_rec_len, max_hb_len, info_flags); + return SSL_decode_v3(pkt, size, pkt_flags, alert_flags, partial_rec_len, max_hb_len, info_flags, sslv3_chello_data); /* At this point, 'size' has to be > 5 */ @@ -495,7 +511,7 @@ uint32_t SSL_decode( } } - return SSL_decode_v3(pkt, size, pkt_flags, alert_flags, partial_rec_len, max_hb_len, info_flags); + return SSL_decode_v3(pkt, size, pkt_flags, alert_flags, partial_rec_len, max_hb_len, info_flags, sslv3_chello_data); } /* very simplistic - just enough to say this is binary data - the rules will make a final @@ -557,4 +573,89 @@ bool IsSSL(const uint8_t* ptr, int len, int pkt_flags) return false; } +void parse_client_hello_data(const uint8_t* pkt, uint16_t size, SSLV3ClientHelloData* client_hello_data) +{ + if (client_hello_data == nullptr) + return; + + if (size < sizeof(ServiceSSLV3Record)) + return; + const ServiceSSLV3Record* rec = (const ServiceSSLV3Record*)pkt; + uint16_t ver = ntohs(rec->version); + if (rec->type != SSLV3RecordType::CLIENT_HELLO || (ver != 0x0300 && ver != 0x0301 && ver != 0x0302 && + ver != 0x0303) || rec->length_msb) + { + return; + } + unsigned length = ntohs(rec->length) + offsetof(ServiceSSLV3Record, version); + if (size < length) + return; + pkt += sizeof(ServiceSSLV3Record); + size -= sizeof(ServiceSSLV3Record); + + /* Session ID (1-byte length). */ + if (size < 1) + return; + length = *((const uint8_t*)pkt); + pkt += length + 1; + if (size < (length + 1)) + return; + size -= length + 1; + + /* Cipher Suites (2-byte length). */ + if (size < 2) + return; + length = ntohs(*((const uint16_t*)pkt)); + pkt += length + 2; + if (size < (length + 2)) + return; + size -= length + 2; + + /* Compression Methods (1-byte length). */ + if (size < 1) + return; + length = *((const uint8_t*)pkt); + pkt += length + 1; + if (size < (length + 1)) + return; + size -= length + 1; + + /* Extensions (2-byte length) */ + if (size < 2) + return; + length = ntohs(*((const uint16_t*)pkt)); + pkt += 2; + size -= 2; + if (size < length) + return; + + /* We need at least type (2 bytes) and length (2 bytes) in the extension. */ + while (length >= 4) + { + const ServiceSSLV3ExtensionServerName* ext = (const ServiceSSLV3ExtensionServerName*)pkt; + if (ntohs(ext->type) == SSL_EXT_SERVER_NAME) + { + /* Found server host name. */ + if (length < sizeof(ServiceSSLV3ExtensionServerName)) + return; + + unsigned len = ntohs(ext->string_length); + if ((length - sizeof(ServiceSSLV3ExtensionServerName)) < len) + return; + + const uint8_t* str = pkt + offsetof(ServiceSSLV3ExtensionServerName, string_length) + + sizeof(ext->string_length); + client_hello_data->host_name = snort_strndup((const char*)str, len); + return; + } + + unsigned len = ntohs(ext->length) + offsetof(ServiceSSLV3ExtensionServerName, list_length); + if (len > length) + return; + + pkt += len; + length -= len; + } +} + } // namespace snort diff --git a/src/protocols/ssl.h b/src/protocols/ssl.h index d10e6078f..664970647 100644 --- a/src/protocols/ssl.h +++ b/src/protocols/ssl.h @@ -195,6 +195,51 @@ struct SSLv2_shello_t uint8_t minor; }; +struct SSLV3ClientHelloData +{ + ~SSLV3ClientHelloData(); + void clear(); + char* host_name = nullptr; +}; + +enum class SSLV3RecordType : uint8_t +{ + CLIENT_HELLO = 1, + SERVER_HELLO = 2, + CERTIFICATE = 11, + SERVER_KEY_XCHG = 12, + SERVER_CERT_REQ = 13, + SERVER_HELLO_DONE = 14, + CERTIFICATE_STATUS = 22 +}; + +/* Usually referred to as a TLS Handshake. */ +struct ServiceSSLV3Record +{ + SSLV3RecordType type; + uint8_t length_msb; + uint16_t length; + uint16_t version; + struct + { + uint32_t time; + uint8_t data[28]; + } random; +}; + +struct ServiceSSLV3ExtensionServerName +{ + uint16_t type; + uint16_t length; + uint16_t list_length; + uint8_t string_length_msb; + uint16_t string_length; + /* String follows. */ +}; + +/* Extension types. */ +#define SSL_EXT_SERVER_NAME 0 + #define SSL_V2_MIN_LEN 5 #pragma pack() @@ -229,7 +274,10 @@ namespace snort { uint32_t SSL_decode( const uint8_t* pkt, int size, uint32_t pktflags, uint32_t prevflags, - uint8_t* alert_flags, uint16_t* partial_rec_len, int hblen, uint32_t* info_flags = nullptr); + uint8_t* alert_flags, uint16_t* partial_rec_len, int hblen, uint32_t* info_flags = nullptr, + SSLV3ClientHelloData* data = nullptr); + + void parse_client_hello_data(const uint8_t* pkt, uint16_t size, SSLV3ClientHelloData*); SO_PUBLIC bool IsTlsClientHello(const uint8_t* ptr, const uint8_t* end); SO_PUBLIC bool IsTlsServerHello(const uint8_t* ptr, const uint8_t* end); diff --git a/src/pub_sub/CMakeLists.txt b/src/pub_sub/CMakeLists.txt index 1787c7fc9..6ce86c532 100644 --- a/src/pub_sub/CMakeLists.txt +++ b/src/pub_sub/CMakeLists.txt @@ -25,6 +25,7 @@ set (PUB_SUB_INCLUDES stream_event_ids.h smb_events.h ssh_events.h + ssl_events.h ) add_library( pub_sub OBJECT diff --git a/src/pub_sub/ssl_events.h b/src/pub_sub/ssl_events.h new file mode 100644 index 000000000..a0566b661 --- /dev/null +++ b/src/pub_sub/ssl_events.h @@ -0,0 +1,50 @@ +//-------------------------------------------------------------------------- +// Copyright (C) 2022-2023 Cisco and/or its affiliates. All rights reserved. +// +// This program is free software; you can redistribute it and/or modify it +// under the terms of the GNU General Public License Version 2 as published +// by the Free Software Foundation. You may not use, modify or distribute +// this program under any other version of the GNU General Public License. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// General Public License for more details. +// +// You should have received a copy of the GNU General Public License along +// with this program; if not, write to the Free Software Foundation, Inc., +// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. +//-------------------------------------------------------------------------- +// ssl_events.h author Serhii Vlasiuk + +#ifndef SSL_EVENTS_H +#define SSL_EVENTS_H + +// This event allows the SSL service inspector to publish extracted SSL handshake client hello data +// for use by data bus subscribers + +#include "framework/data_bus.h" + +struct SslEventIds { enum : unsigned { CHELLO_SERVER_NAME, num_ids }; }; + +const snort::PubKey ssl_chello_pub_key { "ssl_chello", SslEventIds::num_ids }; + +class SslClientHelloEvent : public snort::DataEvent +{ +public: + SslClientHelloEvent(const std::string& ch_server_name, const snort::Packet* packet) : + ch_server_name(ch_server_name), packet(packet) + { } + + const snort::Packet* get_packet() const override + { return packet; } + + const std::string& get_host_name() const + { return ch_server_name; } + +private: + const std::string ch_server_name; + const snort::Packet* packet; +}; + +#endif diff --git a/src/service_inspectors/ssl/ssl_inspector.cc b/src/service_inspectors/ssl/ssl_inspector.cc index 22b71fd6b..071942be3 100644 --- a/src/service_inspectors/ssl/ssl_inspector.cc +++ b/src/service_inspectors/ssl/ssl_inspector.cc @@ -37,6 +37,7 @@ #include "protocols/ssl.h" #include "pub_sub/finalize_packet_event.h" #include "pub_sub/opportunistic_tls_event.h" +#include "pub_sub/ssl_events.h" #include "stream/stream.h" #include "stream/stream_splitter.h" #include "trace/trace_api.h" @@ -44,6 +45,8 @@ #include "ssl_module.h" #include "ssl_splitter.h" +#include "utils/util.h" + using namespace snort; #define SSLPP_ENCRYPTED_FLAGS \ @@ -56,6 +59,8 @@ using namespace snort; THREAD_LOCAL ProfileStats sslPerfStats; THREAD_LOCAL SslStats sslstats; +static unsigned ssl_chello_pub_id = 0; + const PegInfo ssl_peg_names[] = { { CountType::SUM, "packets", "total packets processed" }, @@ -303,8 +308,15 @@ static void snort_ssl(SSL_PROTO_CONF* config, Packet* p) uint8_t heartbleed_type = 0; uint32_t info_flags = 0; + SSLV3ClientHelloData client_hello_data; uint32_t new_flags = SSL_decode(p->data, (int)p->dsize, p->packet_flags, sd->ssn_flags, - &heartbleed_type, &(sd->partial_rec_len[dir+index]), config->max_heartbeat_len, &info_flags); + &heartbleed_type, &(sd->partial_rec_len[dir+index]), config->max_heartbeat_len, &info_flags, &client_hello_data); + + if (client_hello_data.host_name != nullptr) + { + SslClientHelloEvent event(client_hello_data.host_name, p); + DataBus::publish(ssl_chello_pub_id, SslEventIds::CHELLO_SERVER_NAME, event); + } if (heartbleed_type & SSL_HEARTBLEED_REQUEST) { @@ -488,6 +500,8 @@ void Ssl::eval(Packet* p) bool Ssl::configure(SnortConfig*) { + ssl_chello_pub_id = DataBus::get_id(ssl_chello_pub_key); + DataBus::subscribe(intrinsic_pub_key, IntrinsicEventIds::FINALIZE_PACKET, new SslFinalizePacketHandler()); DataBus::subscribe(intrinsic_pub_key, IntrinsicEventIds::OPPORTUNISTIC_TLS, new SslStartTlsEventtHandler()); return true;