]> git.ipfire.org Git - thirdparty/systemd.git/blobdiff - src/libsystemd/sd-journal/compress.c
alloc-util: simplify GREEDY_REALLOC() logic by relying on malloc_usable_size()
[thirdparty/systemd.git] / src / libsystemd / sd-journal / compress.c
index f366a597b5a5b2693a965a8c633ddfe5a9c78228..c788dd8caf9e9e00b7ced2434cac1e0a94c45005 100644 (file)
@@ -1,10 +1,11 @@
 /* SPDX-License-Identifier: LGPL-2.1-or-later */
 
 #include <inttypes.h>
+#include <malloc.h>
 #include <stdlib.h>
 #include <sys/mman.h>
-#include <sys/types.h>
 #include <sys/stat.h>
+#include <sys/types.h>
 #include <unistd.h>
 
 #if HAVE_XZ
@@ -157,8 +158,12 @@ int compress_blob_zstd(
 #endif
 }
 
-int decompress_blob_xz(const void *src, uint64_t src_size,
-                       void **dst, size_t *dst_alloc_size, size_t* dst_size, size_t dst_max) {
+int decompress_blob_xz(
+                const void *src,
+                uint64_t src_size,
+                void **dst,
+                size_t* dst_size,
+                size_t dst_max) {
 
 #if HAVE_XZ
         _cleanup_(lzma_end) lzma_stream s = LZMA_STREAM_INIT;
@@ -168,16 +173,14 @@ int decompress_blob_xz(const void *src, uint64_t src_size,
         assert(src);
         assert(src_size > 0);
         assert(dst);
-        assert(dst_alloc_size);
         assert(dst_size);
-        assert(*dst_alloc_size == 0 || *dst);
 
         ret = lzma_stream_decoder(&s, UINT64_MAX, 0);
         if (ret != LZMA_OK)
                 return -ENOMEM;
 
         space = MIN(src_size * 2, dst_max ?: SIZE_MAX);
-        if (!greedy_realloc(dst, dst_alloc_size, space, 1))
+        if (!greedy_realloc(dst, space, 1))
                 return -ENOMEM;
 
         s.next_in = src;
@@ -203,7 +206,7 @@ int decompress_blob_xz(const void *src, uint64_t src_size,
 
                 used = space - s.avail_out;
                 space = MIN(2 * space, dst_max ?: SIZE_MAX);
-                if (!greedy_realloc(dst, dst_alloc_size, space, 1))
+                if (!greedy_realloc(dst, space, 1))
                         return -ENOMEM;
 
                 s.avail_out = space - used;
@@ -217,8 +220,12 @@ int decompress_blob_xz(const void *src, uint64_t src_size,
 #endif
 }
 
-int decompress_blob_lz4(const void *src, uint64_t src_size,
-                        void **dst, size_t *dst_alloc_size, size_t* dst_size, size_t dst_max) {
+int decompress_blob_lz4(
+                const void *src,
+                uint64_t src_size,
+                void **dst,
+                size_t* dst_size,
+                size_t dst_max) {
 
 #if HAVE_LZ4
         char* out;
@@ -227,9 +234,7 @@ int decompress_blob_lz4(const void *src, uint64_t src_size,
         assert(src);
         assert(src_size > 0);
         assert(dst);
-        assert(dst_alloc_size);
         assert(dst_size);
-        assert(*dst_alloc_size == 0 || *dst);
 
         if (src_size <= 8)
                 return -EBADMSG;
@@ -237,14 +242,9 @@ int decompress_blob_lz4(const void *src, uint64_t src_size,
         size = unaligned_read_le64(src);
         if (size < 0 || (unsigned) size != unaligned_read_le64(src))
                 return -EFBIG;
-        if ((size_t) size > *dst_alloc_size) {
-                out = realloc(*dst, size);
-                if (!out)
-                        return -ENOMEM;
-                *dst = out;
-                *dst_alloc_size = size;
-        } else
-                out = *dst;
+        out = greedy_realloc(dst, size, 1);
+        if (!out)
+                return -ENOMEM;
 
         r = LZ4_decompress_safe((char*)src + 8, out, src_size - 8, size);
         if (r < 0 || r != size)
@@ -258,8 +258,11 @@ int decompress_blob_lz4(const void *src, uint64_t src_size,
 }
 
 int decompress_blob_zstd(
-                const void *src, uint64_t src_size,
-                void **dst, size_t *dst_alloc_size, size_t *dst_size, size_t dst_max) {
+                const void *src,
+                uint64_t src_size,
+                void **dst,
+                size_t *dst_size,
+                size_t dst_max) {
 
 #if HAVE_ZSTD
         uint64_t size;
@@ -267,9 +270,7 @@ int decompress_blob_zstd(
         assert(src);
         assert(src_size > 0);
         assert(dst);
-        assert(dst_alloc_size);
         assert(dst_size);
-        assert(*dst_alloc_size == 0 || *dst);
 
         size = ZSTD_getFrameContentSize(src, src_size);
         if (IN_SET(size, ZSTD_CONTENTSIZE_ERROR, ZSTD_CONTENTSIZE_UNKNOWN))
@@ -280,7 +281,7 @@ int decompress_blob_zstd(
         if (size > SIZE_MAX)
                 return -E2BIG;
 
-        if (!(greedy_realloc(dst, dst_alloc_size, MAX(ZSTD_DStreamOutSize(), size), 1)))
+        if (!(greedy_realloc(dst, MAX(ZSTD_DStreamOutSize(), size), 1)))
                 return -ENOMEM;
 
         _cleanup_(ZSTD_freeDCtxp) ZSTD_DCtx *dctx = ZSTD_createDCtx();
@@ -293,7 +294,7 @@ int decompress_blob_zstd(
         };
         ZSTD_outBuffer output = {
                 .dst = *dst,
-                .size = *dst_alloc_size,
+                .size = MALLOC_SIZEOF_SAFE(*dst),
         };
 
         size_t k = ZSTD_decompressStream(dctx, &output, &input);
@@ -312,57 +313,63 @@ int decompress_blob_zstd(
 
 int decompress_blob(
                 int compression,
-                const void *src, uint64_t src_size,
-                void **dst, size_t *dst_alloc_size, size_t* dst_size, size_t dst_max) {
+                const void *src,
+                uint64_t src_size,
+                void **dst,
+                size_t* dst_size,
+                size_t dst_max) {
 
         if (compression == OBJECT_COMPRESSED_XZ)
                 return decompress_blob_xz(
                                 src, src_size,
-                                dst, dst_alloc_size, dst_size, dst_max);
+                                dst, dst_size, dst_max);
         else if (compression == OBJECT_COMPRESSED_LZ4)
                 return decompress_blob_lz4(
                                 src, src_size,
-                                dst, dst_alloc_size, dst_size, dst_max);
+                                dst, dst_size, dst_max);
         else if (compression == OBJECT_COMPRESSED_ZSTD)
                 return decompress_blob_zstd(
                                 src, src_size,
-                                dst, dst_alloc_size, dst_size, dst_max);
+                                dst, dst_size, dst_max);
         else
                 return -EPROTONOSUPPORT;
 }
 
-int decompress_startswith_xz(const void *src, uint64_t src_size,
-                             void **buffer, size_t *buffer_size,
-                             const void *prefix, size_t prefix_len,
-                             uint8_t extra) {
+int decompress_startswith_xz(
+                const void *src,
+                uint64_t src_size,
+                void **buffer,
+                const void *prefix,
+                size_t prefix_len,
+                uint8_t extra) {
 
 #if HAVE_XZ
         _cleanup_(lzma_end) lzma_stream s = LZMA_STREAM_INIT;
+        size_t allocated;
         lzma_ret ret;
 
-        /* Checks whether the decompressed blob starts with the
-         * mentioned prefix. The byte extra needs to follow the
-         * prefix */
+        /* Checks whether the decompressed blob starts with the mentioned prefix. The byte extra needs to
+         * follow the prefix */
 
         assert(src);
         assert(src_size > 0);
         assert(buffer);
-        assert(buffer_size);
         assert(prefix);
-        assert(*buffer_size == 0 || *buffer);
 
         ret = lzma_stream_decoder(&s, UINT64_MAX, 0);
         if (ret != LZMA_OK)
                 return -EBADMSG;
 
-        if (!(greedy_realloc(buffer, buffer_size, ALIGN_8(prefix_len + 1), 1)))
+        if (!(greedy_realloc(buffer, ALIGN_8(prefix_len + 1), 1)))
                 return -ENOMEM;
 
+        allocated = MALLOC_SIZEOF_SAFE(*buffer);
+
         s.next_in = src;
         s.avail_in = src_size;
 
         s.next_out = *buffer;
-        s.avail_out = *buffer_size;
+        s.avail_out = allocated;
 
         for (;;) {
                 ret = lzma_code(&s, LZMA_FINISH);
@@ -370,19 +377,20 @@ int decompress_startswith_xz(const void *src, uint64_t src_size,
                 if (!IN_SET(ret, LZMA_OK, LZMA_STREAM_END))
                         return -EBADMSG;
 
-                if (*buffer_size - s.avail_out >= prefix_len + 1)
+                if (allocated - s.avail_out >= prefix_len + 1)
                         return memcmp(*buffer, prefix, prefix_len) == 0 &&
                                 ((const uint8_t*) *buffer)[prefix_len] == extra;
 
                 if (ret == LZMA_STREAM_END)
                         return 0;
 
-                s.avail_out += *buffer_size;
+                s.avail_out += allocated;
 
-                if (!(greedy_realloc(buffer, buffer_size, *buffer_size * 2, 1)))
+                if (!(greedy_realloc(buffer, allocated * 2, 1)))
                         return -ENOMEM;
 
-                s.next_out = *(uint8_t**)buffer + *buffer_size - s.avail_out;
+                allocated = MALLOC_SIZEOF_SAFE(*buffer);
+                s.next_out = *(uint8_t**)buffer + allocated - s.avail_out;
         }
 
 #else
@@ -390,36 +398,43 @@ int decompress_startswith_xz(const void *src, uint64_t src_size,
 #endif
 }
 
-int decompress_startswith_lz4(const void *src, uint64_t src_size,
-                              void **buffer, size_t *buffer_size,
-                              const void *prefix, size_t prefix_len,
-                              uint8_t extra) {
+int decompress_startswith_lz4(
+                const void *src,
+                uint64_t src_size,
+                void **buffer,
+                const void *prefix,
+                size_t prefix_len,
+                uint8_t extra) {
+
 #if HAVE_LZ4
-        /* Checks whether the decompressed blob starts with the
-         * mentioned prefix. The byte extra needs to follow the
-         * prefix */
+        /* Checks whether the decompressed blob starts with the mentioned prefix. The byte extra needs to
+         * follow the prefix */
 
+        size_t allocated;
         int r;
 
         assert(src);
         assert(src_size > 0);
         assert(buffer);
-        assert(buffer_size);
         assert(prefix);
-        assert(*buffer_size == 0 || *buffer);
 
         if (src_size <= 8)
                 return -EBADMSG;
 
-        if (!(greedy_realloc(buffer, buffer_size, ALIGN_8(prefix_len + 1), 1)))
+        if (!(greedy_realloc(buffer, ALIGN_8(prefix_len + 1), 1)))
                 return -ENOMEM;
-
-        r = LZ4_decompress_safe_partial((char*)src + 8, *buffer, src_size - 8,
-                                        prefix_len + 1, *buffer_size);
-        /* One lz4 < 1.8.3, we might get "failure" (r < 0), or "success" where
-         * just a part of the buffer is decompressed. But if we get a smaller
-         * amount of bytes than requested, we don't know whether there isn't enough
-         * data to fill the requested size or whether we just got a partial answer.
+        allocated = MALLOC_SIZEOF_SAFE(*buffer);
+
+        r = LZ4_decompress_safe_partial(
+                        (char*)src + 8,
+                        *buffer,
+                        src_size - 8,
+                        prefix_len + 1,
+                        allocated);
+
+        /* One lz4 < 1.8.3, we might get "failure" (r < 0), or "success" where just a part of the buffer is
+         * decompressed. But if we get a smaller amount of bytes than requested, we don't know whether there
+         * isn't enough data to fill the requested size or whether we just got a partial answer.
          */
         if (r < 0 || (size_t) r < prefix_len + 1) {
                 size_t size;
@@ -437,7 +452,7 @@ int decompress_startswith_lz4(const void *src, uint64_t src_size,
 
                 /* Before version 1.8.3, lz4 always tries to decode full a "sequence",
                  * so in pathological cases might need to decompress the full field. */
-                r = decompress_blob_lz4(src, src_size, buffer, buffer_size, &size, 0);
+                r = decompress_blob_lz4(src, src_size, buffer, &size, 0);
                 if (r < 0)
                         return r;
 
@@ -453,17 +468,17 @@ int decompress_startswith_lz4(const void *src, uint64_t src_size,
 }
 
 int decompress_startswith_zstd(
-                const void *src, uint64_t src_size,
-                void **buffer, size_t *buffer_size,
-                const void *prefix, size_t prefix_len,
+                const void *src,
+                uint64_t src_size,
+                void **buffer,
+                const void *prefix,
+                size_t prefix_len,
                 uint8_t extra) {
 #if HAVE_ZSTD
         assert(src);
         assert(src_size > 0);
         assert(buffer);
-        assert(buffer_size);
         assert(prefix);
-        assert(*buffer_size == 0 || *buffer);
 
         uint64_t size = ZSTD_getFrameContentSize(src, src_size);
         if (IN_SET(size, ZSTD_CONTENTSIZE_ERROR, ZSTD_CONTENTSIZE_UNKNOWN))
@@ -476,7 +491,7 @@ int decompress_startswith_zstd(
         if (!dctx)
                 return -ENOMEM;
 
-        if (!(greedy_realloc(buffer, buffer_size, MAX(ZSTD_DStreamOutSize(), prefix_len + 1), 1)))
+        if (!(greedy_realloc(buffer, MAX(ZSTD_DStreamOutSize(), prefix_len + 1), 1)))
                 return -ENOMEM;
 
         ZSTD_inBuffer input = {
@@ -485,7 +500,7 @@ int decompress_startswith_zstd(
         };
         ZSTD_outBuffer output = {
                 .dst = *buffer,
-                .size = *buffer_size,
+                .size = MALLOC_SIZEOF_SAFE(*buffer),
         };
         size_t k;
 
@@ -505,28 +520,30 @@ int decompress_startswith_zstd(
 
 int decompress_startswith(
                 int compression,
-                const void *src, uint64_t src_size,
-                void **buffer, size_t *buffer_size,
-                const void *prefix, size_t prefix_len,
+                const void *src,
+                uint64_t src_size,
+                void **buffer,
+                const void *prefix,
+                size_t prefix_len,
                 uint8_t extra) {
 
         if (compression == OBJECT_COMPRESSED_XZ)
                 return decompress_startswith_xz(
                                 src, src_size,
-                                buffer, buffer_size,
+                                buffer,
                                 prefix, prefix_len,
                                 extra);
 
         else if (compression == OBJECT_COMPRESSED_LZ4)
                 return decompress_startswith_lz4(
                                 src, src_size,
-                                buffer, buffer_size,
+                                buffer,
                                 prefix, prefix_len,
                                 extra);
         else if (compression == OBJECT_COMPRESSED_ZSTD)
                 return decompress_startswith_zstd(
                                 src, src_size,
-                                buffer, buffer_size,
+                                buffer,
                                 prefix, prefix_len,
                                 extra);
         else