]> git.ipfire.org Git - thirdparty/suricata.git/commitdiff
unix-manager: implement multi client support
authorEric Leblond <eric@regit.org>
Mon, 29 Oct 2012 10:56:46 +0000 (11:56 +0100)
committerEric Leblond <eric@regit.org>
Mon, 19 Nov 2012 22:54:27 +0000 (23:54 +0100)
This patch implements the support of multiple clients connected
at once to the unix socket.

src/unix-manager.c

index ee87678d545a69e9df1f9f2708c06d58279b6f91..9b857562429747bb85a69bf9936df89f3dad88fd 100644 (file)
@@ -58,15 +58,19 @@ typedef struct Task_ {
     TAILQ_ENTRY(Task_) next;
 } Task;
 
+typedef struct UnixClient_ {
+    int fd;
+    TAILQ_ENTRY(UnixClient_) next;
+} UnixClient;
+
 typedef struct UnixCommand_ {
     time_t start_timestamp;
     int socket;
-    int client;
     struct sockaddr_un client_addr;
     int select_max;
-    fd_set select_set;
     TAILQ_HEAD(, Command_) commands;
     TAILQ_HEAD(, Task_) tasks;
+    TAILQ_HEAD(, UnixClient_) clients;
 } UnixCommand;
 
 /**
@@ -85,11 +89,11 @@ int UnixNew(UnixCommand * this)
 
     this->start_timestamp = time(NULL);
     this->socket = -1;
-    this->client = -1;
     this->select_max = 0;
 
     TAILQ_INIT(&this->commands);
     TAILQ_INIT(&this->tasks);
+    TAILQ_INIT(&this->clients);
 
     /* Create socket dir */
     ret = mkdir(SOCKET_PATH, S_IRWXU|S_IXGRP|S_IRGRP);
@@ -180,17 +184,47 @@ int UnixNew(UnixCommand * this)
     return 1;
 }
 
+void UnixCommandSetMaxFD(UnixCommand *this) {
+    UnixClient *item;
+
+    if (this == NULL) {
+        SCLogError(SC_ERR_INVALID_ARGUMENT, "Unix command is NULL, warn devel");
+        return;
+    }
+
+    this->select_max = this->socket + 1;
+    TAILQ_FOREACH(item, &this->clients, next) {
+        if (item->fd >= this->select_max) {
+            this->select_max = item->fd + 1;
+        }
+    }
+}
+
 /**
  * \brief Close the unix socket
  */
-void UnixCommandClose(UnixCommand  *this)
+void UnixCommandClose(UnixCommand  *this, int fd)
 {
-    if (this->client == -1)
+    UnixClient *item;
+    int found = 0;
+
+    TAILQ_FOREACH(item, &this->clients, next) {
+        if (item->fd == fd) {
+            found = 1;
+            break;
+        }
+    }
+
+    if (found == 0) {
+        SCLogError(SC_ERR_INVALID_VALUE, "No fd found in client list");
         return;
-    SCLogInfo("Unix socket: close client connection");
-    close(this->client);
-    this->client = -1;
-    this->select_max = this->socket + 1;
+    }
+
+    TAILQ_REMOVE(&this->clients, item, next);
+
+    close(item->fd);
+    UnixCommandSetMaxFD(this);
+    SCFree(item);
 }
 
 /**
@@ -198,9 +232,9 @@ void UnixCommandClose(UnixCommand  *this)
  */
 int UnixCommandSendCallback(const char *buffer, size_t size, void *data)
 {
-    UnixCommand *this = (UnixCommand *) data;
+    int fd = *(int *) data;
 
-    if (send(this->client, buffer, size, MSG_NOSIGNAL) == -1) {
+    if (send(fd, buffer, size, MSG_NOSIGNAL) == -1) {
         SCLogInfo("Unable to send block: %s", strerror(errno));
         return -1;
     }
@@ -227,13 +261,15 @@ int UnixCommandAccept(UnixCommand *this)
     json_t *server_msg;
     json_t *version;
     json_error_t jerror;
+    int client;
     int ret;
+    UnixClient *uclient = NULL;
 
     /* accept client socket */
     socklen_t len = sizeof(this->client_addr);
-    this->client = accept(this->socket, (struct sockaddr *) &this->client_addr,
+    client = accept(this->socket, (struct sockaddr *) &this->client_addr,
                           &len);
-    if (this->client < 0) {
+    if (client < 0) {
         SCLogInfo("Unix socket: accept() error: %s",
                   strerror(errno));
         return 0;
@@ -242,30 +278,30 @@ int UnixCommandAccept(UnixCommand *this)
 
     /* read client version */
     buffer[sizeof(buffer)-1] = 0;
-    ret = recv(this->client, buffer, sizeof(buffer)-1, 0);
+    ret = recv(client, buffer, sizeof(buffer)-1, 0);
     if (ret < 0) {
         SCLogInfo("Command server: client doesn't send version");
-        UnixCommandClose(this);
+        UnixCommandClose(this, client);
         return 0;
     }
     if (ret >= (int)(sizeof(buffer)-1)) {
         SCLogInfo("Command server: client message is too long, "
                   "disconnect him.");
-        UnixCommandClose(this);
+        UnixCommandClose(this, client);
     }
     buffer[ret] = 0;
 
     client_msg = json_loads(buffer, 0, &jerror);
     if (client_msg == NULL) {
         SCLogInfo("Invalid command, error on line %d: %s\n", jerror.line, jerror.text);
-        UnixCommandClose(this);
+        UnixCommandClose(this, client);
         return 0;
     }
 
     version = json_object_get(client_msg, "version");
     if(!json_is_string(version)) {
         SCLogInfo("error: version is not a string");
-        UnixCommandClose(this);
+        UnixCommandClose(this, client);
         return 0;
     }
 
@@ -273,7 +309,7 @@ int UnixCommandAccept(UnixCommand *this)
     if (strcmp(json_string_value(version), UNIX_PROTO_VERSION) != 0) {
         SCLogInfo("Unix socket: invalid client version: \"%s\"",
                 json_string_value(version));
-        UnixCommandClose(this);
+        UnixCommandClose(this, client);
         return 0;
     } else {
         SCLogInfo("Unix socket: client version: \"%s\"",
@@ -283,23 +319,28 @@ int UnixCommandAccept(UnixCommand *this)
     /* send answer */
     server_msg = json_object();
     if (server_msg == NULL) {
-        UnixCommandClose(this);
+        UnixCommandClose(this, client);
         return 0;
     }
     json_object_set_new(server_msg, "return", json_string("OK"));
 
-    if (json_dump_callback(server_msg, UnixCommandSendCallback, this, 0) == -1) {
+    if (json_dump_callback(server_msg, UnixCommandSendCallback, &client, 0) == -1) {
         SCLogWarning(SC_ERR_SOCKET, "Unable to send command");
-        UnixCommandClose(this);
+        UnixCommandClose(this, client);
         return 0;
     }
 
     /* client connected */
     SCLogInfo("Unix socket: client connected");
-    if (this->socket < this->client)
-        this->select_max = this->client + 1;
-    else
-        this->select_max = this->socket + 1;
+    
+    uclient = SCMalloc(sizeof(UnixClient));
+    if (uclient == NULL) {
+        SCLogError(SC_ERR_MEM_ALLOC, "Can't allocate new cient");
+        return 0;
+    }
+    uclient->fd = client;
+    TAILQ_INSERT_TAIL(&this->clients, uclient, next);
+    UnixCommandSetMaxFD(this);
     return 1;
 }
 
@@ -326,7 +367,7 @@ int UnixCommandBackgroundTasks(UnixCommand* this)
  *
  * \retval 0 in case of error, 1 in case of success
  */
-int UnixCommandExecute(UnixCommand * this, char *command)
+int UnixCommandExecute(UnixCommand * this, char *command, UnixClient *client)
 {
     int ret = 1;
     json_error_t error;
@@ -387,7 +428,7 @@ int UnixCommandExecute(UnixCommand * this, char *command)
     }
 
     /* send answer */
-    if (json_dump_callback(server_msg, UnixCommandSendCallback, this, 0) == -1) {
+    if (json_dump_callback(server_msg, UnixCommandSendCallback, &client->fd, 0) == -1) {
         SCLogWarning(SC_ERR_SOCKET, "Unable to send command");
         goto error_cmd;
     }
@@ -399,15 +440,15 @@ error_cmd:
 error:
     json_decref(jsoncmd);
     json_decref(server_msg);
-    UnixCommandClose(this);
+    UnixCommandClose(this, client->fd);
     return 0;
 }
 
-void UnixCommandRun(UnixCommand * this)
+void UnixCommandRun(UnixCommand * this, UnixClient *client)
 {
     char buffer[4096];
     int ret;
-    ret = recv(this->client, buffer, sizeof(buffer) - 1, 0);
+    ret = recv(client->fd, buffer, sizeof(buffer) - 1, 0);
     if (ret <= 0) {
         if (ret == 0) {
             SCLogInfo("Unix socket: lost connection with client");
@@ -415,16 +456,16 @@ void UnixCommandRun(UnixCommand * this)
             SCLogInfo("Unix socket: error on recv() from client: %s",
                       strerror(errno));
         }
-        UnixCommandClose(this);
+        UnixCommandClose(this, client->fd);
         return;
     }
     if (ret >= (int)(sizeof(buffer)-1)) {
         SCLogInfo("Command server: client command is too long, "
                   "disconnect him.");
-        UnixCommandClose(this);
+        UnixCommandClose(this, client->fd);
     }
     buffer[ret] = 0;
-    UnixCommandExecute(this, buffer);
+    UnixCommandExecute(this, buffer, client);
 }
 
 /**
@@ -436,15 +477,19 @@ int UnixMain(UnixCommand * this)
 {
     struct timeval tv;
     int ret;
+    fd_set select_set;
+    UnixClient *uclient;
 
     /* Wait activity on the socket */
-    FD_ZERO(&this->select_set);
-    FD_SET(this->socket, &this->select_set);
-    if (0 <= this->client)
-        FD_SET(this->client, &this->select_set);
+    FD_ZERO(&select_set);
+    FD_SET(this->socket, &select_set);
+    TAILQ_FOREACH(uclient, &this->clients, next) {
+        FD_SET(uclient->fd, &select_set);
+    }
+
     tv.tv_sec = 0;
     tv.tv_usec = 200 * 1000;
-    ret = select(this->select_max, &this->select_set, NULL, NULL, &tv);
+    ret = select(this->select_max, &select_set, NULL, NULL, &tv);
 
     /* catch select() error */
     if (ret == -1) {
@@ -457,7 +502,6 @@ int UnixMain(UnixCommand * this)
     }
 
     if (suricata_ctl_flags & (SURICATA_STOP | SURICATA_KILL)) {
-        UnixCommandClose(this);
         return 1;
     }
 
@@ -466,10 +510,13 @@ int UnixMain(UnixCommand * this)
         return 1;
     }
 
-    if (0 <= this->client && FD_ISSET(this->client, &this->select_set)) {
-        UnixCommandRun(this);
+    
+    TAILQ_FOREACH(uclient, &this->clients, next) {
+        if (FD_ISSET(uclient->fd, &select_set)) {
+            UnixCommandRun(this, uclient);
+        }
     }
-    if (FD_ISSET(this->socket, &this->select_set)) {
+    if (FD_ISSET(this->socket, &select_set)) {
         if (!UnixCommandAccept(this))
             return 0;
     }
@@ -644,6 +691,7 @@ TmEcode UnixManagerRegisterBackgroundTask(
 void *UnixManagerThread(void *td)
 {
     ThreadVars *th_v = (ThreadVars *)td;
+    int ret;
 
     /* set the thread name */
     (void) SCSetThreadName(th_v->name);
@@ -677,10 +725,17 @@ void *UnixManagerThread(void *td)
 
     TmThreadsSetFlag(th_v, THV_INIT_DONE);
     while (1) {
-        UnixMain(&command);
+        ret = UnixMain(&command);
+        if (ret == 0) {
+            SCLogError(SC_ERR_FATAL, "Fatal error on unix socket");
+        }
 
-        if (TmThreadsCheckFlag(th_v, THV_KILL)) {
-            UnixCommandClose(&command);
+        if ((ret == 0) || (TmThreadsCheckFlag(th_v, THV_KILL))) {
+            UnixClient *item;
+            TAILQ_FOREACH(item, &(&command)->clients, next) {
+                close(item->fd);
+                SCFree(item);
+            }
             SCPerfSyncCounters(th_v, 0);
             break;
         }