#include "managers/inspector_manager.h"
#include "profiler/profiler.h"
#include "src/main.h"
+#include "target_based/host_attributes.h"
#include "trace/trace.h"
#include "utils/util.h"
OdpContext& current_odp_ctxt = ctxt.get_odp_ctxt();
assert(pkt_thread_odp_ctxt != ¤t_odp_ctxt);
+ HostAttributesManager::clear_appid_services();
AppIdServiceState::clean();
AppIdPegCounts::cleanup_pegs();
AppIdServiceState::initialize(ctxt.config.memcap);
AppIdContext& ctxt = inspector->get_ctxt();
OdpContext& old_odp_ctxt = ctxt.get_odp_ctxt();
+ ServiceDiscovery::clear_ftp_service_state();
clear_dynamic_host_cache_services();
AppIdPegCounts::cleanup_peg_info();
LuaDetectorManager::clear_lua_detector_mgrs();
if (tmp_snort_protocol_id != snort_protocol_id)
{
snort_protocol_id = tmp_snort_protocol_id;
- Stream::set_snort_protocol_id(p->flow, tmp_snort_protocol_id);
+ Stream::set_snort_protocol_id(p->flow, tmp_snort_protocol_id, true);
}
}
return asd.add_flow_data_id(21, ftp_service);
}
+void ServiceDiscovery::clear_ftp_service_state()
+{
+ ftp_service = nullptr;
+}
+
bool ServiceDiscovery::do_service_discovery(AppIdSession& asd, Packet* p,
AppidSessionDirection direction, AppidChangeBits& change_bits)
{
int fail_service(AppIdSession&, const snort::Packet*, AppidSessionDirection dir, ServiceDetector*, ServiceDiscoveryState* sds = nullptr);
int incompatible_data(AppIdSession&, const snort::Packet*, AppidSessionDirection dir, ServiceDetector*);
static int add_ftp_service_state(AppIdSession&);
+ static void clear_ftp_service_state();
private:
void get_next_service(const snort::Packet*, const AppidSessionDirection dir, AppIdSession&);
return UNKNOWN_PROTOCOL_ID;
}
-SnortProtocolId Stream::set_snort_protocol_id(Flow* flow, SnortProtocolId id)
+SnortProtocolId Stream::set_snort_protocol_id(Flow* flow, SnortProtocolId id, bool is_appid_service)
{
if (!flow)
return UNKNOWN_PROTOCOL_ID;
if ( !flow->is_proxied() )
{
HostAttributesManager::update_service
- (flow->server_ip, flow->server_port, flow->ssn_state.ipprotocol, id);
+ (flow->server_ip, flow->server_port, flow->ssn_state.ipprotocol, id, is_appid_service);
}
return id;
static SnortProtocolId get_snort_protocol_id(Flow*);
// Set the protocol identifier for a stream
- static SnortProtocolId set_snort_protocol_id(Flow*, SnortProtocolId);
+ static SnortProtocolId set_snort_protocol_id(Flow*, SnortProtocolId, bool is_appid_service = false);
// initialize response count and expiration time
static void init_active_response(const Packet*, Flow*);
static THREAD_LOCAL HostAttributeStats host_attribute_stats;
bool HostAttributesDescriptor::update_service
- (uint16_t port, uint16_t protocol, SnortProtocolId snort_protocol_id, bool& updated)
+ (uint16_t port, uint16_t protocol, SnortProtocolId snort_protocol_id, bool& updated,
+ bool is_appid_service)
{
std::lock_guard<std::mutex> lck(host_attributes_lock);
{
if ( s.ipproto == protocol && (uint16_t)s.port == port )
{
+ if ( s.snort_protocol_id != snort_protocol_id )
+ s.appid_service = is_appid_service;
s.snort_protocol_id = snort_protocol_id;
updated = true;
return true;
if ( services.size() < SnortConfig::get_conf()->get_max_services_per_host() )
{
updated = false;
- services.emplace_back(HostServiceDescriptor(port, protocol, snort_protocol_id));
+ services.emplace_back(HostServiceDescriptor(port, protocol, snort_protocol_id, is_appid_service));
return true;
}
return false;
}
+void HostAttributesDescriptor::clear_appid_services()
+{
+ std::lock_guard<std::mutex> lck(host_attributes_lock);
+ for ( auto s = services.begin(); s != services.end(); )
+ {
+ if ( s->appid_service and s->snort_protocol_id != UNKNOWN_PROTOCOL_ID )
+ s = services.erase(s);
+ else
+ s++;
+ }
+}
+
SnortProtocolId HostAttributesDescriptor::get_snort_protocol_id(int ipprotocol, uint16_t port) const
{
std::lock_guard<std::mutex> lck(host_attributes_lock);
return nullptr;
}
-void HostAttributesManager::update_service(const snort::SfIp& host_ip, uint16_t port, uint16_t protocol, SnortProtocolId snort_protocol_id)
+void HostAttributesManager::update_service(const snort::SfIp& host_ip, uint16_t port,
+ uint16_t protocol, SnortProtocolId snort_protocol_id, bool is_appid_service)
{
if ( active_cache )
{
}
bool updated = false;
- if ( host->update_service(port, protocol, snort_protocol_id, updated) )
+ if ( host->update_service(port, protocol, snort_protocol_id, updated, is_appid_service) )
{
if ( updated )
host_attribute_stats.dynamic_service_updates++;
}
}
+void HostAttributesManager::clear_appid_services()
+{
+ if ( active_cache )
+ {
+ auto hosts = active_cache->get_all_data();
+ for ( auto& h : hosts )
+ h.second->clear_appid_services();
+ }
+}
+
int32_t HostAttributesManager::get_num_host_entries()
{
if ( active_cache )
{
public:
HostServiceDescriptor() = default;
- HostServiceDescriptor(uint16_t port, uint16_t protocol, SnortProtocolId spi)
- : port(port), ipproto(protocol), snort_protocol_id(spi)
+ HostServiceDescriptor(uint16_t port, uint16_t protocol, SnortProtocolId spi, bool appid_service)
+ : port(port), ipproto(protocol), snort_protocol_id(spi), appid_service(appid_service)
{ }
~HostServiceDescriptor() = default;
uint16_t port = 0;
uint16_t ipproto = 0;
SnortProtocolId snort_protocol_id = UNKNOWN_PROTOCOL_ID;
+ bool appid_service = false;
};
struct HostPolicyDescriptor
HostAttributesDescriptor() = default;
~HostAttributesDescriptor() = default;
- bool update_service(uint16_t port, uint16_t protocol, SnortProtocolId, bool& updated);
+ bool update_service(uint16_t port, uint16_t protocol, SnortProtocolId, bool& updated,
+ bool is_appid_service = false);
+ void clear_appid_services();
SnortProtocolId get_snort_protocol_id(int ipprotocol, uint16_t port) const;
const snort::SfIp& get_ip_addr() const
static bool add_host(HostAttributesEntry, snort::SnortConfig*);
static HostAttributesEntry find_host(const snort::SfIp&);
- static void update_service(const snort::SfIp&, uint16_t port, uint16_t protocol, SnortProtocolId);
+ static void update_service(const snort::SfIp&, uint16_t port, uint16_t protocol,
+ SnortProtocolId, bool is_appid_service = false);
+ static void clear_appid_services();
static int32_t get_num_host_entries();
static const PegInfo* get_pegs();
static PegCount* get_peg_counts();