]> git.ipfire.org Git - thirdparty/ipxe.git/commitdiff
[tls] Pass I/O buffer to received record handlers
authorMichael Brown <mcb30@ipxe.org>
Thu, 30 Mar 2023 15:28:40 +0000 (16:28 +0100)
committerMichael Brown <mcb30@ipxe.org>
Thu, 30 Mar 2023 22:37:55 +0000 (23:37 +0100)
Prepare for the possibility that a record handler may choose not to
consume the entire record by passing the I/O buffer and requiring the
handler to mark consumed data using iob_pull().

Signed-off-by: Michael Brown <mcb30@ipxe.org>
src/net/tls.c

index e0231b1c40e53bd21750f9ff0b5714ae05b2d94d..272ced24fa91da40631af26e7f9209c2fd535e13 100644 (file)
@@ -1736,15 +1736,15 @@ static int tls_send_finished ( struct tls_connection *tls ) {
  * Receive new Change Cipher record
  *
  * @v tls              TLS connection
- * @v data             Plaintext record
- * @v len              Length of plaintext record
+ * @v iobuf            I/O buffer
  * @ret rc             Return status code
  */
 static int tls_new_change_cipher ( struct tls_connection *tls,
-                                  const void *data, size_t len ) {
+                                  struct io_buffer *iobuf ) {
        const struct {
                uint8_t spec;
-       } __attribute__ (( packed )) *change_cipher = data;
+       } __attribute__ (( packed )) *change_cipher = iobuf->data;
+       size_t len = iob_len ( iobuf );
        int rc;
 
        /* Sanity check */
@@ -1754,6 +1754,7 @@ static int tls_new_change_cipher ( struct tls_connection *tls,
                DBGC_HD ( tls, change_cipher, len );
                return -EINVAL_CHANGE_CIPHER;
        }
+       iob_pull ( iobuf, sizeof ( *change_cipher ) );
 
        /* Change receive cipher spec */
        if ( ( rc = tls_change_cipher ( tls, &tls->rx_cipherspec_pending,
@@ -1771,25 +1772,27 @@ static int tls_new_change_cipher ( struct tls_connection *tls,
  * Receive new Alert record
  *
  * @v tls              TLS connection
- * @v data             Plaintext record
- * @v len              Length of plaintext record
+ * @v iobuf            I/O buffer
  * @ret rc             Return status code
  */
-static int tls_new_alert ( struct tls_connection *tls, const void *data,
-                          size_t len ) {
+static int tls_new_alert ( struct tls_connection *tls,
+                          struct io_buffer *iobuf ) {
        const struct {
                uint8_t level;
                uint8_t description;
                char next[0];
-       } __attribute__ (( packed )) *alert = data;
+       } __attribute__ (( packed )) *alert = iobuf->data;
+       size_t len = iob_len ( iobuf );
 
        /* Sanity check */
        if ( sizeof ( *alert ) != len ) {
                DBGC ( tls, "TLS %p received overlength Alert\n", tls );
-               DBGC_HD ( tls, data, len );
+               DBGC_HD ( tls, alert, len );
                return -EINVAL_ALERT;
        }
+       iob_pull ( iobuf, sizeof ( *alert ) );
 
+       /* Handle alert */
        switch ( alert->level ) {
        case TLS_ALERT_WARNING:
                DBGC ( tls, "TLS %p received warning alert %d\n",
@@ -2403,21 +2406,20 @@ static int tls_new_finished ( struct tls_connection *tls,
  * Receive new Handshake record
  *
  * @v tls              TLS connection
- * @v data             Plaintext record
- * @v len              Length of plaintext record
+ * @v iobuf            I/O buffer
  * @ret rc             Return status code
  */
 static int tls_new_handshake ( struct tls_connection *tls,
-                              const void *data, size_t len ) {
-       size_t remaining = len;
+                              struct io_buffer *iobuf ) {
+       size_t remaining;
        int rc;
 
-       while ( remaining ) {
+       while ( ( remaining = iob_len ( iobuf ) ) ) {
                const struct {
                        uint8_t type;
                        tls24_t length;
                        uint8_t payload[0];
-               } __attribute__ (( packed )) *handshake = data;
+               } __attribute__ (( packed )) *handshake = iobuf->data;
                const void *payload;
                size_t payload_len;
                size_t record_len;
@@ -2426,14 +2428,14 @@ static int tls_new_handshake ( struct tls_connection *tls,
                if ( sizeof ( *handshake ) > remaining ) {
                        DBGC ( tls, "TLS %p received underlength Handshake\n",
                               tls );
-                       DBGC_HD ( tls, data, remaining );
+                       DBGC_HD ( tls, handshake, remaining );
                        return -EINVAL_HANDSHAKE;
                }
                payload_len = tls_uint24 ( &handshake->length );
                if ( payload_len > ( remaining - sizeof ( *handshake ) ) ) {
                        DBGC ( tls, "TLS %p received overlength Handshake\n",
                               tls );
-                       DBGC_HD ( tls, data, len );
+                       DBGC_HD ( tls, handshake, remaining );
                        return -EINVAL_HANDSHAKE;
                }
                payload = &handshake->payload;
@@ -2481,15 +2483,60 @@ static int tls_new_handshake ( struct tls_connection *tls,
                 * which are explicitly excluded).
                 */
                if ( handshake->type != TLS_HELLO_REQUEST )
-                       tls_add_handshake ( tls, data, record_len );
+                       tls_add_handshake ( tls, handshake, record_len );
 
                /* Abort on failure */
                if ( rc != 0 )
                        return rc;
 
                /* Move to next handshake record */
-               data += record_len;
-               remaining -= record_len;
+               iob_pull ( iobuf, record_len );
+       }
+
+       return 0;
+}
+
+/**
+ * Receive new unknown record
+ *
+ * @v tls              TLS connection
+ * @v iobuf            I/O buffer
+ * @ret rc             Return status code
+ */
+static int tls_new_unknown ( struct tls_connection *tls __unused,
+                            struct io_buffer *iobuf ) {
+
+       /* RFC4346 says that we should just ignore unknown record types */
+       iob_pull ( iobuf, iob_len ( iobuf ) );
+       return 0;
+}
+
+/**
+ * Receive new data record
+ *
+ * @v tls              TLS connection
+ * @v rx_data          List of received data buffers
+ * @ret rc             Return status code
+ */
+static int tls_new_data ( struct tls_connection *tls,
+                         struct list_head *rx_data ) {
+       struct io_buffer *iobuf;
+       int rc;
+
+       /* Fail unless we are ready to receive data */
+       if ( ! tls_ready ( tls ) )
+               return -ENOTCONN;
+
+       /* Deliver each I/O buffer in turn */
+       while ( ( iobuf = list_first_entry ( rx_data, struct io_buffer,
+                                            list ) ) ) {
+               list_del ( &iobuf->list );
+               if ( ( rc = xfer_deliver_iob ( &tls->plainstream,
+                                              iobuf ) ) != 0 ) {
+                       DBGC ( tls, "TLS %p could not deliver data: "
+                              "%s\n", tls, strerror ( rc ) );
+                       return rc;
+               }
        }
 
        return 0;
@@ -2505,39 +2552,14 @@ static int tls_new_handshake ( struct tls_connection *tls,
  */
 static int tls_new_record ( struct tls_connection *tls, unsigned int type,
                            struct list_head *rx_data ) {
+       int ( * handler ) ( struct tls_connection *tls,
+                           struct io_buffer *iobuf );
        struct io_buffer *iobuf;
-       int ( * handler ) ( struct tls_connection *tls, const void *data,
-                           size_t len );
        int rc;
 
-       /* Deliver data records to the plainstream interface */
-       if ( type == TLS_TYPE_DATA ) {
-
-               /* Fail unless we are ready to receive data */
-               if ( ! tls_ready ( tls ) )
-                       return -ENOTCONN;
-
-               /* Deliver each I/O buffer in turn */
-               while ( ( iobuf = list_first_entry ( rx_data, struct io_buffer,
-                                                    list ) ) ) {
-                       list_del ( &iobuf->list );
-                       if ( ( rc = xfer_deliver_iob ( &tls->plainstream,
-                                                      iobuf ) ) != 0 ) {
-                               DBGC ( tls, "TLS %p could not deliver data: "
-                                      "%s\n", tls, strerror ( rc ) );
-                               return rc;
-                       }
-               }
-               return 0;
-       }
-
-       /* For all other records, merge into a single I/O buffer */
-       iobuf = iob_concatenate ( rx_data );
-       if ( ! iobuf ) {
-               DBGC ( tls, "TLS %p could not concatenate non-data record "
-                      "type %d\n", tls, type );
-               return -ENOMEM_RX_CONCAT;
-       }
+       /* Deliver data records as-is to the plainstream interface */
+       if ( type == TLS_TYPE_DATA )
+               return tls_new_data ( tls, rx_data );
 
        /* Determine handler */
        switch ( type ) {
@@ -2551,17 +2573,35 @@ static int tls_new_record ( struct tls_connection *tls, unsigned int type,
                handler = tls_new_handshake;
                break;
        default:
-               /* RFC4346 says that we should just ignore unknown
-                * record types.
-                */
-               handler = NULL;
-               DBGC ( tls, "TLS %p ignoring record type %d\n", tls, type );
+               DBGC ( tls, "TLS %p unknown record type %d\n", tls, type );
+               handler = tls_new_unknown;
                break;
        }
 
-       /* Handle record and free I/O buffer */
-       rc = ( handler ? handler ( tls, iobuf->data, iob_len ( iobuf ) ) : 0 );
+       /* Merge into a single I/O buffer */
+       iobuf = iob_concatenate ( rx_data );
+       if ( ! iobuf ) {
+               DBGC ( tls, "TLS %p could not concatenate non-data record "
+                      "type %d\n", tls, type );
+               rc = -ENOMEM_RX_CONCAT;
+               goto err_concatenate;
+       }
+
+       /* Handle record */
+       if ( ( rc = handler ( tls, iobuf ) ) != 0 )
+               goto err_handle;
+
+       /* Sanity check */
+       assert ( iob_len ( iobuf ) == 0 );
+
+       /* Free I/O buffer */
+       free_iob ( iobuf );
+
+       return 0;
+
+ err_handle:
        free_iob ( iobuf );
+ err_concatenate:
        return rc;
 }