]> git.ipfire.org Git - thirdparty/ipxe.git/commitdiff
[base64] Add buffer size parameter to base64_encode() and base64_decode()
authorMichael Brown <mcb30@ipxe.org>
Fri, 24 Apr 2015 14:32:04 +0000 (15:32 +0100)
committerMichael Brown <mcb30@ipxe.org>
Fri, 24 Apr 2015 14:32:04 +0000 (15:32 +0100)
Signed-off-by: Michael Brown <mcb30@ipxe.org>
src/core/base64.c
src/crypto/ocsp.c
src/include/ipxe/base64.h
src/net/tcp/httpcore.c
src/net/tcp/iscsi.c
src/net/validator.c
src/tests/base64_test.c

index c115fcea525f012831be2d9f9b7498863ef1a6fb..e452f7d41ad84cee2545259bbac1da0856e9c317 100644 (file)
@@ -43,80 +43,73 @@ static const char base64[64] =
  * Base64-encode data
  *
  * @v raw              Raw data
- * @v len              Length of raw data
- * @v encoded          Buffer for encoded string
- *
- * The buffer must be the correct length for the encoded string.  Use
- * something like
- *
- *     char buf[ base64_encoded_len ( len ) + 1 ];
- *
- * (the +1 is for the terminating NUL) to provide a buffer of the
- * correct size.
+ * @v raw_len          Length of raw data
+ * @v data             Buffer
+ * @v len              Length of buffer
+ * @ret len            Encoded length
  */
-void base64_encode ( const uint8_t *raw, size_t len, char *encoded ) {
+size_t base64_encode ( const void *raw, size_t raw_len, char *data,
+                      size_t len ) {
        const uint8_t *raw_bytes = ( ( const uint8_t * ) raw );
-       uint8_t *encoded_bytes = ( ( uint8_t * ) encoded );
-       size_t raw_bit_len = ( 8 * len );
+       size_t raw_bit_len = ( 8 * raw_len );
+       size_t used = 0;
        unsigned int bit;
        unsigned int byte;
        unsigned int shift;
        unsigned int tmp;
 
-       for ( bit = 0 ; bit < raw_bit_len ; bit += 6 ) {
+       for ( bit = 0 ; bit < raw_bit_len ; bit += 6, used++ ) {
                byte = ( bit / 8 );
                shift = ( bit % 8 );
                tmp = ( raw_bytes[byte] << shift );
-               if ( ( byte + 1 ) < len )
+               if ( ( byte + 1 ) < raw_len )
                        tmp |= ( raw_bytes[ byte + 1 ] >> ( 8 - shift ) );
                tmp = ( ( tmp >> 2 ) & 0x3f );
-               *(encoded_bytes++) = base64[tmp];
+               if ( used < len )
+                       data[used] = base64[tmp];
        }
-       for ( ; ( bit % 8 ) != 0 ; bit += 6 )
-               *(encoded_bytes++) = '=';
-       *(encoded_bytes++) = '\0';
+       for ( ; ( bit % 8 ) != 0 ; bit += 6, used++ ) {
+               if ( used < len )
+                       data[used] = '=';
+       }
+       if ( used < len )
+               data[used] = '\0';
+       if ( len )
+               data[ len - 1 ] = '\0'; /* Ensure terminator exists */
 
-       DBG ( "Base64-encoded to \"%s\":\n", encoded );
-       DBG_HDA ( 0, raw, len );
-       assert ( strlen ( encoded ) == base64_encoded_len ( len ) );
+       return used;
 }
 
 /**
  * Base64-decode string
  *
  * @v encoded          Encoded string
- * @v raw              Raw data
- * @ret len            Length of raw data, or negative error
- *
- * The buffer must be large enough to contain the decoded data.  Use
- * something like
- *
- *     char buf[ base64_decoded_max_len ( encoded ) ];
- *
- * to provide a buffer of the correct size.
+ * @v data             Buffer
+ * @v len              Length of buffer
+ * @ret len            Length of data, or negative error
  */
-int base64_decode ( const char *encoded, uint8_t *raw ) {
-       const uint8_t *encoded_bytes = ( ( const uint8_t * ) encoded );
-       uint8_t *raw_bytes = ( ( uint8_t * ) raw );
-       uint8_t encoded_byte;
+int base64_decode ( const char *encoded, void *data, size_t len ) {
+       const char *in = encoded;
+       uint8_t *out = data;
+       uint8_t in_char;
        char *match;
-       int decoded;
+       int in_bits;
        unsigned int bit = 0;
        unsigned int pad_count = 0;
-       size_t len;
+       size_t offset;
 
-       /* Zero the raw data */
-       memset ( raw, 0, base64_decoded_max_len ( encoded ) );
+       /* Zero the output buffer */
+       memset ( data, 0, len );
 
        /* Decode string */
-       while ( ( encoded_byte = *(encoded_bytes++) ) ) {
+       while ( ( in_char = *(in++) ) ) {
 
                /* Ignore whitespace characters */
-               if ( isspace ( encoded_byte ) )
+               if ( isspace ( in_char ) )
                        continue;
 
                /* Process pad characters */
-               if ( encoded_byte == '=' ) {
+               if ( in_char == '=' ) {
                        if ( pad_count >= 2 ) {
                                DBG ( "Base64-encoded string \"%s\" has too "
                                      "many pad characters\n", encoded );
@@ -133,18 +126,22 @@ int base64_decode ( const char *encoded, uint8_t *raw ) {
                }
 
                /* Process normal characters */
-               match = strchr ( base64, encoded_byte );
+               match = strchr ( base64, in_char );
                if ( ! match ) {
                        DBG ( "Base64-encoded string \"%s\" contains invalid "
-                             "character '%c'\n", encoded, encoded_byte );
+                             "character '%c'\n", encoded, in_char );
                        return -EINVAL;
                }
-               decoded = ( match - base64 );
+               in_bits = ( match - base64 );
 
                /* Add to raw data */
-               decoded <<= 2;
-               raw_bytes[ bit / 8 ] |= ( decoded >> ( bit % 8 ) );
-               raw_bytes[ bit / 8 + 1 ] |= ( decoded << ( 8 - ( bit % 8 ) ) );
+               in_bits <<= 2;
+               offset = ( bit / 8 );
+               if ( offset < len )
+                       out[offset] |= ( in_bits >> ( bit % 8 ) );
+               offset++;
+               if ( offset < len )
+                       out[offset] |= ( in_bits << ( 8 - ( bit % 8 ) ) );
                bit += 6;
        }
 
@@ -154,12 +151,7 @@ int base64_decode ( const char *encoded, uint8_t *raw ) {
                      "%d\n", encoded, bit );
                return -EINVAL;
        }
-       len = ( bit / 8 );
-
-       DBG ( "Base64-decoded \"%s\" to:\n", encoded );
-       DBG_HDA ( 0, raw, len );
-       assert ( len <= base64_decoded_max_len ( encoded ) );
 
        /* Return length in bytes */
-       return ( len );
+       return ( bit / 8 );
 }
index 66e47c57e57ee750275a14d6e323e4a05705da68..5df55bc96e11878510ae6a22af8409c52bc7aa38 100644 (file)
@@ -233,7 +233,7 @@ static int ocsp_uri_string ( struct ocsp_check *ocsp ) {
                goto err_path_base64;
        }
        base64_encode ( ocsp->request.builder.data, ocsp->request.builder.len,
-                       path_base64_string );
+                       path_base64_string, path_len );
 
        /* URI-encode the Base64-encoded request */
        memset ( &path_uri, 0, sizeof ( path_uri ) );
index eeae2f393b2636d37ca855a619c2a02a99aca986..0c70d83821d8d7e506fdbb921b7c339d908aebe3 100644 (file)
@@ -35,7 +35,8 @@ static inline size_t base64_decoded_max_len ( const char *encoded ) {
        return ( ( ( strlen ( encoded ) + 4 - 1 ) / 4 ) * 3 );
 }
 
-extern void base64_encode ( const uint8_t *raw, size_t len, char *encoded );
-extern int base64_decode ( const char *encoded, uint8_t *raw );
+extern size_t base64_encode ( const void *raw, size_t raw_len, char *data,
+                             size_t len );
+extern int base64_decode ( const char *encoded, void *data, size_t len );
 
 #endif /* _IPXE_BASE64_H */
index d94ab5f0e8c90f5692fd09d4c7ddb9a850dddf11..f14ce9a82b6094592334e5c663a9c4fb8f1ed618 100644 (file)
@@ -1081,7 +1081,8 @@ static char * http_basic_auth ( struct http_request *http ) {
        snprintf ( user_pw, sizeof ( user_pw ), "%s:%s", user, password );
 
        /* Base64-encode the "user:password" string */
-       base64_encode ( ( void * ) user_pw, user_pw_len, user_pw_base64 );
+       base64_encode ( user_pw, user_pw_len, user_pw_base64,
+                       sizeof ( user_pw_base64 ) );
 
        /* Generate the authorisation string */
        len = asprintf ( &auth, "Authorization: Basic %s\r\n",
index e553b214845c4707e9e2b8069a3579318e75a1fa..0099bf5bdb2cd7943f7a65db0c227cc0abf4a587 100644 (file)
@@ -845,7 +845,7 @@ static int iscsi_large_binary_decode ( const char *encoded, uint8_t *raw,
                case 'x' :
                        return base16_decode ( encoded, raw, len );
                case 'b' :
-                       return base64_decode ( encoded, raw );
+                       return base64_decode ( encoded, raw, len );
                }
        }
 
index 4c26cd1bbf6fbe71917719b3566809a26dc3344f..a01269da8be850127cb898702525a11cca6dce9c 100644 (file)
@@ -254,7 +254,8 @@ static int validator_start_download ( struct validator *validator,
        /* Generate URI string */
        len = snprintf ( uri_string, uri_string_len, "%s/%08x.der?subject=",
                         crosscert, crc );
-       base64_encode ( issuer->data, issuer->len, ( uri_string + len ) );
+       base64_encode ( issuer->data, issuer->len, ( uri_string + len ),
+                       ( uri_string_len - len ) );
        DBGC ( validator, "VALIDATOR %p downloading cross-signed certificate "
               "from %s\n", validator, uri_string );
 
index b22158f548a950afb7dd9a058f5f00609c5e7930..0fc595d9073ff94704ba3e270cc7513e8989cc64 100644 (file)
@@ -80,30 +80,42 @@ BASE64 ( random_test,
  * Report a base64 encoding test result
  *
  * @v test             Base64 test
+ * @v file             Test code file
+ * @v line             Test code line
  */
-#define base64_encode_ok( test ) do {                                  \
-       size_t len = base64_encoded_len ( (test)->len );                \
-       char buf[ len + 1 /* NUL */ ];                                  \
-       ok ( len == strlen ( (test)->encoded ) );                       \
-       base64_encode ( (test)->data, (test)->len, buf );               \
-       ok ( strcmp ( (test)->encoded, buf ) == 0 );                    \
-       } while ( 0 )
+static void base64_encode_okx ( struct base64_test *test, const char *file,
+                               unsigned int line ) {
+       size_t len = base64_encoded_len ( test->len );
+       char buf[ len + 1 /* NUL */ ];
+       size_t check_len;
+
+       okx ( len == strlen ( test->encoded ), file, line );
+       check_len = base64_encode ( test->data, test->len, buf, sizeof ( buf ));
+       okx ( check_len == len, file, line );
+       okx ( strcmp ( test->encoded, buf ) == 0, file, line );
+}
+#define base64_encode_ok( test ) base64_encode_okx ( test, __FILE__, __LINE__ )
 
 /**
  * Report a base64 decoding test result
  *
  * @v test             Base64 test
+ * @v file             Test code file
+ * @v line             Test code line
  */
-#define base64_decode_ok( test ) do {                                  \
-       size_t max_len = base64_decoded_max_len ( (test)->encoded );    \
-       uint8_t buf[max_len];                                           \
-       int len;                                                        \
-       len = base64_decode ( (test)->encoded, buf );                   \
-       ok ( len >= 0 );                                                \
-       ok ( ( size_t ) len <= max_len );                               \
-       ok ( ( size_t ) len == (test)->len );                           \
-       ok ( memcmp ( (test)->data, buf, len ) == 0 );                  \
-       } while ( 0 )
+static void base64_decode_okx ( struct base64_test *test, const char *file,
+                               unsigned int line ) {
+       size_t max_len = base64_decoded_max_len ( test->encoded );
+       uint8_t buf[max_len];
+       int len;
+
+       len = base64_decode ( test->encoded, buf, sizeof ( buf ) );
+       okx ( len >= 0, file, line );
+       okx ( ( size_t ) len <= max_len, file, line );
+       okx ( ( size_t ) len == test->len, file, line );
+       okx ( memcmp ( test->data, buf, len ) == 0, file, line );
+}
+#define base64_decode_ok( test ) base64_decode_okx ( test, __FILE__, __LINE__ )
 
 /**
  * Perform Base64 self-tests