From 458d7268e91539c81463eda2054cc45a67c3d406 Mon Sep 17 00:00:00 2001 From: "Oleksandr Stepanov -X (ostepano - SOFTSERVE INC at Cisco)" Date: Fri, 30 May 2025 20:05:38 +0000 Subject: [PATCH] Pull request #4760: mp_unix_transport: refactored socket reconnect Merge in SNORT/snort3 from ~OSTEPANO/snort3:transport_asan to master Squashed commit of the following: commit e87ec546921a79a5e92e2c7dc59806768d1ea074 Author: Oleksandr Stepanov Date: Mon May 26 12:12:00 2025 -0400 mp_unix_transport: refactored socket reconnect --- .../test/unixdomain_connector_test.cc | 106 ++++++++++++++- .../unixdomain_connector.cc | 125 ++++++++++++++++-- .../unixdomain_connector.h | 27 +++- .../mp_unix_transport/mp_unix_transport.cc | 34 ++++- .../mp_unix_transport/mp_unix_transport.h | 8 +- .../test/unix_transport_test.cc | 23 +++- 6 files changed, 295 insertions(+), 28 deletions(-) diff --git a/src/connectors/unixdomain_connector/test/unixdomain_connector_test.cc b/src/connectors/unixdomain_connector/test/unixdomain_connector_test.cc index 3fd572388..f9d5c2d17 100644 --- a/src/connectors/unixdomain_connector/test/unixdomain_connector_test.cc +++ b/src/connectors/unixdomain_connector/test/unixdomain_connector_test.cc @@ -46,6 +46,7 @@ static unsigned s_instance = 0; static unsigned char* s_rec_message = nullptr; static size_t s_rec_message_size = 0; static int s_socket_return = 1; +static bool s_socket_return_switch = false; static int s_bind_return = 0; static int s_listen_return = 0; static int s_accept_return = 2; @@ -140,11 +141,29 @@ ssize_t recv (int, void *buf, size_t n, int) } #ifdef __GLIBC__ -int socket (int, int, int) __THROW { return s_socket_return; } +int socket (int, int, int) __THROW +{ + if(s_socket_return_switch) + { + auto tmp_ret = s_socket_return; + s_socket_return = s_socket_return > 0 ? -1 : 1; + return tmp_ret; + } + return s_socket_return; +} int bind (int, const struct sockaddr*, socklen_t) __THROW { return s_bind_return; } int listen (int, int) __THROW { return s_listen_return; } #else -int socket (int, int, int) { return s_socket_return; } +int socket (int, int, int) +{ + if(s_socket_return_switch) + { + auto tmp_ret = s_socket_return; + s_socket_return = s_socket_return > 0 ? -1 : 1; + return tmp_ret; + } + return s_socket_return; +} int bind (int, const struct sockaddr*, socklen_t) { return s_bind_return; } int listen (int, int) { return s_listen_return; } #endif @@ -798,6 +817,7 @@ void connection_callback(UnixDomainConnector* c, UnixDomainConnectorConfig* conf TEST(unixdomain_connector_listener, listener_accept_stop) { + set_normal_status(); UnixDomainConnectorConfig cfg; cfg.direction = Connector::CONN_DUPLEX; cfg.connector_name = "unixdomain"; @@ -823,6 +843,88 @@ TEST(unixdomain_connector_listener, listener_accept_stop) test_listener_config = nullptr; } +UnixDomainConnectorReconnectHelper* reconnect_helper = nullptr; +UnixDomainConnector* test_reconnect_connector = nullptr; + +void reconnect_update_callback(UnixDomainConnector* connector, bool is_reconnecting) +{ + test_reconnect_connector = connector; +} + +TEST_GROUP(unixdomain_connector_reconnect_helper) +{ + UnixDomainConnectorConfig reconnect_config; + int reconnect_sfd = 0; + void setup() + { + reconnect_config.direction = Connector::CONN_DUPLEX; + reconnect_config.connector_name = "unixdomain-reconnect"; + reconnect_config.paths.push_back("/tmp/pub_sub_reconnect"); + reconnect_config.setup = UnixDomainConnectorConfig::Setup::CALL; + reconnect_config.conn_retries = 2; + reconnect_config.async_receive = false; + reconnect_helper = new UnixDomainConnectorReconnectHelper(reconnect_config, reconnect_update_callback); + } + + void teardown() + { + if (reconnect_helper) + { + delete reconnect_helper; + reconnect_helper = nullptr; + } + } +}; + +TEST(unixdomain_connector_reconnect_helper, connect_then_reconnect_call) +{ + set_normal_status(); + reconnect_helper->connect("/tmp/pub_sub_reconnect", 0); + + CHECK(test_reconnect_connector != nullptr); + + s_poll_undesirable = true; + + auto tmp_test_connector = test_reconnect_connector; + + //trigger the reconnect + test_reconnect_connector->process_receive(); + + //collapse the reconnect_helper joining reconnect thread + delete reconnect_helper; + reconnect_helper = nullptr; + + CHECK(test_reconnect_connector != nullptr); + CHECK(test_reconnect_connector != tmp_test_connector); + delete test_reconnect_connector; + + delete tmp_test_connector; +} + +TEST(unixdomain_connector_reconnect_helper, connect_in_other_thread) +{ + set_normal_status(); + s_socket_return = -1; + s_socket_return_switch = true; + reconnect_helper->connect("/tmp/pub_sub_reconnect", 0); + + delete reconnect_helper; + reconnect_helper = nullptr; + + CHECK(test_reconnect_connector != nullptr); + delete test_reconnect_connector; + test_reconnect_connector = nullptr; +} + +TEST(unixdomain_connector_reconnect_helper, is_reconnect_available_flag) +{ + CHECK(reconnect_helper->is_reconnect_enabled() == true); + reconnect_helper->set_reconnect_enabled(false); + CHECK(reconnect_helper->is_reconnect_enabled() == false); + reconnect_helper->set_reconnect_enabled(true); + CHECK(reconnect_helper->is_reconnect_enabled() == true); +} + int main(int argc, char** argv) { int return_value = CommandLineTestRunner::RunAllTests(argc, argv); diff --git a/src/connectors/unixdomain_connector/unixdomain_connector.cc b/src/connectors/unixdomain_connector/unixdomain_connector.cc index ddfc6f036..97bb769d2 100644 --- a/src/connectors/unixdomain_connector/unixdomain_connector.cc +++ b/src/connectors/unixdomain_connector/unixdomain_connector.cc @@ -46,7 +46,7 @@ THREAD_LOCAL ProfileStats unixdomain_connector_perfstats; /* Module *****************************************************************/ -static bool attempt_connection(int& sfd, const char* path) { +static bool attempt_connection(int& sfd, const char* path, unsigned long timeout_sec) { sfd = socket(AF_UNIX, SOCK_STREAM, 0); if (sfd == -1) { ErrorMessage("UnixDomainC: socket error: %s \n", strerror(errno)); @@ -73,7 +73,31 @@ static bool attempt_connection(int& sfd, const char* path) { strncpy(addr.sun_path, path, sizeof(addr.sun_path) - 1); if (connect(sfd, (struct sockaddr*)&addr, sizeof(struct sockaddr_un)) == -1) { - if (errno != EINPROGRESS) { + if (errno == EINPROGRESS) { + // Wait for the socket to be writable (connection established or failed) + fd_set writefds; + FD_ZERO(&writefds); + FD_SET(sfd, &writefds); + + struct timeval tv; + tv.tv_sec = timeout_sec; + tv.tv_usec = 0; + + int sel = select(sfd + 1, nullptr, &writefds, nullptr, &tv); + if (sel <= 0) { + ErrorMessage("UnixDomainC: connect timeout or select error: %s \n", strerror(errno)); + close(sfd); + return false; + } + + int so_error = 0; + socklen_t len = sizeof(so_error); + if (getsockopt(sfd, SOL_SOCKET, SO_ERROR, &so_error, &len) < 0 || so_error != 0) { + ErrorMessage("UnixDomainC: connect failed after select: %s \n", strerror(so_error ? so_error : errno)); + close(sfd); + return false; + } + } else { ErrorMessage("UnixDomainC: connect error: %s \n", strerror(errno)); close(sfd); return false; @@ -83,8 +107,9 @@ static bool attempt_connection(int& sfd, const char* path) { } // Function to handle connection retries -static void connection_retry_handler(const UnixDomainConnectorConfig& cfg, size_t idx, UnixDomainConnectorUpdateHandler update_handler = nullptr) { - if(update_handler) +static void connection_retry_handler(const UnixDomainConnectorConfig& cfg, size_t idx, + UnixDomainConnectorUpdateHandler update_handler = nullptr, UnixDomainConnectorReconnectHelper* reconnect_helper = nullptr) { + if (update_handler) update_handler(nullptr, ( (cfg.conn_retries > 0) and (cfg.setup == UnixDomainConnectorConfig::Setup::CALL) )); else ConnectorManager::update_thread_connector(cfg.connector_name, idx, nullptr); @@ -99,19 +124,22 @@ static void connection_retry_handler(const UnixDomainConnectorConfig& cfg, size_ const char* path = paths[idx].c_str(); - if(cfg.setup == UnixDomainConnectorConfig::Setup::CALL) + if (cfg.setup == UnixDomainConnectorConfig::Setup::CALL) { LogMessage("UnixDomainC: Attempting to reconnect to %s\n", cfg.paths[idx].c_str()); uint32_t retry_count = 0; while (retry_count < cfg.max_retries) { + if (reconnect_helper and reconnect_helper->is_reconnect_enabled() == false) + return; + int sfd; - if (attempt_connection(sfd, path)) { + if (attempt_connection(sfd, path, cfg.connect_timeout_seconds)) { // Connection successful - UnixDomainConnector* unixdomain_conn = new UnixDomainConnector(cfg, sfd, idx); + UnixDomainConnector* unixdomain_conn = new UnixDomainConnector(cfg, sfd, idx, reconnect_helper); LogMessage("UnixDomainC: Connected to %s\n", path); - if(update_handler) + if (update_handler) { unixdomain_conn->set_update_handler(update_handler); update_handler(unixdomain_conn, false); @@ -137,13 +165,13 @@ static void connection_retry_handler(const UnixDomainConnectorConfig& cfg, size_ } static void start_retry_thread(const UnixDomainConnectorConfig& cfg, size_t idx, UnixDomainConnectorUpdateHandler update_handler = nullptr) { - std::thread retry_thread(connection_retry_handler, cfg, idx, update_handler); + std::thread retry_thread(connection_retry_handler, cfg, idx, update_handler, nullptr); retry_thread.detach(); } -UnixDomainConnector::UnixDomainConnector(const UnixDomainConnectorConfig& unixdomain_connector_config, int sfd, size_t idx) - : Connector(unixdomain_connector_config), sock_fd(sfd), run_thread(false), receive_thread(nullptr), - receive_ring(new ReceiveRing(50)), instance_id(idx), cfg(unixdomain_connector_config) { +UnixDomainConnector::UnixDomainConnector(const UnixDomainConnectorConfig& unixdomain_connector_config, int sfd, size_t idx, UnixDomainConnectorReconnectHelper* reconnect_helper) + : Connector(unixdomain_connector_config), sock_fd(sfd), run_thread(false), receive_thread(nullptr), + receive_ring(new ReceiveRing(50)), instance_id(idx), cfg(unixdomain_connector_config), reconnect_helper(reconnect_helper) { if (unixdomain_connector_config.async_receive) { start_receive_thread(); } @@ -278,7 +306,11 @@ void UnixDomainConnector::process_receive() { sock_fd = -1; } - start_retry_thread(cfg, instance_id, update_handler); + if (reconnect_helper) + reconnect_helper->reconnect(instance_id); + else + start_retry_thread(cfg, instance_id, update_handler); + return; } else if (rval > 0 && pfds[0].revents & POLLIN) { @@ -387,7 +419,7 @@ static void mod_dtor(Module* m) { UnixDomainConnector* unixdomain_connector_tinit_call(const UnixDomainConnectorConfig& cfg, const char* path, size_t idx, const UnixDomainConnectorUpdateHandler& update_handler) { int sfd; - if (!attempt_connection(sfd, path)) { + if (!attempt_connection(sfd, path, 0)) { if (cfg.conn_retries) { // Spawn a new thread to handle connection retries start_retry_thread(cfg, idx, update_handler); @@ -629,3 +661,68 @@ void UnixDomainConnectorListener::stop_accepting_connections() accept_thread = nullptr; } } + +UnixDomainConnectorReconnectHelper::~UnixDomainConnectorReconnectHelper() +{ + if(connection_thread) + { + if (connection_thread->joinable()) + connection_thread->join(); + delete connection_thread; + connection_thread = nullptr; + } +} + +void UnixDomainConnectorReconnectHelper::connect(const char* path, size_t idx) +{ + int sfd; + if (!attempt_connection(sfd, path, cfg.connect_timeout_seconds)) { + if (cfg.conn_retries) { + + connection_thread = new std::thread(connection_retry_handler, cfg, idx, update_handler, this); + return; + } else { + close(sfd); + return; + } + } + if(update_handler) + { + LogMessage("UnixDomainC: Connected to %s\n", path); + UnixDomainConnector* unixdomain_conn = new UnixDomainConnector(cfg, sfd, idx, this); + unixdomain_conn->set_update_handler(update_handler); + update_handler(unixdomain_conn, false); + } + else + { + assert(true); + close(sfd); + } +} + +void UnixDomainConnectorReconnectHelper::reconnect(size_t idx) +{ + if(!reconnect_enabled.load()) + { + return; + } + if (connection_thread) + { + if (connection_thread->joinable()) + connection_thread->join(); + delete connection_thread; + connection_thread = nullptr; + } + + connection_thread = new std::thread(connection_retry_handler, cfg, idx, update_handler, this); +} + +void UnixDomainConnectorReconnectHelper::set_reconnect_enabled(bool enabled) +{ + reconnect_enabled.store(enabled); +} + +bool UnixDomainConnectorReconnectHelper::is_reconnect_enabled() const +{ + return reconnect_enabled.load(); +} diff --git a/src/connectors/unixdomain_connector/unixdomain_connector.h b/src/connectors/unixdomain_connector/unixdomain_connector.h index 08339e30f..d0be0804a 100644 --- a/src/connectors/unixdomain_connector/unixdomain_connector.h +++ b/src/connectors/unixdomain_connector/unixdomain_connector.h @@ -50,6 +50,7 @@ public: }; class UnixDomainConnector; +class UnixDomainConnectorReconnectHelper; typedef std::function UnixDomainConnectorUpdateHandler; typedef std::function UnixDomainConnectorMessageReceivedHandler; @@ -57,7 +58,7 @@ typedef std::function UnixDomainConnectorMessageReceivedHandler; class UnixDomainConnector : public snort::Connector { public: - UnixDomainConnector(const UnixDomainConnectorConfig& config, int sfd, size_t idx); + UnixDomainConnector(const UnixDomainConnectorConfig& config, int sfd, size_t idx, UnixDomainConnectorReconnectHelper* reconnect_helper = nullptr); ~UnixDomainConnector() override; bool transmit_message(const snort::ConnectorMsg&, const ID& = null) override; @@ -69,12 +70,13 @@ public: void set_update_handler(UnixDomainConnectorUpdateHandler handler); void set_message_received_handler(UnixDomainConnectorMessageReceivedHandler handler); + void start_receive_thread(); + int sock_fd; private: typedef Ring ReceiveRing; - void start_receive_thread(); void stop_receive_thread(); void receive_processing_thread(); @@ -89,6 +91,8 @@ private: UnixDomainConnectorUpdateHandler update_handler; UnixDomainConnectorMessageReceivedHandler message_received_handler; + + UnixDomainConnectorReconnectHelper* reconnect_helper; }; typedef std::function UnixDomainConnectorAcceptHandler; @@ -110,6 +114,25 @@ public: }; +class UnixDomainConnectorReconnectHelper +{ + public: + UnixDomainConnectorReconnectHelper(const UnixDomainConnectorConfig& cfg, const UnixDomainConnectorUpdateHandler& update_handler = nullptr) + : cfg(cfg), update_handler(update_handler), connection_thread(nullptr), reconnect_enabled(true) { } + ~UnixDomainConnectorReconnectHelper(); + + void connect(const char* path, size_t idx); + void reconnect(size_t idx); + void set_reconnect_enabled(bool enabled); + bool is_reconnect_enabled() const; + + private: + UnixDomainConnectorConfig cfg; + UnixDomainConnectorUpdateHandler update_handler; + std::thread* connection_thread; + std::atomic reconnect_enabled; +}; + extern SO_PUBLIC UnixDomainConnector* unixdomain_connector_tinit_call(const UnixDomainConnectorConfig& cfg, const char* path, size_t idx, const UnixDomainConnectorUpdateHandler& update_handler = nullptr); #endif // UNIXDOMAIN_CONNECTOR_H diff --git a/src/mp_transport/mp_unix_transport/mp_unix_transport.cc b/src/mp_transport/mp_unix_transport/mp_unix_transport.cc index ca550658e..c9c8354ef 100644 --- a/src/mp_transport/mp_unix_transport/mp_unix_transport.cc +++ b/src/mp_transport/mp_unix_transport/mp_unix_transport.cc @@ -113,6 +113,9 @@ void MPUnixDomainTransport::handle_new_connection(UnixDomainConnector *connector std::lock_guard guard_send(_send_mutex); std::lock_guard guard_read(_read_mutex); + if(!this->is_running.load()) + return; + transport_stats.successful_connections++; auto side_channel = new SideChannel(ScMsgFormat::BINARY); @@ -122,6 +125,7 @@ void MPUnixDomainTransport::handle_new_connection(UnixDomainConnector *connector connector->set_message_received_handler(std::bind(&MPUnixDomainTransport::notify_process_thread, this)); this->side_channels.push_back(new SideChannelHandle(side_channel, cfg, channel_id)); connector->set_update_handler(std::bind(&MPUnixDomainTransport::connector_update_handler, this, std::placeholders::_1, std::placeholders::_2, side_channel)); + connector->start_receive_thread(); } MPUnixDomainTransport::MPUnixDomainTransport(MPUnixDomainTransportConfig *c, MPUnixTransportStats& stats) : MPTransport(), @@ -235,10 +239,14 @@ void MPUnixDomainTransport::notify_process_thread() this->consume_message_received = true; } -void MPUnixDomainTransport::connector_update_handler(UnixDomainConnector *connector, bool is_recconecting, SideChannel *side_channel) +void MPUnixDomainTransport::connector_update_handler(UnixDomainConnector *connector, bool is_reconecting, SideChannel *side_channel) { std::lock_guard guard_send(_send_mutex); std::lock_guard guard_read(_read_mutex); + + if(!this->is_running.load()) + return; + if (side_channel->connector_receive) { delete side_channel->connector_receive; @@ -249,11 +257,12 @@ void MPUnixDomainTransport::connector_update_handler(UnixDomainConnector *connec { connector->set_message_received_handler(std::bind(&MPUnixDomainTransport::notify_process_thread, this)); side_channel->connector_receive = side_channel->connector_transmit = connector; + connector->start_receive_thread(); this->transport_stats.successful_connections++; } else { - if (is_recconecting == false) + if (is_reconecting == false) { MPTransportLog("Accepted connection interrupted, removing handle\n"); for(auto it = this->side_channels.begin(); it != this->side_channels.end(); ++it) @@ -369,6 +378,9 @@ void MPUnixDomainTransport::init_side_channels() if (config->max_processes < 2) return; + if (this->is_running.load()) + return; + auto instance_id = mp_current_process_id = Snort::get_process_id();//Snort instance id auto max_processes = config->max_processes; @@ -392,7 +404,7 @@ void MPUnixDomainTransport::init_side_channels() UnixDomainConnectorConfig* unix_config = new UnixDomainConnectorConfig(); unix_config->setup = UnixDomainConnectorConfig::Setup::ANSWER; - unix_config->async_receive = true; + unix_config->async_receive = false; if (config->conn_retries) { unix_config->conn_retries = config->conn_retries; @@ -426,18 +438,22 @@ void MPUnixDomainTransport::init_side_channels() UnixDomainConnectorConfig* connector_conf = new UnixDomainConnectorConfig(); connector_conf->setup = UnixDomainConnectorConfig::Setup::CALL; - connector_conf->async_receive = true; + connector_conf->async_receive = false; connector_conf->conn_retries = config->conn_retries; connector_conf->retry_interval = config->retry_interval_seconds; connector_conf->max_retries = config->max_retries; connector_conf->connect_timeout_seconds = config->connect_timeout_seconds; connector_conf->paths.push_back(send_path); - unixdomain_connector_tinit_call(*connector_conf, send_path.c_str(), 0, std::bind(&MPUnixDomainTransport::connector_update_handler, this, std::placeholders::_1, std::placeholders::_2, side_channel)); + UnixDomainConnectorReconnectHelper* reconnect_helper = new UnixDomainConnectorReconnectHelper(*connector_conf, + std::bind(&MPUnixDomainTransport::connector_update_handler, this, std::placeholders::_1, std::placeholders::_2, side_channel)); + + reconnect_helper->connect(send_path.c_str(), 0); - this->side_channels.push_back( new SideChannelHandle(side_channel, connector_conf, i)); + this->side_channels.push_back( new SideChannelHandle(side_channel, connector_conf, i, reconnect_helper)); } + assert(!this->consume_thread); this->consume_thread = new std::thread(&MPUnixDomainTransport::process_messages_from_side_channels, this); } @@ -456,6 +472,9 @@ void MPUnixDomainTransport::cleanup_side_channels() SideChannelHandle::~SideChannelHandle() { + if(reconnect_helper) + reconnect_helper->set_reconnect_enabled(false); + if (side_channel) { if (side_channel->connector_receive) @@ -463,6 +482,9 @@ SideChannelHandle::~SideChannelHandle() delete side_channel; } + + if(reconnect_helper) + delete reconnect_helper; if (connector_config) delete connector_config; diff --git a/src/mp_transport/mp_unix_transport/mp_unix_transport.h b/src/mp_transport/mp_unix_transport/mp_unix_transport.h index f3c687d09..5d69dbe10 100644 --- a/src/mp_transport/mp_unix_transport/mp_unix_transport.h +++ b/src/mp_transport/mp_unix_transport/mp_unix_transport.h @@ -83,8 +83,9 @@ struct SerializeFunctionHandle struct SideChannelHandle { - SideChannelHandle(SideChannel* sc, UnixDomainConnectorConfig* cc, const unsigned short& channel_id) : - side_channel(sc), connector_config(cc), channel_id(channel_id) + SideChannelHandle(SideChannel* sc, UnixDomainConnectorConfig* cc, const unsigned short& channel_id, + UnixDomainConnectorReconnectHelper* reconnect_helper = nullptr) : + side_channel(sc), connector_config(cc), channel_id(channel_id), reconnect_helper(reconnect_helper) { } ~SideChannelHandle(); @@ -92,6 +93,7 @@ struct SideChannelHandle SideChannel* side_channel; UnixDomainConnectorConfig* connector_config; unsigned short channel_id; + UnixDomainConnectorReconnectHelper* reconnect_helper; }; struct UnixAcceptorHandle @@ -133,7 +135,7 @@ class MPUnixDomainTransport : public MPTransport void handle_new_connection(UnixDomainConnector* connector, UnixDomainConnectorConfig* cfg, const unsigned short& channel_id); void process_messages_from_side_channels(); void notify_process_thread(); - void connector_update_handler(UnixDomainConnector* connector, bool is_recconecting, SideChannel* side_channel); + void connector_update_handler(UnixDomainConnector* connector, bool is_reconecting, SideChannel* side_channel); void MPTransportLog(const char* msg, ...); MPSerializeFunc get_event_serialization_function(unsigned pub_id, unsigned event_id); diff --git a/src/mp_transport/mp_unix_transport/test/unix_transport_test.cc b/src/mp_transport/mp_unix_transport/test/unix_transport_test.cc index bad3074ac..133e9602a 100644 --- a/src/mp_transport/mp_unix_transport/test/unix_transport_test.cc +++ b/src/mp_transport/mp_unix_transport/test/unix_transport_test.cc @@ -108,6 +108,11 @@ void UnixDomainConnector::set_update_handler(std::function