]> git.ipfire.org Git - thirdparty/systemd.git/commitdiff
memory-util: add CLEANUP_ERASE_PTR() macro and use it 26004/head
authorLennart Poettering <lennart@poettering.net>
Tue, 10 Jan 2023 11:39:14 +0000 (12:39 +0100)
committerLennart Poettering <lennart@poettering.net>
Mon, 16 Jan 2023 15:19:07 +0000 (16:19 +0100)
src/basic/hexdecoct.c
src/fundamental/memory-util-fundamental.h
src/test/test-hexdecoct.c

index dc3b948d8e025d039271dc46fd557b838257ecda..898ed83f862ac6dd759febcd4cefb446bbf39918 100644 (file)
@@ -110,12 +110,17 @@ static int unhex_next(const char **p, size_t *l) {
         return r;
 }
 
-int unhexmem_full(const char *p, size_t l, bool secure, void **ret, size_t *ret_len) {
+int unhexmem_full(
+                const char *p,
+                size_t l,
+                bool secure,
+                void **ret,
+                size_t *ret_len) {
+
         _cleanup_free_ uint8_t *buf = NULL;
         size_t buf_size;
         const char *x;
         uint8_t *z;
-        int r;
 
         assert(p || l == 0);
 
@@ -128,22 +133,20 @@ int unhexmem_full(const char *p, size_t l, bool secure, void **ret, size_t *ret_
         if (!buf)
                 return -ENOMEM;
 
+        CLEANUP_ERASE_PTR(secure ? &buf : NULL, buf_size);
+
         for (x = p, z = buf;;) {
                 int a, b;
 
                 a = unhex_next(&x, &l);
                 if (a == -EPIPE) /* End of string */
                         break;
-                if (a < 0) {
-                        r = a;
-                        goto on_failure;
-                }
+                if (a < 0)
+                        return a;
 
                 b = unhex_next(&x, &l);
-                if (b < 0) {
-                        r = b;
-                        goto on_failure;
-                }
+                if (b < 0)
+                        return b;
 
                 *(z++) = (uint8_t) a << 4 | (uint8_t) b;
         }
@@ -156,12 +159,6 @@ int unhexmem_full(const char *p, size_t l, bool secure, void **ret, size_t *ret_
                 *ret = TAKE_PTR(buf);
 
         return 0;
-
-on_failure:
-        if (secure)
-                explicit_bzero_safe(buf, buf_size);
-
-        return r;
 }
 
 /* https://tools.ietf.org/html/rfc4648#section-6
@@ -765,12 +762,17 @@ static int unbase64_next(const char **p, size_t *l) {
         return ret;
 }
 
-int unbase64mem_full(const char *p, size_t l, bool secure, void **ret, size_t *ret_size) {
+int unbase64mem_full(
+                const char *p,
+                size_t l,
+                bool secure,
+                void **ret,
+                size_t *ret_size) {
+
         _cleanup_free_ uint8_t *buf = NULL;
         const char *x;
         uint8_t *z;
         size_t len;
-        int r;
 
         assert(p || l == 0);
 
@@ -785,60 +787,44 @@ int unbase64mem_full(const char *p, size_t l, bool secure, void **ret, size_t *r
         if (!buf)
                 return -ENOMEM;
 
+        CLEANUP_ERASE_PTR(secure ? &buf : NULL, len);
+
         for (x = p, z = buf;;) {
                 int a, b, c, d; /* a == 00XXXXXX; b == 00YYYYYY; c == 00ZZZZZZ; d == 00WWWWWW */
 
                 a = unbase64_next(&x, &l);
                 if (a == -EPIPE) /* End of string */
                         break;
-                if (a < 0) {
-                        r = a;
-                        goto on_failure;
-                }
-                if (a == INT_MAX) { /* Padding is not allowed at the beginning of a 4ch block */
-                        r = -EINVAL;
-                        goto on_failure;
-                }
+                if (a < 0)
+                        return a;
+                if (a == INT_MAX) /* Padding is not allowed at the beginning of a 4ch block */
+                        return -EINVAL;
 
                 b = unbase64_next(&x, &l);
-                if (b < 0) {
-                        r = b;
-                        goto on_failure;
-                }
-                if (b == INT_MAX) { /* Padding is not allowed at the second character of a 4ch block either */
-                        r = -EINVAL;
-                        goto on_failure;
-                }
+                if (b < 0)
+                        return b;
+                if (b == INT_MAX) /* Padding is not allowed at the second character of a 4ch block either */
+                        return -EINVAL;
 
                 c = unbase64_next(&x, &l);
-                if (c < 0) {
-                        r = c;
-                        goto on_failure;
-                }
+                if (c < 0)
+                        return c;
 
                 d = unbase64_next(&x, &l);
-                if (d < 0) {
-                        r = d;
-                        goto on_failure;
-                }
+                if (d < 0)
+                        return d;
 
                 if (c == INT_MAX) { /* Padding at the third character */
 
-                        if (d != INT_MAX) { /* If the third character is padding, the fourth must be too */
-                                r = -EINVAL;
-                                goto on_failure;
-                        }
+                        if (d != INT_MAX) /* If the third character is padding, the fourth must be too */
+                                return -EINVAL;
 
                         /* b == 00YY0000 */
-                        if (b & 15) {
-                                r = -EINVAL;
-                                goto on_failure;
-                        }
+                        if (b & 15)
+                                return -EINVAL;
 
-                        if (l > 0) { /* Trailing rubbish? */
-                                r = -ENAMETOOLONG;
-                                goto on_failure;
-                        }
+                        if (l > 0) /* Trailing rubbish? */
+                                return -ENAMETOOLONG;
 
                         *(z++) = (uint8_t) a << 2 | (uint8_t) (b >> 4); /* XXXXXXYY */
                         break;
@@ -846,15 +832,11 @@ int unbase64mem_full(const char *p, size_t l, bool secure, void **ret, size_t *r
 
                 if (d == INT_MAX) {
                         /* c == 00ZZZZ00 */
-                        if (c & 3) {
-                                r = -EINVAL;
-                                goto on_failure;
-                        }
+                        if (c & 3)
+                                return -EINVAL;
 
-                        if (l > 0) { /* Trailing rubbish? */
-                                r = -ENAMETOOLONG;
-                                goto on_failure;
-                        }
+                        if (l > 0) /* Trailing rubbish? */
+                                return -ENAMETOOLONG;
 
                         *(z++) = (uint8_t) a << 2 | (uint8_t) b >> 4; /* XXXXXXYY */
                         *(z++) = (uint8_t) b << 4 | (uint8_t) c >> 2; /* YYYYZZZZ */
@@ -868,18 +850,14 @@ int unbase64mem_full(const char *p, size_t l, bool secure, void **ret, size_t *r
 
         *z = 0;
 
+        assert((size_t) (z - buf) <= len);
+
         if (ret_size)
                 *ret_size = (size_t) (z - buf);
         if (ret)
                 *ret = TAKE_PTR(buf);
 
         return 0;
-
-on_failure:
-        if (secure)
-                explicit_bzero_safe(buf, len);
-
-        return r;
 }
 
 void hexdump(FILE *f, const void *p, size_t s) {
index 67621fdb4243f51dcf74c4fa3cc9527ee61d8998..78e2dbec5985f0ca7d715e533ab3412c407ec8db 100644 (file)
@@ -29,6 +29,8 @@ static inline void *explicit_bzero_safe(void *p, size_t l) {
 #endif
 
 struct VarEraser {
+        /* NB: This is a pointer to memory to erase in case of CLEANUP_ERASE(). Pointer to pointer to memory
+         * to erase in case of CLEANUP_ERASE_PTR() */
         void *p;
         size_t size;
 };
@@ -38,5 +40,27 @@ static inline void erase_var(struct VarEraser *e) {
 }
 
 /* Mark var to be erased when leaving scope. */
-#define CLEANUP_ERASE(var) \
-        _cleanup_(erase_var) _unused_ struct VarEraser CONCATENATE(_eraser_, UNIQ) = { .p = &var, .size = sizeof(var) }
+#define CLEANUP_ERASE(var)                                              \
+        _cleanup_(erase_var) _unused_ struct VarEraser CONCATENATE(_eraser_, UNIQ) = { \
+                .p = &(var),                                            \
+                .size = sizeof(var),                                    \
+        }
+
+static inline void erase_varp(struct VarEraser *e) {
+
+        /* Very similar to erase_var(), but assumes `p` is a pointer to a pointer whose memory shall be destructed. */
+        if (!e->p)
+                return;
+
+        explicit_bzero_safe(*(void**) e->p, e->size);
+}
+
+/* Mark pointer so that memory pointed to is erased when leaving scope. Note: this takes a pointer to the
+ * specified pointer, instead of just a copy of it. This is to allow callers to invalidate the pointer after
+ * use, if they like, disabling our automatic erasure (for example because they succeeded with whatever they
+ * wanted to do and now intend to return the allocated buffer to their caller without it being erased). */
+#define CLEANUP_ERASE_PTR(ptr, sz)                                      \
+        _cleanup_(erase_varp) _unused_ struct VarEraser CONCATENATE(_eraser_, UNIQ) = { \
+                .p = (ptr),                                             \
+                .size = (sz),                                           \
+        }
index afdc3b543681e37f351d2bca548abb2dd0148b6d..9d71db6ae19126d6793d9d6ee1395ee00ce60941 100644 (file)
@@ -322,6 +322,13 @@ TEST(base64mem_linebreak) {
                 assert_se(decoded_size == n);
                 assert_se(memcmp(data, decoded, n) == 0);
 
+                /* Also try in secure mode */
+                decoded = mfree(decoded);
+                decoded_size = 0;
+                assert_se(unbase64mem_full(encoded, SIZE_MAX, /* secure= */ true, &decoded, &decoded_size) >= 0);
+                assert_se(decoded_size == n);
+                assert_se(memcmp(data, decoded, n) == 0);
+
                 for (size_t j = 0; j < (size_t) l; j++)
                         assert_se((encoded[j] == '\n') == (j % (m + 1) == m));
         }
@@ -446,7 +453,17 @@ static void test_unbase64mem_one(const char *input, const char *output, int ret)
         size_t size = 0;
 
         assert_se(unbase64mem(input, SIZE_MAX, &buffer, &size) == ret);
+        if (ret >= 0) {
+                assert_se(size == strlen(output));
+                assert_se(memcmp(buffer, output, size) == 0);
+                assert_se(((char*) buffer)[size] == 0);
+        }
+
+        /* also try in secure mode */
+        buffer = mfree(buffer);
+        size = 0;
 
+        assert_se(unbase64mem_full(input, SIZE_MAX, /* secure=*/ true, &buffer, &size) == ret);
         if (ret >= 0) {
                 assert_se(size == strlen(output));
                 assert_se(memcmp(buffer, output, size) == 0);