]> git.ipfire.org Git - thirdparty/chrony.git/commitdiff
socket: improve code
authorMiroslav Lichvar <mlichvar@redhat.com>
Tue, 11 Aug 2020 15:07:14 +0000 (17:07 +0200)
committerMiroslav Lichvar <mlichvar@redhat.com>
Thu, 13 Aug 2020 08:40:18 +0000 (10:40 +0200)
Add more assertions and other checks, and improve coding style a bit.

ntp_io_linux.c
socket.c
socket.h

index bfe9f996f567b4213c6240b78d2b641b7c20baff..7be242faf141edab5d125edd886a910a28dfabd0 100644 (file)
@@ -757,7 +757,7 @@ NIO_Linux_ProcessMessage(SCK_Message *message, NTP_Local_Address *local_addr,
   l2_length = message->length;
   message->length = extract_udp_data(message->data, &message->remote_addr.ip, message->length);
 
-  DEBUG_LOG("Extracted message for %s fd=%d len=%u",
+  DEBUG_LOG("Extracted message for %s fd=%d len=%d",
             UTI_IPSockAddrToString(&message->remote_addr.ip),
             local_addr->sock_fd, message->length);
 
index bdf89914eb97aec227adbe7ce017074177256f25..2262aad548ee1191c4686daf40462f784f0ac95d 100644 (file)
--- a/socket.c
+++ b/socket.c
@@ -118,7 +118,7 @@ prepare_buffers(unsigned int n)
     hdr->msg_hdr.msg_namelen = sizeof (msg->name);
     hdr->msg_hdr.msg_iov = &msg->iov;
     hdr->msg_hdr.msg_iovlen = 1;
-    hdr->msg_hdr.msg_control = &msg->cmsg_buf;
+    hdr->msg_hdr.msg_control = msg->cmsg_buf;
     hdr->msg_hdr.msg_controllen = sizeof (msg->cmsg_buf);
     hdr->msg_hdr.msg_flags = 0;
     hdr->msg_len = 0;
@@ -176,7 +176,7 @@ check_socket_flag(int sock_flag, int fd_flag, int fs_flag)
 static int
 set_socket_nonblock(int sock_fd)
 {
-  if (fcntl(sock_fd, F_SETFL, O_NONBLOCK)) {
+  if (fcntl(sock_fd, F_SETFL, O_NONBLOCK) < 0) {
     DEBUG_LOG("Could not set O_NONBLOCK : %s", strerror(errno));
     return 0;
   }
@@ -656,9 +656,8 @@ log_message(int sock_fd, int direction, SCK_Message *message, const char *prefix
     case SCK_ADDR_IP:
       if (message->remote_addr.ip.ip_addr.family != IPADDR_UNSPEC)
         remote_addr = UTI_IPSockAddrToString(&message->remote_addr.ip);
-      if (message->local_addr.ip.family != IPADDR_UNSPEC) {
+      if (message->local_addr.ip.family != IPADDR_UNSPEC)
         local_addr = UTI_IPToString(&message->local_addr.ip);
-      }
       break;
     case SCK_ADDR_UNIX:
       remote_addr = message->remote_addr.path;
@@ -684,7 +683,7 @@ log_message(int sock_fd, int direction, SCK_Message *message, const char *prefix
       snprintf(tslen, sizeof (tslen), " tslen=%d", message->timestamp.l2_length);
   }
 
-  DEBUG_LOG("%s message%s%s%s%s fd=%d len=%u%s%s%s%s%s%s",
+  DEBUG_LOG("%s message%s%s%s%s fd=%d len=%d%s%s%s%s%s%s",
             prefix,
             remote_addr ? (direction > 0 ? " from " : " to ") : "",
             remote_addr ? remote_addr : "",
@@ -739,7 +738,7 @@ init_message_nonaddress(SCK_Message *message)
 /* ================================================== */
 
 static int
-process_header(struct msghdr *msg, unsigned int msg_length, int sock_fd, int flags,
+process_header(struct msghdr *msg, int msg_length, int sock_fd, int flags,
                SCK_Message *message)
 {
   struct cmsghdr *cmsg;
@@ -921,7 +920,10 @@ receive_messages(int sock_fd, int flags, int max_messages, int *num_messages)
   hdr = ARR_GetElements(recv_headers);
   n = ARR_GetSize(recv_headers);
   n = MIN(n, max_messages);
-  assert(n >= 1);
+
+  if (n < 1 || n > MAX_RECV_MESSAGES ||
+      n > ARR_GetSize(recv_messages) || n > ARR_GetSize(recv_sck_messages))
+    assert(0);
 
   recv_flags = get_recv_flags(flags);
 
@@ -1030,6 +1032,11 @@ send_message(int sock_fd, SCK_Message *message, int flags)
     msg.msg_namelen = 0;
   }
 
+  if (message->length < 0) {
+    DEBUG_LOG("Invalid length %d", message->length);
+    return 0;
+  }
+
   iov.iov_base = message->data;
   iov.iov_len = message->length;
   msg.msg_iov = &iov;
@@ -1404,10 +1411,15 @@ SCK_ShutdownConnection(int sock_fd)
 /* ================================================== */
 
 int
-SCK_Receive(int sock_fd, void *buffer, unsigned int length, int flags)
+SCK_Receive(int sock_fd, void *buffer, int length, int flags)
 {
   int r;
 
+  if (length < 0) {
+    DEBUG_LOG("Invalid length %d", length);
+    return -1;
+  }
+
   r = recv(sock_fd, buffer, length, get_recv_flags(flags));
 
   if (r < 0) {
@@ -1423,16 +1435,21 @@ SCK_Receive(int sock_fd, void *buffer, unsigned int length, int flags)
 /* ================================================== */
 
 int
-SCK_Send(int sock_fd, const void *buffer, unsigned int length, int flags)
+SCK_Send(int sock_fd, const void *buffer, int length, int flags)
 {
   int r;
 
   assert(flags == 0);
 
+  if (length < 0) {
+    DEBUG_LOG("Invalid length %d", length);
+    return -1;
+  }
+
   r = send(sock_fd, buffer, length, 0);
 
   if (r < 0) {
-    DEBUG_LOG("Could not send data fd=%d len=%u : %s", sock_fd, length, strerror(errno));
+    DEBUG_LOG("Could not send data fd=%d len=%d : %s", sock_fd, length, strerror(errno));
     return r;
   }
 
index 7d3fa9d9984843cd083e1b4f6081c05f267007ed..cdbae2de45b9a889759be867c3c7d0fe04278cfa 100644 (file)
--- a/socket.h
+++ b/socket.h
@@ -49,7 +49,7 @@ typedef enum {
 
 typedef struct {
   void *data;
-  unsigned int length;
+  int length;
   SCK_AddressType addr_type;
   int if_index;
 
@@ -119,8 +119,8 @@ extern int SCK_AcceptConnection(int sock_fd, IPSockAddr *remote_addr);
 extern int SCK_ShutdownConnection(int sock_fd);
 
 /* Receive and send data on connected sockets - recv()/send() wrappers */
-extern int SCK_Receive(int sock_fd, void *buffer, unsigned int length, int flags);
-extern int SCK_Send(int sock_fd, const void *buffer, unsigned int length, int flags);
+extern int SCK_Receive(int sock_fd, void *buffer, int length, int flags);
+extern int SCK_Send(int sock_fd, const void *buffer, int length, int flags);
 
 /* Receive a single message or multiple messages.  The functions return
    a pointer to static buffers, or NULL on error.  The buffers are valid until