]> git.ipfire.org Git - thirdparty/snort3.git/commitdiff
Merge pull request #2474 in SNORT/snort3 from ~ARMANDAV/snort3:rna_service to master
authorMasud Hasan (mashasan) <mashasan@cisco.com>
Tue, 22 Sep 2020 01:00:39 +0000 (01:00 +0000)
committerMasud Hasan (mashasan) <mashasan@cisco.com>
Tue, 22 Sep 2020 01:00:39 +0000 (01:00 +0000)
Squashed commit of the following:

commit 45fe15c3bfa63927ccb6d9cedb486ebae9f5b739
Author: Arun Mandava <armandav@cisco.com>
Date:   Mon Sep 21 15:10:43 2020 -0400

    rna: Service discovery with multiple vendor and version support

src/host_tracker/host_tracker.cc
src/host_tracker/host_tracker.h
src/host_tracker/host_tracker_module.cc
src/host_tracker/host_tracker_module.h
src/host_tracker/test/host_cache_allocator_test.cc
src/network_inspectors/rna/rna_app_discovery.cc
src/network_inspectors/rna/rna_app_discovery.h
src/network_inspectors/rna/rna_logger.cc
src/network_inspectors/rna/rna_pnd.cc

index a5770c499f81ae3a1229d859cc8e362597b2af59..1c818a555f11b29f38602c36f16b7fc41c3a8482 100644 (file)
@@ -22,8 +22,9 @@
 #include "config.h"
 #endif
 
+#include "host_cache.h"
+#include "host_cache_allocator.cc"
 #include "host_tracker.h"
-
 #include "utils/util.h"
 
 using namespace snort;
@@ -247,6 +248,39 @@ bool HostTracker::add_service(Port port, IpProtocol proto, AppId appid, bool inf
     return true;
 }
 
+void HostTracker::clear_service(HostApplication& ha)
+{
+    lock_guard<mutex> lck(host_tracker_lock);
+    ha = {};
+}
+
+bool HostTracker::add_service(HostApplication& app, bool* added)
+{
+    host_tracker_stats.service_adds++;
+    lock_guard<mutex> lck(host_tracker_lock);
+
+    for ( auto& s : services )
+    {
+        if ( s.port == app.port and s.proto == app.proto )
+        {
+            if ( s.appid != app.appid and app.appid != APP_ID_NONE )
+            {
+                s.appid = app.appid;
+                s.inferred_appid = app.inferred_appid;
+                if (added)
+                    *added = true;
+            }
+            return true;
+        }
+    }
+
+    services.emplace_back(app.port, app.proto, app.appid, app.inferred_appid);
+    if (added)
+        *added = true;
+
+    return true;
+}
+
 AppId HostTracker::get_appid(Port port, IpProtocol proto, bool inferred_only,
     bool allow_port_wildcard)
 {
@@ -280,15 +314,21 @@ HostApplication HostTracker::get_service(Port port, IpProtocol proto, uint32_t l
     {
         if ( s.port == port and s.proto == proto )
         {
-            if ( s.appid != appid and appid != APP_ID_NONE )
+            if ( appid != APP_ID_NONE and s.appid != appid )
             {
                 s.appid = appid;
                 is_new = true;
+                s.hits = 1;
             }
             else if ( s.last_seen == 0 )
+            {
                 is_new = true;
+                s.hits = 1;
+            }
+            else
+                ++s.hits;
+
             s.last_seen = lseen;
-            ++s.hits;
             return s;
         }
     }
@@ -310,14 +350,25 @@ void HostTracker::update_service(const HostApplication& ha)
         {
             s.hits = ha.hits;
             s.last_seen = ha.last_seen;
-            if ( ha.appid > APP_ID_NONE )
-                s.appid = ha.appid;
             return;
         }
     }
 }
 
-bool HostTracker::update_service_info(HostApplication& ha, const char* vendor, const char* version)
+void HostTracker::update_service_port(HostApplication& app, Port port)
+{
+    lock_guard<mutex> lck(host_tracker_lock);
+    app.port = port;
+}
+
+void HostTracker::update_service_proto(HostApplication& app, IpProtocol proto)
+{
+    lock_guard<mutex> lck(host_tracker_lock);
+    app.proto = proto;
+}
+
+bool HostTracker::update_service_info(HostApplication& ha, const char* vendor,
+    const char* version, uint16_t max_info)
 {
     host_tracker_stats.service_finds++;
     lock_guard<mutex> lck(host_tracker_lock);
@@ -326,23 +377,28 @@ bool HostTracker::update_service_info(HostApplication& ha, const char* vendor, c
     {
         if ( s.port == ha.port and s.proto == ha.proto )
         {
-            bool changed = false;
-            if ( vendor and strncmp(s.vendor, vendor, INFO_SIZE) )
+            if (s.info.size() < max_info)
             {
-                strncpy(s.vendor, vendor, INFO_SIZE);
-                s.vendor[INFO_SIZE-1] = '\0';
-                changed = true;
+                for (auto& i : s.info)
+                {
+                    if (((!version and i.version[0] == '\0') or
+                        (version and !strncmp(version, i.version, INFO_SIZE)))
+                        and ((!vendor and i.vendor[0] == '\0') or
+                        (vendor and !strncmp(vendor, i.vendor, INFO_SIZE))))
+                            return false;
+                }
+                s.info.emplace_back(version, vendor);
             }
-            if ( version and strncmp(s.version, version, INFO_SIZE) )
-            {
-                strncpy(s.version, version, INFO_SIZE);
-                s.version[INFO_SIZE-1] = '\0';
-                changed = true;
-            }
-            if ( !changed )
-                return false;
 
-            ha.appid = s.appid; // copy these info for the caller
+            // copy these info for the caller
+            if (ha.appid == APP_ID_NONE)
+                ha.appid = s.appid;
+            else
+                s.appid = ha.appid;
+
+            for (auto& i: s.info)
+                ha.info.emplace_back(i.version, i.vendor);
+
             ha.hits = s.hits;
             return true;
         }
@@ -407,6 +463,16 @@ size_t HostTracker::get_client_count()
     return clients.size();
 }
 
+HostClient::HostClient(AppId clientid, const char *ver, AppId ser) :
+    id(clientid), service(ser)
+{
+    if (ver)
+    {
+        strncpy(version, ver, INFO_SIZE);
+        version[INFO_SIZE-1] = '\0';
+    }
+}
+
 HostClient HostTracker::get_client(AppId id, const char* version, AppId service, bool& is_new)
 {
     lock_guard<mutex> lck(host_tracker_lock);
@@ -426,6 +492,20 @@ HostClient HostTracker::get_client(AppId id, const char* version, AppId service,
     return clients.back();
 }
 
+HostApplicationInfo::HostApplicationInfo(const char *ver, const char *ven)
+{
+    if (ver)
+    {
+        strncpy(version, ver, INFO_SIZE);
+        version[INFO_SIZE-1] = '\0';
+    }
+    if (ven)
+    {
+        strncpy(vendor, ven, INFO_SIZE);
+        vendor[INFO_SIZE-1] = '\0';
+    }
+}
+
 static inline string to_time_string(uint32_t p_time)
 {
     time_t raw_time = (time_t) p_time;
@@ -474,10 +554,15 @@ void HostTracker::stringify(string& str)
                 if ( s.inferred_appid )
                     str += ", inferred";
             }
-            if ( s.vendor[0] != '\0' )
-                str += ", vendor: " + string(s.vendor);
-            if ( s.version[0] != '\0' )
-                str += ", version: " + string(s.version);
+
+            if ( !s.info.empty() )
+                for ( const auto& i : s.info )
+                {
+                    if ( i.vendor[0] != '\0' )
+                        str += ", vendor: " + string(i.vendor);
+                    if ( i.version[0] != '\0' )
+                        str += ", version: " + string(i.version);
+                }
         }
     }
 
index 1d7890b732d9a99cf90b2a4ad5351a193e300315..4302e9d68d17244c66dd575c4d546d5200eaf2e8 100644 (file)
@@ -66,35 +66,36 @@ struct HostMac
     uint32_t last_seen;
 };
 
+struct HostApplicationInfo
+{
+    HostApplicationInfo() = default;
+    HostApplicationInfo(const char *ver, const char *ven);
+    char vendor[INFO_SIZE] = { 0 };
+    char version[INFO_SIZE] = { 0 };
+};
+
+typedef HostCacheAllocIp<HostApplicationInfo> HostAppInfoAllocator;
+
 struct HostApplication
 {
     HostApplication() = default;
     HostApplication(Port pt, IpProtocol pr, AppId ap, bool in, uint32_t ht = 0, uint32_t ls = 0) :
         port(pt), proto(pr), appid(ap), inferred_appid(in), hits(ht), last_seen(ls) { }
 
-    Port port;
+    Port port = 0;
     IpProtocol proto;
-    AppId appid;
-    bool inferred_appid;
-    uint32_t hits;
-    uint32_t last_seen;
-    char vendor[INFO_SIZE] = { 0 };
-    char version[INFO_SIZE] = { 0 };
+    AppId appid = APP_ID_NONE;
+    bool inferred_appid = false;
+    uint32_t hits = 0;
+    uint32_t last_seen = 0;
+
+    std::vector<HostApplicationInfo, HostAppInfoAllocator> info;
 };
 
 struct HostClient
 {
     HostClient() = default;
-    HostClient(AppId clientid, const char *ver, AppId ser) :
-        id(clientid), service(ser)
-    {
-        if (ver)
-        {
-            strncpy(version, ver, INFO_SIZE);
-            version[INFO_SIZE-1] = '\0';
-        }
-    }
-
+    HostClient(AppId clientid, const char *ver, AppId ser);
     AppId id;
     char version[INFO_SIZE] = { 0 };
     AppId service;
@@ -208,6 +209,10 @@ public:
     // appid detected from one flow to another flow such as BitTorrent.
     bool add_service(Port port, IpProtocol proto,
         AppId appid = APP_ID_NONE, bool inferred_appid = false, bool* added = nullptr);
+    bool add_service(HostApplication& app, bool* added = nullptr);
+    void clear_service(HostApplication& hs);
+    void update_service_port(HostApplication& app, Port port);
+    void update_service_proto(HostApplication& app, IpProtocol proto);
 
     AppId get_appid(Port port, IpProtocol proto, bool inferred_only = false,
         bool allow_port_wildcard = false);
@@ -216,7 +221,8 @@ public:
     HostApplication get_service(Port port, IpProtocol proto, uint32_t lseen, bool& is_new,
         AppId appid = APP_ID_NONE);
     void update_service(const HostApplication& ha);
-    bool update_service_info(HostApplication& ha, const char* vendor, const char* version);
+    bool update_service_info(HostApplication& ha, const char* vendor, const char* version,
+        uint16_t max_info);
     void remove_inferred_services();
 
     size_t get_client_count();
index d7b18fbe203460640380259347599ba2e87ba7d8..0274f910b93aa7871c8b989cfa41618f9dce0b40 100644 (file)
@@ -62,13 +62,12 @@ bool HostTrackerModule::set(const char*, Value& v, SnortConfig*)
         v.get_addr(addr);
 
     else if ( v.is("port") )
-        app.port = v.get_uint16();
-
+        host_cache[addr]->update_service_port(app, v.get_uint16());
     else if ( v.is("proto") )
     {
         const IpProtocol mask[] =
         { IpProtocol::IP, IpProtocol::TCP, IpProtocol::UDP };
-        app.proto = mask[v.get_uint8()];
+        host_cache[addr]->update_service_proto(app, mask[v.get_uint8()]);
     }
 
     else
@@ -82,7 +81,6 @@ bool HostTrackerModule::begin(const char* fqn, int idx, SnortConfig*)
     if ( idx && !strcmp(fqn, "host_tracker") )
     {
         addr.clear();
-        app = {};
     }
     return true;
 }
@@ -92,12 +90,14 @@ bool HostTrackerModule::end(const char* fqn, int idx, SnortConfig*)
     if ( idx && !strcmp(fqn, "host_tracker.services") )
     {
         if ( addr.is_set() )
-            host_cache[addr]->add_service(app.port, app.proto);
-        app = {};
+            host_cache[addr]->add_service(app);
+
+        host_cache[addr]->clear_service(app);
     }
     else if ( idx && !strcmp(fqn, "host_tracker") && addr.is_set() )
     {
         host_cache[addr];
+        host_cache[addr]->clear_service(app);
         addr.clear();
     }
 
index e3d1a3c79498933dc3c2b859dc3f337b5d5516a0..6e9888906a9387d9c86e4b69e71ee8877bf3b1a3 100644 (file)
@@ -31,6 +31,7 @@
 
 #include "framework/module.h"
 #include "host_tracker/host_cache.h"
+#include "host_tracker/host_cache_allocator.cc"
 
 #define host_tracker_help \
     "configure hosts"
index 1497262c2a5893392c31c6c82e38f7f8a5050440..22614e9fc1d042eb19d9488f7dc29320a31f1343 100644 (file)
@@ -23,7 +23,7 @@
 #endif
 
 #include "host_tracker/host_cache.h"
-#include "host_tracker/host_cache_allocator.h"
+#include "host_tracker/host_cache_allocator.cc"
 
 #include <string>
 
@@ -32,6 +32,8 @@
 #include <CppUTest/CommandLineTestRunner.h>
 #include <CppUTest/TestHarness.h>
 
+HostCacheIp host_cache(100);
+
 using namespace std;
 using namespace snort;
 
index 0163c872238307c482bbe0b8cea55b1db2d8bbcf..6298a78a4e14bd0869d515e77d8d432598784e9a 100644 (file)
@@ -71,11 +71,17 @@ void RnaAppDiscovery::process(AppidEvent* appid_event, DiscoveryFilter& filter,
         appid_session_api.get_app_id(&service, &client, &payload, nullptr, nullptr);
 
         if ( appid_change_bits[APPID_SERVICE_BIT] and service > APP_ID_NONE )
-            discover_service(p, proto, ht, (const struct in6_addr*) src_ip->get_ip6_ptr(),
-                src_mac, conf, logger, service);
+        {
+            if ( p->packet_flags & PKT_FROM_SERVER )
+                discover_service(p, proto, ht, (const struct in6_addr*) src_ip->get_ip6_ptr(),
+                    src_mac, conf, logger, p->flow->server_port, service);
+            else if ( p->packet_flags & PKT_FROM_CLIENT )
+                discover_service(p, proto, ht, (const struct in6_addr*) src_ip->get_ip6_ptr(),
+                    src_mac, conf, logger, p->flow->client_port, service);
+        }
 
         if (appid_change_bits[APPID_CLIENT_BIT] and client > APP_ID_NONE
-            and service > APP_ID_NONE)
+            and service > APP_ID_NONE )
         {
             const char* version = appid_session_api.get_client_version();
             discover_client(p, ht, (const struct in6_addr*) src_ip->get_ip6_ptr(), src_mac,
@@ -88,8 +94,12 @@ void RnaAppDiscovery::process(AppidEvent* appid_event, DiscoveryFilter& filter,
         const char* vendor;
         const char* version;
         const AppIdServiceSubtype* subtype;
+        AppId service, client, payload;
+
+        appid_session_api.get_app_id(&service, &client, &payload, nullptr, nullptr);
         appid_session_api.get_service_info(vendor, version, subtype);
-        update_service_info(p, proto, vendor, version, ht, src_ip, src_mac, logger);
+        update_service_info(p, proto, vendor, version, ht, src_ip, src_mac, logger, conf,
+            service);
     }
 
     if ( p->is_from_client() and ( appid_change_bits[APPID_HOST_BIT] or
@@ -113,7 +123,7 @@ void RnaAppDiscovery::process(AppidEvent* appid_event, DiscoveryFilter& filter,
 
 void RnaAppDiscovery::discover_service(const Packet* p, IpProtocol proto, RnaTracker& rt,
     const struct in6_addr* src_ip, const uint8_t* src_mac, RnaConfig* conf,
-    RnaLogger& logger, AppId service)
+    RnaLogger& logger, uint16_t port, AppId service)
 {
     if ( conf and conf->max_host_services and conf->max_host_services <= rt->get_service_count() )
         return;
@@ -121,7 +131,7 @@ void RnaAppDiscovery::discover_service(const Packet* p, IpProtocol proto, RnaTra
     bool is_new = false;
 
     // Work on a local copy instead of reference as we release lock during event generations
-    auto ha = rt->get_service(p->flow->server_port, proto, (uint32_t) packet_time(), is_new);
+    auto ha = rt->get_service(port, proto, (uint32_t) packet_time(), is_new, service);
     if ( is_new )
     {
         if ( proto == IpProtocol::TCP )
@@ -130,33 +140,23 @@ void RnaAppDiscovery::discover_service(const Packet* p, IpProtocol proto, RnaTra
             logger.log(RNA_EVENT_NEW, NEW_UDP_SERVICE, p, &rt, src_ip, src_mac, &ha);
 
         ha.hits = 0; // hit count is reset after logs are written
-        ha.appid = service;
         rt->update_service(ha);
     }
 }
 
 void RnaAppDiscovery::update_service_info(const Packet* p, IpProtocol proto, const char* vendor,
-    const char* version, RnaTracker& rt, const SfIp* ip, const uint8_t* src_mac, RnaLogger& logger)
+    const char* version, RnaTracker& rt, const SfIp* ip, const uint8_t* src_mac, RnaLogger& logger,
+    RnaConfig* conf, AppId service)
 {
     if ( !vendor and !version )
         return;
 
-    HostApplication ha(p->flow->server_port, proto, APP_ID_NONE, false);
-    if ( !rt->update_service_info(ha, vendor, version) )
+    HostApplication ha(p->flow->server_port, proto, service, false);
+    if ( !rt->update_service_info(ha, vendor, version, conf->max_host_service_info) )
         return;
 
     // Work on a local copy for eventing purpose
     ha.last_seen = (uint32_t) packet_time();
-    if ( vendor )
-    {
-        strncpy(ha.vendor, vendor, INFO_SIZE);
-        ha.vendor[INFO_SIZE-1] = '\0';
-    }
-    if ( version )
-    {
-        strncpy(ha.version, version, INFO_SIZE);
-        ha.version[INFO_SIZE-1] = '\0';
-    }
 
     if ( proto == IpProtocol::TCP )
         logger.log(RNA_EVENT_CHANGE, CHANGE_TCP_SERVICE_INFO, p, &rt,
index 6e44f28a59268ba846b92e13ec639e886429914b..513e5251d6a44c96678fe2a817d702ee527a4bd4 100644 (file)
@@ -29,7 +29,7 @@ public:
 
     static void discover_service(const snort::Packet* p, IpProtocol proto, RnaTracker& rt,
         const struct in6_addr* src_ip, const uint8_t* src_mac, RnaConfig* conf,
-        RnaLogger& logger, AppId service = APP_ID_NONE);
+        RnaLogger& logger, uint16_t port, AppId service = APP_ID_NONE);
 
     static void discover_client(const snort::Packet* p, RnaTracker& rt,
         const struct in6_addr* src_ip, const uint8_t* src_mac, RnaConfig* conf,
@@ -37,7 +37,7 @@ public:
 private:
     static void update_service_info(const snort::Packet* p, IpProtocol proto, const char* vendor,
         const char* version, RnaTracker& rt, const snort::SfIp* ip, const uint8_t* src_mac,
-        RnaLogger& logger);
+        RnaLogger& logger, RnaConfig* conf, AppId service);
     static void analyze_user_agent_fingerprint(const snort::Packet* p, const char* host,
         const char* uagent, RnaTracker& rt, const snort::SfIp* ip, const uint8_t* src_mac,
         RnaLogger& logger);
index da8e94266e500bff87c63f9f7db95eb95a76a947..bd9ada57308532b8abf3d1e56b369c0e0bba468f 100644 (file)
@@ -68,6 +68,23 @@ static inline void rna_logger_message(const RnaLoggerEvent& rle)
                 debug_logf(rna_trace, nullptr, "RNA client log: client %u, service %u\n",
                     rle.hc->id, rle.hc->service);
         }
+        if (rle.ha)
+        {
+            debug_logf(rna_trace, nullptr,
+                "RNA Service Info log: appid: %d proto %u, port: %u\n",
+                rle.ha->appid, (uint32_t)rle.ha->proto, rle.ha->port);
+
+            for (auto& s: rle.ha->info)
+            {
+                if (s.vendor[0] != '\0')
+                    debug_logf(rna_trace, nullptr, "RNA Service Info log: vendor: %s\n",
+                        s.vendor);
+
+                if (s.version[0] != '\0')
+                    debug_logf(rna_trace, nullptr, "RNA Service Info log: version: %s\n",
+                        s.version);
+            }
+        }
     }
     else
         debug_logf(rna_trace, nullptr, "RNA log: type %u, subtype %u, mac %s\n",
index 775cd90e8519c37b941c755b4e5cf497c6f57cac..c99bdaf61b0271765fad9b6c95f2e203892d7527 100644 (file)
@@ -178,7 +178,8 @@ void RnaPnd::discover_network(const Packet* p, uint8_t ttl)
     {
         auto proto = p->get_ip_proto_next();
         if ( proto == IpProtocol::TCP or proto == IpProtocol::UDP )
-            RnaAppDiscovery::discover_service(p, proto, ht, src_ip_ptr, src_mac, conf, logger);
+            RnaAppDiscovery::discover_service(p, proto, ht, src_ip_ptr, src_mac, conf,
+                logger, p->flow->server_port);
     }
 
     if ( !new_host )