]> git.ipfire.org Git - thirdparty/gcc.git/blame - libstdc++-v3/include/parallel/partial_sum.h
MAINTAINERS: Update my email address.
[thirdparty/gcc.git] / libstdc++-v3 / include / parallel / partial_sum.h
CommitLineData
c2ba9709
JS
1// -*- C++ -*-
2
3// Copyright (C) 2007 Free Software Foundation, Inc.
4//
5// This file is part of the GNU ISO C++ Library. This library is free
6// software; you can redistribute it and/or modify it under the terms
7// of the GNU General Public License as published by the Free Software
8// Foundation; either version 2, or (at your option) any later
9// version.
10
11// This library is distributed in the hope that it will be useful, but
12// WITHOUT ANY WARRANTY; without even the implied warranty of
13// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14// General Public License for more details.
15
16// You should have received a copy of the GNU General Public License
17// along with this library; see the file COPYING. If not, write to
18// the Free Software Foundation, 59 Temple Place - Suite 330, Boston,
19// MA 02111-1307, USA.
20
21// As a special exception, you may use this file as part of a free
22// software library without restriction. Specifically, if other files
23// instantiate templates or use macros or inline functions from this
24// file, or you compile this file and link it with other files to
25// produce an executable, this file does not by itself cause the
26// resulting executable to be covered by the GNU General Public
27// License. This exception does not however invalidate any other
28// reasons why the executable file might be covered by the GNU General
29// Public License.
30
31/** @file parallel/partial_sum.h
32 * @brief Parallel implementation of std::partial_sum(), i. e. prefix
33 * sums.
34 * This file is a GNU parallel extension to the Standard C++ Library.
35 */
36
37// Written by Johannes Singler.
38
39#ifndef _GLIBCXX_PARALLEL_PARTIAL_SUM_H
40#define _GLIBCXX_PARALLEL_PARTIAL_SUM_H 1
41
c2ba9709
JS
42#include <omp.h>
43#include <bits/stl_algobase.h>
44#include <parallel/parallel.h>
45#include <parallel/numericfwd.h>
46
47namespace __gnu_parallel
48{
49 // Problem: there is no 0-element given.
50
e683ee2a
JS
/** @brief Sequential prefix-sum kernel, used as the base case.
 *
 *  Walks the input once, folding every element into a running value
 *  with @c bin_op and storing the running value to the output after
 *  each step.
 *  @param begin Begin iterator of input sequence.
 *  @param end End iterator of input sequence.
 *  @param result Begin iterator of output sequence.
 *  @param bin_op Associative binary function.
 *  @param value Start value, combined with the first element as
 *  bin_op(value, *begin).  The caller has to supply it because the
 *  neutral element of @c bin_op is unknown in general.
 *  @return End iterator of output sequence. */
  template<typename InputIterator,
           typename OutputIterator,
           typename BinaryOperation>
    inline OutputIterator
    parallel_partial_sum_basecase(
        InputIterator begin, InputIterator end,
        OutputIterator result, BinaryOperation bin_op,
        typename std::iterator_traits<InputIterator>::value_type value)
    {
      // An empty range simply falls through and returns result unchanged.
      for (; begin != end; ++begin)
        {
          value = bin_op(value, *begin);
          *result = value;
          ++result;
        }

      return result;
    }
81
e683ee2a
JS
82/** @brief Parallel partial sum implementation, two-phase approach,
83 no recursion.
84 * @param begin Begin iterator of input sequence.
85 * @param end End iterator of input sequence.
86 * @param result Begin iterator of output sequence.
87 * @param bin_op Associative binary function.
88 * @param n Length of sequence.
89 * @param num_threads Number of threads to use.
90 * @return End iterator of output sequence.
91 */
92template<
93 typename InputIterator,
94 typename OutputIterator,
95 typename BinaryOperation>
c2ba9709 96 OutputIterator
e683ee2a
JS
97 parallel_partial_sum_linear(
98 InputIterator begin, InputIterator end,
99 OutputIterator result, BinaryOperation bin_op,
100 typename std::iterator_traits<InputIterator>::difference_type n)
c2ba9709
JS
101 {
102 typedef std::iterator_traits<InputIterator> traits_type;
103 typedef typename traits_type::value_type value_type;
104 typedef typename traits_type::difference_type difference_type;
105
1661473b
JS
106 if (begin == end)
107 return result;
108
e683ee2a
JS
109 thread_index_t num_threads =
110 std::min<difference_type>(get_max_threads(), n - 1);
111
c2ba9709
JS
112 if (num_threads < 2)
113 {
e683ee2a
JS
114 *result = *begin;
115 return parallel_partial_sum_basecase(
116 begin + 1, end, result + 1, bin_op, *begin);
c2ba9709
JS
117 }
118
e683ee2a
JS
119 difference_type* borders;
120 value_type* sums;
c2ba9709 121
e683ee2a 122# pragma omp parallel num_threads(num_threads)
c2ba9709 123 {
e683ee2a
JS
124# pragma omp single
125 {
126 num_threads = omp_get_num_threads();
127
128 borders = new difference_type[num_threads + 2];
129
130 if (Settings::partial_sum_dilatation == 1.0f)
131 equally_split(n, num_threads + 1, borders);
132 else
133 {
134 difference_type chunk_length =
135 ((double)n /
136 ((double)num_threads + Settings::partial_sum_dilatation)),
137 borderstart = n - num_threads * chunk_length;
138 borders[0] = 0;
1661473b 139 for (int i = 1; i < (num_threads + 1); ++i)
e683ee2a
JS
140 {
141 borders[i] = borderstart;
142 borderstart += chunk_length;
143 }
144 borders[num_threads + 1] = n;
145 }
146
147 sums = static_cast<value_type*>(
148 ::operator new(sizeof(value_type) * num_threads));
149 OutputIterator target_end;
150 } //single
151
1661473b 152 thread_index_t iam = omp_get_thread_num();
e683ee2a
JS
153 if (iam == 0)
154 {
155 *result = *begin;
156 parallel_partial_sum_basecase(begin + 1, begin + borders[1],
157 result + 1, bin_op, *begin);
1661473b 158 new(&(sums[iam])) value_type(*(result + borders[1] - 1));
e683ee2a
JS
159 }
160 else
161 {
1661473b
JS
162 new(&(sums[iam])) value_type(
163 std::accumulate(begin + borders[iam] + 1,
164 begin + borders[iam + 1],
165 *(begin + borders[iam]),
166 bin_op, __gnu_parallel::sequential_tag()));
e683ee2a
JS
167 }
168
169# pragma omp barrier
170
171# pragma omp single
172 parallel_partial_sum_basecase(
173 sums + 1, sums + num_threads, sums + 1, bin_op, sums[0]);
174
175# pragma omp barrier
176
177 // Still same team.
178 parallel_partial_sum_basecase(begin + borders[iam + 1],
179 begin + borders[iam + 2],
180 result + borders[iam + 1], bin_op,
181 sums[iam]);
182 } //parallel
183
184 delete[] sums;
185 delete[] borders;
c2ba9709
JS
186
187 return result + n;
188 }
189
e683ee2a
JS
190/** @brief Parallel partial sum front-end.
191 * @param begin Begin iterator of input sequence.
192 * @param end End iterator of input sequence.
193 * @param result Begin iterator of output sequence.
194 * @param bin_op Associative binary function.
195 * @return End iterator of output sequence. */
196template<
197 typename InputIterator,
198 typename OutputIterator,
199 typename BinaryOperation>
c2ba9709
JS
200 OutputIterator
201 parallel_partial_sum(InputIterator begin, InputIterator end,
e683ee2a 202 OutputIterator result, BinaryOperation bin_op)
c2ba9709 203 {
e683ee2a 204 _GLIBCXX_CALL(begin - end)
c2ba9709
JS
205
206 typedef std::iterator_traits<InputIterator> traits_type;
207 typedef typename traits_type::value_type value_type;
208 typedef typename traits_type::difference_type difference_type;
209
210 difference_type n = end - begin;
211
c2ba9709
JS
212 switch (Settings::partial_sum_algorithm)
213 {
214 case Settings::LINEAR:
e683ee2a
JS
215 // Need an initial offset.
216 return parallel_partial_sum_linear(begin, end, result, bin_op, n);
c2ba9709 217 default:
e683ee2a
JS
218 // Partial_sum algorithm not implemented.
219 _GLIBCXX_PARALLEL_ASSERT(0);
220 return result + n;
c2ba9709
JS
221 }
222 }
223}
224
225#endif