]> git.ipfire.org Git - thirdparty/zstd.git/commitdiff
Allow splitPoint==1.0 (using all samples for both training and testing)
authorJennifer Liu <jenniferliu620@fb.com>
Thu, 5 Jul 2018 17:38:45 +0000 (10:38 -0700)
committerJennifer Liu <jenniferliu620@fb.com>
Thu, 5 Jul 2018 17:38:45 +0000 (10:38 -0700)
lib/dictBuilder/cover.c
tests/playTests.sh

index 7ac7eb1c3b908332d5398090d67d16d7cd2840f9..2c19c0052f3ad178056d9e1a1764f050abe6513c 100644 (file)
@@ -543,10 +543,10 @@ static int COVER_ctx_init(COVER_ctx_t *ctx, const void *samplesBuffer,
   const unsigned kFirst = 0;
   const size_t totalSamplesSize = COVER_sum(samplesSizes, kFirst, nbSamples);
   /* Split samples into testing and training sets */
-  const unsigned nbTrainSamples = (unsigned)((double)nbSamples * splitPoint);
-  const unsigned nbTestSamples = nbSamples - nbTrainSamples;
-  const size_t trainingSamplesSize = COVER_sum(samplesSizes, kFirst, nbTrainSamples);
-  const size_t testSamplesSize = COVER_sum(samplesSizes, nbTrainSamples, nbSamples);
+  const unsigned nbTrainSamples = splitPoint < 1.0 ? (unsigned)((double)nbSamples * splitPoint) : nbSamples;
+  const unsigned nbTestSamples = splitPoint < 1.0 ? nbSamples - nbTrainSamples : nbSamples;
+  const size_t trainingSamplesSize = splitPoint < 1.0 ? COVER_sum(samplesSizes, kFirst, nbTrainSamples) : totalSamplesSize;
+  const size_t testSamplesSize = splitPoint < 1.0 ? COVER_sum(samplesSizes, nbTrainSamples, nbSamples) : totalSamplesSize;
   /* Checks */
   if (totalSamplesSize < MAX(d, sizeof(U64)) ||
       totalSamplesSize >= (size_t)COVER_MAX_SAMPLES_SIZE) {
@@ -559,12 +559,13 @@ static int COVER_ctx_init(COVER_ctx_t *ctx, const void *samplesBuffer,
     DISPLAYLEVEL(1, "Total number of training samples is %u and is invalid.", nbTrainSamples);
     return 0;
   }
-  /* Check if there's testing sample when splitPoint is not 1.0 */
-  if (nbTestSamples < 1 && splitPoint < 1.0) {
+  /* Check if there's testing sample */
+  if (nbTestSamples < 1) {
     DISPLAYLEVEL(1, "Total number of testing samples is %u and is invalid.", nbTestSamples);
     return 0;
   }
-  if (nbTrainSamples + nbTestSamples != nbSamples) {
+  /* Check if nbTrainSamples plus nbTestSamples add up to nbSamples when splitPoint is less than 1*/
+  if (nbTrainSamples + nbTestSamples != nbSamples && splitPoint < 1.0) {
     DISPLAYLEVEL(1, "nbTrainSamples plus nbTestSamples don't add up to nbSamples");
     return 0;
   }
@@ -920,7 +921,8 @@ static void COVER_tryParameters(void *opaque) {
     /* Allocate dst with enough space to compress the maximum sized sample */
     {
       size_t maxSampleSize = 0;
-      for (i = ctx->nbTrainSamples; i < ctx->nbSamples; ++i) {
+      i = parameters.splitPoint < 1.0 ? ctx->nbTrainSamples : 0;
+      for (; i < ctx->nbSamples; ++i) {
         maxSampleSize = MAX(ctx->samplesSizes[i], maxSampleSize);
       }
       dstCapacity = ZSTD_compressBound(maxSampleSize);
@@ -973,7 +975,7 @@ ZDICTLIB_API size_t ZDICT_optimizeTrainFromBuffer_cover(
   /* constants */
   const unsigned nbThreads = parameters->nbThreads;
   const double splitPoint =
-      parameters->splitPoint <= 0.0 ? DEFAULT_SPLITPOINT : parameters->splitPoint;
+      (parameters->splitPoint <= 0.0 || parameters->splitPoint > 1.0) ? DEFAULT_SPLITPOINT : parameters->splitPoint;
   const unsigned kMinD = parameters->d == 0 ? 6 : parameters->d;
   const unsigned kMaxD = parameters->d == 0 ? 8 : parameters->d;
   const unsigned kMinK = parameters->k == 0 ? 50 : parameters->k;
@@ -991,7 +993,7 @@ ZDICTLIB_API size_t ZDICT_optimizeTrainFromBuffer_cover(
   POOL_ctx *pool = NULL;
 
   /* Checks */
-  if (splitPoint <= 0 || splitPoint >= 1) {
+  if (splitPoint <= 0 || splitPoint > 1) {
     LOCALDISPLAYLEVEL(displayLevel, 1, "Incorrect parameters\n");
     return ERROR(GENERIC);
   }
index 985b12d2dc23924875ac1e8465f1631c33d62736..3e15313755ac2c398f3dbc18a4d47f0a12e9f282 100755 (executable)
@@ -425,6 +425,8 @@ rm tmp*
 $ECHO "- Compare size of dictionary from 90% training samples with 80% training samples"
 $ZSTD --train-cover=split=90 -r *.c ../programs/*.c
 $ZSTD --train-cover=split=80 -r *.c ../programs/*.c
+$ECHO "- Create dictionary using all samples for both training and testing"
+$ZSTD --train-cover=split=100 -r *.c ../programs/*.c
 
 $ECHO "\n===>  legacy dictionary builder "