]> git.ipfire.org Git - thirdparty/chrony.git/commitdiff
socket: simplify receiving messages
authorMiroslav Lichvar <mlichvar@redhat.com>
Tue, 24 Mar 2020 14:22:31 +0000 (15:22 +0100)
committerMiroslav Lichvar <mlichvar@redhat.com>
Thu, 26 Mar 2020 10:04:18 +0000 (11:04 +0100)
Don't require the caller to provide a SCK_Message (on stack). Modify the
SCK_ReceiveMessage*() functions to return a pointer to static buffers,
as the message buffer which SCK_Message points to already is.

cmdmon.c
ntp_io.c
nts_ke_server.c
privops.c
socket.c
socket.h

index 908977effd76b90bcfbd67a677f247ad7e1e9b90..60ae1ccb9dbc2814173736a3efccd6ffed15bea2 100644 (file)
--- a/cmdmon.c
+++ b/cmdmon.c
@@ -1237,7 +1237,7 @@ handle_reset(CMD_Request *rx_message, CMD_Reply *tx_message)
 static void
 read_from_cmd_socket(int sock_fd, int event, void *anything)
 {
-  SCK_Message sck_message;
+  SCK_Message *sck_message;
   CMD_Request rx_message;
   CMD_Reply tx_message;
   IPAddr loopback_addr, remote_ip;
@@ -1246,26 +1246,27 @@ read_from_cmd_socket(int sock_fd, int event, void *anything)
   unsigned short rx_command;
   struct timespec now, cooked_now;
 
-  if (!SCK_ReceiveMessage(sock_fd, &sck_message, 0))
+  sck_message = SCK_ReceiveMessage(sock_fd, 0);
+  if (!sck_message)
     return;
 
-  read_length = sck_message.length;
+  read_length = sck_message->length;
 
   /* Get current time cheaply */
   SCH_GetLastEventTime(&cooked_now, NULL, &now);
 
   /* Check if it's from localhost (127.0.0.1, ::1, or Unix domain),
      or an authorised address */
-  switch (sck_message.addr_type) {
+  switch (sck_message->addr_type) {
     case SCK_ADDR_IP:
       assert(sock_fd == sock_fd4 || sock_fd == sock_fd6);
-      remote_ip = sck_message.remote_addr.ip.ip_addr;
+      remote_ip = sck_message->remote_addr.ip.ip_addr;
       SCK_GetLoopbackIPAddress(remote_ip.family, &loopback_addr);
       localhost = UTI_CompareIPs(&remote_ip, &loopback_addr, NULL) == 0;
 
       if (!localhost && !ADF_IsAllowed(access_auth_table, &remote_ip)) {
         DEBUG_LOG("Unauthorised host %s",
-                  UTI_IPSockAddrToString(&sck_message.remote_addr.ip));
+                  UTI_IPSockAddrToString(&sck_message->remote_addr.ip));
         return;
       }
 
@@ -1291,7 +1292,7 @@ read_from_cmd_socket(int sock_fd, int event, void *anything)
     return;
   }
 
-  memcpy(&rx_message, sck_message.data, read_length);
+  memcpy(&rx_message, sck_message->data, read_length);
 
   if (rx_message.pkt_type != PKT_TYPE_CMD_REQUEST ||
       rx_message.res1 != 0 ||
@@ -1313,8 +1314,8 @@ read_from_cmd_socket(int sock_fd, int event, void *anything)
   rx_command = ntohs(rx_message.command);
 
   memset(&tx_message, 0, sizeof (tx_message));
-  sck_message.data = &tx_message;
-  sck_message.length = 0;
+  sck_message->data = &tx_message;
+  sck_message->length = 0;
 
   tx_message.version = PROTO_VERSION_NUMBER;
   tx_message.pkt_type = PKT_TYPE_CMD_REPLY;
@@ -1329,7 +1330,7 @@ read_from_cmd_socket(int sock_fd, int event, void *anything)
 
     if (rx_message.version >= PROTO_VERSION_MISMATCH_COMPAT_SERVER) {
       tx_message.status = htons(STT_BADPKTVERSION);
-      transmit_reply(sock_fd, &sck_message);
+      transmit_reply(sock_fd, sck_message);
     }
     return;
   }
@@ -1339,7 +1340,7 @@ read_from_cmd_socket(int sock_fd, int event, void *anything)
     DEBUG_LOG("Command packet has invalid command %d", rx_command);
 
     tx_message.status = htons(STT_INVALID);
-    transmit_reply(sock_fd, &sck_message);
+    transmit_reply(sock_fd, sck_message);
     return;
   }
 
@@ -1348,7 +1349,7 @@ read_from_cmd_socket(int sock_fd, int event, void *anything)
               expected_length);
 
     tx_message.status = htons(STT_BADPKTLENGTH);
-    transmit_reply(sock_fd, &sck_message);
+    transmit_reply(sock_fd, sck_message);
     return;
   }
 
@@ -1629,7 +1630,7 @@ read_from_cmd_socket(int sock_fd, int event, void *anything)
     static int do_it=1;
 
     if (do_it) {
-      transmit_reply(sock_fd, &sck_message);
+      transmit_reply(sock_fd, sck_message);
     }
 
 #if 0
index a70f6c64e1ba460e051fc9190bf3873284e14639..31e21c7bee99932e8ead327935d664e2c9f0ff07 100644 (file)
--- a/ntp_io.c
+++ b/ntp_io.c
@@ -407,7 +407,7 @@ read_from_socket(int sock_fd, int event, void *anything)
   /* This should only be called when there is something
      to read, otherwise it may block */
 
-  SCK_Message messages[SCK_MAX_RECV_MESSAGES];
+  SCK_Message *messages;
   int i, received, flags = 0;
 
 #ifdef HAVE_LINUX_TIMESTAMPING
@@ -423,8 +423,8 @@ read_from_socket(int sock_fd, int event, void *anything)
 #endif
   }
 
-  received = SCK_ReceiveMessages(sock_fd, messages, SCK_MAX_RECV_MESSAGES, flags);
-  if (received <= 0)
+  messages = SCK_ReceiveMessages(sock_fd, flags, &received);
+  if (!messages)
     return;
 
   for (i = 0; i < received; i++)
index 14bb621b67584d76f40054a5d96a429e9c84fdbb..a3fa9f924b0d6bd81cb20c43dd27c35d3067389d 100644 (file)
@@ -139,28 +139,29 @@ handle_client(int sock_fd, IPSockAddr *addr)
 static void
 handle_helper_request(int fd, int event, void *arg)
 {
-  SCK_Message message;
+  SCK_Message *message;
   HelperRequest *req;
   IPSockAddr client_addr;
   int sock_fd;
 
-  if (!SCK_ReceiveMessage(fd, &message, SCK_FLAG_MSG_DESCRIPTOR))
+  message = SCK_ReceiveMessage(fd, SCK_FLAG_MSG_DESCRIPTOR);
+  if (!message)
     return;
 
-  sock_fd = message.descriptor;
+  sock_fd = message->descriptor;
   if (sock_fd < 0) {
     /* Message with no descriptor is a shutdown command */
     SCH_QuitProgram();
     return;
   }
 
-  if (message.length != sizeof (HelperRequest)) {
+  if (message->length != sizeof (HelperRequest)) {
     DEBUG_LOG("Unexpected message length");
     SCK_CloseSocket(sock_fd);
     return;
   }
 
-  req = message.data;
+  req = message->data;
 
   /* Extract the server key and client address from the request */
   server_keys[current_server_key].id = ntohl(req->key_id);
index 6d06c4ced8c3f2d6b4f4d2d9894cbaa0c9333dab..e999f366bbbd4c41c65b043929a8522ac2608b0c 100644 (file)
--- a/privops.c
+++ b/privops.c
@@ -171,22 +171,22 @@ send_response(int fd, const PrvResponse *res)
 static int
 receive_from_daemon(int fd, PrvRequest *req)
 {
-  SCK_Message message;
+  SCK_Message *message;
 
-  if (!SCK_ReceiveMessage(fd, &message, SCK_FLAG_MSG_DESCRIPTOR) ||
-      message.length != sizeof (*req))
+  message = SCK_ReceiveMessage(fd, SCK_FLAG_MSG_DESCRIPTOR);
+  if (!message || message->length != sizeof (*req))
     return 0;
 
-  memcpy(req, message.data, sizeof (*req));
+  memcpy(req, message->data, sizeof (*req));
 
   if (req->op == OP_BINDSOCKET) {
-    req->data.bind_socket.sock = message.descriptor;
+    req->data.bind_socket.sock = message->descriptor;
 
     /* return error if valid descriptor not found */
     if (req->data.bind_socket.sock < 0)
       return 0;
-  } else if (message.descriptor >= 0) {
-    SCK_CloseSocket(message.descriptor);
+  } else if (message->descriptor >= 0) {
+    SCK_CloseSocket(message->descriptor);
     return 0;
   }
 
index 3c1ffd334ff592e439a8b1d4042d67df350eb4af..9e874d2bb11e44c0fc6efbb7f1c96c812580872b 100644 (file)
--- a/socket.c
+++ b/socket.c
@@ -68,7 +68,7 @@ struct Message {
 };
 
 #ifdef HAVE_RECVMMSG
-#define MAX_RECV_MESSAGES SCK_MAX_RECV_MESSAGES
+#define MAX_RECV_MESSAGES 4
 #define MessageHeader mmsghdr
 #else
 /* Compatible with mmsghdr */
@@ -85,9 +85,10 @@ static int initialised;
 /* Flags supported by socket() */
 static int supported_socket_flags;
 
-/* Arrays of Message and MessageHeader */
+/* Arrays of Message, MessageHeader, and SCK_Message */
 static ARR_Instance recv_messages;
 static ARR_Instance recv_headers;
+static ARR_Instance recv_sck_messages;
 
 static unsigned int received_messages;
 
@@ -867,22 +868,27 @@ process_header(struct msghdr *msg, unsigned int msg_length, int sock_fd, int fla
 
 /* ================================================== */
 
-static int
-receive_messages(int sock_fd, SCK_Message *messages, int max_messages, int flags)
+static SCK_Message *
+receive_messages(int sock_fd, int flags, int max_messages, int *num_messages)
 {
   struct MessageHeader *hdr;
+  SCK_Message *messages;
   unsigned int i, n;
   int ret, recv_flags = 0;
 
   assert(initialised);
 
+  *num_messages = 0;
+
   if (max_messages < 1)
-    return 0;
+    return NULL;
 
   /* Prepare used buffers for new messages */
   prepare_buffers(received_messages);
   received_messages = 0;
 
+  messages = ARR_GetElements(recv_sck_messages);
+
   hdr = ARR_GetElements(recv_headers);
   n = ARR_GetSize(recv_headers);
   n = MIN(n, max_messages);
@@ -903,7 +909,7 @@ receive_messages(int sock_fd, SCK_Message *messages, int max_messages, int flags
 
   if (ret < 0) {
     handle_recv_error(sock_fd, flags);
-    return 0;
+    return NULL;
   }
 
   received_messages = n;
@@ -911,13 +917,15 @@ receive_messages(int sock_fd, SCK_Message *messages, int max_messages, int flags
   for (i = 0; i < n; i++) {
     hdr = ARR_GetElement(recv_headers, i);
     if (!process_header(&hdr->msg_hdr, hdr->msg_len, sock_fd, flags, &messages[i]))
-      return 0;
+      return NULL;
 
     log_message(sock_fd, 1, &messages[i],
                 flags & SCK_FLAG_MSG_ERRQUEUE ? "Received error" : "Received", NULL);
   }
 
-  return n;
+  *num_messages = n;
+
+  return messages;
 }
 
 /* ================================================== */
@@ -1092,6 +1100,8 @@ SCK_Initialise(void)
   ARR_SetSize(recv_messages, MAX_RECV_MESSAGES);
   recv_headers = ARR_CreateInstance(sizeof (struct MessageHeader));
   ARR_SetSize(recv_headers, MAX_RECV_MESSAGES);
+  recv_sck_messages = ARR_CreateInstance(sizeof (SCK_Message));
+  ARR_SetSize(recv_sck_messages, MAX_RECV_MESSAGES);
 
   received_messages = MAX_RECV_MESSAGES;
 
@@ -1115,6 +1125,7 @@ SCK_Initialise(void)
 void
 SCK_Finalise(void)
 {
+  ARR_DestroyInstance(recv_sck_messages);
   ARR_DestroyInstance(recv_headers);
   ARR_DestroyInstance(recv_messages);
 
@@ -1381,18 +1392,20 @@ SCK_Send(int sock_fd, const void *buffer, unsigned int length, int flags)
 
 /* ================================================== */
 
-int
-SCK_ReceiveMessage(int sock_fd, SCK_Message *message, int flags)
+SCK_Message *
+SCK_ReceiveMessage(int sock_fd, int flags)
 {
-  return SCK_ReceiveMessages(sock_fd, message, 1, flags);
+  int num_messages;
+
+  return receive_messages(sock_fd, flags, 1, &num_messages);
 }
 
 /* ================================================== */
 
-int
-SCK_ReceiveMessages(int sock_fd, SCK_Message *messages, int max_messages, int flags)
+SCK_Message *
+SCK_ReceiveMessages(int sock_fd, int flags, int *num_messages)
 {
-  return receive_messages(sock_fd, messages, max_messages, flags);
+  return receive_messages(sock_fd, flags, MAX_RECV_MESSAGES, num_messages);
 }
 
 /* ================================================== */
index ee44526631dcfa0482a0ba4a1f2f9a2cd9211132..949690b1eb5228319aefc828d037501e41f384b0 100644 (file)
--- a/socket.h
+++ b/socket.h
@@ -41,9 +41,6 @@
 #define SCK_FLAG_MSG_ERRQUEUE 1
 #define SCK_FLAG_MSG_DESCRIPTOR 2
 
-/* Maximum number of received messages */
-#define SCK_MAX_RECV_MESSAGES 4
-
 typedef enum {
   SCK_ADDR_UNSPEC = 0,
   SCK_ADDR_IP,
@@ -119,12 +116,11 @@ extern int SCK_ShutdownConnection(int sock_fd);
 extern int SCK_Receive(int sock_fd, void *buffer, unsigned int length, int flags);
 extern int SCK_Send(int sock_fd, const void *buffer, unsigned int length, int flags);
 
-/* Receive a single message or multiple messages.  The functions return the
-   number of received messages, or 0 on error.  The returned data point to
-   static buffers, which are valid until another call of these functions.  */
-extern int SCK_ReceiveMessage(int sock_fd, SCK_Message *message, int flags);
-extern int SCK_ReceiveMessages(int sock_fd, SCK_Message *messages, int max_messages,
-                               int flags);
+/* Receive a single message or multiple messages.  The functions return
+   a pointer to static buffers, or NULL on error.  The buffers are valid until
+   another call of the functions and can be reused for sending messages. */
+extern SCK_Message *SCK_ReceiveMessage(int sock_fd, int flags);
+extern SCK_Message *SCK_ReceiveMessages(int sock_fd, int flags, int *num_messages);
 
 /* Initialise a new message (e.g. before sending) */
 extern void SCK_InitMessage(SCK_Message *message, SCK_AddressType addr_type);