* You may select, at your option, one of the above-listed licenses.
*/
+#define ZSTD_STATIC_LINKING_ONLY
#include "zstd.h"
#include "zstd_errors.h"
char *output;
size_t output_size;
+ const char *dict_file_name;
+ const char *dict_file_dir_name;
+ int32_t dict_id;
+ char *dict;
+ size_t dict_size;
+ ZSTD_DDict* ddict;
+
ZSTD_DCtx* dctx;
int success_count;
static void free_stuff(stuff_t* stuff) {
free(stuff->input);
free(stuff->output);
+ ZSTD_freeDDict(stuff->ddict);
+ free(stuff->dict);
ZSTD_freeDCtx(stuff->dctx);
}
static void usage(void) {
- fprintf(stderr, "check_flipped_bits input_filename");
+ fprintf(stderr, "check_flipped_bits input_filename [-d dict] [-D dict_dir]\n");
+ fprintf(stderr, "\n");
+ fprintf(stderr, "Arguments:\n");
+ fprintf(stderr, " -d file: path to a dictionary file to use.\n");
+ fprintf(stderr, " -D dir : path to a directory, with files containing dictionaries, of the\n"
+ " form DICTID.zstd-dict, e.g., 12345.zstd-dict.\n");
exit(1);
}
return buf;
}
+static ZSTD_DDict* readDict(const char* filename, char **buf, size_t* size, int32_t* dict_id) {
+ ZSTD_DDict* ddict;
+ *buf = readFile(filename, size);
+ if (*buf == NULL) {
+ fprintf(stderr, "Opening dictionary file '%s' failed\n", filename);
+ return NULL;
+ }
+
+ ddict = ZSTD_createDDict_advanced(*buf, *size, ZSTD_dlm_byRef, ZSTD_dct_auto, ZSTD_defaultCMem);
+ if (ddict == NULL) {
+ fprintf(stderr, "Failed to create ddict.\n");
+ return NULL;
+ }
+ if (dict_id != NULL) {
+ *dict_id = ZSTD_getDictID_fromDDict(ddict);
+ }
+ return ddict;
+}
+
+static ZSTD_DDict* readDictByID(stuff_t *stuff, int32_t dict_id, char **buf, size_t* size) {
+ if (stuff->dict_file_dir_name == NULL) {
+ return NULL;
+ } else {
+ size_t dir_name_len = strlen(stuff->dict_file_dir_name);
+ int dir_needs_separator = 0;
+ size_t dict_file_name_alloc_size = dir_name_len + 1 /* '/' */ + 10 /* max int32_t len */ + strlen(".zstd-dict") + 1 /* '\0' */;
+ char *dict_file_name = malloc(dict_file_name_alloc_size);
+ ZSTD_DDict* ddict;
+ int32_t read_dict_id;
+ if (dict_file_name == NULL) {
+ fprintf(stderr, "malloc failed.\n");
+ return 0;
+ }
+
+ if (dir_name_len > 0 && stuff->dict_file_dir_name[dir_name_len - 1] != '/') {
+ dir_needs_separator = 1;
+ }
+
+ snprintf(
+ dict_file_name,
+ dict_file_name_alloc_size,
+ "%s%s%u.zstd-dict",
+ stuff->dict_file_dir_name,
+ dir_needs_separator ? "/" : "",
+ dict_id);
+
+ /* fprintf(stderr, "Loading dict %u from '%s'.\n", dict_id, dict_file_name); */
+
+ ddict = readDict(dict_file_name, buf, size, &read_dict_id);
+ if (ddict == NULL) {
+ fprintf(stderr, "Failed to create ddict from '%s'.\n", dict_file_name);
+ free(dict_file_name);
+ return 0;
+ }
+ if (read_dict_id != dict_id) {
+ fprintf(stderr, "Read dictID (%u) does not match expected (%u).\n", read_dict_id, dict_id);
+ free(dict_file_name);
+ ZSTD_freeDDict(ddict);
+ return 0;
+ }
+
+ free(dict_file_name);
+ return ddict;
+ }
+}
+
static int init_stuff(stuff_t* stuff, int argc, char *argv[]) {
const char* input_filename;
- if (argc != 2) {
+ if (argc < 2) {
usage();
}
return 0;
}
+ stuff->dict_file_name = NULL;
+ stuff->dict_file_dir_name = NULL;
+ stuff->dict_id = 0;
+ stuff->dict = NULL;
+ stuff->dict_size = 0;
+ stuff->ddict = NULL;
+
+ if (argc > 2) {
+ if (!strcmp(argv[2], "-d")) {
+ if (argc > 3) {
+ stuff->dict_file_name = argv[3];
+ } else {
+ usage();
+ }
+ } else
+ if (!strcmp(argv[2], "-D")) {
+ if (argc > 3) {
+ stuff->dict_file_dir_name = argv[3];
+ } else {
+ usage();
+ }
+ } else {
+ usage();
+ }
+ }
+
+ if (stuff->dict_file_dir_name) {
+ int32_t dict_id = ZSTD_getDictID_fromFrame(stuff->input, stuff->input_size);
+ if (dict_id != 0) {
+ stuff->ddict = readDictByID(stuff, dict_id, &stuff->dict, &stuff->dict_size);
+ if (stuff->ddict == NULL) {
+ fprintf(stderr, "Failed to create cached ddict.\n");
+ return 0;
+ }
+ stuff->dict_id = dict_id;
+ }
+ } else
+ if (stuff->dict_file_name) {
+ stuff->ddict = readDict(stuff->dict_file_name, &stuff->dict, &stuff->dict_size, &stuff->dict_id);
+ if (stuff->ddict == NULL) {
+ fprintf(stderr, "Failed to create ddict from '%s'.\n", stuff->dict_file_name);
+ return 0;
+ }
+ }
+
stuff->dctx = ZSTD_createDCtx();
if (stuff->dctx == NULL) {
return 0;
ZSTD_inBuffer in = {stuff->perturbed, stuff->input_size, 0};
ZSTD_outBuffer out = {stuff->output, stuff->output_size, 0};
ZSTD_DCtx* dctx = stuff->dctx;
+ int32_t custom_dict_id = ZSTD_getDictID_fromFrame(in.src, in.size);
+ char *custom_dict = NULL;
+ size_t custom_dict_size = 0;
+ ZSTD_DDict* custom_ddict = NULL;
+
+ if (custom_dict_id != 0 && custom_dict_id != stuff->dict_id) {
+ /* fprintf(stderr, "Instead of dict %u, this perturbed blob wants dict %u.\n", stuff->dict_id, custom_dict_id); */
+ custom_ddict = readDictByID(stuff, custom_dict_id, &custom_dict, &custom_dict_size);
+ }
ZSTD_DCtx_reset(dctx, ZSTD_reset_session_only);
- ZSTD_DCtx_refDDict(dctx, NULL);
+
+ if (custom_ddict != NULL) {
+ ZSTD_DCtx_refDDict(dctx, custom_ddict);
+ } else {
+ ZSTD_DCtx_refDDict(dctx, stuff->ddict);
+ }
while (in.pos != in.size) {
out.pos = 0;
fprintf(
stderr, "Decompression failed: %s\n", ZSTD_getErrorName(ret));
*/
+ if (custom_ddict != NULL) {
+ ZSTD_freeDDict(custom_ddict);
+ free(custom_dict);
+ }
return 0;
}
}
stuff->success_count++;
+
+ if (custom_ddict != NULL) {
+ ZSTD_freeDDict(custom_ddict);
+ free(custom_dict);
+ }
return 1;
}
free_stuff(&stuff);
return 0;
-}
\ No newline at end of file
+}