/**
* Returns the sum of the sample sizes.
*/
-static size_t COVER_sum(const size_t *samplesSizes, unsigned firstSample, unsigned lastSample) {
+static size_t COVER_sum(const size_t *samplesSizes, unsigned nbSamples) {
size_t sum = 0;
unsigned i;
- for (i = firstSample; i < lastSample; ++i) {
+ for (i = 0; i < nbSamples; ++i) {
sum += samplesSizes[i];
}
return sum;
const size_t *samplesSizes, unsigned nbSamples,
unsigned d, double splitPoint) {
const BYTE *const samples = (const BYTE *)samplesBuffer;
- const unsigned kFirst = 0;
- const size_t totalSamplesSize = COVER_sum(samplesSizes, kFirst, nbSamples);
+ const size_t totalSamplesSize = COVER_sum(samplesSizes, nbSamples);
/* Split samples into testing and training sets */
const unsigned nbTrainSamples = splitPoint < 1.0 ? (unsigned)((double)nbSamples * splitPoint) : nbSamples;
const unsigned nbTestSamples = splitPoint < 1.0 ? nbSamples - nbTrainSamples : nbSamples;
- const size_t trainingSamplesSize = splitPoint < 1.0 ? COVER_sum(samplesSizes, kFirst, nbTrainSamples) : totalSamplesSize;
- const size_t testSamplesSize = splitPoint < 1.0 ? COVER_sum(samplesSizes, nbTrainSamples, nbSamples) : totalSamplesSize;
+ const size_t trainingSamplesSize = splitPoint < 1.0 ? COVER_sum(samplesSizes, nbTrainSamples) : totalSamplesSize;
+ const size_t testSamplesSize = splitPoint < 1.0 ? COVER_sum(samplesSizes + nbTrainSamples, nbTestSamples) : totalSamplesSize;
/* Checks */
if (totalSamplesSize < MAX(d, sizeof(U64)) ||
totalSamplesSize >= (size_t)COVER_MAX_SAMPLES_SIZE) {
DISPLAYLEVEL(1, "Total number of testing samples is %u and is invalid.", nbTestSamples);
return 0;
}
- /* Check if nbTrainSamples plus nbTestSamples add up to nbSamples when splitPoint is less than 1*/
- if (nbTrainSamples + nbTestSamples != nbSamples && splitPoint < 1.0) {
- DISPLAYLEVEL(1, "nbTrainSamples plus nbTestSamples don't add up to nbSamples");
- return 0;
- }
/* Zero the context */
memset(ctx, 0, sizeof(*ctx));
DISPLAYLEVEL(2, "Training on %u samples of total size %u\n", nbTrainSamples,
/* constants */
const unsigned nbThreads = parameters->nbThreads;
const double splitPoint =
- (parameters->splitPoint <= 0.0 || parameters->splitPoint > 1.0) ? DEFAULT_SPLITPOINT : parameters->splitPoint;
+ parameters->splitPoint <= 0.0 ? DEFAULT_SPLITPOINT : parameters->splitPoint;
const unsigned kMinD = parameters->d == 0 ? 6 : parameters->d;
const unsigned kMaxD = parameters->d == 0 ? 8 : parameters->d;
const unsigned kMinK = parameters->k == 0 ? 50 : parameters->k;