From: Mike Stepanek (mstepane) Date: Tue, 10 May 2022 11:10:02 +0000 (+0000) Subject: Pull request #3416: wizard: fix code style X-Git-Tag: 3.1.30.0~14 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=c5ee627ba8c452767d96b0c204251236167d1cd6;p=thirdparty%2Fsnort3.git Pull request #3416: wizard: fix code style Merge in SNORT/snort3 from ~YVELYKOZ/snort3:fix_up_code_style to master Squashed commit of the following: commit 4103d16df893987b729caf1dc649de82b58fbda0 Author: Yehor Date: Thu May 5 21:43:30 2022 +0300 wizard: fix code style Following code style aspects was covered: 1. Space after 'if', 'for', 'while', 'switch' and space between braces 2. Newline before mentioned keyword. 3. Newline before 'return'. 4. Adding const to func if applicable. --- diff --git a/src/service_inspectors/wizard/curses.cc b/src/service_inspectors/wizard/curses.cc index f0f1969b2..be258c5e8 100644 --- a/src/service_inspectors/wizard/curses.cc +++ b/src/service_inspectors/wizard/curses.cc @@ -73,7 +73,7 @@ static bool dce_udp_curse(const uint8_t* data, unsigned len, CurseTracker*) const uint8_t dcerpc_cl_hdr_len = 80; const uint8_t cl_len_offset = 74; - if (len >= dcerpc_cl_hdr_len) + if ( len >= dcerpc_cl_hdr_len ) { uint8_t version = data[0]; uint8_t pdu_type = data[1]; @@ -81,22 +81,22 @@ static bool dce_udp_curse(const uint8_t* data, unsigned len, CurseTracker*) uint16_t cl_len; #ifdef WORDS_BIGENDIAN - if (!little_endian) + if ( !little_endian ) #else - if (little_endian) + if ( little_endian ) #endif /* WORDS_BIGENDIAN */ cl_len = (data[cl_len_offset+1] << 8) | data[cl_len_offset]; else cl_len = (data[cl_len_offset] << 8) | data[cl_len_offset+1]; - if ((version == DCERPC_PROTO_MAJOR_VERS__4) && - ((pdu_type == DCERPC_PDU_TYPE__REQUEST) || - (pdu_type == DCERPC_PDU_TYPE__RESPONSE) || - (pdu_type == DCERPC_PDU_TYPE__FAULT) || - (pdu_type == DCERPC_PDU_TYPE__REJECT) || - (pdu_type == DCERPC_PDU_TYPE__FACK)) && - ((cl_len != 0) && - (cl_len + (unsigned)dcerpc_cl_hdr_len) <= len)) + if ( (version == DCERPC_PROTO_MAJOR_VERS__4) and + ((pdu_type == DCERPC_PDU_TYPE__REQUEST) or + (pdu_type == DCERPC_PDU_TYPE__RESPONSE) or + (pdu_type == DCERPC_PDU_TYPE__FAULT) or + (pdu_type == DCERPC_PDU_TYPE__REJECT) or + (pdu_type == DCERPC_PDU_TYPE__FACK)) and + ((cl_len != 0) and + (cl_len + (unsigned)dcerpc_cl_hdr_len) <= len) ) return true; } @@ -109,40 +109,47 @@ static bool dce_tcp_curse(const uint8_t* data, unsigned len, CurseTracker* track CurseTracker::DCE& dce = tracker->dce; uint32_t n = 0; - while (n < len) + while ( n < len ) { - switch (dce.state) + switch ( dce.state ) { case STATE_0: // check major version - if (data[n] != DCERPC_PROTO_MAJOR_VERS__5) + if ( data[n] != DCERPC_PROTO_MAJOR_VERS__5 ) { // go to bad state dce.state = STATE_10; + return false; } + dce.state = (DCE_State)((int)dce.state + 1); break; case STATE_1: // check minor version - if (data[n] != DCERPC_PROTO_MINOR_VERS__0) + if ( data[n] != DCERPC_PROTO_MINOR_VERS__0 ) { // go to bad state dce.state = STATE_10; + return false; } + dce.state = (DCE_State)((int)dce.state + 1); break; case STATE_2: // pdu_type { uint8_t pdu_type = data[n]; - if ((pdu_type != DCERPC_PDU_TYPE__BIND) && - (pdu_type != DCERPC_PDU_TYPE__BIND_ACK)) + + if ( (pdu_type != DCERPC_PDU_TYPE__BIND) and + (pdu_type != DCERPC_PDU_TYPE__BIND_ACK) ) { // go to bad state dce.state = STATE_10; + return false; } + dce.state = (DCE_State)((int)dce.state + 1); break; } @@ -157,9 +164,9 @@ static bool dce_tcp_curse(const uint8_t* data, unsigned len, CurseTracker* track break; case STATE_9: #ifdef WORDS_BIGENDIAN - if (!(dce.helper >> 24)) + if ( !(dce.helper >> 24) ) #else - if (dce.helper >> 24) + if ( dce.helper >> 24 ) #endif /* WORDS_BIGENDIAN */ dce.helper = (data[n] << 8) | (dce.helper & 0XFF); else @@ -168,7 +175,7 @@ static bool dce_tcp_curse(const uint8_t* data, unsigned len, CurseTracker* track dce.helper |= data[n]; } - if (dce.helper >= dce_rpc_co_hdr_len) + if ( dce.helper >= dce_rpc_co_hdr_len ) return true; dce.state = STATE_10; @@ -181,6 +188,7 @@ static bool dce_tcp_curse(const uint8_t* data, unsigned len, CurseTracker* track dce.state = (DCE_State)((int)dce.state + 1); break; } + n++; } @@ -191,39 +199,41 @@ static bool dce_smb_curse(const uint8_t* data, unsigned len, CurseTracker* track { const uint32_t dce_smb_id = 0xff534d42; /* \xffSMB */ const uint32_t dce_smb2_id = 0xfe534d42; /* \xfeSMB */ - const uint8_t session_request = 0x81, session_response = 0x82, - session_message = 0x00; + const uint8_t session_request = 0x81, session_response = 0x82, session_message = 0x00; CurseTracker::DCE& dce = tracker->dce; uint32_t n = 0; - while (n < len) + while ( n < len ) { - switch (dce.state) + switch ( dce.state ) { case STATE_0: - if (data[n] == session_message) + if ( data[n] == session_message ) { dce.state = (DCE_State)((int)dce.state + 2); break; } - if (data[n] == session_request || data[n] == session_response) + if ( data[n] == session_request or data[n] == session_response ) { dce.state = (DCE_State)((int)dce.state + 1); + return false; } dce.state = STATE_9; + return false; case STATE_1: - if (data[n] == session_message) + if ( data[n] == session_message ) { dce.state = (DCE_State)((int)dce.state + 1); break; } dce.state = STATE_9; + return false; case STATE_5: @@ -241,7 +251,8 @@ static bool dce_smb_curse(const uint8_t* data, unsigned len, CurseTracker* track case STATE_8: dce.helper <<= 8; dce.helper |= data[n]; - if ((dce.helper == dce_smb_id) || (dce.helper == dce_smb2_id)) + + if ( (dce.helper == dce_smb_id) or (dce.helper == dce_smb2_id) ) return true; dce.state = (DCE_State)((int)dce.state + 1); @@ -255,6 +266,7 @@ static bool dce_smb_curse(const uint8_t* data, unsigned len, CurseTracker* track dce.state = (DCE_State)((int)dce.state + 1); break; } + n++; } @@ -270,7 +282,7 @@ static bool mms_curse(const uint8_t* data, unsigned len, CurseTracker* tracker) // if the state is set to MMS_STATE__SEARCH it means we most likely // have a split pipelined message coming through and will need to // reset the state - if (mms.state == MMS_STATE__SEARCH) + if ( mms.state == MMS_STATE__SEARCH ) { mms.state = mms.last_state; } @@ -295,9 +307,9 @@ static bool mms_curse(const uint8_t* data, unsigned len, CurseTracker* tracker) }; uint32_t idx = 0; - while (idx < len) + while ( idx < len ) { - switch (mms.state) + switch ( mms.state ) { case MMS_STATE__TPKT_VER: { @@ -336,11 +348,13 @@ static bool mms_curse(const uint8_t* data, unsigned len, CurseTracker* tracker) // . . . . x x x x Destination Reference // x x x x . . . . PDU Type const uint32_t MMS_COTP_PDU_DT_DATA = 0x0F; - if (data[idx] >> 0x04 != MMS_COTP_PDU_DT_DATA) + + if ( data[idx] >> 0x04 != MMS_COTP_PDU_DT_DATA ) { mms.state = MMS_STATE__NOT_FOUND; break; } + mms.state = MMS_STATE__COTP_TPDU_NUM; break; } @@ -361,7 +375,7 @@ static bool mms_curse(const uint8_t* data, unsigned len, CurseTracker* tracker) MMS_OSI_SESSION_SPDU_AC = 0x0E, }; - switch (data[idx]) + switch ( data[idx] ) { // check for a known MMS message tag in the event Session/Pres/ACSE aren't used case MMS_CONFIRMED_REQUEST_TAG: // fallthrough intentional @@ -407,10 +421,10 @@ static bool mms_curse(const uint8_t* data, unsigned len, CurseTracker* tracker) case MMS_STATE__MMS: { // loop through the remaining bytes in the buffer checking for known MMS tags - for (uint32_t i=idx; i < len; i++) + for ( uint32_t i=idx; i < len; i++ ) { // for each remaining byte check to see if it is in the known tag map - switch (data[i]) + switch ( data[i] ) { case MMS_CONFIRMED_REQUEST_TAG: // fallthrough intentional case MMS_CONFIRMED_RESPONSE_TAG: // fallthrough intentional @@ -437,25 +451,28 @@ static bool mms_curse(const uint8_t* data, unsigned len, CurseTracker* tracker) } // exit the loop when a state has been determined - if (mms.state == MMS_STATE__NOT_FOUND + if ( mms.state == MMS_STATE__NOT_FOUND or mms.state == MMS_STATE__SEARCH - or mms.state == MMS_STATE__FOUND) + or mms.state == MMS_STATE__FOUND ) { break; } } + break; } case MMS_STATE__FOUND: { mms.state = MMS_STATE__TPKT_VER; + return true; } case MMS_STATE__NOT_FOUND: { mms.state = MMS_STATE__TPKT_VER; + return false; } @@ -466,11 +483,13 @@ static bool mms_curse(const uint8_t* data, unsigned len, CurseTracker* tracker) break; } } + idx++; } mms.last_state = mms.state; mms.state = MMS_STATE__SEARCH; + return false; } @@ -488,66 +507,72 @@ static bool ssl_v2_curse(const uint8_t* data, unsigned len, CurseTracker* tracke { CurseTracker::SSL& ssl = tracker->ssl; - if (ssl.state == SSL_State::SSL_NOT_FOUND) - { + if ( ssl.state == SSL_State::SSL_NOT_FOUND ) return false; - } - else if (ssl.state == SSL_State::SSL_FOUND) - { + else if ( ssl.state == SSL_State::SSL_FOUND ) return true; - } - for (unsigned i = 0; i < len; ++i) + for ( unsigned i = 0; i < len; ++i ) { uint8_t val = data[i]; - switch (ssl.state) + switch ( ssl.state ) { case SSL_State::BYTE_0_LEN_MSB: - if ((val & SSL_Const::sslv2_msb_set) == 0) + if ( (val & SSL_Const::sslv2_msb_set) == 0 ) { ssl.state = SSL_State::SSL_NOT_FOUND; + return false; } + ssl.total_len = (val & (~SSL_Const::sslv2_msb_set)) << 8; ssl.state = SSL_State::BYTE_1_LEN_LSB; break; case SSL_State::BYTE_1_LEN_LSB: ssl.total_len |= val; - if (ssl.total_len < SSL_Const::hdr_len) + if ( ssl.total_len < SSL_Const::hdr_len ) { ssl.state = SSL_State::SSL_NOT_FOUND; + return false; } + ssl.total_len -= SSL_Const::hdr_len; ssl.state = SSL_State::BYTE_2_CLIENT_HELLO; break; case SSL_State::BYTE_2_CLIENT_HELLO: - if (val != SSL_Const::client_hello) + if ( val != SSL_Const::client_hello ) { ssl.state = SSL_State::SSL_NOT_FOUND; + return false; } + ssl.state = SSL_State::BYTE_3_MAX_MINOR_VER; break; case SSL_State::BYTE_3_MAX_MINOR_VER: - if (val > SSL_Const::sslv3_max_minor_ver) + if ( val > SSL_Const::sslv3_max_minor_ver ) { ssl.state = SSL_State::SSL_NOT_FOUND; + return false; } + ssl.state = SSL_State::BYTE_4_V3_MAJOR; break; case SSL_State::BYTE_4_V3_MAJOR: - if (val > SSL_Const::sslv3_major_ver) + if ( val > SSL_Const::sslv3_major_ver ) { ssl.state = SSL_State::SSL_NOT_FOUND; + return false; } + ssl.state = SSL_State::BYTE_5_SPECS_LEN_MSB; break; @@ -558,11 +583,14 @@ static bool ssl_v2_curse(const uint8_t* data, unsigned len, CurseTracker* tracke case SSL_State::BYTE_6_SPECS_LEN_LSB: ssl.specs_len |= val; - if (ssl.total_len < ssl.specs_len) + + if ( ssl.total_len < ssl.specs_len ) { ssl.state = SSL_State::SSL_NOT_FOUND; + return false; } + ssl.total_len -= ssl.specs_len; ssl.state = SSL_State::BYTE_7_SSNID_LEN_MSB; break; @@ -574,11 +602,14 @@ static bool ssl_v2_curse(const uint8_t* data, unsigned len, CurseTracker* tracke case SSL_State::BYTE_8_SSNID_LEN_LSB: ssl.ssnid_len |= val; - if (ssl.total_len < ssl.ssnid_len) + + if ( ssl.total_len < ssl.ssnid_len ) { ssl.state = SSL_State::SSL_NOT_FOUND; + return false; } + ssl.total_len -= ssl.ssnid_len; ssl.state = SSL_State::BYTE_9_CHLNG_LEN_MSB; break; @@ -590,12 +621,16 @@ static bool ssl_v2_curse(const uint8_t* data, unsigned len, CurseTracker* tracke case SSL_State::BYTE_10_CHLNG_LEN_LSB: ssl.chlng_len |= val; - if (ssl.total_len < ssl.chlng_len) + + if ( ssl.total_len < ssl.chlng_len ) { ssl.state = SSL_State::SSL_NOT_FOUND; + return false; } + ssl.state = SSL_State::SSL_FOUND; + return true; default: @@ -620,24 +655,27 @@ static vector curse_map bool CurseBook::add_curse(const char* key) { - for (const CurseDetails& curse : curse_map) + for ( const CurseDetails& curse : curse_map ) { - if (curse.name == key) + if ( curse.name == key ) { - if (curse.is_tcp) + if ( curse.is_tcp ) tcp_curses.emplace_back(&curse); else non_tcp_curses.emplace_back(&curse); + return true; } } + return false; } const vector& CurseBook::get_curses(bool tcp) const { - if (tcp) + if ( tcp ) return tcp_curses; + return non_tcp_curses; } @@ -672,9 +710,9 @@ TEST_CASE("sslv2 detect", "[SslV2Curse]") auto test = [&](uint32_t incr_by,const uint8_t* ch) { uint32_t i = 0; - while (i <= max_detect) + while ( i <= max_detect ) { - if ((i + incr_by - 1) < max_detect) + if ( (i + incr_by - 1) < max_detect ) { CHECK(tracker.ssl.state == static_cast(i)); CHECK_FALSE(ssl_v2_curse(&ch[i],sizeof(uint8_t) * incr_by,&tracker)); @@ -684,6 +722,7 @@ TEST_CASE("sslv2 detect", "[SslV2Curse]") CHECK(ssl_v2_curse(&ch[i],sizeof(uint8_t) * incr_by,&tracker)); CHECK(tracker.ssl.state == SSL_State::SSL_FOUND); } + i += incr_by; } //subsequent checks must return found @@ -730,9 +769,9 @@ TEST_CASE("sslv2 not found", "[SslV2Curse]") ch_data[fail_at_byte] = bad_data[fail_at_byte]; - for (uint32_t i = 0; i <= fail_at_byte; i++) + for ( uint32_t i = 0; i <= fail_at_byte; i++ ) { - if (i < fail_at_byte) + if ( i < fail_at_byte ) { CHECK(tracker.ssl.state == static_cast(i)); CHECK_FALSE(ssl_v2_curse(&ch_data[i],sizeof(uint8_t),&tracker)); diff --git a/src/service_inspectors/wizard/hexes.cc b/src/service_inspectors/wizard/hexes.cc index 327d95219..8b73db1df 100644 --- a/src/service_inspectors/wizard/hexes.cc +++ b/src/service_inspectors/wizard/hexes.cc @@ -56,7 +56,7 @@ bool HexBook::translate(const char* in, HexVector& out) } else if ( in[i] != ' ' ) { - if ( !isxdigit(in[i]) || byte.size() > 1 ) + if ( !isxdigit(in[i]) or byte.size() > 1 ) return false; byte += in[i]; @@ -64,14 +64,16 @@ bool HexBook::translate(const char* in, HexVector& out) else push = true; - if ( push && !byte.empty() ) + if ( push and !byte.empty() ) { int b = strtol(byte.c_str(), nullptr, 16); out.emplace_back((uint8_t)b); byte.clear(); } + ++i; } + return true; } @@ -93,6 +95,7 @@ void HexBook::add_spell( p = t; ++i; } + p->key = key; p->value = SnortConfig::get_static_name(val); } @@ -104,6 +107,7 @@ bool HexBook::add_spell(const char* key, const char*& val) if ( !translate(key, hv) ) { val = nullptr; + return false; } @@ -114,7 +118,7 @@ bool HexBook::add_spell(const char* key, const char*& val) { int c = hv[i]; - if ( c == WILD && p->any ) + if ( c == WILD and p->any ) p = p->any; else if ( p->next[c] ) @@ -128,10 +132,12 @@ bool HexBook::add_spell(const char* key, const char*& val) if ( p->key == key ) { val = p->value; + return false; } add_spell(key, val, hv, i, p); + return true; } @@ -161,7 +167,9 @@ const MagicPage* HexBook::find_spell( if ( const MagicPage* q = find_spell(s, n, p->any, i+1, bookmark) ) return q; } + return p->value ? p : nullptr; } + return p; } diff --git a/src/service_inspectors/wizard/magic.cc b/src/service_inspectors/wizard/magic.cc index 13397fedc..41ddd9641 100644 --- a/src/service_inspectors/wizard/magic.cc +++ b/src/service_inspectors/wizard/magic.cc @@ -29,6 +29,7 @@ MagicPage::MagicPage(const MagicBook& b) : book(b) { for ( int i = 0; i < 256; ++i ) next[i] = nullptr; + any = nullptr; } @@ -36,9 +37,10 @@ MagicPage::~MagicPage() { for ( int i = 0; i < 256; ++i ) { - if ( next[i] && next[i] != this ) + if ( next[i] and next[i] != this ) delete next[i]; } + delete any; } @@ -47,6 +49,7 @@ const char* MagicBook::find_spell(const uint8_t* data, unsigned len, { assert(p); p = find_spell(data, len, p, 0, bookmark); + return p ? p->value : nullptr; } diff --git a/src/service_inspectors/wizard/magic.h b/src/service_inspectors/wizard/magic.h index c2341e2ac..bc848d296 100644 --- a/src/service_inspectors/wizard/magic.h +++ b/src/service_inspectors/wizard/magic.h @@ -42,7 +42,6 @@ struct MagicPage typedef std::vector HexVector; // MagicBook is a set of MagicPages implementing a trie - class MagicBook { public: diff --git a/src/service_inspectors/wizard/spells.cc b/src/service_inspectors/wizard/spells.cc index 28ef4e290..9b8b51ebc 100644 --- a/src/service_inspectors/wizard/spells.cc +++ b/src/service_inspectors/wizard/spells.cc @@ -67,6 +67,7 @@ bool SpellBook::translate(const char* in, HexVector& out) } ++i; } + return true; } @@ -85,6 +86,7 @@ void SpellBook::add_spell( p = t; ++i; } + p->key = key; p->value = snort::SnortConfig::get_static_name(val); } @@ -96,6 +98,7 @@ bool SpellBook::add_spell(const char* key, const char*& val) if ( !translate(key, hv) ) { val = nullptr; + return false; } @@ -107,10 +110,10 @@ bool SpellBook::add_spell(const char* key, const char*& val) { int c = toupper(hv[i]); - if ( c == WILD && p->any ) + if ( c == WILD and p->any ) p = p->any; - else if ( c != WILD && p->next[c] ) + else if ( c != WILD and p->next[c] ) p = p->next[c]; else @@ -121,10 +124,12 @@ bool SpellBook::add_spell(const char* key, const char*& val) if ( p->key == key ) { val = p->value; + return false; } add_spell(key, val, hv, i, p); + return true; } @@ -156,15 +161,18 @@ const MagicPage* SpellBook::find_spell( if ( const MagicPage* q = find_spell(s, n, p->any, i, bookmark) ) { bookmark = q->any ? q : p; + return q; } + ++i; } + return p; } // If no match but has bookmark, continue lookup from bookmark - if ( !p->value && bookmark ) + if ( !p->value and bookmark ) { p = bookmark; bookmark = nullptr; @@ -174,5 +182,6 @@ const MagicPage* SpellBook::find_spell( return p->value ? p : nullptr; } + return p; } diff --git a/src/service_inspectors/wizard/wiz_module.cc b/src/service_inspectors/wizard/wiz_module.cc index d21ca1ced..1f49d65c4 100644 --- a/src/service_inspectors/wizard/wiz_module.cc +++ b/src/service_inspectors/wizard/wiz_module.cc @@ -131,6 +131,7 @@ void WizardModule::set_trace(const Trace* trace) const const TraceOption* WizardModule::get_trace_options() const { static const TraceOption wizard_trace_options(nullptr, 0, nullptr); + return &wizard_trace_options; } @@ -149,9 +150,9 @@ bool WizardModule::set(const char*, Value& v, SnortConfig*) else if ( v.is("client_first") ) return true; - else if ( v.is("hex") || v.is("spell") ) + else if ( v.is("hex") or v.is("spell") ) { - if (c2s) + if ( c2s ) c2s_patterns.emplace_back(v.get_string()); else s2c_patterns.emplace_back(v.get_string()); @@ -177,7 +178,7 @@ bool WizardModule::begin(const char* fqn, int idx, SnortConfig*) curses = new CurseBook; } - else if ( !strcmp(fqn, "wizard.hexes") || !strcmp(fqn, "wizard.spells") ) + else if ( !strcmp(fqn, "wizard.hexes") or !strcmp(fqn, "wizard.spells") ) { if ( idx > 0 ) { @@ -186,10 +187,10 @@ bool WizardModule::begin(const char* fqn, int idx, SnortConfig*) s2c_patterns.clear(); } } - else if ( !strcmp(fqn, "wizard.hexes.to_client") || !strcmp(fqn, "wizard.spells.to_client") ) + else if ( !strcmp(fqn, "wizard.hexes.to_client") or !strcmp(fqn, "wizard.spells.to_client") ) c2s = false; - else if ( !strcmp(fqn, "wizard.hexes.to_server") || !strcmp(fqn, "wizard.spells.to_server") ) + else if ( !strcmp(fqn, "wizard.hexes.to_server") or !strcmp(fqn, "wizard.spells.to_server") ) c2s = true; return true; @@ -206,6 +207,7 @@ static bool add_spells(MagicBook* b, const string& service, const vector { ParseError("Invalid %s '%s' for service '%s'", hex ? "hex" : "spell", p.c_str(), service.c_str()); + return false; } else if ( service != val ) @@ -238,11 +240,13 @@ bool WizardModule::end(const char* fqn, int idx, SnortConfig*) if ( service.empty() ) { ParseError("Hexes must have a service name"); + return false; } - if ( c2s_patterns.empty() && s2c_patterns.empty() ) + if ( c2s_patterns.empty() and s2c_patterns.empty() ) { ParseError("Hexes must have at least one pattern"); + return false; } if ( !add_spells(c2s_hexes, service, c2s_patterns, true) ) @@ -259,11 +263,13 @@ bool WizardModule::end(const char* fqn, int idx, SnortConfig*) if ( service.empty() ) { ParseError("Spells must have a service name"); + return false; } - if ( c2s_patterns.empty() && s2c_patterns.empty() ) + if ( c2s_patterns.empty() and s2c_patterns.empty() ) { ParseError("Spells must have at least one pattern"); + return false; } if ( !add_spells(c2s_spells, service, c2s_patterns, false) ) @@ -305,6 +311,7 @@ MagicBook* WizardModule::get_book(bool c2s, bool hex) c2s_hexes = nullptr; break; } + return b; } @@ -312,6 +319,7 @@ CurseBook* WizardModule::get_curse_book() { CurseBook* b = curses; curses = nullptr; + return b; } diff --git a/src/service_inspectors/wizard/wiz_module.h b/src/service_inspectors/wizard/wiz_module.h index d8adb0984..2dce2b36d 100644 --- a/src/service_inspectors/wizard/wiz_module.h +++ b/src/service_inspectors/wizard/wiz_module.h @@ -58,7 +58,7 @@ public: MagicBook* get_book(bool c2s, bool hex); CurseBook* get_curse_book(); - uint16_t get_max_search_depth() + uint16_t get_max_search_depth() const { return max_search_depth; } Usage get_usage() const override diff --git a/src/service_inspectors/wizard/wizard.cc b/src/service_inspectors/wizard/wizard.cc index f1d8bc60a..48cff8bba 100644 --- a/src/service_inspectors/wizard/wizard.cc +++ b/src/service_inspectors/wizard/wizard.cc @@ -139,8 +139,10 @@ public: StreamSplitter* get_splitter(bool) override; inline bool finished(Wand& w) - { return !w.hex && !w.spell && w.curse_tracker.empty(); } - void reset(Wand&, bool tcp, bool c2s); + { return !w.hex and !w.spell and w.curse_tracker.empty(); } + + void reset(Wand&, bool, bool); + bool cast_spell(Wand&, Flow*, const uint8_t*, unsigned, uint16_t&); bool spellbind(const MagicPage*&, Flow*, const uint8_t*, unsigned, const MagicPage*&); bool cursebind(const vector&, Flow*, const uint8_t*, unsigned); @@ -176,7 +178,7 @@ MagicSplitter::~MagicSplitter() wizard->rem_ref(); // release trackers - for (unsigned i = 0; i < wand.curse_tracker.size(); i++) + for ( unsigned i = 0; i < wand.curse_tracker.size(); i++ ) delete wand.curse_tracker[i].tracker; } @@ -195,18 +197,21 @@ StreamSplitter::Status MagicSplitter::scan( to_server() ? "c2s" : "s2c", pkt->flow->service); count_hit(pkt->flow); wizard_processed_bytes = 0; + return STOP; } - else if ( wizard->finished(wand) || bytes_scanned >= max(pkt->flow) ) + else if ( wizard->finished(wand) or bytes_scanned >= max(pkt->flow) ) { count_miss(pkt->flow); trace_logf(wizard_trace, pkt, "%s streaming search abandoned\n", to_server() ? "c2s" : "s2c"); wizard_processed_bytes = 0; - if (!pkt->flow->flags.svc_event_generated) + + if ( !pkt->flow->flags.svc_event_generated ) { DataBus::publish(FLOW_NO_SERVICE_EVENT, pkt); pkt->flow->flags.svc_event_generated = true; } + return ABORT; } @@ -217,7 +222,7 @@ StreamSplitter::Status MagicSplitter::scan( // delayed. Because AppId depends on wizard only for SSH detection and SSH inspector can be // attached very early, event is raised here after first scan. In the future, wizard should be // enhanced to abort sooner if it can't detect service. - if (!pkt->flow->service && !pkt->flow->flags.svc_event_generated) + if ( !pkt->flow->service and !pkt->flow->flags.svc_event_generated ) { DataBus::publish(FLOW_NO_SERVICE_EVENT, pkt); pkt->flow->flags.svc_event_generated = true; @@ -269,12 +274,13 @@ void Wizard::reset(Wand& w, bool tcp, bool c2s) w.spell = s2c_spells->page1(); } - if (w.curse_tracker.empty()) + if ( w.curse_tracker.empty() ) { vector pages = curses->get_curses(tcp); + for ( const CurseDetails* curse : pages ) { - if (tcp) + if ( tcp ) w.curse_tracker.emplace_back( CurseServiceTracker{ curse, new CurseTracker } ); else w.curse_tracker.emplace_back( CurseServiceTracker{ curse, nullptr } ); @@ -289,7 +295,7 @@ void Wizard::eval(Packet* p) if ( !p->is_udp() ) return; - if ( !p->data || !p->dsize ) + if ( !p->data or !p->dsize ) return; bool c2s = p->is_from_client(); @@ -298,6 +304,7 @@ void Wizard::eval(Packet* p) uint16_t udp_processed_bytes = 0; ++tstats.udp_scans; + if ( cast_spell(wand, p->flow, p->data, p->dsize, udp_processed_bytes) ) { trace_logf(wizard_trace, p, "%s datagram search found service %s\n", @@ -321,17 +328,19 @@ bool Wizard::spellbind( const MagicPage*& m, Flow* f, const uint8_t* data, unsigned len, const MagicPage*& bookmark) { f->service = m->book.find_spell(data, len, m, bookmark); + return f->service != nullptr; } bool Wizard::cursebind(const vector& curse_tracker, Flow* f, const uint8_t* data, unsigned len) { - for (const CurseServiceTracker& cst : curse_tracker) + for ( const CurseServiceTracker& cst : curse_tracker ) { - if (cst.curse->alg(data, len, cst.tracker)) + if ( cst.curse->alg(data, len, cst.tracker) ) { f->service = cst.curse->service; + if ( f->service ) return true; } @@ -344,23 +353,21 @@ bool Wizard::cast_spell( Wand& w, Flow* f, const uint8_t* data, unsigned len, uint16_t& wizard_processed_bytes) { auto curse_len = len; - len = std::min(len, static_cast(max_search_depth - wizard_processed_bytes)); - wizard_processed_bytes += len; - if ( w.hex && spellbind(w.hex, f, data, len, w.bookmark) ) + if ( w.hex and spellbind(w.hex, f, data, len, w.bookmark) ) return true; - if ( w.spell && spellbind(w.spell, f, data, len, w.bookmark) ) + if ( w.spell and spellbind(w.spell, f, data, len, w.bookmark) ) return true; - if (cursebind(w.curse_tracker, f, data, curse_len)) + if ( cursebind(w.curse_tracker, f, data, curse_len) ) return true; // If we reach max value of wizard_processed_bytes, // but not assign any inspector - raise tcp_miss and stop - if ( !f->service && wizard_processed_bytes >= max_search_depth ) + if ( !f->service and wizard_processed_bytes >= max_search_depth ) { w.spell = nullptr; w.hex = nullptr;