/* TODO(bkjg): add tests for this function */
int distribution_sub(distribution_t *d1, distribution_t *d2) {
- if (d1->total_sum < d2->total_sum) {
- d1->total_sum = d2->total_sum - d1->total_sum;
- for (size_t i = 0; i < tree_size(d1->num_buckets); ++i) {
- if (d1->tree[i].maximum != d2->tree[i].maximum ||
- d1->tree[i].bucket_counter > d2->tree[i].bucket_counter) {
- pthread_mutex_unlock(&d2->mutex);
- pthread_mutex_unlock(&d1->mutex);
- return EINVAL;
- }
-
- d1->tree[i].bucket_counter =
- d2->tree[i].bucket_counter - d1->tree[i].bucket_counter;
- }
- } else {
- d1->total_sum -= d2->total_sum;
- for (size_t i = 0; i < tree_size(d1->num_buckets); ++i) {
- if (d1->tree[i].maximum != d2->tree[i].maximum ||
- d1->tree[i].bucket_counter < d2->tree[i].bucket_counter) {
- pthread_mutex_unlock(&d2->mutex);
- pthread_mutex_unlock(&d1->mutex);
- return EINVAL;
- }
-
- d1->tree[i].bucket_counter -= d2->tree[i].bucket_counter;
- }
+ int cmp_status = distribution_cmp(d1, d2);
+ if (cmp_status != 1 && cmp_status != 0) { // i.e. d1 < d2 or can't compare
+ if (cmp_status == -1)
+ cmp_status = ERANGE;
+ return cmp_status;
+ }
+
+ pthread_mutex_lock(&d1->mutex);
+ pthread_mutex_lock(&d2->mutex);
+
+ d1->total_sum -= d2->total_sum;
+ d1->total_square_sum -= d2->total_square_sum;
+ for (size_t i = 0; i < tree_size(d1->num_buckets); i++) {
+ d1->tree[i].bucket_counter -= d2.tree[i].bucket_counter;
}
pthread_mutex_unlock(&d2->mutex);