]> git.ipfire.org Git - thirdparty/chrony.git/commitdiff
nts: improve session code
authorMiroslav Lichvar <mlichvar@redhat.com>
Tue, 7 Jul 2020 10:34:29 +0000 (12:34 +0200)
committerMiroslav Lichvar <mlichvar@redhat.com>
Thu, 9 Jul 2020 12:47:30 +0000 (14:47 +0200)
Add more comments and assertions, replace getsockopt() call with
SCK_GetIntOption(), replace strncmp() with memcmp(), move a return
statement for clarity, and remove an unused field from the instance
record.

nts_ke_session.c

index 05ca99f384f2e0edb0f754e4874a10925820e9da..a686db25e99537390aafff8d217332ed0ad44dad 100644 (file)
@@ -81,7 +81,6 @@ struct NKSN_Instance_Record {
 
   struct Message message;
   int new_message;
-  int ended_message;
 };
 
 /* ================================================== */
@@ -110,6 +109,8 @@ add_record(struct Message *message, int critical, int type, const void *body, in
 {
   struct RecordHeader header;
 
+  assert(message->length <= sizeof (message->data));
+
   if (body_length < 0 || body_length > 0xffff || type < 0 || type > 0x7fff ||
       message->length + sizeof (header) + body_length > sizeof (message->data))
     return 0;
@@ -301,31 +302,14 @@ session_timeout(void *arg)
 
 /* ================================================== */
 
-static int
-get_socket_error(int sock_fd)
-{
-  int optval;
-  socklen_t optlen = sizeof (optval);
-
-  if (getsockopt(sock_fd, SOL_SOCKET, SO_ERROR, &optval, &optlen) < 0) {
-    DEBUG_LOG("getsockopt() failed : %s", strerror(errno));
-    return EINVAL;
-  }
-
-  return optval;
-}
-
-/* ================================================== */
-
 static int
 check_alpn(NKSN_Instance inst)
 {
   gnutls_datum_t alpn;
-  int r;
 
-  r = gnutls_alpn_get_selected_protocol(inst->tls_session, &alpn);
-  if (r < 0 || alpn.size != sizeof (NKE_ALPN_NAME) - 1 ||
-      strncmp((const char *)alpn.data, NKE_ALPN_NAME, sizeof (NKE_ALPN_NAME) - 1))
+  if (gnutls_alpn_get_selected_protocol(inst->tls_session, &alpn) < 0 ||
+      alpn.size != sizeof (NKE_ALPN_NAME) - 1 ||
+      memcmp(alpn.data, NKE_ALPN_NAME, sizeof (NKE_ALPN_NAME) - 1) != 0)
     return 0;
 
   return 1;
@@ -375,9 +359,11 @@ handle_event(NKSN_Instance inst, int event)
       if (event != SCH_FILE_OUTPUT)
         return 0;
 
-      r = get_socket_error(inst->sock_fd);
+      /* Get the socket error */
+      if (!SCK_GetIntOption(inst->sock_fd, SOL_SOCKET, SO_ERROR, &r))
+        r = EINVAL;
 
-      if (r) {
+      if (r != 0) {
         LOG(LOGS_ERR, "Could not connect to %s : %s", inst->label, strerror(r));
         stop_session(inst);
         return 0;
@@ -446,6 +432,7 @@ handle_event(NKSN_Instance inst, int event)
 
     case KE_SEND:
       assert(inst->new_message && message->complete);
+      assert(message->length <= sizeof (message->data) && message->length > message->sent);
 
       r = gnutls_record_send(inst->tls_session, &message->data[message->sent],
                              message->length - message->sent);
@@ -513,7 +500,9 @@ handle_event(NKSN_Instance inst, int event)
 
       /* Server will send a response to the client */
       change_state(inst, inst->server ? KE_SEND : KE_SHUTDOWN);
-      break;
+
+      /* Return success to process the received message */
+      return 1;
 
     case KE_SHUTDOWN:
       r = gnutls_bye(inst->tls_session, GNUTLS_SHUT_RDWR);
@@ -539,9 +528,8 @@ handle_event(NKSN_Instance inst, int event)
 
     default:
       assert(0);
+      return 0;
   }
-
-  return 1;
 }
 
 /* ================================================== */
@@ -554,6 +542,9 @@ read_write_socket(int fd, int event, void *arg)
   if (!handle_event(inst, event))
     return;
 
+  /* A valid message was received.  Call the handler to process the message,
+     and prepare a response if it is a server. */
+
   reset_message_parsing(&inst->message);
 
   if (!(inst->handler)(inst->handler_arg)) {
@@ -602,13 +593,15 @@ init_gnutls(void)
   if (r < 0)
     LOG_FATAL("Could not initialise %s : %s", "gnutls", gnutls_strerror(r));
 
-  /* NTS specification requires TLS1.3 or later */
+  /* Prepare a priority cache for server and client NTS-KE sessions
+     (the NTS specification requires TLS1.3 or later) */
   r = gnutls_priority_init2(&priority_cache,
                             "-VERS-SSL3.0:-VERS-TLS1.0:-VERS-TLS1.1:-VERS-TLS1.2",
                             NULL, GNUTLS_PRIORITY_INIT_DEF_APPEND);
   if (r < 0)
     LOG_FATAL("Could not initialise %s : %s", "priority cache", gnutls_strerror(r));
 
+  /* Use our clock instead of the system clock in certificate verification */
   gnutls_global_set_time_function(get_time);
 
   gnutls_initialised = 1;
@@ -704,7 +697,7 @@ NKSN_CreateInstance(int server_mode, const char *server_name,
   inst->server_name = server_name ? Strdup(server_name) : NULL;
   inst->handler = handler;
   inst->handler_arg = handler_arg;
-  /* Replace NULL arg with the session itself */
+  /* Replace a NULL argument with the session itself */
   if (!inst->handler_arg)
     inst->handler_arg = inst;
 
@@ -751,7 +744,6 @@ NKSN_StartSession(NKSN_Instance inst, int sock_fd, const char *label,
 
   reset_message(&inst->message);
   inst->new_message = 0;
-  inst->ended_message = 0;
 
   change_state(inst, inst->server ? KE_HANDSHAKE : KE_WAIT_CONNECT);
 
@@ -785,6 +777,7 @@ NKSN_EndMessage(NKSN_Instance inst)
 {
   assert(!inst->message.complete);
 
+  /* Terminate the message */
   if (!add_record(&inst->message, 1, NKE_RECORD_END_OF_MESSAGE, NULL, 0))
     return 0;
 
@@ -806,6 +799,7 @@ NKSN_GetRecord(NKSN_Instance inst, int *critical, int *type, int *body_length,
   if (!get_record(&inst->message, critical, &type2, body_length, body, buffer_length))
     return 0;
 
+  /* Hide the end-of-message record */
   if (type2 == NKE_RECORD_END_OF_MESSAGE)
     return 0;