]> git.ipfire.org Git - thirdparty/snort3.git/commitdiff
Pull request #4239: Handle gracefully decode error cases
authorMaya Dagon (mdagon) <mdagon@cisco.com>
Wed, 8 May 2024 03:39:20 +0000 (03:39 +0000)
committerPriyanka Bangalore Gurudev (prbg) <prbg@cisco.com>
Wed, 8 May 2024 03:39:20 +0000 (03:39 +0000)
Merge in SNORT/snort3 from ~MDAGON/snort3:defensive to master

Squashed commit of the following:

commit 963134b2cf090fe6bb8811dfdebe5aa590878ffa
Author: maya dagon <mdagon@cisco.com>
Date:   Wed May 1 11:00:55 2024 -0400

    framework: api version bump

commit fdbfa6df53a6ad24aa4f25ebcd1a379c7ef320b0
Author: maya dagon <mdagon@cisco.com>
Date:   Thu Apr 4 08:35:42 2024 -0400

    framework: expand decode flags

commit 7da61b14fdf0114059f7e1a2a9a3a066afdd91b8
Author: PRATEEK MOHAN PRABHU -X (pratepra - XORIANT CORPORATION at Cisco) <pratepra@cisco.com>
Date:   Tue Jan 16 16:32:22 2024 +0530

    protocols: defensive fix for malformed packets, discard log

src/framework/base_api.h
src/framework/decode_data.h
src/network_inspectors/packet_tracer/packet_tracer.cc
src/network_inspectors/packet_tracer/packet_tracer.h
src/protocols/packet_manager.cc
src/protocols/test/CMakeLists.txt
src/protocols/test/decode_err_len_test.cc [new file with mode: 0644]

index 49bbfdf09f8efa0ef8a39de217c80199e3dd8830..f2ca5bf8ccb15536ad9462affc2d60a7220e218c 100644 (file)
@@ -29,7 +29,7 @@
 
 // this is the current version of the base api
 // must be prefixed to subtype version
-#define BASE_API_VERSION 17
+#define BASE_API_VERSION 18
 
 // set options to API_OPTIONS to ensure compatibility
 #ifndef API_OPTIONS
index 6dc4e63b79df844a179a5d3ba30ef4e89d4fcae7..64a3522aa8dbb25a45f606729fb4e506007e862f 100644 (file)
@@ -84,32 +84,35 @@ enum class PktType : std::uint8_t
 #define PROTO_BIT__ANY_SSN  (PROTO_BIT__ANY_IP | PROTO_BIT__PDU | PROTO_BIT__FILE | PROTO_BIT__USER)
 #define PROTO_BIT__ANY_TYPE (PROTO_BIT__ANY_SSN | PROTO_BIT__ARP)
 
-enum DecodeFlags : std::uint16_t
+enum DecodeFlags : std::uint32_t
 {
-    DECODE_ERR_CKSUM_IP =   0x0001,  // error flags
-    DECODE_ERR_CKSUM_TCP =  0x0002,
-    DECODE_ERR_CKSUM_UDP =  0x0004,
-    DECODE_ERR_CKSUM_ICMP = 0x0008,
-    DECODE_ERR_BAD_TTL =    0x0010,
+    DECODE_ERR_CKSUM_IP =   0x00000001,  // error flags
+    DECODE_ERR_CKSUM_TCP =  0x00000002,
+    DECODE_ERR_CKSUM_UDP =  0x00000004,
+    DECODE_ERR_CKSUM_ICMP = 0x00000008,
+    DECODE_ERR_BAD_TTL =    0x00000010,
 
     DECODE_ERR_CKSUM_ALL = ( DECODE_ERR_CKSUM_IP | DECODE_ERR_CKSUM_TCP |
         DECODE_ERR_CKSUM_UDP | DECODE_ERR_CKSUM_ICMP ),
-    DECODE_ERR_FLAGS = ( DECODE_ERR_CKSUM_ALL | DECODE_ERR_BAD_TTL ),
 
-    DECODE_PKT_TRUST =      0x0020,  // trust this packet
-    DECODE_FRAG =           0x0040,  // ip - fragmented packet
-    DECODE_MF =             0x0080,  // ip - more fragments
-    DECODE_DF =             0x0100,  // ip - don't fragment
+    DECODE_PKT_TRUST =      0x00000020,  // trust this packet
+    DECODE_FRAG =           0x00000040,  // ip - fragmented packet
+    DECODE_MF =             0x00000080,  // ip - more fragments
+    DECODE_DF =             0x00000100,  // ip - don't fragment
 
     // using decode flags in lieu of creating user layer for now
-    DECODE_C2S =            0x0200,  // user - client to server
-    DECODE_SOF =            0x0400,  // user - start of flow
-    DECODE_EOF =            0x0800,  // user - end of flow
-    DECODE_GTP =            0x1000,
-
-    DECODE_TCP_MSS =        0x2000,
-    DECODE_TCP_TS =         0x4000,
-    DECODE_TCP_WS =         0x8000,
+    DECODE_C2S =            0x00000200,  // user - client to server
+    DECODE_SOF =            0x00000400,  // user - start of flow
+    DECODE_EOF =            0x00000800,  // user - end of flow
+    DECODE_GTP =            0x00001000,
+
+    DECODE_TCP_MSS =        0x00002000,
+    DECODE_TCP_TS =         0x00004000,
+    DECODE_TCP_WS =         0x00008000,
+
+    DECODE_ERR_LEN =        0X00010000,  // received incorrect len from DAQ
+
+    DECODE_ERR_FLAGS = ( DECODE_ERR_CKSUM_ALL | DECODE_ERR_BAD_TTL | DECODE_ERR_LEN ),
 };
 
 struct DecodeData
@@ -125,7 +128,7 @@ struct DecodeData
     uint16_t sp = 0;    /* source port (TCP/UDP) */
     uint16_t dp = 0;    /* dest port (TCP/UDP) */
 
-    uint16_t decode_flags = 0;
+    uint32_t decode_flags = 0;
     PktType type = PktType::NONE;
 
     snort::ip::IpApi ip_api;
index c064ef790a44536fca7ed945bfb583d6e20bd824..22d3c69bf2819db6778e296eadaa593b2400f793 100644 (file)
@@ -164,6 +164,17 @@ void PacketTracer::log(const char* format, ...)
     va_end(ap);
 }
 
+void PacketTracer::log_msg_only(const char* format, ...)
+{
+    if (is_paused())
+        return;
+
+    va_list ap;
+    va_start(ap, format);
+    s_pkt_trace->log_va(format, ap, false, true);
+    va_end(ap);
+}
+
 void PacketTracer::log(TracerMute mute, const char* format, ...)
 {
     if ( s_pkt_trace->mutes[mute] )
@@ -289,14 +300,14 @@ void PacketTracer::populate_buf(const char* format, va_list ap, char* buffer, ui
         buff_len = max_buff_size - 1;
 }
 
-void PacketTracer::log_va(const char* format, va_list ap, bool daq_log)
+void PacketTracer::log_va(const char* format, va_list ap, bool daq_log, bool msg_only)
 {
     // FIXIT-L Need to find way to add 'PktTracerDbg' string as part of format string.
     std::string dbg_str;
     if (shell_enabled and !daq_log) // only add debug string during shell execution
     {
         dbg_str = "PktTracerDbg ";
-        if (strcmp(format, "\n") != 0)
+        if (!msg_only && (strcmp(format, "\n") != 0))
             dbg_str += get_debug_session();
         dbg_str += format;
         format = dbg_str.c_str();
index 4dd397ebc54fcd20b15cb0ceda41a04e90680ab0..8776b5f6333ba00b92aa3d9292f9e9e2523d2e53 100644 (file)
@@ -71,6 +71,7 @@ public:
 
     static SO_PUBLIC void log(const char* format, ...) __attribute__((format (printf, 1, 2)));
     static SO_PUBLIC void log(TracerMute, const char* format, ...) __attribute__((format (printf, 2, 3)));
+    static SO_PUBLIC void log_msg_only(const char* format, ...) __attribute__((format (printf, 1, 2)));
 
     static SO_PUBLIC void daq_log(const char* format, ...) __attribute__((format (printf, 1, 2)));
     static SO_PUBLIC void pt_timer_start();
@@ -99,7 +100,7 @@ protected:
     template<typename T = PacketTracer> static void _thread_init();
 
     // non-static functions
-    void log_va(const char*, va_list, bool);
+    void log_va(const char*, va_list, bool, bool msg_only = false);
     void populate_buf(const char*, va_list, char*, uint32_t&);
     void add_ip_header_info(const snort::Packet&);
     void add_eth_header_info(const snort::Packet&);
index a9553f46e0ebcc3467a44190d4eb5662b4d7605d..17fc08e6921997502835a6490984bb3e8c50b491 100644 (file)
@@ -33,6 +33,7 @@
 #include "main/snort_config.h"
 #include "packet_io/active.h"
 #include "packet_io/sfdaq.h"
+#include "packet_tracer/packet_tracer.h"
 #include "profiler/profiler_defs.h"
 #include "stream/stream.h"
 #include "trace/trace_api.h"
@@ -171,6 +172,7 @@ void PacketManager::handle_decode_failure(Packet* p, RawData& raw, const CodecDa
     // if the codec exists, it failed
     if (CodecManager::s_proto_map[to_utype(prev_prot_id)])
     {
+        PacketTracer::log_msg_only("Packet %" PRIu64": decoding error\n", p->context->packet_number);
         s_stats[discards]++;
     }
     else
@@ -345,7 +347,13 @@ void PacketManager::decode(
 
         // Shrink the buffer of undecoded data
         const uint16_t curr_lyr_len = codec_data.lyr_len + codec_data.invalid_bytes;
-        assert(curr_lyr_len <= raw.len);
+        if (curr_lyr_len > raw.len)
+        {
+            p->ptrs.decode_flags |= DECODE_ERR_LEN;
+            PacketTracer::log("Packet %" PRIu64": current layer len %d > raw length %d\n", p->context->packet_number,
+                curr_lyr_len, raw.len);
+            return;
+        }
         raw.len -= curr_lyr_len;
         raw.data += curr_lyr_len;
 
index b43021bb48e926946ab7c6d34019c55cf3daae51..b0e338f652220829211f98f93e7b2a5f3f688326 100644 (file)
@@ -1,5 +1,10 @@
-
 add_cpputest( get_geneve_opt_test
     SOURCES
         ../packet.cc
 )
+
+add_cpputest( decode_err_len_test
+  SOURCES
+      ../packet.cc
+      ../packet_manager.cc
+)
diff --git a/src/protocols/test/decode_err_len_test.cc b/src/protocols/test/decode_err_len_test.cc
new file mode 100644 (file)
index 0000000..29620e2
--- /dev/null
@@ -0,0 +1,137 @@
+//--------------------------------------------------------------------------
+// Copyright (C) 2024-2024 Cisco and/or its affiliates. All rights reserved.
+//
+// This program is free software; you can redistribute it and/or modify it
+// under the terms of the GNU General Public License Version 2 as published
+// by the Free Software Foundation.  You may not use, modify or distribute
+// this program under any other version of the GNU General Public License.
+//
+// This program is distributed in the hope that it will be useful, but
+// WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+// General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License along
+// with this program; if not, write to the Free Software Foundation, Inc.,
+// 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
+//--------------------------------------------------------------------------
+// decode_err_len_test.cc author Maya Dagon <mdagon@cisco.com>
+
+#ifdef HAVE_CONFIG_H
+#include "config.h"
+#endif
+
+#include "detection/detection_engine.h"
+#include "detection/ips_context.h"
+#include "flow/expect_cache.h"
+#include "log/text_log.h"
+#include "main/snort_config.h"
+#include "managers/codec_manager.h"
+#include "packet_io/sfdaq.h"
+#include "packet_tracer/packet_tracer.h"
+#include "profiler/profiler_defs.h"
+#include "stream/stream.h"
+#include "trace/trace_api.h"
+
+#include "protocols/packet.h"
+#include "protocols/packet_manager.h"
+
+#include <CppUTest/CommandLineTestRunner.h>
+#include <CppUTest/TestHarness.h>
+#include <CppUTestExt/MockSupport.h>
+
+using namespace snort;
+
+
+//------------------------------
+// Stubs
+//------------------------------
+bool snort::TextLog_Print(TextLog* const, const char*, ...) { return true; }
+bool snort::TextLog_Write(TextLog* const, const char*, int) { return true; }
+bool snort::TextLog_Putc(TextLog* const, char) { return true; }
+void snort::trace_vprintf(const char*, TraceLevel, const char*, const Packet*, const char*, va_list) { }
+void TraceApi::filter(const Packet&) {}
+uint8_t TraceApi::get_constraints_generation() { return 0; }
+
+void ExpectFlow::reset_expect_flows() {}
+bool SnortConfig::tunnel_bypass_enabled(uint16_t) const { return false; }
+const vlan::VlanTagHdr* layer::get_vlan_layer(const Packet*) { return nullptr; }
+const geneve::GeneveLyr* layer::get_geneve_layer(const Packet*, bool) { return nullptr; }
+void ip::IpApi::reset() {}
+void PacketTracer::log_msg_only(const char*, ...) {}
+void PacketTracer::log(const char*, ...) {}
+int DetectionEngine::queue_event(unsigned, unsigned) { return 0; }
+Packet* DetectionEngine::get_encode_packet() { return nullptr; }
+void show_percent_stats(PegCount*, const char*[], unsigned, const char*) {}
+void layer::set_packet_pointer(const Packet* const) {}
+bool layer::set_inner_ip_api(const Packet* const, ip::IpApi&, int8_t&) { return true; }
+int layer::get_inner_ip_lyr_index(const Packet* const) { return 0; }
+int layer::get_inner_ip6_frag_index(const Packet* const) { return 0; }
+uint8_t Stream::get_flow_ttl(Flow*, char, bool) { return 0; }
+bool SFDAQ::forwarding_packet(const DAQ_PktHdr_t*) { return false; }
+void sum_stats(PegCount*, PegCount*, unsigned, bool) {}
+IpsContext::IpsContext(unsigned):
+    packet(nullptr), encode_packet(nullptr), pkth (nullptr), buf(nullptr),
+    stash(nullptr), otnx(nullptr), equeue(nullptr), context_num(0),
+    active_rules(IpsContext::NONE), state(IpsContext::IDLE), check_tags(false), clear_inspectors(false),
+    data(0), depends_on(nullptr), next_to_process(nullptr) { searches.context = nullptr; }
+IpsContext::~IpsContext() {}
+Buffer::Buffer(uint8_t*, uint32_t) :
+    base(nullptr), end(0), max_len(0), off(0) {}
+EncState::EncState(const ip::IpApi& api, EncodeFlags f, IpProtocol pr, uint8_t t, uint16_t data_size) :
+    ip_api(api), flags(f), dsize(data_size), next_ethertype(ProtocolId::ETHERTYPE_NOT_SET),
+    next_proto(pr), ttl(t) {}
+
+THREAD_LOCAL bool TimeProfilerStats::enabled = false;
+THREAD_LOCAL const Trace* decode_trace = nullptr;
+std::array<uint8_t, num_protocol_ids> CodecManager::s_proto_map {
+    { 0 }
+};
+
+THREAD_LOCAL ProtocolId CodecManager::grinder_id = ProtocolId::ETHERTYPE_NOT_SET;
+THREAD_LOCAL uint8_t CodecManager::grinder = 0;
+
+//-----------------------------
+// Mocks
+//-----------------------------
+class MockCodec : public Codec
+{
+public:
+    MockCodec() : Codec("mock_codec") { }
+
+    bool decode(const RawData& raw, CodecData& codec_data, DecodeData&) override
+    {
+        codec_data.lyr_len = raw.len +1;
+        codec_data.next_prot_id = ProtocolId::FINISHED_DECODE;
+        return true;
+    }
+};
+
+MockCodec mock_cd;
+std::array<Codec*, UINT8_MAX> CodecManager::s_protocols { { &mock_cd } };
+THREAD_LOCAL uint8_t CodecManager::max_layers = 1;
+
+//-----------------------------
+// Test
+//-----------------------------
+
+TEST_GROUP(decode_err_len_tests)
+{
+};
+
+TEST(decode_err_len_tests, layer_len_more_than_raw)
+{
+    Packet p(false);
+    p.context = new IpsContext();
+    _daq_msg msg;
+    memset(&msg, 0, sizeof(_daq_msg));
+    p.daq_msg = &msg;
+    PacketManager::decode(&p, nullptr, nullptr, 10, false);
+    CHECK_TRUE((p.ptrs.decode_flags & DECODE_ERR_LEN) != 0);
+    delete p.context;
+}
+
+int main(int argc, char** argv)
+{
+    return CommandLineTestRunner::RunAllTests(argc, argv);
+}