]> git.ipfire.org Git - people/ms/rstp.git/blobdiff - ctl_socket_client.c
Use new RSTP library.
[people/ms/rstp.git] / ctl_socket_client.c
index 46d585193b876fa0ea93e7eec83a80e970d7c293..1a219bc441672b349dfbec64eb4fa4f5eac27c75 100644 (file)
@@ -32,6 +32,7 @@
 #include <stdio.h>
 #include <stdlib.h>
 
+#define NO_DAEMON
 #include "log.h"
 
 static int fd = -1;
@@ -81,28 +82,31 @@ void ctl_client_cleanup(void)
        }
 }
 
-int send_ctl_message(int cmd, void *inbuf, int lin, void *outbuf, int *lout,
-                    int *res)
+int send_ctl_message(int cmd, void *inbuf, int lin, void *outbuf, int lout,
+                    LogString *log, int *res)
 {
        struct ctl_msg_hdr mhdr;
        struct msghdr msg;
-       struct iovec iov[2];
+       struct iovec iov[3];
        int l;
 
        msg.msg_name = NULL;
        msg.msg_namelen = 0;
        msg.msg_iov = iov;
-       msg.msg_iovlen = 2;
+       msg.msg_iovlen = 3;
        msg.msg_control = NULL;
        msg.msg_controllen = 0;
 
        mhdr.cmd = cmd;
        mhdr.lin = lin;
-       mhdr.lout = lout != NULL ? *lout : 0;
+       mhdr.lout = lout;
+       mhdr.llog = sizeof(log->buf) - 1;
        iov[0].iov_base = &mhdr;
        iov[0].iov_len = sizeof(mhdr);
        iov[1].iov_base = (void *)inbuf;
        iov[1].iov_len = lin;
+        iov[2].iov_base = log->buf;
+        iov[2].iov_len = 0;
 
        l = sendmsg(fd, &msg, 0);
        if (l < 0) {
@@ -115,7 +119,10 @@ int send_ctl_message(int cmd, void *inbuf, int lin, void *outbuf, int *lout,
        }
 
        iov[1].iov_base = outbuf;
-       iov[1].iov_len = lout != NULL ? *lout : 0;
+       iov[1].iov_len = lout;
+
+        iov[2].iov_base = log->buf;
+        iov[2].iov_len = sizeof(log->buf);
 
        {
                struct pollfd pfd;
@@ -145,16 +152,24 @@ int send_ctl_message(int cmd, void *inbuf, int lin, void *outbuf, int *lout,
                        ERROR("Error getting message from server: %m");
                        return -1;
                }
-               if (l < sizeof(mhdr) || l != sizeof(mhdr) + mhdr.lout
+               if (l < sizeof(mhdr)
+                   || l != sizeof(mhdr) + mhdr.lout + mhdr.llog
                    || mhdr.cmd != cmd) {
                        ERROR("Error getting message from server: Bad format");
                        return -1;
                }
+               if (mhdr.lout != lout) {
+                       ERROR("Error, unexpected result length %d, "
+                             "expected %d\n", mhdr.lout, lout);
+                       return -1;
+               }
+               if (mhdr.llog >= sizeof(log->buf)) {
+                       ERROR("Invalid log message length %d", mhdr.llog);
+                       return -1;
+               }
        }
-       if (lout)
-               *lout = mhdr.lout;
        if (res)
                *res = mhdr.res;
-
+       log->buf[mhdr.llog] = 0;
        return 0;
 }