From: Steve Chew (stechew) Date: Wed, 2 Jun 2021 16:47:38 +0000 (+0000) Subject: Merge pull request #2847 in SNORT/snort3 from ~SBAIGAL/snort3:control to master X-Git-Tag: 3.1.6.0~33 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=086b8adf78ec4555529f37ca98aef3a9f1aa0edc;p=thirdparty%2Fsnort3.git Merge pull request #2847 in SNORT/snort3 from ~SBAIGAL/snort3:control to master Squashed commit of the following: commit f796ba1326bf9713867d2bde5234273887282f98 Author: Steven Baigal (sbaigal) Date: Thu Apr 22 14:56:59 2021 -0400 control: expose ContrlConn API commit 3d0c000b8d0652bec02df2a08db9f23d2be971ec Author: Michael Altizer Date: Tue Feb 23 12:38:31 2021 -0500 control: Remove unused IdleProcessing functionality commit 90df551fac422ae1bf5ddee21a0d040dd111373c Author: Michael Altizer Date: Thu Nov 14 12:05:10 2019 -0500 control: refactor control channel management to better handle control responses commit 5d017cb4965f875f80dc5bf8edc3d074128f4c4e Author: Michael Altizer Date: Wed Feb 10 12:05:49 2021 -0500 Revert "Merge pull request #2639 in SNORT/snort3 from ~STECHEW/snort3:control_request_fix_shared_ptr to master" This reverts commit e7250bd6995941337e37529fd8594093de4db2ef. --- diff --git a/src/control/CMakeLists.txt b/src/control/CMakeLists.txt index 1e6610057..f58c705de 100644 --- a/src/control/CMakeLists.txt +++ b/src/control/CMakeLists.txt @@ -1,11 +1,13 @@ +set ( CONTROL_INCLUDES control.h ) + +if ( ENABLE_SHELL ) + set ( SHELL_SOURCES control.cc control_mgmt.cc control_mgmt.h ) +endif ( ENABLE_SHELL ) add_library ( control OBJECT - idle_processing.h - idle_processing.cc + ${SHELL_SOURCES} ) -add_catch_test( idle_processing_test - NO_TEST_SOURCE - SOURCES - idle_processing.cc +install (FILES ${CONTROL_INCLUDES} + DESTINATION ${INCLUDE_INSTALL_PATH}/control ) diff --git a/src/control/control.cc b/src/control/control.cc new file mode 100644 index 000000000..e2d34cca1 --- /dev/null +++ b/src/control/control.cc @@ -0,0 +1,192 @@ +//-------------------------------------------------------------------------- +// Copyright (C) 2017-2020 Cisco and/or its affiliates. All rights reserved. +// +// This program is free software; you can redistribute it and/or modify it +// under the terms of the GNU General Public License Version 2 as published +// by the Free Software Foundation. You may not use, modify or distribute +// this program under any other version of the GNU General Public License. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// General Public License for more details. +// +// You should have received a copy of the GNU General Public License along +// with this program; if not, write to the Free Software Foundation, Inc., +// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. +//-------------------------------------------------------------------------- +// control.cc author Bhagya Tholpady +// author Michael Altizer + +#ifdef HAVE_CONFIG_H +#include "config.h" +#endif + +#include "control.h" + +#include "log/messages.h" +#include "main.h" +#include "main/shell.h" +#include "managers/module_manager.h" +#include "utils/util.h" + +#include "control_mgmt.h" + +using namespace snort; + + +ControlConn* ControlConn::query_from_lua(const lua_State* L) +{ + return ControlMgmt::find_control(L); +} + +//------------------------------------------------------------------------ +// control channel class +// ----------------------------------------------------------------------- + +ControlConn::ControlConn(int fd, bool local) : fd(fd), local(local) +{ + shell = new Shell; + configure(); + show_prompt(); +} + +ControlConn::~ControlConn() +{ + shutdown(); + delete shell; +} + +void ControlConn::shutdown() +{ + if (!local) + close(fd); + fd = -1; +} + +void ControlConn::configure() const +{ + ModuleManager::load_commands(shell); +} + +int ControlConn::read_commands() +{ + char buf[STD_BUF]; + int commands_found = 0; + ssize_t n = 0; + + while ((n = read(fd, buf, sizeof(buf) - 1)) > 0) + { + buf[n] = '\0'; + char* p = buf; + char* nl; + while ((nl = strchr(p, '\n')) != nullptr) + { + std::string command = next_command; + next_command.append(buf, nl - p); + pending_commands.push(std::move(next_command)); + next_command.clear(); + p = nl + 1; + commands_found++; + } + if (*p != '\0') + next_command.append(p); + else if (local) + { + // For stdin, we are only guaranteed to have some amount of input ending in a + // newline and future read() calls will block. To avoid blocking, assume that + // we're done reading if the input ended with a newline. + break; + } + } + + if (n < 0 && errno != EAGAIN && errno != EINTR) + { + ErrorMessage("Error reading from control descriptor: %s\n", get_error(errno)); + return -1; + } + if (n == 0 && commands_found == 0) + return -1; + + return commands_found; +} + +int ControlConn::execute_commands() +{ + int executed = 0; + while (!is_closed() && !blocked && !pending_commands.empty()) + { + const std::string& command = pending_commands.front(); + std::string rsp; + shell->execute(command.c_str(), rsp); + if (!rsp.empty()) + respond("%s", rsp.c_str()); + if (!blocked) + show_prompt(); + pending_commands.pop(); + executed++; + } + + return executed; +} + +void ControlConn::block() +{ + blocked = true; +} + +void ControlConn::unblock() +{ + if (blocked) + { + blocked = false; + execute_commands(); + if (!blocked && !show_prompt()) + shutdown(); + } +} + +// FIXIT-L would like to flush prompt w/o \n +bool ControlConn::show_prompt() +{ + return respond("%s\n", get_prompt()); +} + +bool ControlConn::respond(const char* format, va_list& ap) +{ + char buf[STD_BUF]; + int response_len = vsnprintf(buf, sizeof(buf), format, ap); + if (response_len < 0 || response_len == sizeof(buf)) + return false; + buf[response_len] = '\0'; + + int bytes_written = 0; + while (bytes_written < response_len) + { + ssize_t n = write(fd, buf + bytes_written, response_len - bytes_written); + if (n < 0) + { + if (errno != EAGAIN && errno != EINTR) + { + shutdown(); + return false; + } + } + else + bytes_written += n; + } + return true; +} + +bool ControlConn::respond(const char* format, ...) +{ + if (is_closed()) + return false; + + va_list ap; + va_start(ap, format); + bool ret = respond(format, ap); + va_end(ap); + + return ret; +} diff --git a/src/main/control.h b/src/control/control.h similarity index 67% rename from src/main/control.h rename to src/control/control.h index 563354bd6..42452a312 100644 --- a/src/main/control.h +++ b/src/control/control.h @@ -23,38 +23,54 @@ #ifndef CONTROL_H #define CONTROL_H -#include "main/request.h" +#include +#include +#include + #include "main/snort_types.h" +struct lua_State; + class ControlConn { public: - ControlConn(int fd, bool local_control = false); + ControlConn(int fd, bool local); ~ControlConn(); ControlConn(const ControlConn&) = delete; ControlConn& operator=(const ControlConn&) = delete; int get_fd() const { return fd; } - class Shell* get_shell() const { return sh; } - SharedRequest get_request() const { return request; } - bool is_local_control() const { return local_control; } + class Shell* get_shell() const { return shell; } void block(); void unblock(); - void send_queued_response(); + bool is_blocked() const { return blocked; } + bool is_closed() const { return (fd == -1); } + SO_PUBLIC bool is_local() const { return local; } + + bool has_pending_command() const { return !pending_commands.empty(); } void configure() const; - int shell_execute(int& current_fd, SharedRequest& current_request); - bool show_prompt() const; + int read_commands(); + int execute_commands(); + SO_PUBLIC bool respond(const char* format, va_list& ap); + SO_PUBLIC bool respond(const char* format, ...) __attribute__((format (printf, 2, 3))); + void shutdown(); + + SO_PUBLIC static ControlConn* query_from_lua(const lua_State*); + +private: + bool show_prompt(); private: + std::queue pending_commands; + std::string next_command; + class Shell *shell; int fd; + bool local; bool blocked = false; - bool local_control; - class Shell *sh; - SharedRequest request; }; #endif diff --git a/src/control/control_mgmt.cc b/src/control/control_mgmt.cc new file mode 100644 index 000000000..eb29f1696 --- /dev/null +++ b/src/control/control_mgmt.cc @@ -0,0 +1,474 @@ +//-------------------------------------------------------------------------- +// Copyright (C) 2017-2020 Cisco and/or its affiliates. All rights reserved. +// +// This program is free software; you can redistribute it and/or modify it +// under the terms of the GNU General Public License Version 2 as published +// by the Free Software Foundation. You may not use, modify or distribute +// this program under any other version of the GNU General Public License. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// General Public License for more details. +// +// You should have received a copy of the GNU General Public License along +// with this program; if not, write to the Free Software Foundation, Inc., +// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. +//-------------------------------------------------------------------------- +// control_mgmt.cc author Bhagya Tholpady +// author Devendra Dahiphale +// author Michael Altizer + +#ifdef HAVE_CONFIG_H +#include "config.h" +#endif + +#include "control_mgmt.h" + +#include +#include +#include +#include + +#include +#include + +#include "log/messages.h" +#include "main/shell.h" +#include "main/snort_config.h" +#include "utils/stats.h" +#include "utils/util.h" +#include "utils/util_cstring.h" + +#include "control.h" + +using namespace snort; + +static constexpr unsigned MAX_CONTROL_FDS = 16; + +static int listener = -1; +static socklen_t sock_addr_size = 0; +static struct sockaddr* sock_addr = nullptr; +static struct sockaddr_in in_addr; +static struct sockaddr_un unix_addr; +static std::unordered_map controls; + +#ifdef __linux__ + +//------------------------------------------------------------------------- +// Linux epoll descriptor polling implementation (Linux-only) +//------------------------------------------------------------------------- + +#include + +static int epoll_fd = -1; +static unsigned nfds; + +static bool init_controls() +{ + epoll_fd = epoll_create1(0); + if (epoll_fd == -1) + { + ErrorMessage("Failed to create epoll file descriptor: %s\n", get_error(errno)); + return false; + } + nfds = 0; + return true; +} + +static bool register_control_fd(const int fd) +{ + if (nfds == MAX_CONTROL_FDS) + { + WarningMessage("Failed to add file descriptor, exceed max (%d)\n", nfds); + return false; + } + + struct epoll_event event; + event.events = EPOLLIN; + event.data.fd = fd; + if (epoll_ctl(epoll_fd, EPOLL_CTL_ADD, fd, &event)) + { + WarningMessage("Failed to add file descriptor %d to epoll(%d): %s\n", fd, epoll_fd, get_error(errno)); + return false; + } + + nfds++; + return true; +} + +static void unregister_control_fd(const int, const int curr_fd) +{ + // File descriptors are automatically removed from the epoll instance when they're closed + if (curr_fd != -1 && epoll_ctl(epoll_fd, EPOLL_CTL_DEL, curr_fd, nullptr)) + WarningMessage("Failed to remove file descriptor %d from epoll(%d): %s (%d)\n", curr_fd, epoll_fd, get_error(errno), errno); + nfds--; +} + +static bool poll_control_fds(int ready[MAX_CONTROL_FDS], unsigned& nready, int dead[MAX_CONTROL_FDS], unsigned& ndead) +{ + if (epoll_fd == -1 || nfds == 0) + return false; + + static struct epoll_event events[MAX_CONTROL_FDS]; + int ret = epoll_wait(epoll_fd, events, nfds, 0); + if (ret <= 0) + { + if (ret < 0 && errno != EINTR) + ErrorMessage("Failed to poll control descriptors: %s\n", get_error(errno)); + return false; + } + nready = ndead = 0; + for (int i = 0; i < ret; i++) + { + struct epoll_event* ev = &events[i]; + int fd = ev->data.fd; + if (ev->events & POLLIN) + ready[nready++] = fd; + if (ev->events & (POLLHUP | POLLERR)) + { + if (ev->events & POLLERR) + ErrorMessage("Failed to poll control descriptor %d!\n", fd); + dead[ndead++] = fd; + } + } + + return true; +} + +static void term_controls() +{ + if (epoll_fd >= 0) + { + close(epoll_fd); + epoll_fd = -1; + } +} + +#else + +//------------------------------------------------------------------------- +// POSIX poll descriptor polling implementation (default) +//------------------------------------------------------------------------- + +static struct pollfd pfds[MAX_CONTROL_FDS]; +static nfds_t npfds; + +static bool init_controls() +{ + npfds = 0; + return true; +} + +static bool register_control_fd(const int fd) +{ + if (npfds == MAX_CONTROL_FDS) + return false; + + struct pollfd* pfd = &pfds[npfds]; + pfd->fd = fd; + pfd->events = POLLIN; + npfds++; + + return true; +} + +static void unregister_control_fd(const int orig_fd, const int) +{ + for (nfds_t i = 0; i < npfds; i++) + { + if (pfds[i].fd == orig_fd) + { + npfds--; + // If this wasn't the last element, swap that in + if (i < npfds) + pfds[i].fd = pfds[npfds].fd; + break; + } + } +} + +static bool poll_control_fds(int ready[MAX_CONTROL_FDS], unsigned& nready, int dead[MAX_CONTROL_FDS], unsigned& ndead) +{ + if (npfds == 0) + return false; + + int ret = poll(pfds, npfds, 0); + if (ret <= 0) + { + if (ret < 0 && errno != EINTR) + ErrorMessage("Failed to poll control descriptors: %s\n", get_error(errno)); + return false; + } + nready = ndead = 0; + for (unsigned i = 0; i < MAX_CONTROL_FDS; i++) + { + struct pollfd* pfd = &pfds[i]; + int fd = pfd->fd; + if (pfd->revents & (POLLHUP | POLLERR | POLLNVAL)) + { + if (pfd->revents & (POLLERR | POLLNVAL)) + ErrorMessage("Failed to poll control descriptor %d!\n", fd); + dead[ndead++] = fd; + } + if (pfd->revents & POLLIN) + ready[nready++] = fd; + } + return true; +} + +static void term_controls() +{ + npfds = 0; +} + +#endif + +//------------------------------------------------------------------------- +// Platform agnostic private functions +//------------------------------------------------------------------------- + +// FIXIT-M make these non-blocking +// FIXIT-M bind to configured ip including INADDR_ANY +// (default is loopback if enabled) +static int setup_socket_family(const SnortConfig* sc) +{ + int family = AF_UNSPEC; + + if (sc->remote_control_port) + { + memset(&in_addr, 0, sizeof(in_addr)); + + in_addr.sin_family = AF_INET; + in_addr.sin_addr.s_addr = htonl(0x7F000001); + in_addr.sin_port = htons(sc->remote_control_port); + sock_addr = (struct sockaddr*)&in_addr; + sock_addr_size = sizeof(in_addr); + family = AF_INET; + } + else if (!sc->remote_control_socket.empty()) + { + std::string fullpath; + const char* path_sep = strrchr(sc->remote_control_socket.c_str(), '/'); + if (path_sep != nullptr) + fullpath = sc->remote_control_socket; + else + get_instance_file(fullpath, sc->remote_control_socket.c_str()); + + memset(&unix_addr, 0, sizeof(unix_addr)); + unix_addr.sun_family = AF_UNIX; + SnortStrncpy(unix_addr.sun_path, fullpath.c_str(), sizeof(unix_addr.sun_path)); + sock_addr = (struct sockaddr*)&unix_addr; + sock_addr_size = sizeof(unix_addr); + unlink(fullpath.c_str()); + family = AF_UNIX; + } + return family; +} + +static bool accept_conn() +{ + int fd = accept(listener, sock_addr, &sock_addr_size); + if (fd < 0) + { + ErrorMessage("Failed to accept control socket connection: %s\n", get_error(errno)); + return false; + } + if (fcntl(fd, F_SETFL, fcntl(fd, F_GETFL) | O_NONBLOCK) < 0) + { + ErrorMessage("Failed to put control socket connection in non-blocking mode: %s\n", + get_error(errno)); + close(fd); + return false; + } + if (!ControlMgmt::add_control(fd, false)) + { + ErrorMessage("Failed to add control connection for descriptor %d\n", fd); + close(fd); + return false; + } + + // FIXIT-L authenticate, use ssl ? + return true; +} + +static void delete_control(const std::unordered_map::const_iterator& iter) +{ + ControlConn* ctrlcon = iter->second; + + // FIXIT-L hacky way to keep the control around until it's no longer being referenced + if (ctrlcon->is_blocked()) + return; + + unregister_control_fd(iter->first, ctrlcon->get_fd()); + delete ctrlcon; + controls.erase(iter); +} + +static void delete_control(int fd) +{ + const auto& iter = controls.find(fd); + if (iter != controls.cend()) + delete_control(iter); +} + +static bool process_control_commands(int fd) +{ + const auto iter = controls.find(fd); + if (iter == controls.cend()) + return false; + + ControlConn* ctrlcon = iter->second; + + int read = ctrlcon->read_commands(); + if (read <= 0) + { + if (read < 0) + delete_control(iter); + return false; + } + + int executed = ctrlcon->execute_commands(); + if (executed > 0) + { + if (ctrlcon->is_local()) + proc_stats.local_commands += executed; + else + proc_stats.remote_commands += executed; + } + + if (ctrlcon->is_closed()) + delete_control(iter); + + return (executed > 0); +} + +static void clear_controls() +{ + for (const auto& p : controls) + { + ControlConn* ctrlcon = p.second; + unregister_control_fd(p.first, ctrlcon->get_fd()); + delete ctrlcon; + } + controls.clear(); +} + +//------------------------------------------------------------------------- +// Public API +//------------------------------------------------------------------------- + +bool ControlMgmt::add_control(int fd, bool local) +{ + auto i = controls.find(fd); + if (i != controls.cend()) + { + if (i->second->is_closed()) + { + delete_control(i); + } + else + { + WarningMessage("Duplicated control channel file descriptor, fd = %d\n", fd); + return false; + } + } + + if (!register_control_fd(fd)) + return false; + + ControlConn* ctrlcon = new ControlConn(fd, local); + controls[fd] = ctrlcon; + + return true; +} + +ControlConn* ControlMgmt::find_control(const lua_State* L) +{ + for (const auto& p : controls) + { + ControlConn* ctrlcon = p.second; + if (ctrlcon->get_shell()->get_lua() == L) + return ctrlcon; + } + return nullptr; +} + +void ControlMgmt::reconfigure_controls() +{ + for (const auto& p : controls) + p.second->configure(); +} + +int ControlMgmt::socket_init(const SnortConfig* sc) +{ + if (!init_controls()) + FatalError("Failed to initialize controls.\n"); + + int sock_family = setup_socket_family(sc); + if (sock_family == AF_UNSPEC) + return -1; + + listener = socket(sock_family, SOCK_STREAM, 0); + + if (listener < 0) + FatalError("Failed to create control listener: %s\n", get_error(errno)); + + // FIXIT-M want to disable time wait + int on = 1; + setsockopt(listener, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)); + + if (bind(listener, sock_addr, sock_addr_size) < 0) + FatalError("Failed to bind control listener: %s\n", get_error(errno)); + + if (listen(listener, MAX_CONTROL_FDS) < 0) + FatalError("Failed to start listening on control listener: %s\n", get_error(errno)); + + if (!register_control_fd(listener)) + FatalError("Failed to register listener socket.\n"); + + return 0; +} + +void ControlMgmt::socket_term() +{ + clear_controls(); + + if (listener >= 0) + { + close(listener); + listener = -1; + } + + term_controls(); +} + +bool ControlMgmt::service_users() +{ + static int ready[MAX_CONTROL_FDS], dead[MAX_CONTROL_FDS]; + unsigned nready, ndead; + + if (!poll_control_fds(ready, nready, dead, ndead)) + return false; + + // Process ready descriptors first, even if they're dead, to honor their last request + unsigned serviced = 0; + for (unsigned i = 0; i < nready; i++) + { + int fd = ready[i]; + if (fd == listener) + { + // Got a new connection request, attempt to accept it and store it in controls + if (accept_conn()) + serviced++; + } + else if (process_control_commands(fd)) + serviced++; + } + + for (unsigned i = 0; i < ndead; i++) + delete_control(dead[i]); + + return (serviced > 0); +} + diff --git a/src/main/control_mgmt.h b/src/control/control_mgmt.h similarity index 64% rename from src/main/control_mgmt.h rename to src/control/control_mgmt.h index 6aa772257..d860bc615 100644 --- a/src/main/control_mgmt.h +++ b/src/control/control_mgmt.h @@ -17,41 +17,33 @@ //-------------------------------------------------------------------------- // control_mgmt.h author Bhagya Tholpady // author Devendra Dahiphale +// author Michael Altizer // This provides functions to create and control remote/local connections, // socket creation/deletion/management functions, and shell commands used by the analyzer. #ifndef CONTROL_MGMT_H #define CONTROL_MGMT_H -#include - -#include "request.h" - class ControlConn; +struct lua_State; + +namespace snort +{ +struct SnortConfig; +} class ControlMgmt { public: - static void add_control(int fd, bool local_control); + static bool add_control(int fd, bool local_control); static void reconfigure_controls(); - static int socket_init(); - static int socket_term(); - static int socket_conn(); + static int socket_init(const snort::SnortConfig*); + static void socket_term(); - static bool process_control_commands(int& current_fd, SharedRequest& current_request, int); - static bool process_control_commands(int& current_fd, SharedRequest& current_request); + static ControlConn* find_control(const lua_State*); - static ControlConn* find_control(int fd); - static bool find_control(int fd, std::vector::iterator& control); - - static void delete_controls(); - static void delete_control(int fd); - static void delete_control(std::vector::iterator& control); - - static bool service_users(int& current_fd, SharedRequest& current_request); - -private: - static int setup_socket_family(); + static bool service_users(); }; + #endif diff --git a/src/control/idle_processing.cc b/src/control/idle_processing.cc deleted file mode 100644 index e757dff08..000000000 --- a/src/control/idle_processing.cc +++ /dev/null @@ -1,82 +0,0 @@ -//-------------------------------------------------------------------------- -// Copyright (C) 2014-2021 Cisco and/or its affiliates. All rights reserved. -// Copyright (C) 2011-2013 Sourcefire, Inc. -// -// This program is free software; you can redistribute it and/or modify it -// under the terms of the GNU General Public License Version 2 as published -// by the Free Software Foundation. You may not use, modify or distribute -// this program under any other version of the GNU General Public License. -// -// This program is distributed in the hope that it will be useful, but -// WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -// General Public License for more details. -// -// You should have received a copy of the GNU General Public License along -// with this program; if not, write to the Free Software Foundation, Inc., -// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. -//-------------------------------------------------------------------------- - -// idle_processing.c author Ron Dempster -// -// Allow functions to be registered to be called when packet -// processing is idle. - -#ifdef HAVE_CONFIG_H -#include "config.h" -#endif - -#include "idle_processing.h" - -#include - -static std::vector s_idle_handlers; - -void IdleProcessing::register_handler(IdleHook f) -{ s_idle_handlers.emplace_back(f); } - -void IdleProcessing::execute() -{ - for ( const auto& f : s_idle_handlers ) - f(); -} - -void IdleProcessing::unregister_all() -{ s_idle_handlers.clear(); } - -//-------------------------------------------------------------------------- -// tests -//-------------------------------------------------------------------------- - -#ifdef CATCH_TEST_BUILD - -#include "catch/catch.hpp" - -static unsigned s_niph1 = 0; -static unsigned s_niph2 = 0; - -static void iph1() { s_niph1++; } -static void iph2() { s_niph2++; } - -TEST_CASE("idle callback", "[control]") -{ - IdleProcessing::register_handler(iph1); - IdleProcessing::register_handler(iph2); - - IdleProcessing::execute(); - CHECK(s_niph1 == 1); - CHECK(s_niph2 == 1); - - IdleProcessing::execute(); - CHECK((s_niph1 == 2)); - CHECK((s_niph2 == 2)); - - IdleProcessing::unregister_all(); - - IdleProcessing::execute(); - CHECK((s_niph1 == 2)); - CHECK((s_niph2 == 2)); -} - -#endif - diff --git a/src/control/idle_processing.h b/src/control/idle_processing.h deleted file mode 100644 index e29a17a88..000000000 --- a/src/control/idle_processing.h +++ /dev/null @@ -1,36 +0,0 @@ -//-------------------------------------------------------------------------- -// Copyright (C) 2014-2021 Cisco and/or its affiliates. All rights reserved. -// Copyright (C) 2011-2013 Sourcefire, Inc. -// -// This program is free software; you can redistribute it and/or modify it -// under the terms of the GNU General Public License Version 2 as published -// by the Free Software Foundation. You may not use, modify or distribute -// this program under any other version of the GNU General Public License. -// -// This program is distributed in the hope that it will be useful, but -// WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -// General Public License for more details. -// -// You should have received a copy of the GNU General Public License along -// with this program; if not, write to the Free Software Foundation, Inc., -// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. -//-------------------------------------------------------------------------- - -#ifndef IDLE_PROCESSING_H -#define IDLE_PROCESSING_H - -using IdleHook = void (*)(); - -class IdleProcessing -{ -public: - static void register_handler(IdleHook); - static void execute(); - - // only needs to be called if changing out the handler set - static void unregister_all(); -}; - -#endif - diff --git a/src/host_tracker/host_cache_module.cc b/src/host_tracker/host_cache_module.cc index a845adf7f..8987ee20b 100644 --- a/src/host_tracker/host_cache_module.cc +++ b/src/host_tracker/host_cache_module.cc @@ -28,8 +28,8 @@ #include #include +#include "control/control.h" #include "log/messages.h" -#include "main.h" #include "managers/module_manager.h" #include "utils/util.h" @@ -48,15 +48,15 @@ static int host_cache_dump(lua_State* L) return 0; } -static int host_cache_get_stats(lua_State*) +static int host_cache_get_stats(lua_State* L) { HostCacheModule* mod = (HostCacheModule*) ModuleManager::get_module(HOST_CACHE_NAME); if ( mod ) { - SharedRequest current_request = get_current_request(); + ControlConn* ctrlcon = ControlConn::query_from_lua(L); string outstr = mod->get_host_cache_stats(); - current_request->respond(outstr.c_str()); + ctrlcon->respond("%s", outstr.c_str()); } return 0; } diff --git a/src/host_tracker/test/host_cache_module_test.cc b/src/host_tracker/test/host_cache_module_test.cc index 318ea60e2..6a495290c 100644 --- a/src/host_tracker/test/host_cache_module_test.cc +++ b/src/host_tracker/test/host_cache_module_test.cc @@ -26,9 +26,9 @@ #include #include +#include "control/control.h" #include "host_tracker/host_cache_module.h" #include "host_tracker/host_cache.h" -#include "main/request.h" #include "main/snort_config.h" #include "managers/module_manager.h" @@ -45,9 +45,11 @@ static HostCacheModule module; #define LOG_MAX 128 static char logged_message[LOG_MAX+1]; -static SharedRequest mock_request = std::make_shared(); -void Request::respond(const char*, bool, bool) { } -SharedRequest get_current_request() { return mock_request; } +static ControlConn ctrlcon(1, true); +ControlConn::ControlConn(int, bool) {} +ControlConn::~ControlConn() {} +ControlConn* ControlConn::query_from_lua(const lua_State*) { return &ctrlcon; } +bool ControlConn::respond(const char*, ...) { return true; } namespace snort { diff --git a/src/main.cc b/src/main.cc index cc9a9c43e..775c33fa3 100644 --- a/src/main.cc +++ b/src/main.cc @@ -25,7 +25,7 @@ #include -#include "control/idle_processing.h" +#include "control/control.h" #include "detection/signature.h" #include "framework/module.h" #include "helpers/process.h" @@ -65,7 +65,7 @@ #endif #ifdef SHELL -#include "main/control_mgmt.h" +#include "control/control_mgmt.h" #include "main/ac_shell_cmd.h" #endif @@ -111,16 +111,6 @@ static int main_read() return pig_poke->get(-1); } -static SharedRequest current_request = std::make_shared(); -#ifdef SHELL -static int current_fd = -1; -#endif - -SharedRequest get_current_request() -{ - return current_request; -} - //------------------------------------------------------------------------- // pig foo //------------------------------------------------------------------------- @@ -283,23 +273,31 @@ static Pig* get_lazy_pig(unsigned max) // main commands //------------------------------------------------------------------------- -static AnalyzerCommand* get_command(AnalyzerCommand* ac, bool from_shell) +static AnalyzerCommand* get_command(AnalyzerCommand* ac, ControlConn* ctrlcon) { #ifndef SHELL - UNUSED(from_shell); + UNUSED(ctrlcon); #else - if ( from_shell ) - return ( new ACShellCmd(current_fd, ac) ); + if (ctrlcon) + return ( new ACShellCmd(ctrlcon, ac) ); else #endif return ac; } -void snort::main_broadcast_command(AnalyzerCommand* ac, bool from_shell) +static void send_response(ControlConn* ctrlcon, const char* response) +{ + if (ctrlcon) + ctrlcon->respond("%s", response); + else + LogMessage("%s", response); +} + +void snort::main_broadcast_command(AnalyzerCommand* ac, ControlConn* ctrlcon) { unsigned dispatched = 0; - ac = get_command(ac, from_shell); + ac = get_command(ac, ctrlcon); debug_logf(snort_trace, TRACE_MAIN, nullptr, "Broadcasting %s command\n", ac->stringify()); for (unsigned idx = 0; idx < max_pigs; ++idx) @@ -313,10 +311,10 @@ void snort::main_broadcast_command(AnalyzerCommand* ac, bool from_shell) } #ifdef REG_TEST -void snort::main_unicast_command(AnalyzerCommand* ac, unsigned target, bool from_shell) +void snort::main_unicast_command(AnalyzerCommand* ac, unsigned target, ControlConn* ctrlcon) { assert(target < max_pigs); - ac = get_command(ac, from_shell); + ac = get_command(ac, ctrlcon); if (!pigs[target].queue_command(ac)) orphan_commands.push(ac); } @@ -324,34 +322,35 @@ void snort::main_unicast_command(AnalyzerCommand* ac, unsigned target, bool from int main_dump_stats(lua_State* L) { - bool from_shell = ( L != nullptr ); - current_request->respond("== dumping stats\n", from_shell); - main_broadcast_command(new ACGetStats(), from_shell); + ControlConn* ctrlcon = ControlConn::query_from_lua(L); + send_response(ctrlcon, "== dumping stats\n"); + main_broadcast_command(new ACGetStats(), ctrlcon); return 0; } int main_reset_stats(lua_State* L) { + ControlConn* ctrlcon = ControlConn::query_from_lua(L); int type = luaL_optint(L, 1, 0); - bool from_shell = ( L != nullptr ); - current_request->respond("== clearing stats\n", from_shell); - main_broadcast_command(new ACResetStats(static_cast(type)), true); + ctrlcon->respond("== clearing stats\n"); + main_broadcast_command(new ACResetStats(static_cast(type)), ctrlcon); return 0; } int main_rotate_stats(lua_State* L) { - bool from_shell = ( L != nullptr ); - current_request->respond("== rotating stats\n", from_shell); - main_broadcast_command(new ACRotate(), from_shell); + ControlConn* ctrlcon = ControlConn::query_from_lua(L); + send_response(ctrlcon, "== rotating stats\n"); + main_broadcast_command(new ACRotate(), ctrlcon); return 0; } int main_reload_config(lua_State* L) { + ControlConn* ctrlcon = ControlConn::query_from_lua(L); if ( Swapper::get_reload_in_progress() ) { - current_request->respond("== reload pending; retry\n"); + send_response(ctrlcon, "== reload pending; retry\n"); return 0; } const char* fname = nullptr; @@ -366,11 +365,11 @@ int main_reload_config(lua_State* L) plugin_path = luaL_checkstring(L, 2); std::ostringstream plugin_path_msg; plugin_path_msg << "-- reload plugin_path: " << plugin_path << "\n"; - current_request->respond(plugin_path_msg.str().c_str()); + send_response(ctrlcon, plugin_path_msg.str().c_str()); } } - current_request->respond(".. reloading configuration\n"); + send_response(ctrlcon, ".. reloading configuration\n"); const SnortConfig* old = SnortConfig::get_conf(); SnortConfig* sc = Snort::get_reload_config(fname, plugin_path, old); @@ -380,12 +379,13 @@ int main_reload_config(lua_State* L) { std::string response_message = "== reload failed - restart required - "; response_message += get_reload_errors_description() + "\n"; - current_request->respond(response_message.c_str()); + send_response(ctrlcon, response_message.c_str()); reset_reload_errors(); } else - current_request->respond("== reload failed - bad config\n"); + send_response(ctrlcon, "== reload failed - bad config\n"); + HostAttributesManager::load_failure_cleanup(); return 0; } @@ -408,18 +408,18 @@ int main_reload_config(lua_State* L) TraceApi::thread_reinit(sc->trace_config); proc_stats.conf_reloads++; - bool from_shell = ( L != nullptr ); - current_request->respond(".. swapping configuration\n", from_shell); - main_broadcast_command(new ACSwap(new Swapper(old, sc), current_request, from_shell), from_shell); + send_response(ctrlcon, ".. swapping configuration\n"); + main_broadcast_command(new ACSwap(new Swapper(old, sc), ctrlcon), ctrlcon); return 0; } int main_reload_policy(lua_State* L) { + ControlConn* ctrlcon = ControlConn::query_from_lua(L); if ( Swapper::get_reload_in_progress() ) { - current_request->respond("== reload pending; retry\n"); + send_response(ctrlcon, "== reload pending; retry\n"); return 0; } const char* fname = nullptr; @@ -431,10 +431,10 @@ int main_reload_policy(lua_State* L) } if ( fname and *fname ) - current_request->respond(".. reloading policy\n"); + send_response(ctrlcon, ".. reloading policy\n"); else { - current_request->respond("== filename required\n"); + send_response(ctrlcon, "== filename required\n"); return 0; } @@ -443,25 +443,25 @@ int main_reload_policy(lua_State* L) if ( !sc ) { - current_request->respond("== reload failed\n"); + send_response(ctrlcon, "== reload failed\n"); return 0; } sc->update_reload_id(); SnortConfig::set_conf(sc); proc_stats.policy_reloads++; - bool from_shell = ( L != nullptr ); - current_request->respond(".. swapping policy\n", from_shell); - main_broadcast_command(new ACSwap(new Swapper(old, sc), current_request, from_shell), from_shell); + send_response(ctrlcon, ".. swapping policy\n"); + main_broadcast_command(new ACSwap(new Swapper(old, sc), ctrlcon), ctrlcon); return 0; } int main_reload_module(lua_State* L) { + ControlConn* ctrlcon = ControlConn::query_from_lua(L); if ( Swapper::get_reload_in_progress() ) { - current_request->respond("== reload pending; retry\n"); + send_response(ctrlcon, "== reload pending; retry\n"); return 0; } const char* fname = nullptr; @@ -473,10 +473,10 @@ int main_reload_module(lua_State* L) } if ( fname and *fname ) - current_request->respond(".. reloading module\n"); + send_response(ctrlcon, ".. reloading module\n"); else { - current_request->respond("== module name required\n"); + send_response(ctrlcon, "== module name required\n"); return 0; } @@ -485,25 +485,24 @@ int main_reload_module(lua_State* L) if ( !sc ) { - current_request->respond("== reload failed\n"); + send_response(ctrlcon, "== reload failed\n"); return 0; } sc->update_reload_id(); SnortConfig::set_conf(sc); proc_stats.policy_reloads++; - bool from_shell = ( L != nullptr ); - current_request->respond(".. swapping module\n", from_shell); - main_broadcast_command(new ACSwap(new Swapper(old, sc), current_request, from_shell), from_shell); + send_response(ctrlcon, ".. swapping module\n"); + main_broadcast_command(new ACSwap(new Swapper(old, sc), ctrlcon), ctrlcon); return 0; } int main_reload_daq(lua_State* L) { - bool from_shell = ( L != nullptr ); - current_request->respond(".. reloading daq module\n", from_shell); - main_broadcast_command(new ACDAQSwap(), from_shell); + ControlConn* ctrlcon = ControlConn::query_from_lua(L); + send_response(ctrlcon, ".. reloading daq module\n"); + main_broadcast_command(new ACDAQSwap(), ctrlcon); proc_stats.daq_reloads++; return 0; @@ -511,22 +510,21 @@ int main_reload_daq(lua_State* L) int main_reload_hosts(lua_State* L) { + ControlConn* ctrlcon = ControlConn::query_from_lua(L); if ( Swapper::get_reload_in_progress() ) { WarningMessage("Reload in progress. Cannot reload host attribute table.\n"); - current_request->respond("== reload pending; retry\n"); + send_response(ctrlcon, "== reload pending; retry\n"); return 0; } SnortConfig* sc = SnortConfig::get_main_conf(); - bool from_shell = false; const char* fname; if ( L ) { Lua::ManageStack(L, 1); fname = luaL_optstring(L, 1, sc->attribute_hosts_file.c_str()); - from_shell = true; } else fname = sc->attribute_hosts_file.c_str(); @@ -534,19 +532,19 @@ int main_reload_hosts(lua_State* L) if ( fname and *fname ) { LogMessage("Reloading Host attribute table from %s.\n", fname); - current_request->respond(".. reloading hosts table\n"); + send_response(ctrlcon, ".. reloading hosts table\n"); } else { ErrorMessage("Reload failed. Host attribute table filename required.\n"); - current_request->respond("== filename required\n"); + send_response(ctrlcon, "== filename required\n"); return 0; } if ( !HostAttributesManager::load_hosts_file(sc, fname) ) { ErrorMessage("Host attribute table reload from %s failed.\n", fname); - current_request->respond("== reload failed\n"); + send_response(ctrlcon, "== reload failed\n"); return 0; } @@ -555,17 +553,18 @@ int main_reload_hosts(lua_State* L) assert( num_hosts >= 0 ); LogMessage("Host attribute table: %d hosts loaded successfully.\n", num_hosts); - current_request->respond(".. swapping hosts table\n", from_shell); - main_broadcast_command(new ACHostAttributesSwap(current_request, from_shell), from_shell); + send_response(ctrlcon, ".. swapping hosts table\n"); + main_broadcast_command(new ACHostAttributesSwap(ctrlcon), ctrlcon); return 0; } int main_delete_inspector(lua_State* L) { + ControlConn* ctrlcon = ControlConn::query_from_lua(L); if ( Swapper::get_reload_in_progress() ) { - current_request->respond("== delete pending; retry\n"); + send_response(ctrlcon, "== delete pending; retry\n"); return 0; } const char* iname = nullptr; @@ -577,10 +576,10 @@ int main_delete_inspector(lua_State* L) } if ( iname and *iname ) - current_request->respond(".. deleting inspector\n"); + send_response(ctrlcon, ".. deleting inspector\n"); else { - current_request->respond("== inspector name required\n"); + send_response(ctrlcon, "== inspector name required\n"); return 0; } @@ -589,51 +588,51 @@ int main_delete_inspector(lua_State* L) if ( !sc ) { - current_request->respond("== reload failed\n"); + send_response(ctrlcon, "== reload failed\n"); return 0; } SnortConfig::set_conf(sc); proc_stats.inspector_deletions++; - bool from_shell = ( L != nullptr ); - current_request->respond(".. deleted inspector\n", from_shell); - main_broadcast_command(new ACSwap(new Swapper(old, sc), current_request, from_shell), from_shell); + send_response(ctrlcon, ".. deleted inspector\n"); + main_broadcast_command(new ACSwap(new Swapper(old, sc), ctrlcon), ctrlcon); return 0; } int main_process(lua_State* L) { + ControlConn* ctrlcon = ControlConn::query_from_lua(L); const char* f = lua_tostring(L, 1); if ( !f ) { - current_request->respond("== pcap filename required\n"); + send_response(ctrlcon, "== pcap filename required\n"); return 0; } - current_request->respond("== queuing pcap\n"); + send_response(ctrlcon, "== queuing pcap\n"); Trough::add_source(Trough::SOURCE_LIST, f); return 0; } int main_pause(lua_State* L) { - bool from_shell = ( L != nullptr ); - current_request->respond("== pausing\n", from_shell); - main_broadcast_command(new ACPause(), from_shell); + ControlConn* ctrlcon = ControlConn::query_from_lua(L); + send_response(ctrlcon, "== pausing\n"); + main_broadcast_command(new ACPause(), ctrlcon); paused = true; return 0; } int main_resume(lua_State* L) { - bool from_shell = ( L != nullptr ); + ControlConn* ctrlcon = ControlConn::query_from_lua(L); uint64_t pkt_num = 0; #ifdef REG_TEST int target = -1; #endif - if (from_shell) + if (L) { const int num_of_args = lua_gettop(L); if (num_of_args) @@ -641,7 +640,7 @@ int main_resume(lua_State* L) pkt_num = lua_tointeger(L, 1); if (pkt_num < 1) { - current_request->respond("Invalid usage of resume(n), n should be a number > 0\n"); + send_response(ctrlcon, "Invalid usage of resume(n), n should be a number > 0\n"); return 0; } #ifdef REG_TEST @@ -650,7 +649,7 @@ int main_resume(lua_State* L) target = lua_tointeger(L, 2); if (target < 0 or unsigned(target) >= max_pigs) { - current_request->respond( + send_response(ctrlcon, "Invalid usage of resume(n,m), m should be a number >= 0 and less than number of threads\n"); return 0; } @@ -658,24 +657,26 @@ int main_resume(lua_State* L) #endif } } - current_request->respond("== resuming\n", from_shell); + send_response(ctrlcon, "== resuming\n"); #ifdef REG_TEST if (target >= 0) - main_unicast_command(new ACResume(pkt_num), target, from_shell); + main_unicast_command(new ACResume(pkt_num), target, ctrlcon); else - main_broadcast_command(new ACResume(pkt_num), from_shell); + main_broadcast_command(new ACResume(pkt_num), ctrlcon); #else - main_broadcast_command(new ACResume(pkt_num), from_shell); + main_broadcast_command(new ACResume(pkt_num), ctrlcon); #endif paused = false; return 0; } #ifdef SHELL -int main_detach(lua_State*) +int main_detach(lua_State* L) { - current_request->respond("== detaching\n"); + ControlConn* ctrlcon = ControlConn::query_from_lua(L); + send_response(ctrlcon, "== detaching\n"); + ctrlcon->shutdown(); return 0; } @@ -685,31 +686,43 @@ int main_dump_plugins(lua_State*) PluginManager::dump_plugins(); return 0; } - #endif int main_quit(lua_State* L) { - bool from_shell = ( L != nullptr ); - current_request->respond("== stopping\n", from_shell); - main_broadcast_command(new ACStop(), from_shell); + ControlConn* ctrlcon = ControlConn::query_from_lua(L); + send_response(ctrlcon, "== stopping\n"); + main_broadcast_command(new ACStop(), ctrlcon); exit_requested = true; return 0; } -int main_help(lua_State*) +int main_help(lua_State* L) { - const Command* cmd = get_snort_module()->get_commands(); - - while ( cmd->name ) + ControlConn* ctrlcon = ControlConn::query_from_lua(L); + std::list modules = ModuleManager::get_all_modules(); + for (const auto& m : modules) { - std::string info = cmd->name; - info += cmd->get_arg_list(); - info += ": "; - info += cmd->help; - info += "\n"; - current_request->respond(info.c_str()); - ++cmd; + const Command* cmd = m->get_commands(); + if (!cmd) + continue; + std::string prefix; + if (strcmp(m->get_name(), "snort")) + { + prefix = m->get_name(); + prefix += '.'; + } + while (cmd->name) + { + std::string info = prefix; + info += cmd->name; + info += cmd->get_arg_list(); + info += ": "; + info += cmd->help; + info += "\n"; + send_response(ctrlcon, info.c_str()); + ++cmd; + } } return 0; } @@ -785,8 +798,6 @@ static bool house_keeping() reap_commands(); - IdleProcessing::execute(); - Periodic::check(); InspectorManager::empty_trash(); @@ -797,7 +808,7 @@ static bool house_keeping() static void service_check() { #ifdef SHELL - if (all_pthreads_started && ControlMgmt::service_users(current_fd, current_request) ) + if (all_pthreads_started && ControlMgmt::service_users() ) return; #endif @@ -904,14 +915,6 @@ static bool set_mode() else LogMessage("Commencing packet processing\n"); -#ifdef SHELL - if ( use_shell(sc) ) - { - LogMessage("Entering command shell\n"); - ControlMgmt::add_control(STDOUT_FILENO, true); - } -#endif - return true; } @@ -1034,10 +1037,19 @@ static void main_loop() const unsigned num_threads = (!Trough::has_next()) ? swine : max_pigs; for (unsigned i = 0; i < num_threads; i++) all_pthreads_started &= pigs_started[i]; -#ifdef REG_TEST if (all_pthreads_started) + { +#ifdef REG_TEST LogMessage("All pthreads started\n"); #endif +#ifdef SHELL + if (use_shell(SnortConfig::get_conf())) + { + LogMessage("Entering command shell\n"); + ControlMgmt::add_control(STDOUT_FILENO, true); + } +#endif + } } if ( !exit_requested and (swine < max_pigs) and (src = Trough::get_next()) ) @@ -1054,7 +1066,7 @@ static void main_loop() static void snort_main() { #ifdef SHELL - ControlMgmt::socket_init(); + ControlMgmt::socket_init(SnortConfig::get_conf()); #endif SnortConfig::get_conf()->thread_config->implement_thread_affinity( diff --git a/src/main.h b/src/main.h index 4551dd910..94cca638c 100644 --- a/src/main.h +++ b/src/main.h @@ -21,12 +21,9 @@ #ifndef MAIN_H #define MAIN_H -#include "main/request.h" - struct lua_State; const char* get_prompt(); -SharedRequest get_current_request(); // commands provided by the snort module int main_delete_inspector(lua_State* = nullptr); diff --git a/src/main/CMakeLists.txt b/src/main/CMakeLists.txt index 680d861a9..0d378c553 100644 --- a/src/main/CMakeLists.txt +++ b/src/main/CMakeLists.txt @@ -2,7 +2,6 @@ set (INCLUDES analyzer_command.h policy.h - request.h snort.h snort_config.h snort_debug.h @@ -18,11 +17,9 @@ set (LOCAL_INCLUDES ) if ( ENABLE_SHELL ) - set ( SHELL_SOURCES control.cc control.h control_mgmt.cc control_mgmt.h ac_shell_cmd.h ac_shell_cmd.cc) + set ( SHELL_SOURCES ac_shell_cmd.h ac_shell_cmd.cc) endif ( ENABLE_SHELL ) -add_subdirectory(test) - add_library (main OBJECT analyzer.cc analyzer.h @@ -34,7 +31,6 @@ add_library (main OBJECT oops_handler.cc oops_handler.h policy.cc - request.cc shell.h shell.cc snort.cc diff --git a/src/main/ac_shell_cmd.cc b/src/main/ac_shell_cmd.cc index 6080e6e8b..e4cb34628 100644 --- a/src/main/ac_shell_cmd.cc +++ b/src/main/ac_shell_cmd.cc @@ -25,40 +25,25 @@ #include -#include "control_mgmt.h" -#include "control.h" +#include "control/control.h" -ACShellCmd::ACShellCmd(int fd, AnalyzerCommand *ac) : ac(ac) +ACShellCmd::ACShellCmd(ControlConn* ctrlcon, AnalyzerCommand* ac) : ctrlcon(ctrlcon), ac(ac) { assert(ac); - ControlConn* control_conn = ControlMgmt::find_control(fd); - - if( control_conn ) - { - control_conn->block(); - control_fd = fd; - } + if (ctrlcon) + ctrlcon->block(); } bool ACShellCmd::execute(Analyzer& analyzer, void** state) { - ControlConn* control_conn = ControlMgmt::find_control(control_fd); - - if( control_conn ) - control_conn->send_queued_response(); - return ac->execute(analyzer, state); } ACShellCmd::~ACShellCmd() { delete ac; - ControlConn* control = ControlMgmt::find_control(control_fd); - if( control ) - { - control->send_queued_response(); - control->unblock(); - } + if (ctrlcon) + ctrlcon->unblock(); } diff --git a/src/main/ac_shell_cmd.h b/src/main/ac_shell_cmd.h index 2fc6fd1bd..54e9d1241 100644 --- a/src/main/ac_shell_cmd.h +++ b/src/main/ac_shell_cmd.h @@ -15,7 +15,7 @@ // with this program; if not, write to the Free Software Foundation, Inc., // 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. //-------------------------------------------------------------------------- -// control_mgmt.h author Bhagya Tholpady +// ac_shell_cmd.h author Bhagya Tholpady // // This provides functions to create and control remote/local connections, // socket creation/deletion/management functions, and shell commands used by the analyzer. @@ -26,17 +26,19 @@ #include "main/analyzer.h" #include "main/analyzer_command.h" +class ControlConn; + class ACShellCmd : public snort::AnalyzerCommand { public: ACShellCmd() = delete; - ACShellCmd(int fd, snort::AnalyzerCommand* ac_cmd); + ACShellCmd(ControlConn*, snort::AnalyzerCommand*); bool execute(Analyzer&, void**) override; const char* stringify() override { return ac->stringify(); } ~ACShellCmd() override; private: - int control_fd = -1; + ControlConn* ctrlcon; snort::AnalyzerCommand* ac; }; diff --git a/src/main/analyzer_command.cc b/src/main/analyzer_command.cc index e7f6440e3..3a8fd2d10 100644 --- a/src/main/analyzer_command.cc +++ b/src/main/analyzer_command.cc @@ -25,6 +25,7 @@ #include +#include "control/control.h" #include "framework/module.h" #include "log/messages.h" #include "managers/module_manager.h" @@ -105,7 +106,7 @@ bool ACResetStats::execute(Analyzer&, void**) ACResetStats::ACResetStats(clear_counter_type_t requested_type_l) : requested_type( requested_type_l) { } -ACSwap::ACSwap(Swapper* ps, SharedRequest req, bool from_shell) : ps(ps), request(req), from_shell(from_shell) +ACSwap::ACSwap(Swapper* ps, ControlConn *ctrlcon) : ps(ps), ctrlcon(ctrlcon) { assert(Swapper::get_reload_in_progress() == false); Swapper::set_reload_in_progress(true); @@ -181,11 +182,12 @@ ACSwap::~ACSwap() Swapper::set_reload_in_progress(false); LogMessage("== reload complete\n"); - request->respond("== reload complete\n", from_shell, true); + if (ctrlcon && !ctrlcon->is_local()) + ctrlcon->respond("== reload complete\n"); } -ACHostAttributesSwap::ACHostAttributesSwap(SharedRequest req, bool from_shell) - : request(req), from_shell(from_shell) +ACHostAttributesSwap::ACHostAttributesSwap(ControlConn *ctrlcon) + : ctrlcon(ctrlcon) { assert(Swapper::get_reload_in_progress() == false); Swapper::set_reload_in_progress(true); @@ -202,7 +204,8 @@ ACHostAttributesSwap::~ACHostAttributesSwap() HostAttributesManager::swap_cleanup(); Swapper::set_reload_in_progress(false); LogMessage("== reload host attributes complete\n"); - request->respond("== reload host attributes complete\n", from_shell, true); + if (ctrlcon && !ctrlcon->is_local()) + ctrlcon->respond("== reload host attributes complete\n"); } bool ACDAQSwap::execute(Analyzer& analyzer, void**) diff --git a/src/main/analyzer_command.h b/src/main/analyzer_command.h index 4b2c5a576..f19b784ed 100644 --- a/src/main/analyzer_command.h +++ b/src/main/analyzer_command.h @@ -20,10 +20,12 @@ #ifndef ANALYZER_COMMANDS_H #define ANALYZER_COMMANDS_H -#include "request.h" -#include "snort_types.h" +#include + +#include "main/snort_types.h" class Analyzer; +class ControlConn; class Swapper; namespace snort @@ -140,27 +142,25 @@ class ACSwap : public snort::AnalyzerCommand { public: ACSwap() = delete; - ACSwap(Swapper* ps, SharedRequest req, bool from_shell); + ACSwap(Swapper* ps, ControlConn* ctrlcon); bool execute(Analyzer&, void**) override; const char* stringify() override { return "SWAP"; } ~ACSwap() override; private: Swapper *ps; - SharedRequest request; - bool from_shell; + ControlConn* ctrlcon; }; class ACHostAttributesSwap : public snort::AnalyzerCommand { public: - ACHostAttributesSwap(SharedRequest req, bool from_shell); + ACHostAttributesSwap(ControlConn* ctrlcon); bool execute(Analyzer&, void**) override; const char* stringify() override { return "HOST_ATTRIBUTES_SWAP"; } ~ACHostAttributesSwap() override; private: - SharedRequest request; - bool from_shell; + ControlConn* ctrlcon; }; class ACDAQSwap : public snort::AnalyzerCommand @@ -174,10 +174,10 @@ public: namespace snort { // from main.cc -SO_PUBLIC void main_broadcast_command(snort::AnalyzerCommand* ac, bool from_shell = false); #ifdef REG_TEST -void main_unicast_command(AnalyzerCommand* ac, unsigned target, bool from_shell = false); +void main_unicast_command(AnalyzerCommand* ac, unsigned target, ControlConn* ctrlcon = nullptr); #endif +SO_PUBLIC void main_broadcast_command(snort::AnalyzerCommand* ac, ControlConn* ctrlcon = nullptr); } #endif diff --git a/src/main/control.cc b/src/main/control.cc deleted file mode 100644 index a871e9cea..000000000 --- a/src/main/control.cc +++ /dev/null @@ -1,106 +0,0 @@ -//-------------------------------------------------------------------------- -// Copyright (C) 2017-2021 Cisco and/or its affiliates. All rights reserved. -// -// This program is free software; you can redistribute it and/or modify it -// under the terms of the GNU General Public License Version 2 as published -// by the Free Software Foundation. You may not use, modify or distribute -// this program under any other version of the GNU General Public License. -// -// This program is distributed in the hope that it will be useful, but -// WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -// General Public License for more details. -// -// You should have received a copy of the GNU General Public License along -// with this program; if not, write to the Free Software Foundation, Inc., -// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. -//-------------------------------------------------------------------------- - -#ifdef HAVE_CONFIG_H -#include "config.h" -#endif - -#include "control.h" - -#include "main.h" -#include "managers/module_manager.h" -#include "utils/util.h" - -#include "control_mgmt.h" -#include "shell.h" - -using namespace snort; -using namespace std; - -//------------------------------------------------------------------------ -// control channel class -// ----------------------------------------------------------------------- - -ControlConn::ControlConn(int i, bool local) -{ - fd = i; - local_control = local; - sh = new Shell; - request = std::make_shared(fd); - configure(); - show_prompt(); -} - -ControlConn::~ControlConn() -{ - if( !local_control ) - close(fd); - delete sh; -} - -void ControlConn::configure() const -{ - ModuleManager::load_commands(sh); -} - -int ControlConn::shell_execute(int& current_fd, SharedRequest& current_request) -{ - if ( !request->read() ) - return -1; - - current_fd = fd; - current_request = request; - - std::string rsp; - sh->execute(request->get(), rsp); - - if ( !rsp.empty() and !is_blocked() ) - request->respond(rsp.c_str()); - - if ( fd >= 0 and !is_blocked() ) - show_prompt(); - - return fd; -} - -void ControlConn::block() -{ - blocked = true; -} - -void ControlConn::unblock() -{ - blocked = false; - if ( !show_prompt() ) - ControlMgmt::delete_control(fd); -} - -void ControlConn::send_queued_response() -{ -#ifdef SHELL - request->send_queued_response(); -#endif -} - -// FIXIT-L would like to flush prompt w/o \n -bool ControlConn::show_prompt() const -{ - std::string s = get_prompt(); - s += "\n"; - return request->write_response(s.c_str()); -} diff --git a/src/main/control_mgmt.cc b/src/main/control_mgmt.cc deleted file mode 100644 index 3e08d2312..000000000 --- a/src/main/control_mgmt.cc +++ /dev/null @@ -1,474 +0,0 @@ -//-------------------------------------------------------------------------- -// Copyright (C) 2017-2021 Cisco and/or its affiliates. All rights reserved. -// -// This program is free software; you can redistribute it and/or modify it -// under the terms of the GNU General Public License Version 2 as published -// by the Free Software Foundation. You may not use, modify or distribute -// this program under any other version of the GNU General Public License. -// -// This program is distributed in the hope that it will be useful, but -// WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -// General Public License for more details. -// -// You should have received a copy of the GNU General Public License along -// with this program; if not, write to the Free Software Foundation, Inc., -// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. -//-------------------------------------------------------------------------- -// control_mgmt.cc author Bhagya Tholpady -// author Devendra Dahiphale - -#ifdef HAVE_CONFIG_H -#include "config.h" -#endif - -#include "control_mgmt.h" - -#include -#include - -#include -#include - -#include "log/messages.h" -#include "utils/stats.h" -#include "utils/util.h" -#include "control.h" -#include "snort_config.h" -#include "utils/util_cstring.h" - -using namespace snort; -using namespace std; - -static int listener = -1; -static socklen_t sock_addr_size = 0; -static struct sockaddr* sock_addr = nullptr; -static struct sockaddr_in in_addr; -static struct sockaddr_un unix_addr; - -//------------------------------------------------------------------------- -// epoll implementation (supported by only linux systems) -//------------------------------------------------------------------------- -// Only linux systems support epoll (event polling) mechanism. -// It allows a process to monitor multiple file descriptors -// and get notification (using epoll_wait) when I/O is possible on them. -#ifdef __linux__ - -#include -#include - -static int epoll_fd = -1; -static unordered_map controls; - -#define MAX_EPOLL_EVENTS 16 - -static void add_to_epoll(const int fd) -{ - if (epoll_fd == -1) - { - epoll_fd = epoll_create1(0); - if (epoll_fd == -1) - FatalError("Failed to create epoll file descriptor: %s\n", get_error(errno)); - } - - struct epoll_event event; - event.events = EPOLLIN; - event.data.fd = fd; - - if (epoll_ctl(epoll_fd, EPOLL_CTL_ADD, fd, &event)) - WarningMessage("Failed to add file descriptor to epoll: %s\n", get_error(errno)); -} - -static void remove_from_epoll(const int fd) -{ - if (epoll_ctl(epoll_fd, EPOLL_CTL_DEL, fd, nullptr)) - WarningMessage("Failed to remove file descriptor from epoll\n"); -} - -void ControlMgmt::add_control(int fd, bool local) -{ - if (controls.find(fd) != controls.end()) - { - assert(0); - WarningMessage("Cannot have two active connections with the same fd\n"); - return; - } - - controls[fd] = new ControlConn(fd, local); - add_to_epoll(fd); -} - -ControlConn* ControlMgmt::find_control(int fd) -{ - auto control_conn = controls.find(fd); - if (control_conn == controls.end()) - return nullptr; - - return control_conn->second; -} - -void ControlMgmt::delete_control(int fd) -{ - auto control_conn = find_control(fd); - if (control_conn) - { - remove_from_epoll(fd); - delete control_conn; - controls.erase(fd); - } -} - -void ControlMgmt::reconfigure_controls() -{ - for (auto &control : controls) - control.second->configure(); -} - -void ControlMgmt::delete_controls() -{ - for (auto &control : controls) - { - remove_from_epoll(control.first); - delete control.second; - } - controls.clear(); -} - -int ControlMgmt::socket_init() -{ - int sock_family = setup_socket_family(); - - if (sock_family == AF_UNSPEC) - return -1; - - listener = socket(sock_family, SOCK_STREAM, 0); - - if (listener < 0) - FatalError("socket failed: %s\n", get_error(errno)); - - // FIXIT-M want to disable time wait - int on = 1; - if (setsockopt(listener, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)) < 0) - FatalError("setsockopt() call failed: %s", get_error(errno)); - - if (::bind(listener, sock_addr, sock_addr_size) < 0) - FatalError("bind failed: %s\n", get_error(errno)); - - // FIXIT-M configure max conns - if (listen(listener, 0) < 0) - FatalError("listen failed: %s\n", get_error(errno)); - - add_to_epoll(listener); - - return 0; -} - -int ControlMgmt::socket_term() -{ - delete_controls(); - - if (listener >= 0) - close(listener); - - listener = -1; - - if (epoll_fd >= 0) - close(epoll_fd); - - epoll_fd = -1; - - return 0; -} - -bool ControlMgmt::process_control_commands(int& current_fd, SharedRequest& current_request, int evnt_fd) -{ - auto control_conn = controls.find(evnt_fd); - - if (control_conn == controls.end()) - return false; - - SharedRequest old_request = current_request; - int fd = control_conn->second->shell_execute(current_fd, current_request); - current_fd = -1; - current_request = old_request; - - if (fd < 0) - return false; - - if (control_conn->second->is_local_control()) - proc_stats.local_commands++; - else - proc_stats.remote_commands++; - - return true; -} - -bool ControlMgmt::service_users(int& current_fd, SharedRequest& current_request) -{ - bool ret = false; - struct epoll_event events[MAX_EPOLL_EVENTS]; - - int event_count = epoll_wait(epoll_fd, events, MAX_EPOLL_EVENTS, 0); - for(int i = 0; i < event_count; i++) - { - if (listener == events[i].data.fd) - { - // got a new connection request, accept it and store it in controls - if( !socket_conn() ) - { - ret = true; - } - } - else - { - ret = process_control_commands(current_fd, current_request, events[i].data.fd); - if (!ret && (events[i].events & EPOLLHUP)) - { - // FIXIT-L quick fix to emulate not poll()ing for events for blocked connections - ControlConn* control_conn = find_control(events[i].data.fd); - if (control_conn && !control_conn->is_blocked()) - delete_control(events[i].data.fd); - } - } - } - return ret; -} - -#else -//------------------------------------------------------------------------- -// select implementation (default) -//------------------------------------------------------------------------- -// Default implementation using select() for monitoring multiple fds. - -static fd_set inputs; -static vector controls; - -void ControlMgmt::add_control(int fd, bool local) -{ - controls.emplace_back(new ControlConn(fd, local)); -} - -bool ControlMgmt::find_control(int fd, vector::iterator& control) -{ - control = find_if(controls.begin(), controls.end(), - [=](const ControlConn* c) { return c->get_fd() == fd; }); - - if (control != controls.end()) - return true; - else - return false; -} - -ControlConn* ControlMgmt::find_control(int fd) -{ - vector::iterator it; - - ControlConn* control = find_control(fd, it) ? (*it) : nullptr; - return control; -} - -void ControlMgmt::delete_control(vector::iterator& control) -{ - delete *control; - control = controls.erase(control); -} - -void ControlMgmt::delete_control(int fd) -{ - vector::iterator control; - if (find_control(fd, control)) - delete_control(control); -} - -void ControlMgmt::reconfigure_controls() -{ - for (auto control : controls) - { - control->configure(); - } -} - -void ControlMgmt::delete_controls() -{ - for (auto control : controls) - { - delete control; - } - controls.clear(); -} - -int ControlMgmt::socket_init() -{ - int sock_family = setup_socket_family(); - - if (sock_family == AF_UNSPEC) - return -1; - - listener = socket(sock_family, SOCK_STREAM, 0); - - if (listener < 0) - FatalError("socket failed: %s\n", get_error(errno)); - - // FIXIT-M want to disable time wait - int on = 1; - setsockopt(listener, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)); - - if (::bind(listener, sock_addr, sock_addr_size) < 0) - FatalError("bind failed: %s\n", get_error(errno)); - - // FIXIT-M configure max conns - if (listen(listener, 0) < 0) - FatalError("listen failed: %s\n", get_error(errno)); - - return 0; -} - -int ControlMgmt::socket_term() -{ - delete_controls(); - - if (listener >= 0) - close(listener); - - listener = -1; - - return 0; -} - -bool ControlMgmt::process_control_commands(int& current_fd, SharedRequest& current_request) -{ - bool ret = false; - - for (vector::iterator control = - controls.begin(); control != controls.end();) - { - int fd = (*control)->get_fd(); - if (FD_ISSET(fd, &inputs)) - { - SharedRequest old_request = current_request; - fd = (*control)->shell_execute(current_fd, current_request); - current_fd = -1; - current_request = old_request; - if (fd < 0) - { - delete_control(control); - ret = false; - continue; - } - else - { - if ((*control)->is_local_control()) - proc_stats.local_commands++; - else - proc_stats.remote_commands++; - ret = true; - } - } - ++control; - } - return ret; -} - -bool ControlMgmt::service_users(int& current_fd, SharedRequest& current_request) -{ - FD_ZERO(&inputs); - int max_fd = -1; - bool ret = false; - - for (auto control : controls) - { - int fd = control->get_fd(); - if (fd >= 0 and !control->is_blocked()) - { - FD_SET(fd, &inputs); - if (fd > max_fd) - max_fd = fd; - } - } - if (listener >= 0) - { - FD_SET(listener, &inputs); - if (listener > max_fd) - max_fd = listener; - } - - struct timeval timeout; - timeout.tv_sec = 0; - timeout.tv_usec = 0; - - if (select(max_fd+1, &inputs, nullptr, nullptr, &timeout) > 0) - { - ret = process_control_commands(current_fd, current_request); - - if (listener >= 0) - { - if (FD_ISSET(listener, &inputs)) - { - if (!socket_conn()) - { - ret = true; - } - } - } - } - return ret; -} - -#endif - -int ControlMgmt::socket_conn() -{ - int remote_control = accept(listener, sock_addr, &sock_addr_size); - - if (remote_control < 0) - return -1; - - add_control(remote_control, false); - - // FIXIT-L authenticate, use ssl ? - return 0; -} - -//------------------------------------------------------------------------- -// socket foo -//------------------------------------------------------------------------- -// FIXIT-M make these non-blocking -// FIXIT-M bind to configured ip including INADDR_ANY -// (default is loopback if enabled) -int ControlMgmt::setup_socket_family() -{ - int family = AF_UNSPEC; - const SnortConfig* sc = SnortConfig::get_conf(); - - if (sc->remote_control_port) - { - memset(&in_addr, 0, sizeof(in_addr)); - - in_addr.sin_family = AF_INET; - in_addr.sin_addr.s_addr = htonl(0x7F000001); - in_addr.sin_port = htons(sc->remote_control_port); - sock_addr = (struct sockaddr*)&in_addr; - sock_addr_size = sizeof(in_addr); - family = AF_INET; - } - else if (!sc->remote_control_socket.empty()) - { - string fullpath; - const char* path_sep = strrchr(sc->remote_control_socket.c_str(), '/'); - if (path_sep != nullptr) - fullpath = sc->remote_control_socket; - else - get_instance_file(fullpath, sc->remote_control_socket.c_str()); - - memset(&unix_addr, 0, sizeof(unix_addr)); - unix_addr.sun_family = AF_UNIX; - SnortStrncpy(unix_addr.sun_path, fullpath.c_str(), sizeof(unix_addr.sun_path)); - sock_addr = (struct sockaddr*)&unix_addr; - sock_addr_size = sizeof(unix_addr); - unlink(fullpath.c_str()); - family = AF_UNIX; - } - return family; -} - - diff --git a/src/main/request.cc b/src/main/request.cc deleted file mode 100644 index 9e5566fd9..000000000 --- a/src/main/request.cc +++ /dev/null @@ -1,119 +0,0 @@ -//-------------------------------------------------------------------------- -// Copyright (C) 2017-2021 Cisco and/or its affiliates. All rights reserved. -// -// This program is free software; you can redistribute it and/or modify it -// under the terms of the GNU General Public License Version 2 as published -// by the Free Software Foundation. You may not use, modify or distribute -// this program under any other version of the GNU General Public License. -// -// This program is distributed in the hope that it will be useful, but -// WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -// General Public License for more details. -// -// You should have received a copy of the GNU General Public License along -// with this program; if not, write to the Free Software Foundation, Inc., -// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. -//-------------------------------------------------------------------------- - -#ifdef HAVE_CONFIG_H -#include "config.h" -#endif - -#include "request.h" - -#include "log/messages.h" -#include "main.h" -#include "utils/util.h" - -using namespace snort; -using namespace std; - -//------------------------------------------------------------------------- -// request foo -//------------------------------------------------------------------------- - -bool Request::read() -{ - bool newline_found = false; - char buf; - ssize_t n = 0; - - while ( (bytes_read < sizeof(read_buf)) and ((n = ::read(fd, &buf, 1)) > 0) ) - { - read_buf[bytes_read++] = buf; - - if (buf == '\n') - { - newline_found = true; - break; - } - } - - if ( n <= 0 and errno != EAGAIN and errno != EINTR ) - return false; - - if ( bytes_read == sizeof(read_buf) ) - bytes_read = 0; - - if ( newline_found ) - { - read_buf[bytes_read] = '\0'; - bytes_read = 0; - return true; - } - else - return false; -} - -bool Request::write_response(const char* s) const -{ - ssize_t n = write(fd, s, strlen(s)); - if ( n < 0 and errno != EAGAIN and errno != EINTR ) - return false; - else - return true; -} - -// FIXIT-L supporting only simple strings for now -// could support var args formats -void Request::respond(const char* s, bool queue_response, bool remote_only) -{ - if (remote_only && (fd == STDOUT_FILENO)) - return; - - if ( fd < 1 ) - { - if (!remote_only) - LogMessage("%s", s); - return; - } - - if ( queue_response ) - { - lock_guard lock(queued_response_mutex); - queued_response.emplace(s); - return; - } - write_response(s); -} - -#ifdef SHELL -bool Request::send_queued_response() -{ - const char* qr; - { - lock_guard lock(queued_response_mutex); - if ( queued_response.empty() ) - return false; - qr = queued_response.front(); - queued_response.pop(); - } - return write_response(qr); -} -#endif - -SharedRequest get_dispatched_request() -{ - return get_current_request(); -} diff --git a/src/main/request.h b/src/main/request.h deleted file mode 100644 index 991b71b21..000000000 --- a/src/main/request.h +++ /dev/null @@ -1,56 +0,0 @@ -//-------------------------------------------------------------------------- -// Copyright (C) 2017-2021 Cisco and/or its affiliates. All rights reserved. -// -// This program is free software; you can redistribute it and/or modify it -// under the terms of the GNU General Public License Version 2 as published -// by the Free Software Foundation. You may not use, modify or distribute -// this program under any other version of the GNU General Public License. -// -// This program is distributed in the hope that it will be useful, but -// WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -// General Public License for more details. -// -// You should have received a copy of the GNU General Public License along -// with this program; if not, write to the Free Software Foundation, Inc., -// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. -//-------------------------------------------------------------------------- - -// This header includes request class which is used by the control connections -// to read control commands and send responses for those commands. - -#ifndef REQUEST_H -#define REQUEST_H - -#include -#include -#include - -#include "main/snort_types.h" - -class SO_PUBLIC Request -{ -public: - Request(int f = -1) : fd(f), bytes_read(0) { } - - bool read(); - const char* get() { return read_buf; } - bool write_response(const char* s) const; - void respond(const char* s, bool queue_response = false, bool remote_only = false); -#ifdef SHELL - bool send_queued_response(); -#endif - -private: - int fd; - char read_buf[1024]; - size_t bytes_read; - std::queue queued_response; - std::mutex queued_response_mutex; -}; - -using SharedRequest = std::shared_ptr; - -SO_PUBLIC SharedRequest get_dispatched_request(); - -#endif diff --git a/src/main/shell.h b/src/main/shell.h index 1c010b24e..356efeff3 100644 --- a/src/main/shell.h +++ b/src/main/shell.h @@ -66,6 +66,9 @@ public: bool get_loaded() const { return loaded; } + lua_State* get_lua() const + { return lua; } + public: static bool is_trusted(const std::string& key); static void allowlist_append(const char* keyword, bool is_prefix); diff --git a/src/main/snort.cc b/src/main/snort.cc index 3f7c2c2c8..ba2493982 100644 --- a/src/main/snort.cc +++ b/src/main/snort.cc @@ -79,8 +79,8 @@ #endif #ifdef SHELL +#include "control/control_mgmt.h" #include "ac_shell_cmd.h" -#include "control_mgmt.h" #endif #include "snort_config.h" diff --git a/src/main/snort_module.cc b/src/main/snort_module.cc index 94a33dc85..a57dbe2cb 100644 --- a/src/main/snort_module.cc +++ b/src/main/snort_module.cc @@ -120,9 +120,9 @@ static const Command snort_cmds[] = { "pause", main_pause, nullptr, "suspend packet processing" }, { "resume", main_resume, s_pktnum, "continue packet processing. " - "If number of packet is specified, will resume for n packets and pause" }, + "If number of packets is specified, will resume for n packets and pause" }, - { "detach", main_detach, nullptr, "exit shell w/o shutdown" }, + { "detach", main_detach, nullptr, "detach from control shell (without shutting down)" }, { "quit", main_quit, nullptr, "shutdown and dump-stats" }, { "help", main_help, nullptr, "this output" }, diff --git a/src/main/test/CMakeLists.txt b/src/main/test/CMakeLists.txt index 4edae5227..e26ac4664 100644 --- a/src/main/test/CMakeLists.txt +++ b/src/main/test/CMakeLists.txt @@ -1,9 +1,4 @@ if ( ENABLE_SHELL ) - add_cpputest(request_test - SOURCES - ../request.cc - ) - add_cpputest(distill_verdict_test SOURCES stubs.h diff --git a/src/main/test/request_test.cc b/src/main/test/request_test.cc deleted file mode 100644 index 79a761a2c..000000000 --- a/src/main/test/request_test.cc +++ /dev/null @@ -1,73 +0,0 @@ -//-------------------------------------------------------------------------- -// Copyright (C) 2019-2021 Cisco and/or its affiliates. All rights reserved. -// -// This program is free software; you can redistribute it and/or modify it -// under the terms of the GNU General Public License Version 2 as published -// by the Free Software Foundation. You may not use, modify or distribute -// this program under any other version of the GNU General Public License. -// -// This program is distributed in the hope that it will be useful, but -// WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -// General Public License for more details. -// -// You should have received a copy of the GNU General Public License along -// with this program; if not, write to the Free Software Foundation, Inc., -// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. -//-------------------------------------------------------------------------- -// request_test.cc author Devendra Dahiphale - -#ifdef HAVE_CONFIG_H -#include "config.h" -#endif - -#include - -#include "main/request.h" - -#include -#include - -namespace snort -{ -void ErrorMessage(const char*,...) { } -void LogMessage(const char*,...) { } -} - -using namespace snort; - -Request& get_current_request() -{ - static Request my_req; - return my_req; -} - -//-------------------------------------------------------------------------- -// Request tests -//-------------------------------------------------------------------------- -TEST_GROUP(request_tests) -{}; - -//-------------------------------------------------------------------------- -// Make sure multiple responses are queued -//-------------------------------------------------------------------------- -TEST(request_tests, queued_response_test) -{ - Request request(STDOUT_FILENO); - - CHECK(request.send_queued_response() == false); // empty queue - request.respond("reloading", true); - request.respond("swapping", true); - CHECK(request.send_queued_response() == true); - CHECK(request.send_queued_response() == true); - CHECK(request.send_queued_response() == false); // empty queue after being written -} - -//------------------------------------------------------------------------- -// main -//------------------------------------------------------------------------- -int main(int argc, char** argv) -{ - return CommandLineTestRunner::RunAllTests(argc, argv); -} - diff --git a/src/network_inspectors/appid/appid_inspector.cc b/src/network_inspectors/appid/appid_inspector.cc index 37abd1499..48a0ff738 100644 --- a/src/network_inspectors/appid/appid_inspector.cc +++ b/src/network_inspectors/appid/appid_inspector.cc @@ -192,7 +192,7 @@ void AppIdInspector::tterm() void AppIdInspector::tear_down(SnortConfig*) { - main_broadcast_command(new ACThirdPartyAppIdCleanup(), true); + main_broadcast_command(new ACThirdPartyAppIdCleanup()); } void AppIdInspector::eval(Packet* p) diff --git a/src/network_inspectors/appid/appid_module.cc b/src/network_inspectors/appid/appid_module.cc index 7f18386f5..3198bfbfa 100644 --- a/src/network_inspectors/appid/appid_module.cc +++ b/src/network_inspectors/appid/appid_module.cc @@ -28,6 +28,7 @@ #include #include +#include "control/control.h" #include "host_tracker/host_cache.h" #include "log/messages.h" #include "main/analyzer.h" @@ -182,15 +183,13 @@ class ACThirdPartyAppIdContextUnload : public AnalyzerCommand public: bool execute(Analyzer&, void**) override; ACThirdPartyAppIdContextUnload(const AppIdInspector& inspector, ThirdPartyAppIdContext* tp_ctxt, - SharedRequest current_request, bool from_shell): inspector(inspector), - tp_ctxt(tp_ctxt), request(current_request), from_shell(from_shell) { } + ControlConn* ctrlcon): inspector(inspector), tp_ctxt(tp_ctxt), ctrlcon(ctrlcon) { } ~ACThirdPartyAppIdContextUnload() override; const char* stringify() override { return "THIRD-PARTY_CONTEXT_UNLOAD"; } private: const AppIdInspector& inspector; ThirdPartyAppIdContext* tp_ctxt = nullptr; - SharedRequest request; - bool from_shell; + ControlConn* ctrlcon; }; bool ACThirdPartyAppIdContextUnload::execute(Analyzer& ac, void**) @@ -216,7 +215,7 @@ ACThirdPartyAppIdContextUnload::~ACThirdPartyAppIdContextUnload() ctxt.create_tp_appid_ctxt(); main_broadcast_command(new ACThirdPartyAppIdContextSwap(inspector)); LogMessage("== reload third-party complete\n"); - request->respond("== reload third-party complete\n", from_shell, true); + ctrlcon->respond("== reload third-party complete\n"); Swapper::set_reload_in_progress(false); } @@ -224,16 +223,14 @@ class ACOdpContextSwap : public AnalyzerCommand { public: bool execute(Analyzer&, void**) override; - ACOdpContextSwap(const AppIdInspector& inspector, OdpContext& odp_ctxt, - SharedRequest current_request, bool from_shell) : inspector(inspector), - odp_ctxt(odp_ctxt), request(current_request), from_shell(from_shell) { } + ACOdpContextSwap(const AppIdInspector& inspector, OdpContext& odp_ctxt, ControlConn* ctrlcon) : + inspector(inspector), odp_ctxt(odp_ctxt), ctrlcon(ctrlcon) { } ~ACOdpContextSwap() override; const char* stringify() override { return "ODP_CONTEXT_SWAP"; } private: const AppIdInspector& inspector; OdpContext& odp_ctxt; - SharedRequest request; - bool from_shell; + ControlConn* ctrlcon; }; bool ACOdpContextSwap::execute(Analyzer&, void**) @@ -269,7 +266,7 @@ ACOdpContextSwap::~ACOdpContextSwap() ctxt.get_odp_ctxt().get_app_info_mgr().dump_appid_configurations(file_path); } LogMessage("== reload detectors complete\n"); - request->respond("== reload detectors complete\n", from_shell, true); + ctrlcon->respond("== reload detectors complete\n"); Swapper::set_reload_in_progress(false); } @@ -307,46 +304,43 @@ static int enable_debug(lua_State* L) AppIdDebugLogEvent event(&constraints, "AppIdDbg"); DataBus::publish(APPID_DEBUG_LOG_EVENT, event); - main_broadcast_command(new AcAppIdDebug(&constraints), true); + main_broadcast_command(new AcAppIdDebug(&constraints), ControlConn::query_from_lua(L)); return 0; } -static int disable_debug(lua_State*) +static int disable_debug(lua_State* L) { AppIdDebugLogEvent event(nullptr, ""); DataBus::publish(APPID_DEBUG_LOG_EVENT, event); - main_broadcast_command(new AcAppIdDebug(nullptr), true); + main_broadcast_command(new AcAppIdDebug(nullptr), ControlConn::query_from_lua(L)); return 0; } static int reload_third_party(lua_State* L) { - SharedRequest current_request = get_current_request(); + ControlConn* ctrlcon = ControlConn::query_from_lua(L); if (Swapper::get_reload_in_progress()) { - current_request->respond("== reload pending; retry\n"); + ctrlcon->respond("== reload pending; retry\n"); return 0; } AppIdInspector* inspector = (AppIdInspector*) InspectorManager::get_inspector(MOD_NAME); if (!inspector) { - current_request->respond("== reload third-party failed - appid not enabled\n"); + ctrlcon->respond("== reload third-party failed - appid not enabled\n"); return 0; } const AppIdContext& ctxt = inspector->get_ctxt(); ThirdPartyAppIdContext* old_ctxt = ctxt.get_tp_appid_ctxt(); if (!old_ctxt) { - current_request->respond("== reload third-party failed - third-party module doesn't exist\n"); + ctrlcon->respond("== reload third-party failed - third-party module doesn't exist\n"); return 0; } Swapper::set_reload_in_progress(true); - - bool from_shell = ( L != nullptr ); - current_request->respond("== unloading old third-party configuration\n", from_shell); - main_broadcast_command(new ACThirdPartyAppIdContextUnload(*inspector, old_ctxt, - current_request, from_shell), from_shell); + ctrlcon->respond("== unloading old third-party configuration\n"); + main_broadcast_command(new ACThirdPartyAppIdContextUnload(*inspector, old_ctxt, ctrlcon), ctrlcon); return 0; } @@ -361,20 +355,20 @@ static void clear_dynamic_host_cache_services() static int reload_detectors(lua_State* L) { - SharedRequest current_request = get_current_request(); + ControlConn* ctrlcon = ControlConn::query_from_lua(L); if (Swapper::get_reload_in_progress()) { - current_request->respond("== reload pending; retry\n"); + ctrlcon->respond("== reload pending; retry\n"); return 0; } AppIdInspector* inspector = (AppIdInspector*) InspectorManager::get_inspector(MOD_NAME); if (!inspector) { - current_request->respond("== reload detectors failed - appid not enabled\n"); + ctrlcon->respond("== reload detectors failed - appid not enabled\n"); return 0; } Swapper::set_reload_in_progress(true); - current_request->respond(".. reloading detectors\n"); + ctrlcon->respond(".. reloading detectors\n"); AppIdContext& ctxt = inspector->get_ctxt(); OdpContext& old_odp_ctxt = ctxt.get_odp_ctxt(); @@ -393,10 +387,8 @@ static int reload_detectors(lua_State* L) odp_thread_local_ctxt->initialize(ctxt, true, true); odp_ctxt.initialize(*inspector); - bool from_shell = ( L != nullptr ); - current_request->respond("== swapping detectors configuration\n", from_shell); - main_broadcast_command(new ACOdpContextSwap(*inspector, old_odp_ctxt, - current_request, from_shell), from_shell); + ctrlcon->respond("== swapping detectors configuration\n"); + main_broadcast_command(new ACOdpContextSwap(*inspector, old_odp_ctxt, ctrlcon), ctrlcon); return 0; } diff --git a/src/network_inspectors/appid/appid_stats.cc b/src/network_inspectors/appid/appid_stats.cc index 6527ac90f..e780f448e 100644 --- a/src/network_inspectors/appid/appid_stats.cc +++ b/src/network_inspectors/appid/appid_stats.cc @@ -264,7 +264,6 @@ void AppIdStatistics::update(const AppIdSession& asd) update_stats(asd, client_id, bucket); } -// Currently not registered to IdleProcessing void AppIdStatistics::flush() { if ( !enabled ) diff --git a/src/network_inspectors/packet_capture/capture_module.cc b/src/network_inspectors/packet_capture/capture_module.cc index 3313d78dd..b4f4d179b 100644 --- a/src/network_inspectors/packet_capture/capture_module.cc +++ b/src/network_inspectors/packet_capture/capture_module.cc @@ -26,6 +26,7 @@ #include +#include "control/control.h" #include "main/analyzer_command.h" #include "profiler/profiler.h" @@ -100,13 +101,13 @@ private: static int enable(lua_State* L) { main_broadcast_command(new PacketCaptureDebug(lua_tostring(L, 1), - luaL_optint(L, 2, 0)), true); + luaL_optint(L, 2, 0)), ControlConn::query_from_lua(L)); return 0; } -static int disable(lua_State*) +static int disable(lua_State* L) { - main_broadcast_command(new PacketCaptureDebug(nullptr, -1), true); + main_broadcast_command(new PacketCaptureDebug(nullptr, -1), ControlConn::query_from_lua(L)); return 0; } diff --git a/src/network_inspectors/packet_tracer/packet_tracer_module.cc b/src/network_inspectors/packet_tracer/packet_tracer_module.cc index a8c8fc346..7b0302bbe 100644 --- a/src/network_inspectors/packet_tracer/packet_tracer_module.cc +++ b/src/network_inspectors/packet_tracer/packet_tracer_module.cc @@ -22,14 +22,15 @@ #include "config.h" #endif -#include - #include "packet_tracer_module.h" +#include + +#include "control/control.h" +#include "log/messages.h" +#include "main/analyzer_command.h" #include "main/snort_config.h" #include "profiler/profiler.h" -#include "main/analyzer_command.h" -#include "log/messages.h" #include "sfip/sf_ip.h" #include "packet_tracer.h" @@ -148,13 +149,13 @@ static int enable(lua_State* L) constraints.set_bits |= PacketConstraints::SetBits::SRC_PORT; if ( dport ) constraints.set_bits |= PacketConstraints::SetBits::DST_PORT; - main_broadcast_command(new PacketTracerDebug(&constraints), true); + main_broadcast_command(new PacketTracerDebug(&constraints), ControlConn::query_from_lua(L)); return 0; } -static int disable(lua_State*) +static int disable(lua_State* L) { - main_broadcast_command(new PacketTracerDebug(nullptr), true); + main_broadcast_command(new PacketTracerDebug(nullptr), ControlConn::query_from_lua(L)); return 0; } diff --git a/src/network_inspectors/perf_monitor/perf_module.cc b/src/network_inspectors/perf_monitor/perf_module.cc index 69310fda0..1e6ede6e5 100644 --- a/src/network_inspectors/perf_monitor/perf_module.cc +++ b/src/network_inspectors/perf_monitor/perf_module.cc @@ -26,6 +26,7 @@ #include +#include "control/control.h" #include "log/messages.h" #include "main/analyzer_command.h" #include "main/snort.h" @@ -155,8 +156,8 @@ static int enable_flow_ip_profiling(lua_State* L) auto* new_constraints = new PerfConstraints(true, luaL_optint(L, 1, 0), luaL_optint(L, 2, 0)); - main_broadcast_command(new PerfMonFlowIPDebug(new_constraints, true, perf_monitor), - true); + ControlConn* ctrlcon = ControlConn::query_from_lua(L); + main_broadcast_command(new PerfMonFlowIPDebug(new_constraints, true, perf_monitor), ctrlcon); LogMessage("Enabling flow ip profiling with sample interval %d packet count %d\n", new_constraints->sample_interval, new_constraints->pkt_cnt); @@ -164,7 +165,7 @@ static int enable_flow_ip_profiling(lua_State* L) return 0; } -static int disable_flow_ip_profiling(lua_State*) +static int disable_flow_ip_profiling(lua_State* L) { PerfMonitor* perf_monitor = (PerfMonitor*)InspectorManager::get_inspector(PERF_NAME, true); @@ -185,8 +186,8 @@ static int disable_flow_ip_profiling(lua_State*) auto* new_constraints = perf_monitor->get_original_constraints(); - main_broadcast_command(new PerfMonFlowIPDebug(new_constraints, false, perf_monitor), - true); + ControlConn* ctrlcon = ControlConn::query_from_lua(L); + main_broadcast_command(new PerfMonFlowIPDebug(new_constraints, false, perf_monitor), ctrlcon); LogMessage("Disabling flow ip profiling\n"); diff --git a/src/network_inspectors/rna/rna_module.cc b/src/network_inspectors/rna/rna_module.cc index 808f15d3e..af0a77360 100644 --- a/src/network_inspectors/rna/rna_module.cc +++ b/src/network_inspectors/rna/rna_module.cc @@ -31,10 +31,10 @@ #include #include +#include "control/control.h" #include "host_tracker/host_cache.h" #include "log/messages.h" #include "lua/lua.h" -#include "main/request.h" #include "main/snort_config.h" #include "managers/inspector_manager.h" #include "managers/module_manager.h" @@ -92,13 +92,12 @@ static int purge_data(lua_State* L) if ( rna ) { HostCacheMac* mac_cache = new HostCacheMac(MAC_CACHE_INITIAL_SIZE); - bool from_shell = ( L != nullptr ); - main_broadcast_command(new DataPurgeAC(mac_cache), from_shell); + ControlConn* ctrlcon = ControlConn::query_from_lua(L); + main_broadcast_command(new DataPurgeAC(mac_cache), ctrlcon); host_cache.invalidate(); - SharedRequest request = get_dispatched_request(); - request->respond("data purge done\n", from_shell, true); + ctrlcon->respond("data purge done\n"); LogMessage("data purge done\n"); } diff --git a/src/network_inspectors/rna/test/rna_module_mock.h b/src/network_inspectors/rna/test/rna_module_mock.h index fe4b2451e..367bb92c0 100644 --- a/src/network_inspectors/rna/test/rna_module_mock.h +++ b/src/network_inspectors/rna/test/rna_module_mock.h @@ -21,8 +21,6 @@ #ifndef RNA_MODULE_MOCK_H #define RNA_MODULE_MOCK_H -#include "main/request.h" - #include "../rna_mac_cache.cc" THREAD_LOCAL RnaStats rna_stats; @@ -111,19 +109,20 @@ private: RnaModuleConfig* mod_conf = nullptr; }; - } // end of namespace snort -static SharedRequest mock_request = std::make_shared(); -void Request::respond(const char*, bool, bool) { } -SharedRequest get_dispatched_request() { return mock_request; } +void snort::main_broadcast_command(snort::AnalyzerCommand*, ControlConn*) {} +static ControlConn s_ctrlcon(1, true); +ControlConn::ControlConn(int, bool) {} +ControlConn::~ControlConn() {} +ControlConn* ControlConn::query_from_lua(const lua_State*) { return &s_ctrlcon; } +bool ControlConn::respond(const char*, ...) { return true; } HostCacheMac* get_host_cache_mac() { return nullptr; } DataPurgeAC::~DataPurgeAC() = default; bool DataPurgeAC::execute(Analyzer&, void**) { return true;} -void snort::main_broadcast_command(AnalyzerCommand*, bool) { } void set_host_cache_mac(HostCacheMac*) { } Inspector* InspectorManager::get_inspector(const char*, bool, const SnortConfig*) diff --git a/src/trace/trace_swap.cc b/src/trace/trace_swap.cc index c9cabebe1..8ac9ff043 100644 --- a/src/trace/trace_swap.cc +++ b/src/trace/trace_swap.cc @@ -25,6 +25,7 @@ #include +#include "control/control.h" #include "framework/module.h" #include "framework/packet_constraints.h" #include "log/messages.h" @@ -362,8 +363,8 @@ static int set(lua_State* L) if ( log_params.set_constraints ) trace_parser.finalize_constraints(); - main_broadcast_command(new TraceSwap( - &trace_parser.get_trace_config(), log_params), true); + ControlConn* ctrlcon = ControlConn::query_from_lua(L); + main_broadcast_command(new TraceSwap(&trace_parser.get_trace_config(), log_params), ctrlcon); } else delete trace_config; @@ -371,11 +372,12 @@ static int set(lua_State* L) return 0; } -static int clear(lua_State*) +static int clear(lua_State* L) { + ControlConn* ctrlcon = ControlConn::query_from_lua(L); // Create an empty overlay TraceConfig // It will be set in a SnortConfig during TraceSwap execution and owned by it after - main_broadcast_command(new TraceSwap(new TraceConfig, {}), true); + main_broadcast_command(new TraceSwap(new TraceConfig, {}), ctrlcon); return 0; }