]> git.ipfire.org Git - thirdparty/gcc.git/blame - libstdc++-v3/include/parallel/multiway_mergesort.h
Daily bump.
[thirdparty/gcc.git] / libstdc++-v3 / include / parallel / multiway_mergesort.h
CommitLineData
c2ba9709
JS
1// -*- C++ -*-
2
fac9044d 3// Copyright (C) 2007, 2008 Free Software Foundation, Inc.
c2ba9709
JS
4//
5// This file is part of the GNU ISO C++ Library. This library is free
6// software; you can redistribute it and/or modify it under the terms
7// of the GNU General Public License as published by the Free Software
8// Foundation; either version 2, or (at your option) any later
9// version.
10
11// This library is distributed in the hope that it will be useful, but
12// WITHOUT ANY WARRANTY; without even the implied warranty of
13// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14// General Public License for more details.
15
16// You should have received a copy of the GNU General Public License
17// along with this library; see the file COPYING. If not, write to
18// the Free Software Foundation, 59 Temple Place - Suite 330, Boston,
19// MA 02111-1307, USA.
20
21// As a special exception, you may use this file as part of a free
22// software library without restriction. Specifically, if other files
23// instantiate templates or use macros or inline functions from this
24// file, or you compile this file and link it with other files to
25// produce an executable, this file does not by itself cause the
26// resulting executable to be covered by the GNU General Public
27// License. This exception does not however invalidate any other
28// reasons why the executable file might be covered by the GNU General
29// Public License.
30
31/** @file parallel/multiway_mergesort.h
32 * @brief Parallel multiway merge sort.
33 * This file is a GNU parallel extension to the Standard C++ Library.
34 */
35
36// Written by Johannes Singler.
37
38#ifndef _GLIBCXX_PARALLEL_MERGESORT_H
39#define _GLIBCXX_PARALLEL_MERGESORT_H 1
40
41#include <vector>
42
43#include <parallel/basic_iterator.h>
44#include <bits/stl_algo.h>
45#include <parallel/parallel.h>
46#include <parallel/multiway_merge.h>
c2ba9709
JS
47
48namespace __gnu_parallel
49{
50
e683ee2a
JS
/** @brief Subsequence description.
 *
 *  A pair of positions delimiting one piece of a sequence.  Within
 *  this file the pair is used as the half-open range
 *  @c [begin, end), i.e. the piece holds @c end - begin elements.
 *
 *  @param _DifferenceTp Integral type used for the two positions.
 */
template<typename _DifferenceTp>
  struct Piece
  {
    typedef _DifferenceTp difference_type;

    /** @brief Begin of subsequence. */
    difference_type begin;

    /** @brief End of subsequence. */
    difference_type end;
  };
63
e683ee2a
JS
64/** @brief Data accessed by all threads.
65 *
66 * PMWMS = parallel multiway mergesort */
67template<typename RandomAccessIterator>
c2ba9709
JS
68 struct PMWMSSortingData
69 {
70 typedef std::iterator_traits<RandomAccessIterator> traits_type;
71 typedef typename traits_type::value_type value_type;
72 typedef typename traits_type::difference_type difference_type;
73
e683ee2a
JS
74 /** @brief Number of threads involved. */
75 thread_index_t num_threads;
76
c2ba9709
JS
77 /** @brief Input begin. */
78 RandomAccessIterator source;
79
80 /** @brief Start indices, per thread. */
81 difference_type* starts;
82
83 /** @brief Temporary arrays for each thread.
84 *
85 * Indirection Allows using the temporary storage in different
86 * ways, without code duplication.
87 * @see _GLIBCXX_MULTIWAY_MERGESORT_COPY_LAST */
88 value_type** temporaries;
89
90#if _GLIBCXX_MULTIWAY_MERGESORT_COPY_LAST
91 /** @brief Storage in which to sort. */
92 RandomAccessIterator* sorting_places;
93
94 /** @brief Storage into which to merge. */
95 value_type** merging_places;
96#else
97 /** @brief Storage in which to sort. */
98 value_type** sorting_places;
99
100 /** @brief Storage into which to merge. */
101 RandomAccessIterator* merging_places;
102#endif
103 /** @brief Samples. */
104 value_type* samples;
105
106 /** @brief Offsets to add to the found positions. */
107 difference_type* offsets;
108
109 /** @brief Pieces of data to merge @c [thread][sequence] */
110 std::vector<Piece<difference_type> >* pieces;
c2ba9709 111
c2ba9709
JS
112 /** @brief Stable sorting desired. */
113 bool stable;
e683ee2a
JS
114};
115
116/**
117 * @brief Select samples from a sequence.
118 * @param sd Pointer to algorithm data. Result will be placed in
119 * @c sd->samples.
120 * @param num_samples Number of samples to select.
121 */
122template<typename RandomAccessIterator, typename _DifferenceTp>
5817ff8e 123 void
e683ee2a
JS
124 determine_samples(PMWMSSortingData<RandomAccessIterator>* sd,
125 _DifferenceTp& num_samples)
c2ba9709 126 {
1661473b
JS
127 typedef std::iterator_traits<RandomAccessIterator> traits_type;
128 typedef typename traits_type::value_type value_type;
c2ba9709
JS
129 typedef _DifferenceTp difference_type;
130
e683ee2a 131 thread_index_t iam = omp_get_thread_num();
c2ba9709 132
ee1b5fc5 133 num_samples = _Settings::get().sort_mwms_oversampling * sd->num_threads - 1;
c2ba9709 134
e683ee2a 135 difference_type* es = new difference_type[num_samples + 2];
c891154f 136
e683ee2a
JS
137 equally_split(sd->starts[iam + 1] - sd->starts[iam],
138 num_samples + 1, es);
c2ba9709 139
5817ff8e
PC
140 for (difference_type i = 0; i < num_samples; ++i)
141 ::new(&(sd->samples[iam * num_samples + i]))
142 value_type(sd->source[sd->starts[iam] + es[i + 1]]);
e683ee2a
JS
143
144 delete[] es;
c2ba9709
JS
145 }
146
e683ee2a
JS
147/** @brief PMWMS code executed by each thread.
148 * @param sd Pointer to algorithm data.
149 * @param comp Comparator.
150 */
151template<typename RandomAccessIterator, typename Comparator>
5817ff8e 152 void
e683ee2a
JS
153 parallel_sort_mwms_pu(PMWMSSortingData<RandomAccessIterator>* sd,
154 Comparator& comp)
c2ba9709
JS
155 {
156 typedef std::iterator_traits<RandomAccessIterator> traits_type;
157 typedef typename traits_type::value_type value_type;
158 typedef typename traits_type::difference_type difference_type;
159
e683ee2a 160 thread_index_t iam = omp_get_thread_num();
c2ba9709
JS
161
162 // Length of this thread's chunk, before merging.
163 difference_type length_local = sd->starts[iam + 1] - sd->starts[iam];
164
165#if _GLIBCXX_MULTIWAY_MERGESORT_COPY_LAST
166 typedef RandomAccessIterator SortingPlacesIterator;
167
168 // Sort in input storage.
169 sd->sorting_places[iam] = sd->source + sd->starts[iam];
170#else
171 typedef value_type* SortingPlacesIterator;
172
173 // Sort in temporary storage, leave space for sentinel.
e683ee2a
JS
174 sd->sorting_places[iam] = sd->temporaries[iam] =
175 static_cast<value_type*>(
176 ::operator new(sizeof(value_type) * (length_local + 1)));
c2ba9709
JS
177
178 // Copy there.
e683ee2a
JS
179 std::uninitialized_copy(sd->source + sd->starts[iam],
180 sd->source + sd->starts[iam] + length_local,
181 sd->sorting_places[iam]);
c2ba9709
JS
182#endif
183
184 // Sort locally.
e683ee2a
JS
185 if (sd->stable)
186 __gnu_sequential::stable_sort(sd->sorting_places[iam],
187 sd->sorting_places[iam] + length_local,
188 comp);
c2ba9709 189 else
e683ee2a
JS
190 __gnu_sequential::sort(sd->sorting_places[iam],
191 sd->sorting_places[iam] + length_local,
192 comp);
c2ba9709
JS
193
194 // Invariant: locally sorted subsequence in sd->sorting_places[iam],
195 // sd->sorting_places[iam] + length_local.
ee1b5fc5
BK
196 const _Settings& __s = _Settings::get();
197 if (__s.sort_splitting == SAMPLING)
c2ba9709 198 {
e683ee2a
JS
199 difference_type num_samples;
200 determine_samples(sd, num_samples);
201
202# pragma omp barrier
203
204# pragma omp single
205 __gnu_sequential::sort(sd->samples,
206 sd->samples + (num_samples * sd->num_threads),
207 comp);
208
209# pragma omp barrier
210
5817ff8e 211 for (int s = 0; s < sd->num_threads; ++s)
e683ee2a
JS
212 {
213 // For each sequence.
214 if (num_samples * iam > 0)
215 sd->pieces[iam][s].begin =
216 std::lower_bound(sd->sorting_places[s],
1661473b
JS
217 sd->sorting_places[s]
218 + (sd->starts[s + 1] - sd->starts[s]),
e683ee2a
JS
219 sd->samples[num_samples * iam],
220 comp)
221 - sd->sorting_places[s];
222 else
223 // Absolute beginning.
224 sd->pieces[iam][s].begin = 0;
225
226 if ((num_samples * (iam + 1)) < (num_samples * sd->num_threads))
227 sd->pieces[iam][s].end =
228 std::lower_bound(sd->sorting_places[s],
1661473b
JS
229 sd->sorting_places[s]
230 + (sd->starts[s + 1] - sd->starts[s]),
231 sd->samples[num_samples * (iam + 1)],
232 comp)
e683ee2a
JS
233 - sd->sorting_places[s];
234 else
235 // Absolute end.
236 sd->pieces[iam][s].end = sd->starts[s + 1] - sd->starts[s];
237 }
c2ba9709 238 }
ee1b5fc5 239 else if (__s.sort_splitting == EXACT)
c2ba9709 240 {
e683ee2a
JS
241# pragma omp barrier
242
243 std::vector<std::pair<SortingPlacesIterator, SortingPlacesIterator> >
244 seqs(sd->num_threads);
5817ff8e 245 for (int s = 0; s < sd->num_threads; ++s)
e683ee2a 246 seqs[s] = std::make_pair(sd->sorting_places[s],
1661473b
JS
247 sd->sorting_places[s]
248 + (sd->starts[s + 1] - sd->starts[s]));
e683ee2a
JS
249
250 std::vector<SortingPlacesIterator> offsets(sd->num_threads);
251
252 // if not last thread
253 if (iam < sd->num_threads - 1)
254 multiseq_partition(seqs.begin(), seqs.end(),
255 sd->starts[iam + 1], offsets.begin(), comp);
256
5817ff8e 257 for (int seq = 0; seq < sd->num_threads; ++seq)
e683ee2a
JS
258 {
259 // for each sequence
260 if (iam < (sd->num_threads - 1))
261 sd->pieces[iam][seq].end = offsets[seq] - seqs[seq].first;
262 else
263 // very end of this sequence
5817ff8e
PC
264 sd->pieces[iam][seq].end = (sd->starts[seq + 1]
265 - sd->starts[seq]);
e683ee2a
JS
266 }
267
268# pragma omp barrier
269
5817ff8e 270 for (int seq = 0; seq < sd->num_threads; ++seq)
e683ee2a
JS
271 {
272 // For each sequence.
273 if (iam > 0)
274 sd->pieces[iam][seq].begin = sd->pieces[iam - 1][seq].end;
275 else
276 // Absolute beginning.
277 sd->pieces[iam][seq].begin = 0;
278 }
c2ba9709
JS
279 }
280
c2ba9709
JS
281 // Offset from target begin, length after merging.
282 difference_type offset = 0, length_am = 0;
5817ff8e 283 for (int s = 0; s < sd->num_threads; ++s)
c2ba9709 284 {
e683ee2a
JS
285 length_am += sd->pieces[iam][s].end - sd->pieces[iam][s].begin;
286 offset += sd->pieces[iam][s].begin;
c2ba9709
JS
287 }
288
289#if _GLIBCXX_MULTIWAY_MERGESORT_COPY_LAST
290 // Merge to temporary storage, uninitialized creation not possible
291 // since there is no multiway_merge calling the placement new
292 // instead of the assignment operator.
1661473b 293 // XXX incorrect (de)construction
e683ee2a 294 sd->merging_places[iam] = sd->temporaries[iam] =
5817ff8e
PC
295 static_cast<value_type*>(::operator new(sizeof(value_type)
296 * length_am));
c2ba9709
JS
297#else
298 // Merge directly to target.
299 sd->merging_places[iam] = sd->source + offset;
300#endif
e683ee2a
JS
301 std::vector<std::pair<SortingPlacesIterator, SortingPlacesIterator> >
302 seqs(sd->num_threads);
c2ba9709 303
5817ff8e 304 for (int s = 0; s < sd->num_threads; ++s)
c2ba9709 305 {
1661473b 306 seqs[s] =
5817ff8e
PC
307 std::make_pair(sd->sorting_places[s] + sd->pieces[iam][s].begin,
308 sd->sorting_places[s] + sd->pieces[iam][s].end);
c2ba9709
JS
309 }
310
1661473b
JS
311 multiway_merge(seqs.begin(), seqs.end(), sd->merging_places[iam], comp,
312 length_am, sd->stable, false, sequential_tag());
c2ba9709 313
e683ee2a 314# pragma omp barrier
c2ba9709
JS
315
316#if _GLIBCXX_MULTIWAY_MERGESORT_COPY_LAST
317 // Write back.
e683ee2a
JS
318 std::copy(sd->merging_places[iam],
319 sd->merging_places[iam] + length_am,
320 sd->source + offset);
c2ba9709
JS
321#endif
322
fac9044d 323 ::operator delete(sd->temporaries[iam]);
c2ba9709
JS
324 }
325
e683ee2a
JS
326/** @brief PMWMS main call.
327 * @param begin Begin iterator of sequence.
328 * @param end End iterator of sequence.
329 * @param comp Comparator.
330 * @param n Length of sequence.
331 * @param num_threads Number of threads to use.
332 * @param stable Stable sorting.
333 */
334template<typename RandomAccessIterator, typename Comparator>
5817ff8e 335 void
e683ee2a 336 parallel_sort_mwms(RandomAccessIterator begin, RandomAccessIterator end,
5817ff8e
PC
337 Comparator comp, typename
338 std::iterator_traits<RandomAccessIterator>::
339 difference_type n, int num_threads, bool stable)
c2ba9709
JS
340 {
341 _GLIBCXX_CALL(n)
e683ee2a 342
c2ba9709
JS
343 typedef std::iterator_traits<RandomAccessIterator> traits_type;
344 typedef typename traits_type::value_type value_type;
345 typedef typename traits_type::difference_type difference_type;
346
347 if (n <= 1)
348 return;
349
e683ee2a 350 // at least one element per thread
c2ba9709
JS
351 if (num_threads > n)
352 num_threads = static_cast<thread_index_t>(n);
353
e683ee2a 354 // shared variables
c2ba9709 355 PMWMSSortingData<RandomAccessIterator> sd;
e683ee2a 356 difference_type* starts;
ee1b5fc5 357 const _Settings& __s = _Settings::get();
c2ba9709 358
e683ee2a
JS
359# pragma omp parallel num_threads(num_threads)
360 {
361 num_threads = omp_get_num_threads(); //no more threads than requested
362
363# pragma omp single
364 {
365 sd.num_threads = num_threads;
366 sd.source = begin;
367 sd.temporaries = new value_type*[num_threads];
c2ba9709
JS
368
369#if _GLIBCXX_MULTIWAY_MERGESORT_COPY_LAST
e683ee2a
JS
370 sd.sorting_places = new RandomAccessIterator[num_threads];
371 sd.merging_places = new value_type*[num_threads];
c2ba9709 372#else
e683ee2a
JS
373 sd.sorting_places = new value_type*[num_threads];
374 sd.merging_places = new RandomAccessIterator[num_threads];
c2ba9709
JS
375#endif
376
ee1b5fc5 377 if (__s.sort_splitting == SAMPLING)
e683ee2a
JS
378 {
379 unsigned int size =
ee1b5fc5 380 (__s.sort_mwms_oversampling * num_threads - 1)
1661473b 381 * num_threads;
e683ee2a 382 sd.samples = static_cast<value_type*>(
5817ff8e 383 ::operator new(size * sizeof(value_type)));
e683ee2a
JS
384 }
385 else
386 sd.samples = NULL;
387
388 sd.offsets = new difference_type[num_threads - 1];
389 sd.pieces = new std::vector<Piece<difference_type> >[num_threads];
5817ff8e 390 for (int s = 0; s < num_threads; ++s)
e683ee2a
JS
391 sd.pieces[s].resize(num_threads);
392 starts = sd.starts = new difference_type[num_threads + 1];
393 sd.stable = stable;
394
395 difference_type chunk_length = n / num_threads;
396 difference_type split = n % num_threads;
397 difference_type pos = 0;
5817ff8e 398 for (int i = 0; i < num_threads; ++i)
e683ee2a
JS
399 {
400 starts[i] = pos;
401 pos += (i < split) ? (chunk_length + 1) : chunk_length;
402 }
403 starts[num_threads] = pos;
404 }
405
406 // Now sort in parallel.
407 parallel_sort_mwms_pu(&sd, comp);
408 } //parallel
c2ba9709 409
c2ba9709
JS
410 delete[] starts;
411 delete[] sd.temporaries;
412 delete[] sd.sorting_places;
413 delete[] sd.merging_places;
414
ee1b5fc5 415 if (__s.sort_splitting == SAMPLING)
fac9044d 416 ::operator delete(sd.samples);
c2ba9709
JS
417
418 delete[] sd.offsets;
419 delete[] sd.pieces;
c2ba9709 420 }
e683ee2a 421} //namespace __gnu_parallel
c2ba9709
JS
422
423#endif