return 1;
}
-/* Decompress a zstd stream from PIN/SIN to POUT/SOUT. Code based on RFC 8878.
+/* Decompress a single zstd frame from *PPIN, ending at PINEND, to *PPOUT/SOUT.
Return 1 on success, 0 on error. */
static int
-elf_zstd_decompress (const unsigned char *pin, size_t sin,
- unsigned char *zdebug_table, unsigned char *pout,
- size_t sout)
+elf_zstd_decompress_frame (const unsigned char **ppin,
+ const unsigned char *pinend,
+ unsigned char *zdebug_table, unsigned char **ppout,
+ size_t sout)
{
- const unsigned char *pinend;
+ const unsigned char *pin;
+ unsigned char *pout;
unsigned char *poutstart;
unsigned char *poutend;
struct elf_zstd_seq_decode literal_decode;
uint64_t content_size;
int last_block;
- pinend = pin + sin;
+ pin = *ppin;
+ pout = *ppout;
poutstart = pout;
- poutend = pout + sout;
literal_decode.table = NULL;
literal_decode.table_bits = -1;
repeated_offset2 = 4;
repeated_offset3 = 8;
- if (unlikely (sin < 4))
+ if (unlikely (pinend - pin < 4))
{
elf_uncompress_failed ();
return 0;
}
if (unlikely (content_size != (size_t) content_size
- || (size_t) content_size != sout))
+ || (size_t) content_size > sout))
{
elf_uncompress_failed ();
return 0;
}
+ poutend = pout + content_size;
+
last_block = 0;
while (!last_block)
{
pin += 4;
}
- if (pin != pinend)
+ *ppin = pin;
+ *ppout = pout;
+
+ return 1;
+}
+
+/* Decompress a zstd stream from PIN/SIN to POUT/SOUT. Code based on RFC 8878.
+ Return 1 on success, 0 on error. */
+
+static int
+elf_zstd_decompress (const unsigned char *pin, size_t sin,
+ unsigned char *zdebug_table, unsigned char *pout,
+ size_t sout)
+{
+ const unsigned char *pinend;
+
+ pinend = pin + sin;
+
+ while (sin > 0)
+ {
+ const unsigned char *pin_frame;
+ unsigned char *pout_frame;
+
+ pin_frame = pin;
+ pout_frame = pout;
+ if (!elf_zstd_decompress_frame (&pin_frame, pinend, zdebug_table,
+ &pout_frame, sout))
+ return 0;
+
+ sin -= pin_frame - pin;
+ pin = pin_frame;
+ sout -= pout_frame - pout;
+ pout = pout_frame;
+ }
+
+ if (sout > 0)
{
elf_uncompress_failed ();
return 0;
size_t orig_bufsize;
size_t i;
char *compressed_buf;
- size_t compressed_bufsize;
size_t compressed_size;
+ size_t chunk_size;
unsigned char *uncompressed_buf;
size_t r;
clockid_t cid;
return;
}
- compressed_bufsize = ZSTD_compressBound (orig_bufsize);
- compressed_buf = malloc (compressed_bufsize);
- if (compressed_buf == NULL)
- {
- perror ("malloc");
- goto fail;
- }
+ /* Split the input into 100K chunks. This is to approximate the fact that lld
+ splits the input into 1M shards. */
- r = ZSTD_compress (compressed_buf, compressed_bufsize,
- orig_buf, orig_bufsize, 3);
- if (ZSTD_isError (r))
+ compressed_size = 0;
+ compressed_buf = NULL;
+ chunk_size = 100 << 10;
+ for (i = 0; i < orig_bufsize; i += chunk_size)
{
- fprintf (stderr, "zstd compress failed: %s\n", ZSTD_getErrorName (r));
- goto fail;
+ size_t chunk_input_size;
+ size_t chunk_compressed_size;
+
+ chunk_input_size = orig_bufsize - i;
+ if (chunk_input_size > chunk_size)
+ chunk_input_size = chunk_size;
+
+ chunk_compressed_size = ZSTD_compressBound (chunk_input_size);
+ compressed_buf = realloc (compressed_buf, compressed_size + chunk_compressed_size);
+ if (compressed_buf == NULL)
+ {
+ perror ("realloc");
+ goto fail;
+ }
+
+ r = ZSTD_compress (compressed_buf + compressed_size,
+ chunk_compressed_size,
+ orig_buf + i, chunk_input_size, 3);
+ if (ZSTD_isError (r))
+ {
+ fprintf (stderr, "zstd compress failed: %s\n", ZSTD_getErrorName (r));
+ goto fail;
+ }
+ compressed_size += r;
}
- compressed_size = r;
uncompressed_buf = malloc (orig_bufsize);
if (uncompressed_buf == NULL)