]> git.ipfire.org Git - thirdparty/snort3.git/commitdiff
Merge pull request #2733 in SNORT/snort3 from ~OSHUMEIK/snort3:sslv2_curse to master
authorBhagya Tholpady (bbantwal) <bbantwal@cisco.com>
Thu, 11 Feb 2021 17:30:00 +0000 (17:30 +0000)
committerBhagya Tholpady (bbantwal) <bbantwal@cisco.com>
Thu, 11 Feb 2021 17:30:00 +0000 (17:30 +0000)
Squashed commit of the following:

commit af61d25062a0f28247cd017cd9a2f4269f0655bc
Author: ryanhoff <ryanhoff@cisco.com>
Date:   Tue Jan 21 16:55:33 2020 -0500

    wizard: add support for sslv2 detection

    The curse ignores specs/challenge/session_id length values.
    It's up to the inspector to decide about it.

lua/snort_defaults.lua
src/service_inspectors/wizard/CMakeLists.txt
src/service_inspectors/wizard/curses.cc
src/service_inspectors/wizard/curses.h
src/service_inspectors/wizard/wiz_module.cc

index 9afdff06a26ebae9341a912c7c93814936d4b26b..d95926bf0e11bddf6117d4b9c03cb5fd46dd9b04 100644 (file)
@@ -424,7 +424,7 @@ default_wizard =
           to_server = telnet_commands, to_client = telnet_commands },
     },
 
-    curses = {'dce_udp', 'dce_tcp', 'dce_smb'}
+    curses = {'dce_udp', 'dce_tcp', 'dce_smb', 'sslv2'}
 }
 
 ---------------------------------------------------------------------------
index a27614b4052d3710117480ccb77addaf8483c74a..31cbc290cd22e0d0cb264fe4fb9084ead4b54c8a 100644 (file)
@@ -18,3 +18,9 @@ else (STATIC_INSPECTORS)
     add_dynamic_module(wizard inspectors ${FILE_LIST})
 
 endif (STATIC_INSPECTORS)
+
+add_catch_test(curses_test
+    NO_TEST_SOURCE
+    SOURCES
+        curses.cc
+)
index 188a414aeb60178a056f893196714a527093fa4b..8f3f20edb144681cb479d10475bc56fd1a9a4f28 100644 (file)
@@ -103,35 +103,32 @@ static bool dce_udp_curse(const uint8_t* data, unsigned len, CurseTracker*)
 static bool dce_tcp_curse(const uint8_t* data, unsigned len, CurseTracker* tracker)
 {
     const uint8_t dce_rpc_co_hdr_len = 16;
+    CurseTracker::DCE& dce = tracker->dce;
 
     uint32_t n = 0;
     while (n < len)
     {
-        switch (tracker->state)
+        switch (dce.state)
         {
         case STATE_0: // check major version
-        {
             if (data[n] != DCERPC_PROTO_MAJOR_VERS__5)
             {
                 // go to bad state
-                tracker->state = STATE_10;
+                dce.state = STATE_10;
                 return false;
             }
-            tracker->state = (DCE_States)((int)tracker->state + 1);
+            dce.state = (DCE_State)((int)dce.state + 1);
             break;
-        }
 
         case STATE_1: // check minor version
-        {
             if (data[n] != DCERPC_PROTO_MINOR_VERS__0)
             {
                 // go to bad state
-                tracker->state = STATE_10;
+                dce.state = STATE_10;
                 return false;
             }
-            tracker->state = (DCE_States)((int)tracker->state + 1);
+            dce.state = (DCE_State)((int)dce.state + 1);
             break;
-        }
 
         case STATE_2: // pdu_type
         {
@@ -140,47 +137,45 @@ static bool dce_tcp_curse(const uint8_t* data, unsigned len, CurseTracker* track
                 (pdu_type != DCERPC_PDU_TYPE__BIND_ACK))
             {
                 // go to bad state
-                tracker->state = STATE_10;
+                dce.state = STATE_10;
                 return false;
             }
-            tracker->state = (DCE_States)((int)tracker->state + 1);
+            dce.state = (DCE_State)((int)dce.state + 1);
             break;
         }
 
         case STATE_4: //little endian
-            tracker->helper = (data[n] & 0x10) << 20;
-            tracker->state = (DCE_States)((int)tracker->state + 1);
+            dce.helper = (data[n] & 0x10) << 20;
+            dce.state = (DCE_State)((int)dce.state + 1);
             break;
         case STATE_8:
-            tracker->helper |= data[n];
-            tracker->state = (DCE_States)((int)tracker->state + 1);
+            dce.helper |= data[n];
+            dce.state = (DCE_State)((int)dce.state + 1);
             break;
         case STATE_9:
-        {
 #ifdef WORDS_BIGENDIAN
-            if (!(tracker->helper >> 24))
+            if (!(dce.helper >> 24))
 #else
-            if (tracker->helper >> 24)
+            if (dce.helper >> 24)
 #endif  /* WORDS_BIGENDIAN */
-                tracker->helper = (data[n] << 8) | (tracker->helper & 0XFF);
+                dce.helper = (data[n] << 8) | (dce.helper & 0XFF);
             else
             {
-                tracker->helper <<=8;
-                tracker->helper |= data[n];
+                dce.helper <<=8;
+                dce.helper |= data[n];
             }
 
-            if (tracker->helper >= dce_rpc_co_hdr_len)
+            if (dce.helper >= dce_rpc_co_hdr_len)
                 return true;
 
-            tracker->state = STATE_10;
+            dce.state = STATE_10;
             break;
-        }
 
         case STATE_10:
             // no match
             return false;
         default:
-            tracker->state = (DCE_States)((int)tracker->state + 1);
+            dce.state = (DCE_State)((int)dce.state + 1);
             break;
         }
         n++;
@@ -195,72 +190,66 @@ static bool dce_smb_curse(const uint8_t* data, unsigned len, CurseTracker* track
     const uint32_t dce_smb2_id = 0xfe534d42;  /* \xfeSMB */
     const uint8_t session_request = 0x81, session_response = 0x82,
                   session_message = 0x00;
+    CurseTracker::DCE& dce = tracker->dce;
 
     uint32_t n = 0;
     while (n < len)
     {
-        switch (tracker->state)
+        switch (dce.state)
         {
         case STATE_0:
-        {
             if (data[n] == session_message)
             {
-                tracker->state = (DCE_States)((int)tracker->state + 2);
+                dce.state = (DCE_State)((int)dce.state + 2);
                 break;
             }
 
             if (data[n] == session_request || data[n] == session_response)
             {
-                tracker->state = (DCE_States)((int)tracker->state + 1);
+                dce.state = (DCE_State)((int)dce.state + 1);
                 return false;
             }
 
-            tracker->state = STATE_9;
+            dce.state = STATE_9;
             return false;
-        }
+
         case STATE_1:
-        {
             if (data[n] == session_message)
             {
-                tracker->state = (DCE_States)((int)tracker->state + 1);
+                dce.state = (DCE_State)((int)dce.state + 1);
                 break;
             }
 
-            tracker->state = STATE_9;
+            dce.state = STATE_9;
             return false;
-        }
+
         case STATE_5:
-        {
-            tracker->helper = data[n];
-            tracker->state = (DCE_States)((int)tracker->state + 1);
+            dce.helper = data[n];
+            dce.state = (DCE_State)((int)dce.state + 1);
             break;
-        }
+
         case STATE_6:
         case STATE_7:
-        {
-            tracker->helper <<= 8;
-            tracker->helper |= data[n];
-            tracker->state = (DCE_States)((int)tracker->state + 1);
+            dce.helper <<= 8;
+            dce.helper |= data[n];
+            dce.state = (DCE_State)((int)dce.state + 1);
             break;
-        }
 
         case STATE_8:
-        {
-            tracker->helper <<= 8;
-            tracker->helper |= data[n];
-            if ((tracker->helper == dce_smb_id) || (tracker->helper == dce_smb2_id))
+            dce.helper <<= 8;
+            dce.helper |= data[n];
+            if ((dce.helper == dce_smb_id) || (dce.helper == dce_smb2_id))
                 return true;
 
-            tracker->state = (DCE_States)((int)tracker->state + 1);
+            dce.state = (DCE_State)((int)dce.state + 1);
             break;
-        }
 
         case STATE_9:
             // no match
             return false;
 
         default:
-            tracker->state = (DCE_States)((int)tracker->state + 1);
+            dce.state = (DCE_State)((int)dce.state + 1);
             break;
         }
         n++;
@@ -269,6 +258,138 @@ static bool dce_smb_curse(const uint8_t* data, unsigned len, CurseTracker* track
     return false;
 }
 
+namespace SSL_Const
+{
+static constexpr uint8_t hdr_len = 9;
+static constexpr uint8_t sslv2_msb_set = 0x80;
+static constexpr uint8_t client_hello = 0x01;
+static constexpr uint8_t sslv3_major_ver = 0x03;
+static constexpr uint8_t sslv3_max_minor_ver = 0x03;
+}
+
+static bool ssl_v2_curse(const uint8_t* data, unsigned len, CurseTracker* tracker)
+{
+    CurseTracker::SSL& ssl = tracker->ssl;
+
+    if (ssl.state == SSL_State::SSL_NOT_FOUND)
+    {
+        return false;
+    }
+    else if (ssl.state == SSL_State::SSL_FOUND)
+    {
+        return true;
+    }
+
+    for (unsigned i = 0; i < len; ++i)
+    {
+        uint8_t val = data[i];
+
+        switch (ssl.state)
+        {
+        case SSL_State::BYTE_0_LEN_MSB:
+            if ((val & SSL_Const::sslv2_msb_set) == 0)
+            {
+                ssl.state = SSL_State::SSL_NOT_FOUND;
+                return false;
+            }
+            ssl.total_len = (val & (~SSL_Const::sslv2_msb_set)) << 8;
+            ssl.state = SSL_State::BYTE_1_LEN_LSB;
+            break;
+
+        case SSL_State::BYTE_1_LEN_LSB:
+            ssl.total_len |= val;
+            if (ssl.total_len < SSL_Const::hdr_len)
+            {
+                ssl.state = SSL_State::SSL_NOT_FOUND;
+                return false;
+            }
+            ssl.total_len -= SSL_Const::hdr_len;
+            ssl.state = SSL_State::BYTE_2_CLIENT_HELLO;
+            break;
+
+        case SSL_State::BYTE_2_CLIENT_HELLO:
+            if (val != SSL_Const::client_hello)
+            {
+                ssl.state = SSL_State::SSL_NOT_FOUND;
+                return false;
+            }
+            ssl.state = SSL_State::BYTE_3_MAX_MINOR_VER;
+            break;
+
+        case SSL_State::BYTE_3_MAX_MINOR_VER:
+            if (val > SSL_Const::sslv3_max_minor_ver)
+            {
+                ssl.state = SSL_State::SSL_NOT_FOUND;
+                return false;
+            }
+            ssl.state = SSL_State::BYTE_4_V3_MAJOR;
+            break;
+
+        case SSL_State::BYTE_4_V3_MAJOR:
+            if (val > SSL_Const::sslv3_major_ver)
+            {
+                ssl.state = SSL_State::SSL_NOT_FOUND;
+                return false;
+            }
+            ssl.state = SSL_State::BYTE_5_SPECS_LEN_MSB;
+            break;
+
+        case SSL_State::BYTE_5_SPECS_LEN_MSB:
+            ssl.specs_len = val << 8;
+            ssl.state = SSL_State::BYTE_6_SPECS_LEN_LSB;
+            break;
+
+        case SSL_State::BYTE_6_SPECS_LEN_LSB:
+            ssl.specs_len |= val;
+            if (ssl.total_len < ssl.specs_len)
+            {
+                ssl.state = SSL_State::SSL_NOT_FOUND;
+                return false;
+            }
+            ssl.total_len -= ssl.specs_len;
+            ssl.state = SSL_State::BYTE_7_SSNID_LEN_MSB;
+            break;
+
+        case SSL_State::BYTE_7_SSNID_LEN_MSB:
+            ssl.ssnid_len = val << 8;
+            ssl.state = SSL_State::BYTE_8_SSNID_LEN_LSB;
+            break;
+
+        case SSL_State::BYTE_8_SSNID_LEN_LSB:
+            ssl.ssnid_len |= val;
+            if (ssl.total_len < ssl.ssnid_len)
+            {
+                ssl.state = SSL_State::SSL_NOT_FOUND;
+                return false;
+            }
+            ssl.total_len -= ssl.ssnid_len;
+            ssl.state = SSL_State::BYTE_9_CHLNG_LEN_MSB;
+            break;
+
+        case SSL_State::BYTE_9_CHLNG_LEN_MSB:
+            ssl.chlng_len = val << 8;
+            ssl.state = SSL_State::BYTE_10_CHLNG_LEN_LSB;
+            break;
+
+        case SSL_State::BYTE_10_CHLNG_LEN_LSB:
+            ssl.chlng_len |= val;
+            if (ssl.total_len < ssl.chlng_len)
+            {
+                ssl.state = SSL_State::SSL_NOT_FOUND;
+                return false;
+            }
+            ssl.state = SSL_State::SSL_FOUND;
+            return true;
+
+        default:
+            return false;
+        }
+    }
+
+    return false;
+}
+
+
 // map between service and curse details
 static vector<CurseDetails> curse_map
 {
@@ -276,6 +397,7 @@ static vector<CurseDetails> curse_map
     { "dce_udp", "dcerpc",      dce_udp_curse, false },
     { "dce_tcp", "dcerpc",      dce_tcp_curse, true  },
     { "dce_smb", "netbios-ssn", dce_smb_curse, true  },
+    { "sslv2"  , "ssl",         ssl_v2_curse , true  }
 };
 
 bool CurseBook::add_curse(const char* key)
@@ -301,3 +423,121 @@ const vector<const CurseDetails*>& CurseBook::get_curses(bool tcp) const
     return non_tcp_curses;
 }
 
+#ifdef CATCH_TEST_BUILD
+
+#include "catch/catch.hpp"
+#include <cstring>
+
+//client hello with v2 header advertising sslv2
+static const uint8_t ssl_v2_ch[] =
+{ 0x80,0x59,0x01,0x00,0x02,0x00,0x30,0x00,0x00,0x00,0x20,0x00,0x00,0x39,0x00,0x00,
+  0x38,0x00,0x00,0x35,0x00,0x00,0x16,0x00,0x00,0x13,0x00,0x00,0x0a,0x00,0x00,0x33,
+  0x00,0x00,0x32,0x00,0x00,0x2f,0x00,0x00,0x07,0x00,0x00,0x05,0x00,0x00,0x04,0x00,
+  0x00,0x15,0x00,0x00,0x12,0x00,0x00,0x09,0x00,0x00,0xff,0xda,0x86,0xfa,0xb4,0x73,
+  0x5a,0x1e,0x11,0xd1,0xdb,0x58,0x4b,0x59,0xe1,0x07,0x51,0x5f,0x13,0x46,0xa2,0xdd,
+  0xee,0xda,0xc1,0x9d,0xdc,0xd7,0xb8,0x86,0x51,0x10,0x5a };
+
+//client hello with v2 header advertising tls 1.0
+static const uint8_t ssl_v2_v3_ch[] =
+{ 0x80,0x59,0x01,0x03,0x01,0x00,0x30,0x00,0x00,0x00,0x20,0x00,0x00,0x39,0x00,0x00,
+  0x38,0x00,0x00,0x35,0x00,0x00,0x16,0x00,0x00,0x13,0x00,0x00,0x0a,0x00,0x00,0x33,
+  0x00,0x00,0x32,0x00,0x00,0x2f,0x00,0x00,0x07,0x00,0x00,0x05,0x00,0x00,0x04,0x00,
+  0x00,0x15,0x00,0x00,0x12,0x00,0x00,0x09,0x00,0x00,0xff,0xda,0x86,0xfa,0xb4,0x73,
+  0x5a,0x1e,0x11,0xd1,0xdb,0x58,0x4b,0x59,0xe1,0x07,0x51,0x5f,0x13,0x46,0xa2,0xdd,
+  0xee,0xda,0xc1,0x9d,0xdc,0xd7,0xb8,0x86,0x51,0x10,0x5a };
+
+TEST_CASE("sslv2 detect", "[SslV2Curse]")
+{
+    uint32_t max_detect = static_cast<uint32_t>(SSL_State::BYTE_10_CHLNG_LEN_LSB);
+    CurseTracker tracker{ };
+
+    auto test = [&](uint32_t incr_by,const uint8_t* ch)
+        {
+            uint32_t i = 0;
+            while (i <= max_detect)
+            {
+                if ((i + incr_by - 1) < max_detect)
+                {
+                    CHECK(tracker.ssl.state == static_cast<SSL_State>(i));
+                    CHECK_FALSE(ssl_v2_curse(&ch[i],sizeof(uint8_t) * incr_by,&tracker));
+                }
+                else
+                {
+                    CHECK(ssl_v2_curse(&ch[i],sizeof(uint8_t) * incr_by,&tracker));
+                    CHECK(tracker.ssl.state == SSL_State::SSL_FOUND);
+                }
+                i += incr_by;
+            }
+            //subsequent checks must return found
+            CHECK(ssl_v2_curse(&ch[max_detect + 1],sizeof(uint8_t),&tracker));
+            CHECK(tracker.ssl.state == SSL_State::SSL_FOUND);
+        };
+
+    //sslv2 with ssl version 2
+    SECTION("1 byte v2"){ test(1,ssl_v2_ch); }
+    SECTION("2 bytes v2"){ test(2,ssl_v2_ch); }
+    SECTION("3 bytes v2"){ test(3,ssl_v2_ch); }
+    SECTION("4 bytes v2"){ test(4,ssl_v2_ch); }
+    SECTION("5 bytes v2"){ test(5,ssl_v2_ch); }
+    SECTION("6 bytes v2"){ test(6,ssl_v2_ch); }
+    SECTION("7 bytes v2"){ test(7,ssl_v2_ch); }
+    SECTION("8 bytes v2"){ test(8,ssl_v2_ch); }
+    SECTION("9 bytes v2"){ test(9,ssl_v2_ch); }
+    SECTION("10 bytes v2"){ test(10,ssl_v2_ch); }
+    SECTION("11 bytes v2"){ test(11,ssl_v2_ch);}
+
+    //sslv2 with tls version 1.0
+    SECTION("1 byte v2_v3"){ test(1,ssl_v2_v3_ch); }
+    SECTION("2 bytes v2_v3"){ test(2,ssl_v2_v3_ch); }
+    SECTION("3 bytes v2_v3"){ test(3,ssl_v2_v3_ch); }
+    SECTION("4 bytes v2_v3"){ test(4,ssl_v2_v3_ch); }
+    SECTION("5 bytes v2_v3"){ test(5,ssl_v2_v3_ch); }
+    SECTION("6 bytes v2_v3"){ test(6,ssl_v2_v3_ch); }
+    SECTION("7 bytes v2_v3"){ test(7,ssl_v2_v3_ch); }
+    SECTION("8 bytes v2_v3"){ test(8,ssl_v2_v3_ch); }
+    SECTION("9 bytes v2_v3"){ test(9,ssl_v2_v3_ch); }
+    SECTION("10 bytes v2_v3"){ test(10,ssl_v2_v3_ch); }
+    SECTION("11 bytes v2_v3"){ test(11,ssl_v2_v3_ch); }
+}
+
+TEST_CASE("sslv2 not found", "[SslV2Curse]")
+{
+    uint32_t max_detect = static_cast<uint32_t>(SSL_State::BYTE_10_CHLNG_LEN_LSB);
+    CurseTracker tracker{};
+    uint8_t bad_data[] = {0x00,0x08,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff};
+    auto test = [&](uint32_t fail_at_byte)
+        {
+            uint8_t ch_data[sizeof(ssl_v2_ch)];
+            memcpy(ch_data,ssl_v2_ch,sizeof(ssl_v2_ch));
+
+            ch_data[fail_at_byte] = bad_data[fail_at_byte];
+
+            for (uint32_t i = 0; i <= fail_at_byte; i++)
+            {
+                if (i < fail_at_byte)
+                {
+                    CHECK(tracker.ssl.state == static_cast<SSL_State>(i));
+                    CHECK_FALSE(ssl_v2_curse(&ch_data[i],sizeof(uint8_t),&tracker));
+                }
+                else
+                {
+                    CHECK_FALSE(ssl_v2_curse(&ch_data[i],sizeof(uint8_t),&tracker));
+                    CHECK(tracker.ssl.state == SSL_State::SSL_NOT_FOUND);
+                }
+            }
+            //subsequent checks must return ssl not found
+            CHECK_FALSE(ssl_v2_curse(&ch_data[max_detect + 1],sizeof(uint8_t),&tracker));
+            CHECK(tracker.ssl.state == SSL_State::SSL_NOT_FOUND);
+        };
+
+    SECTION("byte 0"){ test(0);}
+    SECTION("byte 1"){ test(1);}
+    SECTION("byte 2"){ test(2);}
+    SECTION("byte 3"){ test(3);}
+    SECTION("byte 4"){ test(4);}
+    SECTION("byte 6"){ test(6);}
+    SECTION("byte 8"){ test(8);}
+    SECTION("byte 10"){ test(10);}
+}
+
+#endif
index 00484ea61de20f6e137e24ee940f6884cd70aafb..62fd286b33f11ef4183afc7fb750157667a8ee53 100644 (file)
@@ -24,7 +24,7 @@
 #include <string>
 #include <vector>
 
-enum DCE_States
+enum DCE_State
 {
     STATE_0 = 0,
     STATE_1,
@@ -39,13 +39,46 @@ enum DCE_States
     STATE_10
 };
 
+enum SSL_State
+{
+    BYTE_0_LEN_MSB = 0,
+    BYTE_1_LEN_LSB,
+    BYTE_2_CLIENT_HELLO,
+    BYTE_3_MAX_MINOR_VER,
+    BYTE_4_V3_MAJOR,
+    BYTE_5_SPECS_LEN_MSB,
+    BYTE_6_SPECS_LEN_LSB,
+    BYTE_7_SSNID_LEN_MSB,
+    BYTE_8_SSNID_LEN_LSB,
+    BYTE_9_CHLNG_LEN_MSB,
+    BYTE_10_CHLNG_LEN_LSB,
+    SSL_FOUND,
+    SSL_NOT_FOUND
+};
+
 class CurseTracker
 {
 public:
-    DCE_States state;
-    uint32_t helper;
+    struct DCE
+    {
+        DCE_State state;
+        uint32_t helper;
+    } dce;
+
+    struct SSL
+    {
+        SSL_State state;
+        unsigned total_len;
+        unsigned ssnid_len;
+        unsigned specs_len;
+        unsigned chlng_len;
+    } ssl;
 
-    CurseTracker() { state = STATE_0; }
+    CurseTracker()
+    {
+        dce.state = DCE_State::STATE_0;
+        ssl.state = SSL_State::BYTE_0_LEN_MSB;
+    }
 };
 
 typedef bool (* curse_alg)(const uint8_t* data, unsigned len, CurseTracker*);
index dfb621ffa06b51859dc9f346698831ab3bd7f0ce..cbbf5583b7d6ffea31f756f519bca1ccb4a25f5d 100644 (file)
@@ -103,7 +103,7 @@ static const Parameter s_params[] =
     { "spells", Parameter::PT_LIST, wizard_spells_params, nullptr,
       "criteria for text service identification" },
 
-    { "curses", Parameter::PT_MULTI, "dce_smb | dce_udp | dce_tcp", nullptr,
+    { "curses", Parameter::PT_MULTI, "dce_smb | dce_udp | dce_tcp | sslv2", nullptr,
       "enable service identification based on internal algorithm" },
 
     { nullptr, Parameter::PT_MAX, nullptr, nullptr, nullptr }