]> git.ipfire.org Git - thirdparty/zstd.git/commitdiff
Add Support for Dictionaries 1932/head
authorW. Felix Handte <w@felixhandte.com>
Wed, 18 Dec 2019 19:26:35 +0000 (11:26 -0800)
committerW. Felix Handte <w@felixhandte.com>
Wed, 18 Dec 2019 19:54:39 +0000 (11:54 -0800)
contrib/diagnose_corruption/check_flipped_bits.c

index 58f7a45d540aa3c12bd463c4a02953fa61e4cf28..78473ea59f54b4fe98e19bd9880e70cc22d32b79 100644 (file)
@@ -8,6 +8,7 @@
  * You may select, at your option, one of the above-listed licenses.
  */
 
+#define ZSTD_STATIC_LINKING_ONLY
 #include "zstd.h"
 #include "zstd_errors.h"
 
@@ -27,6 +28,13 @@ typedef struct {
   char *output;
   size_t output_size;
 
+  const char *dict_file_name;
+  const char *dict_file_dir_name;
+  int32_t dict_id;
+  char *dict;
+  size_t dict_size;
+  ZSTD_DDict* ddict;
+
   ZSTD_DCtx* dctx;
 
   int success_count;
@@ -36,11 +44,18 @@ typedef struct {
 static void free_stuff(stuff_t* stuff) {
   free(stuff->input);
   free(stuff->output);
+  ZSTD_freeDDict(stuff->ddict);
+  free(stuff->dict);
   ZSTD_freeDCtx(stuff->dctx);
 }
 
 static void usage(void) {
-  fprintf(stderr, "check_flipped_bits input_filename");
+  fprintf(stderr, "check_flipped_bits input_filename [-d dict] [-D dict_dir]\n");
+  fprintf(stderr, "\n");
+  fprintf(stderr, "Arguments:\n");
+  fprintf(stderr, "    -d file: path to a dictionary file to use.\n");
+  fprintf(stderr, "    -D dir : path to a directory, with files containing dictionaries, of the\n"
+                  "             form DICTID.zstd-dict, e.g., 12345.zstd-dict.\n");
   exit(1);
 }
 
@@ -104,10 +119,76 @@ static char* readFile(const char* filename, size_t* size) {
   return buf;
 }
 
+static ZSTD_DDict* readDict(const char* filename, char **buf, size_t* size, int32_t* dict_id) {
+  ZSTD_DDict* ddict;
+  *buf = readFile(filename, size);
+  if (*buf == NULL) {
+    fprintf(stderr, "Opening dictionary file '%s' failed\n", filename);
+    return NULL;
+  }
+
+  ddict = ZSTD_createDDict_advanced(*buf, *size, ZSTD_dlm_byRef, ZSTD_dct_auto, ZSTD_defaultCMem);
+  if (ddict == NULL) {
+    fprintf(stderr, "Failed to create ddict.\n");
+    return NULL;
+  }
+  if (dict_id != NULL) {
+    *dict_id = ZSTD_getDictID_fromDDict(ddict);
+  }
+  return ddict;
+}
+
+static ZSTD_DDict* readDictByID(stuff_t *stuff, int32_t dict_id, char **buf, size_t* size) {
+  if (stuff->dict_file_dir_name == NULL) {
+    return NULL;
+  } else {
+    size_t dir_name_len = strlen(stuff->dict_file_dir_name);
+    int dir_needs_separator = 0;
+    size_t dict_file_name_alloc_size = dir_name_len + 1 /* '/' */ + 10 /* max int32_t len */ + strlen(".zstd-dict") + 1 /* '\0' */;
+    char *dict_file_name = malloc(dict_file_name_alloc_size);
+    ZSTD_DDict* ddict;
+    int32_t read_dict_id;
+    if (dict_file_name == NULL) {
+      fprintf(stderr, "malloc failed.\n");
+      return 0;
+    }
+
+    if (dir_name_len > 0 && stuff->dict_file_dir_name[dir_name_len - 1] != '/') {
+      dir_needs_separator = 1;
+    }
+
+    snprintf(
+      dict_file_name,
+      dict_file_name_alloc_size,
+      "%s%s%u.zstd-dict",
+      stuff->dict_file_dir_name,
+      dir_needs_separator ? "/" : "",
+      dict_id);
+
+    /* fprintf(stderr, "Loading dict %u from '%s'.\n", dict_id, dict_file_name); */
+
+    ddict = readDict(dict_file_name, buf, size, &read_dict_id);
+    if (ddict == NULL) {
+      fprintf(stderr, "Failed to create ddict from '%s'.\n", dict_file_name);
+      free(dict_file_name);
+      return 0;
+    }
+    if (read_dict_id != dict_id) {
+      fprintf(stderr, "Read dictID (%u) does not match expected (%u).\n", read_dict_id, dict_id);
+      free(dict_file_name);
+      ZSTD_freeDDict(ddict);
+      return 0;
+    }
+
+    free(dict_file_name);
+    return ddict;
+  }
+}
+
 static int init_stuff(stuff_t* stuff, int argc, char *argv[]) {
   const char* input_filename;
 
-  if (argc != 2) {
+  if (argc < 2) {
     usage();
   }
 
@@ -133,6 +214,51 @@ static int init_stuff(stuff_t* stuff, int argc, char *argv[]) {
     return 0;
   }
 
+  stuff->dict_file_name = NULL;
+  stuff->dict_file_dir_name = NULL;
+  stuff->dict_id = 0;
+  stuff->dict = NULL;
+  stuff->dict_size = 0;
+  stuff->ddict = NULL;
+
+  if (argc > 2) {
+    if (!strcmp(argv[2], "-d")) {
+      if (argc > 3) {
+        stuff->dict_file_name = argv[3];
+      } else {
+        usage();
+      }
+    } else
+    if (!strcmp(argv[2], "-D")) {
+      if (argc > 3) {
+        stuff->dict_file_dir_name = argv[3];
+      } else {
+        usage();
+      }
+    } else {
+      usage();
+    }
+  }
+
+  if (stuff->dict_file_dir_name) {
+    int32_t dict_id = ZSTD_getDictID_fromFrame(stuff->input, stuff->input_size);
+    if (dict_id != 0) {
+      stuff->ddict = readDictByID(stuff, dict_id, &stuff->dict, &stuff->dict_size);
+      if (stuff->ddict == NULL) {
+        fprintf(stderr, "Failed to create cached ddict.\n");
+        return 0;
+      }
+      stuff->dict_id = dict_id;
+    }
+  } else
+  if (stuff->dict_file_name) {
+    stuff->ddict = readDict(stuff->dict_file_name, &stuff->dict, &stuff->dict_size, &stuff->dict_id);
+    if (stuff->ddict == NULL) {
+      fprintf(stderr, "Failed to create ddict from '%s'.\n", stuff->dict_file_name);
+      return 0;
+    }
+  }
+
   stuff->dctx = ZSTD_createDCtx();
   if (stuff->dctx == NULL) {
     return 0;
@@ -149,9 +275,23 @@ static int test_decompress(stuff_t* stuff) {
   ZSTD_inBuffer in = {stuff->perturbed, stuff->input_size, 0};
   ZSTD_outBuffer out = {stuff->output, stuff->output_size, 0};
   ZSTD_DCtx* dctx = stuff->dctx;
+  int32_t custom_dict_id = ZSTD_getDictID_fromFrame(in.src, in.size);
+  char *custom_dict = NULL;
+  size_t custom_dict_size = 0;
+  ZSTD_DDict* custom_ddict = NULL;
+
+  if (custom_dict_id != 0 && custom_dict_id != stuff->dict_id) {
+    /* fprintf(stderr, "Instead of dict %u, this perturbed blob wants dict %u.\n", stuff->dict_id, custom_dict_id); */
+    custom_ddict = readDictByID(stuff, custom_dict_id, &custom_dict, &custom_dict_size);
+  }
 
   ZSTD_DCtx_reset(dctx, ZSTD_reset_session_only);
-  ZSTD_DCtx_refDDict(dctx, NULL);
+
+  if (custom_ddict != NULL) {
+    ZSTD_DCtx_refDDict(dctx, custom_ddict);
+  } else {
+    ZSTD_DCtx_refDDict(dctx, stuff->ddict);
+  }
 
   while (in.pos != in.size) {
     out.pos = 0;
@@ -168,11 +308,20 @@ static int test_decompress(stuff_t* stuff) {
       fprintf(
           stderr, "Decompression failed: %s\n", ZSTD_getErrorName(ret));
       */
+      if (custom_ddict != NULL) {
+        ZSTD_freeDDict(custom_ddict);
+        free(custom_dict);
+      }
       return 0;
     }
   }
 
   stuff->success_count++;
+
+  if (custom_ddict != NULL) {
+    ZSTD_freeDDict(custom_ddict);
+    free(custom_dict);
+  }
   return 1;
 }
 
@@ -245,4 +394,4 @@ int main(int argc, char* argv[]) {
   free_stuff(&stuff);
 
   return 0;
-}
\ No newline at end of file
+}