]> git.ipfire.org Git - thirdparty/snort3.git/commitdiff
Merge pull request #608 in SNORT/snort3 from mpls_encode to master
authorShawn Turner (shaturne) <shaturne@cisco.com>
Fri, 2 Sep 2016 18:03:28 +0000 (14:03 -0400)
committerShawn Turner (shaturne) <shaturne@cisco.com>
Fri, 2 Sep 2016 18:03:28 +0000 (14:03 -0400)
Squashed commit of the following:

commit 42ccbfaa13ee35556dfde13671aebb74b99ac014
Author: Bhagya Tholpady <bbantwal@cisco.com>
Date:   Tue Aug 30 01:13:12 2016 -0400

    porting mpls encode changes from 2.x

17 files changed:
src/codecs/ip/cd_ipv4.cc
src/codecs/ip/cd_ipv6.cc
src/codecs/ip/cd_tcp.cc
src/codecs/ip/cd_udp.cc
src/codecs/link/cd_fabricpath.cc
src/codecs/link/cd_mpls.cc
src/codecs/link/cd_pppoe.cc
src/codecs/misc/cd_gtp.cc
src/codecs/root/cd_eth.cc
src/flow/flow.cc
src/flow/flow.h
src/flow/flow_control.cc
src/framework/codec.h
src/piglet_plugins/pp_codec_iface.cc
src/protocols/layer.cc
src/protocols/layer.h
src/protocols/packet_manager.cc

index 5bbf623d8f93fa13016393d805b1d93064a4dde7..f762a97f9f688c8c337ff90e6df0547f019c4780 100644 (file)
@@ -115,7 +115,7 @@ public:
     bool decode(const RawData&, CodecData&, DecodeData&) override;
     void log(TextLog* const, const uint8_t* pkt, const uint16_t len) override;
     bool encode(const uint8_t* const raw_in, const uint16_t raw_len,
-        EncState&, Buffer&) override;
+        EncState&, Buffer&, Flow*) override;
     void update(const ip::IpApi&, const EncodeFlags, uint8_t* raw_pkt,
         uint16_t lyr_len, uint32_t& updated_len) override;
     void format(bool reverse, uint8_t* raw_pkt, DecodeData& snort) override;
@@ -629,7 +629,7 @@ static inline uint16_t IpId_Next()
  ******************** E N C O D E R  ******************************
  ******************************************************************/
 bool Ipv4Codec::encode(const uint8_t* const raw_in, const uint16_t /*raw_len*/,
-    EncState& enc, Buffer& buf)
+    EncState& enc, Buffer& buf, Flow*)
 {
     if (!buf.allocate(ip::IP4_HEADER_LEN))
         return false;
index c615db7c8c2e61a8c810dd9c6d8e5442d3cf1bfa..8d8edf67046577cd0f203e51a2bf1704579d64a7 100644 (file)
@@ -94,7 +94,7 @@ public:
     void get_protocol_ids(std::vector<ProtocolId>& v) override;
     bool decode(const RawData&, CodecData&, DecodeData&) override;
     bool encode(const uint8_t* const raw_in, const uint16_t raw_len,
-        EncState&, Buffer&) override;
+        EncState&, Buffer&, Flow*) override;
     void update(const ip::IpApi&, const EncodeFlags, uint8_t* raw_pkt,
         uint16_t lyr_len, uint32_t& updated_len) override;
     void format(bool reverse, uint8_t* raw_pkt, DecodeData& snort) override;
@@ -556,7 +556,7 @@ void Ipv6Codec::log(TextLog* const text_log, const uint8_t* raw_pkt,
  ******************************************************************/
 
 bool Ipv6Codec::encode(const uint8_t* const raw_in, const uint16_t /*raw_len*/,
-    EncState& enc, Buffer& buf)
+    EncState& enc, Buffer& buf, Flow*)
 {
     if (!buf.allocate(sizeof(ip::IP6Hdr)))
         return false;
index def9c6a96af9b04f2b287a7a630bae153a416d8d..f3424ae98b6f14205919586b30096eac4ac17fda 100644 (file)
@@ -114,7 +114,7 @@ public:
     void log(TextLog* const, const uint8_t* pkt, const uint16_t len) override;
     bool decode(const RawData&, CodecData&, DecodeData&) override;
     bool encode(const uint8_t* const raw_in, const uint16_t raw_len,
-        EncState&, Buffer&) override;
+        EncState&, Buffer&, Flow*) override;
     void update(const ip::IpApi&, const EncodeFlags, uint8_t* raw_pkt,
         uint16_t lyr_len, uint32_t& updated_len) override;
     void format(bool reverse, uint8_t* raw_pkt, DecodeData& snort) override;
@@ -579,7 +579,7 @@ void TcpCodec::log(TextLog* const text_log, const uint8_t* raw_pkt,
 //-------------------------------------------------------------------------
 
 bool TcpCodec::encode(const uint8_t* const raw_in, const uint16_t /*raw_len*/,
-    EncState& enc, Buffer& buf)
+    EncState& enc, Buffer& buf, Flow*)
 {
     const tcp::TCPHdr* const hi = reinterpret_cast<const tcp::TCPHdr*>(raw_in);
 
index bdabe170ccccb1980b261779828ba9a6cf156ca3..959da7ed83241d188c4f623d3f6f9aa689a4b474 100644 (file)
@@ -146,7 +146,7 @@ public:
     bool decode(const RawData&, CodecData&, DecodeData&) override;
 
     bool encode(const uint8_t* const raw_in, const uint16_t raw_len,
-        EncState&, Buffer&) override;
+        EncState&, Buffer&, Flow*) override;
     void update(const ip::IpApi&, const EncodeFlags, uint8_t* raw_pkt,
         uint16_t lyr_len, uint32_t& updated_len) override;
     void format(bool reverse, uint8_t* raw_pkt, DecodeData& snort) override;
@@ -346,7 +346,7 @@ void UdpCodec::log(TextLog* const text_log, const uint8_t* raw_pkt,
  ******************************************************************/
 
 bool UdpCodec::encode(const uint8_t* const raw_in, const uint16_t /*raw_len*/,
-    EncState& enc, Buffer& buf)
+    EncState& enc, Buffer& buf, Flow*)
 {
     // If we enter this function, this packe is some sort of tunnel.
 
index b8bce50ef6e011d4977132ea53994de773149575..917b02405cc7603d7d388039a5940bf32be48286 100644 (file)
@@ -51,7 +51,7 @@ public:
     void get_protocol_ids(std::vector<ProtocolId>& v) override;
     bool decode(const RawData&, CodecData&, DecodeData&) override;
     bool encode(const uint8_t* const raw_in, const uint16_t raw_len,
-        EncState&, Buffer&) override;
+        EncState&, Buffer&, Flow*) override;
     void format(bool reverse, uint8_t* raw_pkt, DecodeData& snort) override;
 };
 
@@ -83,7 +83,7 @@ bool FabricPathCodec::decode(const RawData& raw, CodecData& codec, DecodeData&)
 }
 
 bool FabricPathCodec::encode(const uint8_t* const raw_in, const uint16_t /*raw_len*/,
-    EncState& enc, Buffer& buf)
+    EncState& enc, Buffer& buf, Flow*)
 {
     // not raw ip -> encode layer 2
     bool raw = ( enc.flags & ENC_FLAG_RAW );
index cd6f59bf3b091863bbc23630f98ef6485ff9e296..bce857f86bef608dd29ebe83af290c4eb6ad8864 100644 (file)
@@ -26,6 +26,7 @@
 #include "protocols/mpls.h"
 #include "main/snort_config.h"
 #include "log/text_log.h"
+#include "utils/safec.h"
 
 #define CD_MPLS_NAME "mpls"
 #define CD_MPLS_HELP "support for multiprotocol label switching"
@@ -128,6 +129,8 @@ public:
 
     void get_protocol_ids(std::vector<ProtocolId>& v) override;
     bool decode(const RawData&, CodecData&, DecodeData&) override;
+    bool encode(const uint8_t* const raw_in, const uint16_t raw_len,
+        EncState&, Buffer&, Flow*) override;
     void log(TextLog* const, const uint8_t* pkt, const uint16_t len) override;
 
 private:
@@ -232,6 +235,32 @@ bool MplsCodec::decode(const RawData& raw, CodecData& codec, DecodeData& snort)
 
     return true;
 }
+bool MplsCodec::encode(const uint8_t* const raw_in, const uint16_t raw_len,
+        EncState& enc, Buffer& buf, Flow* pflow)
+{
+    uint16_t hdr_len = raw_len;
+    const uint8_t* hdr_start = raw_in;
+    if( pflow )
+    {
+        Layer mpls_lyr = pflow->get_mpls_layer_per_dir(enc.forward());
+
+        if( mpls_lyr.length )
+        {
+            hdr_len = mpls_lyr.length;
+            hdr_start = mpls_lyr.start;
+        }
+
+    }
+
+    if (!buf.allocate(hdr_len))
+        return false;
+
+    memcpy_s(buf.data(), hdr_len, hdr_start, hdr_len);
+    enc.next_ethertype = ProtocolId::ETHERTYPE_NOT_SET;
+    enc.next_proto = IpProtocol::PROTO_NOT_SET;
+
+    return true;
+}
 
 /*
  * check if reserved labels are used properly
index 422819ba10df76536083207aca51929d0279a2b4..f31442415eecc045011031864ea7c2b6cc047fb4 100644 (file)
@@ -75,7 +75,7 @@ public: ~PPPoECodec() { }
 
     bool decode(const RawData&, CodecData&, DecodeData&) override final;
     bool encode(const uint8_t* const raw_in, const uint16_t raw_len,
-        EncState&, Buffer&) override final;
+        EncState&, Buffer&, Flow*) override final;
 
 protected:
     PPPoECodec(const char* s, PppoepktType type) :
@@ -265,7 +265,7 @@ bool PPPoECodec::decode(const RawData& raw,
  ******************************************************************/
 
 bool PPPoECodec::encode(const uint8_t* const raw_in, const uint16_t raw_len,
-    EncState&, Buffer& buf)
+    EncState&, Buffer& buf, Flow*)
 {
     if (!buf.allocate(raw_len))
         return false;
index d119daaa7b7c64f41465f6401a34415756120d66..2c5e1ce9f2ed4e734d66875ca2868c82425ec62d 100644 (file)
@@ -63,7 +63,7 @@ public:
     void get_protocol_ids(std::vector<ProtocolId>& v) override;
     bool decode(const RawData&, CodecData&, DecodeData&) override;
     bool encode(const uint8_t* const raw_in, const uint16_t raw_len,
-        EncState&, Buffer&) override;
+        EncState&, Buffer&, Flow*) override;
     void update(const ip::IpApi&, const EncodeFlags, uint8_t* raw_pkt,
         uint16_t lyr_len, uint32_t& updated_len) override;
 };
@@ -235,7 +235,7 @@ static inline bool update_GTP_length(GTPHdr* const gtph, int gtp_total_len)
 }
 
 bool GtpCodec::encode(const uint8_t* const raw_in, const uint16_t raw_len,
-    EncState&, Buffer& buf)
+    EncState&, Buffer& buf, Flow*)
 {
     if (buf.allocate(raw_len))
         return false;
index adaf6a462e517efb297255927d8a0f093d711db7..0f090a077ce6ada8aafc61f1371b7060614807c1 100644 (file)
@@ -62,7 +62,7 @@ public:
     void log(TextLog* const, const uint8_t* pkt, const uint16_t len) override;
     bool decode(const RawData&, CodecData&, DecodeData&) override;
     bool encode(const uint8_t* const raw_in, const uint16_t raw_len,
-        EncState&, Buffer&) override;
+        EncState&, Buffer&, Flow*) override;
     void format(bool reverse, uint8_t* raw_pkt, DecodeData& snort) override;
     void update(const ip::IpApi&, const EncodeFlags, uint8_t* raw_pkt,
         uint16_t lyr_len, uint32_t& updated_len) override;
@@ -151,7 +151,7 @@ void EthCodec::log(TextLog* const text_log, const uint8_t* raw_pkt,
 //-------------------------------------------------------------------------
 
 bool EthCodec::encode(const uint8_t* const raw_in, const uint16_t /*raw_len*/,
-    EncState& enc, Buffer& buf)
+    EncState& enc, Buffer& buf, Flow*)
 {
     const eth::EtherHdr* hi = reinterpret_cast<const eth::EtherHdr*>(raw_in);
 
index 00514d85e6ecbcaee21491a6def4d9236e95435e..8c5dbcb6031c8c5818dc3a913b3dffef8e75d9d9 100644 (file)
@@ -95,6 +95,17 @@ void Flow::term()
 
     if ( ha_state )
         delete ha_state;
+
+    if ( clientMplsLyr.length )
+    {
+        delete[] clientMplsLyr.start;
+        clientMplsLyr.length = 0;
+    }
+    if ( serverMplsLyr.length )
+    {
+        delete[] serverMplsLyr.start;
+        serverMplsLyr.length = 0;
+    }
 }
 
 void Flow::reset(bool do_cleanup)
@@ -110,6 +121,17 @@ void Flow::reset(bool do_cleanup)
 
     free_application_data();
 
+    if ( clientMplsLyr.length )
+    {
+        delete[] clientMplsLyr.start;
+        clientMplsLyr.length = 0;
+    }
+    if ( serverMplsLyr.length )
+    {
+        delete[] serverMplsLyr.start;
+        serverMplsLyr.length = 0;
+    }
+
     // FIXIT-M cleanup() winds up calling clear()
     if ( ssn_client )
     {
@@ -145,6 +167,17 @@ void Flow::restart(bool free_flow_data)
     if ( free_flow_data )
         free_application_data();
 
+    if ( clientMplsLyr.length )
+    {
+        delete[] clientMplsLyr.start;
+        clientMplsLyr.length = 0;
+    }
+    if ( serverMplsLyr.length )
+    {
+        delete[] serverMplsLyr.start;
+        serverMplsLyr.length = 0;
+    }
+
     bitop->reset();
 
     ssn_state.ignore_direction = 0;
@@ -429,3 +462,40 @@ void Flow::get_application_ids(AppId& serviceAppId, AppId& clientAppId,
     payloadAppId = application_ids[APP_PROTOID_PAYLOAD];
     miscAppId    = application_ids[APP_PROTOID_MISC];
 }
+
+void Flow::set_mpls_layer_per_dir(Packet* p)
+{
+    const Layer* mpls_lyr = layer::get_mpls_layer(p);
+
+    if ( !mpls_lyr || !(mpls_lyr->start) )
+        return;
+
+    if ( p->packet_flags & PKT_FROM_CLIENT )
+    {
+        if ( !clientMplsLyr.length )
+        {
+            clientMplsLyr.length = mpls_lyr->length;
+            clientMplsLyr.prot_id = mpls_lyr->prot_id;
+            clientMplsLyr.start = new uint8_t[mpls_lyr->length];
+            memcpy((void *)clientMplsLyr.start, mpls_lyr->start, mpls_lyr->length);
+        }
+    }
+    else
+    {
+        if ( !serverMplsLyr.length )
+        {
+            serverMplsLyr.length = mpls_lyr->length;
+            serverMplsLyr.prot_id = mpls_lyr->prot_id;
+            serverMplsLyr.start = new uint8_t[mpls_lyr->length];
+            memcpy((void *)serverMplsLyr.start, mpls_lyr->start, mpls_lyr->length);
+        }
+    }
+}
+
+Layer Flow::get_mpls_layer_per_dir(bool client)
+{
+    if ( client )
+        return clientMplsLyr;
+    else
+        return serverMplsLyr;
+}
index 01df9deeda9c04a10ad062bb84df7dbaafc10cf4..e3576ba620033f6d06ace83720f689c6f06de090 100644 (file)
@@ -175,6 +175,8 @@ public:
     int get_expire(const Packet*);
     bool expired(const Packet*);
     void set_ttl(Packet*, bool client);
+    void set_mpls_layer_per_dir(Packet*);
+    Layer get_mpls_layer_per_dir(bool);
 
     uint32_t update_session_flags( uint32_t flags )
     {
@@ -312,6 +314,7 @@ public:  // FIXIT-M privatize if possible
     Inspector* gadget;    // service handler
     Inspector* data;
     const char* service;
+    Layer clientMplsLyr, serverMplsLyr;
 
     unsigned policy_id;
 
index 9de59c1cf728b432037eb07b1b3bcf2f41c29456..2a2ea03bf3906cf4421ff3c7af11b6af94aecf68 100644 (file)
@@ -482,6 +482,10 @@ unsigned FlowControl::process(Flow* flow, Packet* p)
 
     flow->set_direction(p);
 
+    // This requires the packet direction to be set
+    if ( p->proto_bits & PROTO_BIT__MPLS )
+        flow->set_mpls_layer_per_dir(p);
+
     switch ( flow->flow_state )
     {
     case Flow::FlowState::SETUP:
index d9f682c13bafd0fdced137daf72b38728bdb3876..b0d2feb0408bb0024245fcddfdd83048064d1663 100644 (file)
@@ -41,6 +41,7 @@ struct TextLog;
 struct _daq_pkthdr;
 struct Packet;
 struct Layer;
+class Flow;
 enum CodecSid : uint32_t;
 
 namespace ip
@@ -329,7 +330,7 @@ public:
     virtual bool encode(const uint8_t* const /*raw_in */,
         const uint16_t /*raw_len*/,
         EncState&,
-        Buffer&)
+        Buffer&, Flow*)
     { return true; }
 
     /*
index 9e9a323e4606df960e60cfe195a8de9d7c8b3883..c21294b8d9aefce50c877cb470341a6aa3193696 100644 (file)
@@ -37,6 +37,7 @@
 #include "pp_enc_state_iface.h"
 #include "pp_ip_api_iface.h"
 #include "pp_raw_buffer_iface.h"
+#include "pp_flow_iface.h"
 
 // FIXIT-M delete this, and make the IpApi arg in codec.update required
 static const ip::IpApi default_ip_api {};
@@ -156,10 +157,11 @@ static const luaL_Reg methods[] =
             auto& rb = RawBufferIface.get(L, 1); // raw_in
             auto& es = EncStateIface.get(L, 2);
             auto& b = BufferIface.get(L, 3);
+            auto& flow = FlowIface.get(L, 4);
 
             auto& self = CodecIface.get(L);
 
-            bool result = self.encode(get_data(rb), rb.size(), es, b);
+            bool result = self.encode(get_data(rb), rb.size(), es, b, &flow);
 
             lua_pushboolean(L, result);
 
index bd4d8e88015b92726c657aec07b41b7faa9b0825..69d6d65e1bfdcb4eb6bf3fba68c01c82d9988717 100644 (file)
@@ -79,6 +79,24 @@ static inline const uint8_t* find_inner_layer(const Layer* lyr,
     return nullptr;
 }
 
+static inline const Layer* find_layer(const Layer* lyr,
+    uint8_t num_layers,
+    ProtocolId prot_id1,
+    ProtocolId prot_id2)
+{
+    int tmp = num_layers-1;
+    lyr = &lyr[tmp];
+
+    for (int i = num_layers - 1; i >= 0; i--)
+    {
+        if (lyr->prot_id == prot_id1 ||
+            lyr->prot_id == prot_id2)
+            return lyr;
+        lyr--;
+    }
+    return nullptr;
+}
+
 void set_packet_pointer(const Packet* const p)
 { curr_pkt = p; }
 
@@ -116,6 +134,15 @@ const eapol::EtherEapol* get_eapol_layer(const Packet* const p)
         find_inner_layer(lyr, num_layers, ProtocolId::ETHERTYPE_EAPOL));
 }
 
+const Layer* get_mpls_layer(const Packet* const p)
+{
+    uint8_t num_layers = p->num_layers;
+    const Layer* lyr = p->layers;
+
+    return find_layer(lyr, num_layers, ProtocolId::ETHERTYPE_MPLS_UNICAST,
+            ProtocolId::ETHERTYPE_MPLS_MULTICAST);
+}
+
 const vlan::VlanTagHdr* get_vlan_layer(const Packet* const p)
 {
     uint8_t num_layers = p->num_layers;
index 202d618e29445123eb7dd42bdc8935dc6c2c2e58..2cd4a678c8c246f3f50b6d6e1994f2b624dc4a61 100644 (file)
@@ -100,6 +100,7 @@ SO_PUBLIC const uint8_t* get_root_layer(const Packet* const);
 SO_PUBLIC const udp::UDPHdr* get_outer_udp_lyr(const Packet* const);
 // return the inner ip layer's index in the p->layers array
 SO_PUBLIC int get_inner_ip_lyr_index(const Packet* const p);
+SO_PUBLIC const Layer* get_mpls_layer(const Packet* const p);
 
 // Two versions of this because ip_defrag:: wants to call this on
 // its rebuilt packet, not on the current packet.  Extra function
index 61c6436aff6b1c9d480e79eec46bb6dedca1cb91..01daf9277bd313958cc2e49a83f37431d027a3ec 100644 (file)
@@ -390,7 +390,7 @@ bool PacketManager::encode(const Packet* p,
             const Layer& l = lyrs[i];
             ProtocolIndex mapped_prot =
                 i ? CodecManager::s_proto_map[to_utype(l.prot_id)] : CodecManager::grinder;
-            if (!CodecManager::s_protocols[mapped_prot]->encode(l.start, l.length, enc, buf))
+            if (!CodecManager::s_protocols[mapped_prot]->encode(l.start, l.length, enc, buf, p->flow))
             {
                 return false;
             }
@@ -408,7 +408,7 @@ bool PacketManager::encode(const Packet* p,
         ProtocolIndex mapped_prot =
             i ? CodecManager::s_proto_map[to_utype(l.prot_id)] : CodecManager::grinder;
 
-        if (!CodecManager::s_protocols[mapped_prot]->encode(l.start, l.length, enc, buf))
+        if (!CodecManager::s_protocols[mapped_prot]->encode(l.start, l.length, enc, buf, p->flow))
         {
             return false;
         }