From: Shivani Bhardwaj Date: Thu, 12 May 2022 18:00:00 +0000 (+0530) Subject: base64: add Base64Mode enum X-Git-Tag: suricata-5.0.10~35 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=ac0b72a055f3853412782ec4edd2e4765f0bac9a;p=thirdparty%2Fsuricata.git base64: add Base64Mode enum (cherry picked from commit 9131d1d85715c817a22d2a987f4a01cf42e07757) --- diff --git a/src/datasets.c b/src/datasets.c index 48619f4728..9233695a2c 100644 --- a/src/datasets.c +++ b/src/datasets.c @@ -319,7 +319,8 @@ static int DatasetLoadString(Dataset *set) SCLogDebug("line: '%s'", line); uint8_t decoded[strlen(line)]; - uint32_t len = DecodeBase64(decoded, (const uint8_t *)line, strlen(line), 1); + uint32_t len = + DecodeBase64(decoded, (const uint8_t *)line, strlen(line), BASE64_MODE_STRICT); if (len == 0) FatalError(SC_ERR_FATAL, "bad base64 encoding %s/%s", set->name, set->load); @@ -335,7 +336,8 @@ static int DatasetLoadString(Dataset *set) *r = '\0'; uint8_t decoded[strlen(line)]; - uint32_t len = DecodeBase64(decoded, (const uint8_t *)line, strlen(line), 1); + uint32_t len = + DecodeBase64(decoded, (const uint8_t *)line, strlen(line), BASE64_MODE_STRICT); if (len == 0) FatalError(SC_ERR_FATAL, "bad base64 encoding %s/%s", set->name, set->load); @@ -1023,7 +1025,8 @@ int DatasetAddSerialized(Dataset *set, const char *string) switch (set->type) { case DATASET_TYPE_STRING: { uint8_t decoded[strlen(string)]; - uint32_t len = DecodeBase64(decoded, (const uint8_t *)string, strlen(string), 1); + uint32_t len = DecodeBase64( + decoded, (const uint8_t *)string, strlen(string), BASE64_MODE_STRICT); if (len == 0) { return -2; } @@ -1104,7 +1107,8 @@ int DatasetRemoveSerialized(Dataset *set, const char *string) switch (set->type) { case DATASET_TYPE_STRING: { uint8_t decoded[strlen(string)]; - uint32_t len = DecodeBase64(decoded, (const uint8_t *)string, strlen(string), 1); + uint32_t len = DecodeBase64( + decoded, (const uint8_t *)string, strlen(string), BASE64_MODE_STRICT); if (len == 0) { return -2; } diff --git a/src/detect-base64-decode.c b/src/detect-base64-decode.c index 2e7808db37..020d00537c 100644 --- a/src/detect-base64-decode.c +++ b/src/detect-base64-decode.c @@ -84,8 +84,8 @@ int DetectBase64DecodeDoMatch(DetectEngineThreadCtx *det_ctx, const Signature *s PrintRawDataFp(stdout, payload, decode_len); #endif - det_ctx->base64_decoded_len = DecodeBase64(det_ctx->base64_decoded, - payload, decode_len, 0); + det_ctx->base64_decoded_len = + DecodeBase64(det_ctx->base64_decoded, payload, decode_len, BASE64_MODE_RELAX); SCLogDebug("Decoded %d bytes from base64 data.", det_ctx->base64_decoded_len); #if 0 diff --git a/src/runmode-unittests.c b/src/runmode-unittests.c index 357d30889c..3c91322ffd 100644 --- a/src/runmode-unittests.c +++ b/src/runmode-unittests.c @@ -178,6 +178,7 @@ static void RegisterUnittests(void) ThreadMacrosRegisterTests(); UtilSpmSearchRegistertests(); UtilActionRegisterTests(); + Base64RegisterTests(); SCClassConfRegisterTests(); SCThresholdConfRegisterTests(); SCRConfRegisterTests(); diff --git a/src/util-base64.c b/src/util-base64.c index bea92d52bb..5c6e0ac6e2 100644 --- a/src/util-base64.c +++ b/src/util-base64.c @@ -23,7 +23,7 @@ */ #include "util-base64.h" - +#include "util-unittest.h" /* Constants */ #define BASE64_TABLE_MAX 122 @@ -88,8 +88,7 @@ static inline void DecodeBase64Block(uint8_t ascii[ASCII_BLOCK], uint8_t b64[B64 * * \return Number of bytes decoded, or 0 if no data is decoded or it fails */ -uint32_t DecodeBase64(uint8_t *dest, const uint8_t *src, uint32_t len, - int strict) +uint32_t DecodeBase64(uint8_t *dest, const uint8_t *src, uint32_t len, Base64Mode mode) { int val; uint32_t padding = 0, numDecoded = 0, bbidx = 0, valid = 1, i; @@ -102,11 +101,13 @@ uint32_t DecodeBase64(uint8_t *dest, const uint8_t *src, uint32_t len, /* Get decimal representation */ val = GetBase64Value(src[i]); if (val < 0) { - + if (mode == BASE64_MODE_RFC2045 && src[i] == ' ') { + continue; + } /* Invalid character found, so decoding fails */ if (src[i] != '=') { valid = 0; - if (strict) { + if (mode != BASE64_MODE_RELAX) { numDecoded = 0; } break; @@ -149,3 +150,28 @@ uint32_t DecodeBase64(uint8_t *dest, const uint8_t *src, uint32_t len, return numDecoded; } + +#ifdef UNITTESTS + +static int DecodeString(void) +{ + /* + * SGV sbG8= : Hello + * SGVsbG8gV29ybGQ= : Hello World + * */ + + const char *src = "SGVs bG8 gV29y bGQ="; + uint8_t *dst = SCMalloc(sizeof(src) * 30); + int res = DecodeBase64(dst, (const uint8_t *)src, 30, 1); + printf("%d\n", res); + printf("dst str = \"%s\"", (const char *)dst); + FAIL_IF(res <= 0); + SCFree(dst); + PASS; +} + +void Base64RegisterTests(void) +{ + UtRegisterTest("DecodeString", DecodeString); +} +#endif diff --git a/src/util-base64.h b/src/util-base64.h index 7c8bed6262..2de0275ce5 100644 --- a/src/util-base64.h +++ b/src/util-base64.h @@ -48,8 +48,17 @@ #define ASCII_BLOCK 3 #define B64_BLOCK 4 +typedef enum { + BASE64_MODE_RELAX, + BASE64_MODE_RFC2045, /* SPs are allowed during transfer but must be skipped by Decoder */ + BASE64_MODE_STRICT, +} Base64Mode; + /* Function prototypes */ -uint32_t DecodeBase64(uint8_t *dest, const uint8_t *src, uint32_t len, - int strict); +uint32_t DecodeBase64(uint8_t *dest, const uint8_t *src, uint32_t len, Base64Mode mode); + +#endif +#ifdef UNITTESTS +void Base64RegisterTests(void); #endif diff --git a/src/util-decode-mime.c b/src/util-decode-mime.c index 3b1cc4ab5c..8ece5522a5 100644 --- a/src/util-decode-mime.c +++ b/src/util-decode-mime.c @@ -1235,8 +1235,8 @@ static uint8_t ProcessBase64Remainder(const uint8_t *buf, uint32_t len, /* Only decode if divisible by 4 */ if (state->bvr_len == B64_BLOCK || force) { - remdec = DecodeBase64(state->data_chunk + state->data_chunk_len, - state->bvremain, state->bvr_len, 1); + remdec = DecodeBase64(state->data_chunk + state->data_chunk_len, state->bvremain, + state->bvr_len, BASE64_MODE_RFC2045); if (remdec > 0) { /* Track decoded length */ @@ -1337,8 +1337,8 @@ static int ProcessBase64BodyLine(const uint8_t *buf, uint32_t len, SCLogDebug("Decoding: %u", len - rem1 - rem2); - numDecoded = DecodeBase64(state->data_chunk + state->data_chunk_len, - buf + offset, tobuf, 1); + numDecoded = DecodeBase64(state->data_chunk + state->data_chunk_len, buf + offset, + tobuf, BASE64_MODE_RFC2045); if (numDecoded > 0) { /* Track decoded length */ @@ -3007,7 +3007,7 @@ static int MimeBase64DecodeTest01(void) if (dst == NULL) return 0; - ret = DecodeBase64(dst, (const uint8_t *)base64msg, strlen(base64msg), 1); + ret = DecodeBase64(dst, (const uint8_t *)base64msg, strlen(base64msg), BASE64_MODE_RFC2045); if (memcmp(dst, msg, strlen(msg)) == 0) { ret = 1;