From: Jennifer Liu Date: Thu, 5 Jul 2018 17:38:45 +0000 (-0700) Subject: Allow splitPoint==1.0 (using all samples for both training and testing) X-Git-Tag: v0.0.29~69^2~7 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=a085d1aae1dbd841baa8ed927465e4d686ccc213;p=thirdparty%2Fzstd.git Allow splitPoint==1.0 (using all samples for both training and testing) --- diff --git a/lib/dictBuilder/cover.c b/lib/dictBuilder/cover.c index 7ac7eb1c3..2c19c0052 100644 --- a/lib/dictBuilder/cover.c +++ b/lib/dictBuilder/cover.c @@ -543,10 +543,10 @@ static int COVER_ctx_init(COVER_ctx_t *ctx, const void *samplesBuffer, const unsigned kFirst = 0; const size_t totalSamplesSize = COVER_sum(samplesSizes, kFirst, nbSamples); /* Split samples into testing and training sets */ - const unsigned nbTrainSamples = (unsigned)((double)nbSamples * splitPoint); - const unsigned nbTestSamples = nbSamples - nbTrainSamples; - const size_t trainingSamplesSize = COVER_sum(samplesSizes, kFirst, nbTrainSamples); - const size_t testSamplesSize = COVER_sum(samplesSizes, nbTrainSamples, nbSamples); + const unsigned nbTrainSamples = splitPoint < 1.0 ? (unsigned)((double)nbSamples * splitPoint) : nbSamples; + const unsigned nbTestSamples = splitPoint < 1.0 ? nbSamples - nbTrainSamples : nbSamples; + const size_t trainingSamplesSize = splitPoint < 1.0 ? COVER_sum(samplesSizes, kFirst, nbTrainSamples) : totalSamplesSize; + const size_t testSamplesSize = splitPoint < 1.0 ? COVER_sum(samplesSizes, nbTrainSamples, nbSamples) : totalSamplesSize; /* Checks */ if (totalSamplesSize < MAX(d, sizeof(U64)) || totalSamplesSize >= (size_t)COVER_MAX_SAMPLES_SIZE) { @@ -559,12 +559,13 @@ static int COVER_ctx_init(COVER_ctx_t *ctx, const void *samplesBuffer, DISPLAYLEVEL(1, "Total number of training samples is %u and is invalid.", nbTrainSamples); return 0; } - /* Check if there's testing sample when splitPoint is not 1.0 */ - if (nbTestSamples < 1 && splitPoint < 1.0) { + /* Check if there's testing sample */ + if (nbTestSamples < 1) { DISPLAYLEVEL(1, "Total number of testing samples is %u and is invalid.", nbTestSamples); return 0; } - if (nbTrainSamples + nbTestSamples != nbSamples) { + /* Check if nbTrainSamples plus nbTestSamples add up to nbSamples when splitPoint is less than 1*/ + if (nbTrainSamples + nbTestSamples != nbSamples && splitPoint < 1.0) { DISPLAYLEVEL(1, "nbTrainSamples plus nbTestSamples don't add up to nbSamples"); return 0; } @@ -920,7 +921,8 @@ static void COVER_tryParameters(void *opaque) { /* Allocate dst with enough space to compress the maximum sized sample */ { size_t maxSampleSize = 0; - for (i = ctx->nbTrainSamples; i < ctx->nbSamples; ++i) { + i = parameters.splitPoint < 1.0 ? ctx->nbTrainSamples : 0; + for (; i < ctx->nbSamples; ++i) { maxSampleSize = MAX(ctx->samplesSizes[i], maxSampleSize); } dstCapacity = ZSTD_compressBound(maxSampleSize); @@ -973,7 +975,7 @@ ZDICTLIB_API size_t ZDICT_optimizeTrainFromBuffer_cover( /* constants */ const unsigned nbThreads = parameters->nbThreads; const double splitPoint = - parameters->splitPoint <= 0.0 ? DEFAULT_SPLITPOINT : parameters->splitPoint; + (parameters->splitPoint <= 0.0 || parameters->splitPoint > 1.0) ? DEFAULT_SPLITPOINT : parameters->splitPoint; const unsigned kMinD = parameters->d == 0 ? 6 : parameters->d; const unsigned kMaxD = parameters->d == 0 ? 8 : parameters->d; const unsigned kMinK = parameters->k == 0 ? 50 : parameters->k; @@ -991,7 +993,7 @@ ZDICTLIB_API size_t ZDICT_optimizeTrainFromBuffer_cover( POOL_ctx *pool = NULL; /* Checks */ - if (splitPoint <= 0 || splitPoint >= 1) { + if (splitPoint <= 0 || splitPoint > 1) { LOCALDISPLAYLEVEL(displayLevel, 1, "Incorrect parameters\n"); return ERROR(GENERIC); } diff --git a/tests/playTests.sh b/tests/playTests.sh index 985b12d2d..3e1531375 100755 --- a/tests/playTests.sh +++ b/tests/playTests.sh @@ -425,6 +425,8 @@ rm tmp* $ECHO "- Compare size of dictionary from 90% training samples with 80% training samples" $ZSTD --train-cover=split=90 -r *.c ../programs/*.c $ZSTD --train-cover=split=80 -r *.c ../programs/*.c +$ECHO "- Create dictionary using all samples for both training and testing" +$ZSTD --train-cover=split=100 -r *.c ../programs/*.c $ECHO "\n===> legacy dictionary builder "