const unsigned kFirst = 0;
const size_t totalSamplesSize = COVER_sum(samplesSizes, kFirst, nbSamples);
/* Split samples into testing and training sets */
- const unsigned nbTrainSamples = (unsigned)((double)nbSamples * splitPoint);
- const unsigned nbTestSamples = nbSamples - nbTrainSamples;
- const size_t trainingSamplesSize = COVER_sum(samplesSizes, kFirst, nbTrainSamples);
- const size_t testSamplesSize = COVER_sum(samplesSizes, nbTrainSamples, nbSamples);
+ const unsigned nbTrainSamples = splitPoint < 1.0 ? (unsigned)((double)nbSamples * splitPoint) : nbSamples;
+ const unsigned nbTestSamples = splitPoint < 1.0 ? nbSamples - nbTrainSamples : nbSamples;
+ const size_t trainingSamplesSize = splitPoint < 1.0 ? COVER_sum(samplesSizes, kFirst, nbTrainSamples) : totalSamplesSize;
+ const size_t testSamplesSize = splitPoint < 1.0 ? COVER_sum(samplesSizes, nbTrainSamples, nbSamples) : totalSamplesSize;
/* Checks */
if (totalSamplesSize < MAX(d, sizeof(U64)) ||
totalSamplesSize >= (size_t)COVER_MAX_SAMPLES_SIZE) {
DISPLAYLEVEL(1, "Total number of training samples is %u and is invalid.", nbTrainSamples);
return 0;
}
- /* Check if there's testing sample when splitPoint is not 1.0 */
- if (nbTestSamples < 1 && splitPoint < 1.0) {
+ /* Check if there's testing sample */
+ if (nbTestSamples < 1) {
DISPLAYLEVEL(1, "Total number of testing samples is %u and is invalid.", nbTestSamples);
return 0;
}
- if (nbTrainSamples + nbTestSamples != nbSamples) {
+ /* Check if nbTrainSamples plus nbTestSamples add up to nbSamples when splitPoint is less than 1*/
+ if (nbTrainSamples + nbTestSamples != nbSamples && splitPoint < 1.0) {
DISPLAYLEVEL(1, "nbTrainSamples plus nbTestSamples don't add up to nbSamples");
return 0;
}
/* Allocate dst with enough space to compress the maximum sized sample */
{
size_t maxSampleSize = 0;
- for (i = ctx->nbTrainSamples; i < ctx->nbSamples; ++i) {
+ i = parameters.splitPoint < 1.0 ? ctx->nbTrainSamples : 0;
+ for (; i < ctx->nbSamples; ++i) {
maxSampleSize = MAX(ctx->samplesSizes[i], maxSampleSize);
}
dstCapacity = ZSTD_compressBound(maxSampleSize);
/* constants */
const unsigned nbThreads = parameters->nbThreads;
const double splitPoint =
- parameters->splitPoint <= 0.0 ? DEFAULT_SPLITPOINT : parameters->splitPoint;
+ (parameters->splitPoint <= 0.0 || parameters->splitPoint > 1.0) ? DEFAULT_SPLITPOINT : parameters->splitPoint;
const unsigned kMinD = parameters->d == 0 ? 6 : parameters->d;
const unsigned kMaxD = parameters->d == 0 ? 8 : parameters->d;
const unsigned kMinK = parameters->k == 0 ? 50 : parameters->k;
POOL_ctx *pool = NULL;
/* Checks */
- if (splitPoint <= 0 || splitPoint >= 1) {
+ if (splitPoint <= 0 || splitPoint > 1) {
LOCALDISPLAYLEVEL(displayLevel, 1, "Incorrect parameters\n");
return ERROR(GENERIC);
}