]> git.ipfire.org Git - thirdparty/snort3.git/commitdiff
Merge pull request #1166 in SNORT/snort3 from thread_local_move to master
authorRuss Combs (rucombs) <rucombs@cisco.com>
Wed, 28 Mar 2018 13:14:32 +0000 (09:14 -0400)
committerRuss Combs (rucombs) <rucombs@cisco.com>
Wed, 28 Mar 2018 13:14:32 +0000 (09:14 -0400)
Squashed commit of the following:

commit 4c20483a6cdab22db78fcca39f7778edbeef9f70
Author: Bhagya Tholpady <bbantwal@cisco.com>
Date:   Tue Mar 27 00:01:10 2018 -0400

    detect: moving thread locals identified to ips context

src/actions/act_replace.cc
src/detection/detection_engine.cc
src/detection/detection_engine.h
src/detection/ips_context.h
src/main/snort.cc

index 0a41211d2c20189578fcd3b8173df58fb297b05b..91292a5f363ce7beadbe87885a3b00e23b3c4ba7 100644 (file)
@@ -23,6 +23,7 @@
 
 #include "act_replace.h"
 
+#include "detection/detection_engine.h"
 #include "framework/ips_action.h"
 #include "framework/module.h"
 #include "packet_io/active.h"
@@ -39,59 +40,46 @@ using namespace snort;
 // queue foo
 //--------------------------------------------------------------------------
 
-struct Replacement
-{
-    std::string data;
-    unsigned offset;
-};
-
-#define MAX_REPLACEMENTS 32
-static THREAD_LOCAL Replacement* rpl;
-static THREAD_LOCAL int num_rpl = 0;
-
 void Replace_ResetQueue()
 {
-    num_rpl = 0;
+    DetectionEngine::clear_replacement();
 }
 
 void Replace_QueueChange(const std::string& s, unsigned off)
 {
-    Replacement* r;
-
-    if ( num_rpl == MAX_REPLACEMENTS )
-        return;
-
-    r = rpl + num_rpl++;
-
-    r->data = s;
-    r->offset = off;
+    DetectionEngine::add_replacement(s, off);
 }
 
-static inline void Replace_ApplyChange(Packet* p, Replacement* r)
+static inline void Replace_ApplyChange(Packet* p, std::string& data, unsigned offset)
 {
-    uint8_t* start = const_cast<uint8_t*>(p->data) + r->offset;
+    uint8_t* start = const_cast<uint8_t*>(p->data) + offset;
     const uint8_t* end = p->data + p->dsize;
     unsigned len;
 
-    if ( (start + r->data.size()) >= end )
-        len = p->dsize - r->offset;
+    if ( (start + data.size()) >= end )
+        len = p->dsize - offset;
     else
-        len = r->data.size();
+        len = data.size();
 
-    memcpy(start, r->data.c_str(), len);
+    memcpy(start, data.c_str(), len);
 }
 
 static void Replace_ModifyPacket(Packet* p)
 {
-    if ( num_rpl == 0 )
-        return;
+    std::string data;
+    unsigned offset;
+    bool modified = false;
 
-    for ( int n = 0; n < num_rpl; n++ )
+    while ( DetectionEngine::get_replacement(data, offset) )
     {
-        Replace_ApplyChange(p, rpl+n);
+        modified = true;
+        Replace_ApplyChange(p, data, offset);
     }
-    p->packet_flags |= PKT_MODIFIED;
-    num_rpl = 0;
+
+    if ( modified )
+        p->packet_flags |= PKT_MODIFIED;
+
+    DetectionEngine::clear_replacement();
 }
 
 //-------------------------------------------------------------------------
@@ -183,12 +171,6 @@ static IpsAction* rep_ctor(Module* m)
 static void rep_dtor(IpsAction* p)
 { delete p; }
 
-static void rep_tinit()
-{ rpl = new Replacement[MAX_REPLACEMENTS]; }
-
-static void rep_tterm()
-{ delete[] rpl; }
-
 static ActionApi rep_api
 {
     {
@@ -206,8 +188,8 @@ static ActionApi rep_api
     Actions::ALERT,
     nullptr,
     nullptr,
-    rep_tinit,
-    rep_tterm,
+    nullptr,
+    nullptr,
     rep_ctor,
     rep_dtor
 };
index 58d047f9bbc5501e1c98174f4dfaa1f819d5712d..275dde47c6e91ff4233b0c543765cc00fc70625e 100644 (file)
@@ -24,6 +24,7 @@
 
 #include "detection_engine.h"
 
+#include "actions/act_replace.h"
 #include "events/sfeventq.h"
 #include "filters/sfthreshold.h"
 #include "framework/endianness.h"
@@ -163,6 +164,34 @@ void DetectionEngine::set_data(unsigned id, IpsContextData* p)
 IpsContextData* DetectionEngine::get_data(unsigned id)
 { return Snort::get_switcher()->get_context()->get_context_data(id); }
 
+void DetectionEngine::add_replacement(const std::string& s, unsigned off)
+{ 
+    Replacement r;
+
+    r.data = s;
+    r.offset = off;
+    Snort::get_switcher()->get_context()->rpl.push_back(r); 
+}
+
+bool DetectionEngine::get_replacement(std::string& s, unsigned& off)
+{ 
+    if ( Snort::get_switcher()->get_context()->rpl.empty() )
+        return false;
+
+    auto rep = Snort::get_switcher()->get_context()->rpl.back();
+
+    s = rep.data;
+    off = rep.offset;
+
+    Snort::get_switcher()->get_context()->rpl.pop_back();
+    return true;
+}
+
+void DetectionEngine::clear_replacement()
+{
+    Snort::get_switcher()->get_context()->rpl.clear();
+}
+
 void DetectionEngine::disable_all(Packet* p)
 { p->context->active_rules = IpsContext::NONE; }
 
index 92fc6881d3c03323abf865c08ebcdf9b385a4a0f..444ee0fefdb494179ec731ac55cd3fd978f5a4a9 100644 (file)
@@ -31,6 +31,7 @@
 #include "main/snort_types.h"
 
 struct DataPointer;
+struct Replacement;
 
 namespace snort
 {
@@ -75,6 +76,10 @@ public:
     static void set_data(unsigned id, IpsContextData*);
     static IpsContextData* get_data(unsigned id);
 
+    static void add_replacement(const std::string&, unsigned);
+    static bool get_replacement(std::string&, unsigned&);
+    static void clear_replacement();
+
     static bool detect(Packet*, bool offload_ok = false);
     static void inspect(Packet*);
 
index f25c6925cbefdbf4f6b184f5276dcd62684e2d37..99e0073055ca64c9468319d4fba61667bc06e92e 100644 (file)
@@ -40,6 +40,11 @@ struct SF_EVENTQ;
 namespace snort
 {
 struct SnortConfig;
+struct Replacement
+{
+    std::string data;
+    unsigned offset;
+};
 
 class SO_PUBLIC IpsContextData
 {
@@ -92,6 +97,8 @@ public:
     ActiveRules active_rules;
     bool check_tags;
 
+    std::vector<Replacement> rpl;
+
     static const unsigned buf_size = Codec::PKT_MAX;
 
 private:
index ca6b528946e742c354e271d276cd6c535e4bd1f8..25c304d0a3c5f9cf685cce2ee990a1c38e769c78 100644 (file)
@@ -882,32 +882,32 @@ DAQ_Verdict Snort::process_packet(
 }
 
 // process (wire-only) packet verdicts here
-static DAQ_Verdict update_verdict(DAQ_Verdict verdict, int& inject)
+static DAQ_Verdict update_verdict(Packet* p, DAQ_Verdict verdict, int& inject)
 {
     if ( Active::packet_was_dropped() and Active::can_block() )
     {
         if ( verdict == DAQ_VERDICT_PASS )
             verdict = DAQ_VERDICT_BLOCK;
     }
-    else if ( s_packet->packet_flags & PKT_RESIZED )
+    else if ( p->packet_flags & PKT_RESIZED )
     {
         // we never increase, only trim, but daq doesn't support resizing wire packet
-        PacketManager::encode_update(s_packet);
+        PacketManager::encode_update(p);
 
-        if ( !SFDAQ::inject(s_packet->pkth, 0, s_packet->pkt, s_packet->pkth->pktlen) )
+        if ( !SFDAQ::inject(p->pkth, 0, p->pkt, p->pkth->pktlen) )
         {
             inject = 1;
             verdict = DAQ_VERDICT_BLOCK;
         }
     }
-    else if ( s_packet->packet_flags & PKT_MODIFIED )
+    else if ( p->packet_flags & PKT_MODIFIED )
     {
         // this packet was normalized and/or has replacements
-        PacketManager::encode_update(s_packet);
+        PacketManager::encode_update(p);
         verdict = DAQ_VERDICT_REPLACE;
     }
-    else if ( (s_packet->packet_flags & PKT_IGNORE) ||
-        (s_packet->flow && s_packet->flow->get_ignore_direction( ) == SSN_DIR_BOTH) )
+    else if ( (p->packet_flags & PKT_IGNORE) ||
+        (p->flow && p->flow->get_ignore_direction( ) == SSN_DIR_BOTH) )
     {
         if ( !Active::get_tunnel_bypass() )
         {
@@ -919,10 +919,10 @@ static DAQ_Verdict update_verdict(DAQ_Verdict verdict, int& inject)
             aux_counts.internal_whitelist++;
         }
     }
-    else if ( s_packet->ptrs.decode_flags & DECODE_PKT_TRUST )
+    else if ( p->ptrs.decode_flags & DECODE_PKT_TRUST )
     {
-        if (s_packet->flow)
-            s_packet->flow->set_ignore_direction(SSN_DIR_BOTH);
+        if (p->flow)
+            p->flow->set_ignore_direction(SSN_DIR_BOTH);
         verdict = DAQ_VERDICT_WHITELIST;
     }
     else
@@ -960,7 +960,7 @@ DAQ_Verdict Snort::packet_callback(
     ActionManager::execute(s_packet);
 
     int inject = 0;
-    verdict = update_verdict(verdict, inject);
+    verdict = update_verdict(s_packet, verdict, inject);
 
     PacketTracer::log("NAP id %u, IPS id %u, Verdict %s\n",
         get_network_policy()->policy_id, get_ips_policy()->policy_id,