static unsigned char* s_rec_message = nullptr;
static size_t s_rec_message_size = 0;
static int s_socket_return = 1;
+static bool s_socket_return_switch = false;
static int s_bind_return = 0;
static int s_listen_return = 0;
static int s_accept_return = 2;
}
#ifdef __GLIBC__
-int socket (int, int, int) __THROW { return s_socket_return; }
+int socket (int, int, int) __THROW
+{
+ if(s_socket_return_switch)
+ {
+ auto tmp_ret = s_socket_return;
+ s_socket_return = s_socket_return > 0 ? -1 : 1;
+ return tmp_ret;
+ }
+ return s_socket_return;
+}
int bind (int, const struct sockaddr*, socklen_t) __THROW { return s_bind_return; }
int listen (int, int) __THROW { return s_listen_return; }
#else
-int socket (int, int, int) { return s_socket_return; }
+int socket (int, int, int)
+{
+ if(s_socket_return_switch)
+ {
+ auto tmp_ret = s_socket_return;
+ s_socket_return = s_socket_return > 0 ? -1 : 1;
+ return tmp_ret;
+ }
+ return s_socket_return;
+}
int bind (int, const struct sockaddr*, socklen_t) { return s_bind_return; }
int listen (int, int) { return s_listen_return; }
#endif
TEST(unixdomain_connector_listener, listener_accept_stop)
{
+ set_normal_status();
UnixDomainConnectorConfig cfg;
cfg.direction = Connector::CONN_DUPLEX;
cfg.connector_name = "unixdomain";
test_listener_config = nullptr;
}
+UnixDomainConnectorReconnectHelper* reconnect_helper = nullptr;
+UnixDomainConnector* test_reconnect_connector = nullptr;
+
+void reconnect_update_callback(UnixDomainConnector* connector, bool is_reconnecting)
+{
+ test_reconnect_connector = connector;
+}
+
+TEST_GROUP(unixdomain_connector_reconnect_helper)
+{
+ UnixDomainConnectorConfig reconnect_config;
+ int reconnect_sfd = 0;
+ void setup()
+ {
+ reconnect_config.direction = Connector::CONN_DUPLEX;
+ reconnect_config.connector_name = "unixdomain-reconnect";
+ reconnect_config.paths.push_back("/tmp/pub_sub_reconnect");
+ reconnect_config.setup = UnixDomainConnectorConfig::Setup::CALL;
+ reconnect_config.conn_retries = 2;
+ reconnect_config.async_receive = false;
+ reconnect_helper = new UnixDomainConnectorReconnectHelper(reconnect_config, reconnect_update_callback);
+ }
+
+ void teardown()
+ {
+ if (reconnect_helper)
+ {
+ delete reconnect_helper;
+ reconnect_helper = nullptr;
+ }
+ }
+};
+
+TEST(unixdomain_connector_reconnect_helper, connect_then_reconnect_call)
+{
+ set_normal_status();
+ reconnect_helper->connect("/tmp/pub_sub_reconnect", 0);
+
+ CHECK(test_reconnect_connector != nullptr);
+
+ s_poll_undesirable = true;
+
+ auto tmp_test_connector = test_reconnect_connector;
+
+ //trigger the reconnect
+ test_reconnect_connector->process_receive();
+
+ //collapse the reconnect_helper joining reconnect thread
+ delete reconnect_helper;
+ reconnect_helper = nullptr;
+
+ CHECK(test_reconnect_connector != nullptr);
+ CHECK(test_reconnect_connector != tmp_test_connector);
+ delete test_reconnect_connector;
+
+ delete tmp_test_connector;
+}
+
+TEST(unixdomain_connector_reconnect_helper, connect_in_other_thread)
+{
+ set_normal_status();
+ s_socket_return = -1;
+ s_socket_return_switch = true;
+ reconnect_helper->connect("/tmp/pub_sub_reconnect", 0);
+
+ delete reconnect_helper;
+ reconnect_helper = nullptr;
+
+ CHECK(test_reconnect_connector != nullptr);
+ delete test_reconnect_connector;
+ test_reconnect_connector = nullptr;
+}
+
+TEST(unixdomain_connector_reconnect_helper, is_reconnect_available_flag)
+{
+ CHECK(reconnect_helper->is_reconnect_enabled() == true);
+ reconnect_helper->set_reconnect_enabled(false);
+ CHECK(reconnect_helper->is_reconnect_enabled() == false);
+ reconnect_helper->set_reconnect_enabled(true);
+ CHECK(reconnect_helper->is_reconnect_enabled() == true);
+}
+
int main(int argc, char** argv)
{
int return_value = CommandLineTestRunner::RunAllTests(argc, argv);
/* Module *****************************************************************/
-static bool attempt_connection(int& sfd, const char* path) {
+static bool attempt_connection(int& sfd, const char* path, unsigned long timeout_sec) {
sfd = socket(AF_UNIX, SOCK_STREAM, 0);
if (sfd == -1) {
ErrorMessage("UnixDomainC: socket error: %s \n", strerror(errno));
strncpy(addr.sun_path, path, sizeof(addr.sun_path) - 1);
if (connect(sfd, (struct sockaddr*)&addr, sizeof(struct sockaddr_un)) == -1) {
- if (errno != EINPROGRESS) {
+ if (errno == EINPROGRESS) {
+ // Wait for the socket to be writable (connection established or failed)
+ fd_set writefds;
+ FD_ZERO(&writefds);
+ FD_SET(sfd, &writefds);
+
+ struct timeval tv;
+ tv.tv_sec = timeout_sec;
+ tv.tv_usec = 0;
+
+ int sel = select(sfd + 1, nullptr, &writefds, nullptr, &tv);
+ if (sel <= 0) {
+ ErrorMessage("UnixDomainC: connect timeout or select error: %s \n", strerror(errno));
+ close(sfd);
+ return false;
+ }
+
+ int so_error = 0;
+ socklen_t len = sizeof(so_error);
+ if (getsockopt(sfd, SOL_SOCKET, SO_ERROR, &so_error, &len) < 0 || so_error != 0) {
+ ErrorMessage("UnixDomainC: connect failed after select: %s \n", strerror(so_error ? so_error : errno));
+ close(sfd);
+ return false;
+ }
+ } else {
ErrorMessage("UnixDomainC: connect error: %s \n", strerror(errno));
close(sfd);
return false;
}
// Function to handle connection retries
-static void connection_retry_handler(const UnixDomainConnectorConfig& cfg, size_t idx, UnixDomainConnectorUpdateHandler update_handler = nullptr) {
- if(update_handler)
+static void connection_retry_handler(const UnixDomainConnectorConfig& cfg, size_t idx,
+ UnixDomainConnectorUpdateHandler update_handler = nullptr, UnixDomainConnectorReconnectHelper* reconnect_helper = nullptr) {
+ if (update_handler)
update_handler(nullptr, ( (cfg.conn_retries > 0) and (cfg.setup == UnixDomainConnectorConfig::Setup::CALL) ));
else
ConnectorManager::update_thread_connector(cfg.connector_name, idx, nullptr);
const char* path = paths[idx].c_str();
- if(cfg.setup == UnixDomainConnectorConfig::Setup::CALL)
+ if (cfg.setup == UnixDomainConnectorConfig::Setup::CALL)
{
LogMessage("UnixDomainC: Attempting to reconnect to %s\n", cfg.paths[idx].c_str());
uint32_t retry_count = 0;
while (retry_count < cfg.max_retries) {
+ if (reconnect_helper and reconnect_helper->is_reconnect_enabled() == false)
+ return;
+
int sfd;
- if (attempt_connection(sfd, path)) {
+ if (attempt_connection(sfd, path, cfg.connect_timeout_seconds)) {
// Connection successful
- UnixDomainConnector* unixdomain_conn = new UnixDomainConnector(cfg, sfd, idx);
+ UnixDomainConnector* unixdomain_conn = new UnixDomainConnector(cfg, sfd, idx, reconnect_helper);
LogMessage("UnixDomainC: Connected to %s\n", path);
- if(update_handler)
+ if (update_handler)
{
unixdomain_conn->set_update_handler(update_handler);
update_handler(unixdomain_conn, false);
}
static void start_retry_thread(const UnixDomainConnectorConfig& cfg, size_t idx, UnixDomainConnectorUpdateHandler update_handler = nullptr) {
- std::thread retry_thread(connection_retry_handler, cfg, idx, update_handler);
+ std::thread retry_thread(connection_retry_handler, cfg, idx, update_handler, nullptr);
retry_thread.detach();
}
-UnixDomainConnector::UnixDomainConnector(const UnixDomainConnectorConfig& unixdomain_connector_config, int sfd, size_t idx)
- : Connector(unixdomain_connector_config), sock_fd(sfd), run_thread(false), receive_thread(nullptr),
- receive_ring(new ReceiveRing(50)), instance_id(idx), cfg(unixdomain_connector_config) {
+UnixDomainConnector::UnixDomainConnector(const UnixDomainConnectorConfig& unixdomain_connector_config, int sfd, size_t idx, UnixDomainConnectorReconnectHelper* reconnect_helper)
+ : Connector(unixdomain_connector_config), sock_fd(sfd), run_thread(false), receive_thread(nullptr),
+ receive_ring(new ReceiveRing(50)), instance_id(idx), cfg(unixdomain_connector_config), reconnect_helper(reconnect_helper) {
if (unixdomain_connector_config.async_receive) {
start_receive_thread();
}
sock_fd = -1;
}
- start_retry_thread(cfg, instance_id, update_handler);
+ if (reconnect_helper)
+ reconnect_helper->reconnect(instance_id);
+ else
+ start_retry_thread(cfg, instance_id, update_handler);
+
return;
}
else if (rval > 0 && pfds[0].revents & POLLIN) {
UnixDomainConnector* unixdomain_connector_tinit_call(const UnixDomainConnectorConfig& cfg, const char* path, size_t idx, const UnixDomainConnectorUpdateHandler& update_handler) {
int sfd;
- if (!attempt_connection(sfd, path)) {
+ if (!attempt_connection(sfd, path, 0)) {
if (cfg.conn_retries) {
// Spawn a new thread to handle connection retries
start_retry_thread(cfg, idx, update_handler);
accept_thread = nullptr;
}
}
+
+UnixDomainConnectorReconnectHelper::~UnixDomainConnectorReconnectHelper()
+{
+ if(connection_thread)
+ {
+ if (connection_thread->joinable())
+ connection_thread->join();
+ delete connection_thread;
+ connection_thread = nullptr;
+ }
+}
+
+void UnixDomainConnectorReconnectHelper::connect(const char* path, size_t idx)
+{
+ int sfd;
+ if (!attempt_connection(sfd, path, cfg.connect_timeout_seconds)) {
+ if (cfg.conn_retries) {
+
+ connection_thread = new std::thread(connection_retry_handler, cfg, idx, update_handler, this);
+ return;
+ } else {
+ close(sfd);
+ return;
+ }
+ }
+ if(update_handler)
+ {
+ LogMessage("UnixDomainC: Connected to %s\n", path);
+ UnixDomainConnector* unixdomain_conn = new UnixDomainConnector(cfg, sfd, idx, this);
+ unixdomain_conn->set_update_handler(update_handler);
+ update_handler(unixdomain_conn, false);
+ }
+ else
+ {
+ assert(true);
+ close(sfd);
+ }
+}
+
+void UnixDomainConnectorReconnectHelper::reconnect(size_t idx)
+{
+ if(!reconnect_enabled.load())
+ {
+ return;
+ }
+ if (connection_thread)
+ {
+ if (connection_thread->joinable())
+ connection_thread->join();
+ delete connection_thread;
+ connection_thread = nullptr;
+ }
+
+ connection_thread = new std::thread(connection_retry_handler, cfg, idx, update_handler, this);
+}
+
+void UnixDomainConnectorReconnectHelper::set_reconnect_enabled(bool enabled)
+{
+ reconnect_enabled.store(enabled);
+}
+
+bool UnixDomainConnectorReconnectHelper::is_reconnect_enabled() const
+{
+ return reconnect_enabled.load();
+}
};
class UnixDomainConnector;
+class UnixDomainConnectorReconnectHelper;
typedef std::function<void (UnixDomainConnector*,bool)> UnixDomainConnectorUpdateHandler;
typedef std::function<void ()> UnixDomainConnectorMessageReceivedHandler;
class UnixDomainConnector : public snort::Connector
{
public:
- UnixDomainConnector(const UnixDomainConnectorConfig& config, int sfd, size_t idx);
+ UnixDomainConnector(const UnixDomainConnectorConfig& config, int sfd, size_t idx, UnixDomainConnectorReconnectHelper* reconnect_helper = nullptr);
~UnixDomainConnector() override;
bool transmit_message(const snort::ConnectorMsg&, const ID& = null) override;
void set_update_handler(UnixDomainConnectorUpdateHandler handler);
void set_message_received_handler(UnixDomainConnectorMessageReceivedHandler handler);
+ void start_receive_thread();
+
int sock_fd;
private:
typedef Ring<snort::ConnectorMsg*> ReceiveRing;
- void start_receive_thread();
void stop_receive_thread();
void receive_processing_thread();
UnixDomainConnectorUpdateHandler update_handler;
UnixDomainConnectorMessageReceivedHandler message_received_handler;
+
+ UnixDomainConnectorReconnectHelper* reconnect_helper;
};
typedef std::function<void (UnixDomainConnector*, UnixDomainConnectorConfig*)> UnixDomainConnectorAcceptHandler;
};
+class UnixDomainConnectorReconnectHelper
+{
+ public:
+ UnixDomainConnectorReconnectHelper(const UnixDomainConnectorConfig& cfg, const UnixDomainConnectorUpdateHandler& update_handler = nullptr)
+ : cfg(cfg), update_handler(update_handler), connection_thread(nullptr), reconnect_enabled(true) { }
+ ~UnixDomainConnectorReconnectHelper();
+
+ void connect(const char* path, size_t idx);
+ void reconnect(size_t idx);
+ void set_reconnect_enabled(bool enabled);
+ bool is_reconnect_enabled() const;
+
+ private:
+ UnixDomainConnectorConfig cfg;
+ UnixDomainConnectorUpdateHandler update_handler;
+ std::thread* connection_thread;
+ std::atomic<bool> reconnect_enabled;
+};
+
extern SO_PUBLIC UnixDomainConnector* unixdomain_connector_tinit_call(const UnixDomainConnectorConfig& cfg, const char* path, size_t idx, const UnixDomainConnectorUpdateHandler& update_handler = nullptr);
#endif // UNIXDOMAIN_CONNECTOR_H
std::lock_guard<std::mutex> guard_send(_send_mutex);
std::lock_guard<std::mutex> guard_read(_read_mutex);
+ if(!this->is_running.load())
+ return;
+
transport_stats.successful_connections++;
auto side_channel = new SideChannel(ScMsgFormat::BINARY);
connector->set_message_received_handler(std::bind(&MPUnixDomainTransport::notify_process_thread, this));
this->side_channels.push_back(new SideChannelHandle(side_channel, cfg, channel_id));
connector->set_update_handler(std::bind(&MPUnixDomainTransport::connector_update_handler, this, std::placeholders::_1, std::placeholders::_2, side_channel));
+ connector->start_receive_thread();
}
MPUnixDomainTransport::MPUnixDomainTransport(MPUnixDomainTransportConfig *c, MPUnixTransportStats& stats) : MPTransport(),
this->consume_message_received = true;
}
-void MPUnixDomainTransport::connector_update_handler(UnixDomainConnector *connector, bool is_recconecting, SideChannel *side_channel)
+void MPUnixDomainTransport::connector_update_handler(UnixDomainConnector *connector, bool is_reconecting, SideChannel *side_channel)
{
std::lock_guard<std::mutex> guard_send(_send_mutex);
std::lock_guard<std::mutex> guard_read(_read_mutex);
+
+ if(!this->is_running.load())
+ return;
+
if (side_channel->connector_receive)
{
delete side_channel->connector_receive;
{
connector->set_message_received_handler(std::bind(&MPUnixDomainTransport::notify_process_thread, this));
side_channel->connector_receive = side_channel->connector_transmit = connector;
+ connector->start_receive_thread();
this->transport_stats.successful_connections++;
}
else
{
- if (is_recconecting == false)
+ if (is_reconecting == false)
{
MPTransportLog("Accepted connection interrupted, removing handle\n");
for(auto it = this->side_channels.begin(); it != this->side_channels.end(); ++it)
if (config->max_processes < 2)
return;
+ if (this->is_running.load())
+ return;
+
auto instance_id = mp_current_process_id = Snort::get_process_id();//Snort instance id
auto max_processes = config->max_processes;
UnixDomainConnectorConfig* unix_config = new UnixDomainConnectorConfig();
unix_config->setup = UnixDomainConnectorConfig::Setup::ANSWER;
- unix_config->async_receive = true;
+ unix_config->async_receive = false;
if (config->conn_retries)
{
unix_config->conn_retries = config->conn_retries;
UnixDomainConnectorConfig* connector_conf = new UnixDomainConnectorConfig();
connector_conf->setup = UnixDomainConnectorConfig::Setup::CALL;
- connector_conf->async_receive = true;
+ connector_conf->async_receive = false;
connector_conf->conn_retries = config->conn_retries;
connector_conf->retry_interval = config->retry_interval_seconds;
connector_conf->max_retries = config->max_retries;
connector_conf->connect_timeout_seconds = config->connect_timeout_seconds;
connector_conf->paths.push_back(send_path);
- unixdomain_connector_tinit_call(*connector_conf, send_path.c_str(), 0, std::bind(&MPUnixDomainTransport::connector_update_handler, this, std::placeholders::_1, std::placeholders::_2, side_channel));
+ UnixDomainConnectorReconnectHelper* reconnect_helper = new UnixDomainConnectorReconnectHelper(*connector_conf,
+ std::bind(&MPUnixDomainTransport::connector_update_handler, this, std::placeholders::_1, std::placeholders::_2, side_channel));
+
+ reconnect_helper->connect(send_path.c_str(), 0);
- this->side_channels.push_back( new SideChannelHandle(side_channel, connector_conf, i));
+ this->side_channels.push_back( new SideChannelHandle(side_channel, connector_conf, i, reconnect_helper));
}
+ assert(!this->consume_thread);
this->consume_thread = new std::thread(&MPUnixDomainTransport::process_messages_from_side_channels, this);
}
SideChannelHandle::~SideChannelHandle()
{
+ if(reconnect_helper)
+ reconnect_helper->set_reconnect_enabled(false);
+
if (side_channel)
{
if (side_channel->connector_receive)
delete side_channel;
}
+
+ if(reconnect_helper)
+ delete reconnect_helper;
if (connector_config)
delete connector_config;
struct SideChannelHandle
{
- SideChannelHandle(SideChannel* sc, UnixDomainConnectorConfig* cc, const unsigned short& channel_id) :
- side_channel(sc), connector_config(cc), channel_id(channel_id)
+ SideChannelHandle(SideChannel* sc, UnixDomainConnectorConfig* cc, const unsigned short& channel_id,
+ UnixDomainConnectorReconnectHelper* reconnect_helper = nullptr) :
+ side_channel(sc), connector_config(cc), channel_id(channel_id), reconnect_helper(reconnect_helper)
{ }
~SideChannelHandle();
SideChannel* side_channel;
UnixDomainConnectorConfig* connector_config;
unsigned short channel_id;
+ UnixDomainConnectorReconnectHelper* reconnect_helper;
};
struct UnixAcceptorHandle
void handle_new_connection(UnixDomainConnector* connector, UnixDomainConnectorConfig* cfg, const unsigned short& channel_id);
void process_messages_from_side_channels();
void notify_process_thread();
- void connector_update_handler(UnixDomainConnector* connector, bool is_recconecting, SideChannel* side_channel);
+ void connector_update_handler(UnixDomainConnector* connector, bool is_reconecting, SideChannel* side_channel);
void MPTransportLog(const char* msg, ...);
MPSerializeFunc get_event_serialization_function(unsigned pub_id, unsigned event_id);
test_update_handler = h;
}
+void UnixDomainConnector::start_receive_thread()
+{
+
+}
+
static snort::ConnectorMsg* test_msg_answer = nullptr;
static snort::ConnectorMsg* test_msg_call = nullptr;
static uint8_t* test_msg_call_data = nullptr;
}
return snort::ConnectorMsg();
}
-UnixDomainConnector::UnixDomainConnector(const UnixDomainConnectorConfig& config, int sfd, size_t idx) : Connector(config) // cppcheck-suppress uninitMemberVar
+UnixDomainConnector::UnixDomainConnector(const UnixDomainConnectorConfig& config, int sfd, size_t idx, UnixDomainConnectorReconnectHelper*) : Connector(config) // cppcheck-suppress uninitMemberVar
{ cfg = config; } // cppcheck-suppress useInitializationList
UnixDomainConnector::~UnixDomainConnector()
{
close(0);
}
+UnixDomainConnectorReconnectHelper::~UnixDomainConnectorReconnectHelper()
+{}
+
+void UnixDomainConnectorReconnectHelper::connect(const char* path, size_t idx)
+{
+ unixdomain_connector_tinit_call(cfg, path, idx, update_handler);
+}
+
+void UnixDomainConnectorReconnectHelper::reconnect(size_t idx)
+{}
+
+void UnixDomainConnectorReconnectHelper::set_reconnect_enabled(bool enabled)
+{
+ reconnect_enabled.store(enabled);
+}
+
UnixDomainConnector* unixdomain_connector_tinit_call(const UnixDomainConnectorConfig& cfg, const char* path, size_t idx, const UnixDomainConnectorUpdateHandler& update_handler)
{
if(cfg.setup == UnixDomainConnectorConfig::Setup::CALL)