// -*- C++ -*-

// Copyright (C) 2007 Free Software Foundation, Inc.
//
// This file is part of the GNU ISO C++ Library. This library is free
// software; you can redistribute it and/or modify it under the terms
// of the GNU General Public License as published by the Free Software
// Foundation; either version 2, or (at your option) any later
// version.

// This library is distributed in the hope that it will be useful, but
// WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
// General Public License for more details.

// You should have received a copy of the GNU General Public License
// along with this library; see the file COPYING. If not, write to
// the Free Software Foundation, 59 Temple Place - Suite 330, Boston,
// MA 02111-1307, USA.

// As a special exception, you may use this file as part of a free
// software library without restriction. Specifically, if other files
// instantiate templates or use macros or inline functions from this
// file, or you compile this file and link it with other files to
// produce an executable, this file does not by itself cause the
// resulting executable to be covered by the GNU General Public
// License. This exception does not however invalidate any other
// reasons why the executable file might be covered by the GNU General
// Public License.

/** @file parallel/balanced_quicksort.h
 *  @brief Implementation of a dynamically load-balanced parallel quicksort.
 *
 *  It works in-place and needs only logarithmic extra memory.
 *  This file is a GNU parallel extension to the Standard C++ Library.
 */

// Written by Johannes Singler.

#ifndef _GLIBCXX_PARALLEL_BAL_QUICKSORT_H
#define _GLIBCXX_PARALLEL_BAL_QUICKSORT_H 1

#include <parallel/basic_iterator.h>
#include <bits/stl_algo.h>

#include <parallel/settings.h>
#include <parallel/partition.h>
#include <parallel/random_number.h>
#include <parallel/queue.h>
#include <functional>

#if _GLIBCXX_ASSERTIONS
#include <parallel/checkers.h>
#endif

namespace __gnu_parallel
{
  /** @brief Information local to one thread in the parallel quicksort run. */
  template<typename RandomAccessIterator>
    struct QSBThreadLocal
    {
      typedef std::iterator_traits<RandomAccessIterator> traits_type;
      typedef typename traits_type::difference_type difference_type;

      /** @brief Continuous part of the sequence, described by an
       *  iterator pair. */
      typedef std::pair<RandomAccessIterator, RandomAccessIterator> Piece;

      /** @brief Initial piece to work on. */
      Piece initial;

      /** @brief Work-stealing queue. */
      RestrictedBoundedConcurrentQueue<Piece> leftover_parts;

      /** @brief Number of threads involved in this algorithm. */
      thread_index_t num_threads;

      /** @brief Pointer to a counter of elements left over to sort. */
      volatile difference_type* elements_leftover;

      /** @brief The complete sequence to sort. */
      Piece global;

      /** @brief Constructor.
       *  @param queue_size Size of the work-stealing queue. */
      QSBThreadLocal(int queue_size) : leftover_parts(queue_size) { }
    };
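
  // Usage sketch for the work-stealing queue (illustrative only, restating
  // how leftover_parts is used below): the owning thread pushes the larger
  // part of every split at the front and also pops from the front, i.e. the
  // most recently pushed and therefore smaller pieces, while idle threads
  // steal the oldest and therefore larger pieces from the back:
  //
  //   tl.leftover_parts.push_front(std::make_pair(split_pos2, end)); // owner
  //   bool ok = tl.leftover_parts.pop_front(current);                // owner
  //   bool stolen = tls[victim]->leftover_parts.pop_back(current);   // thief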

  /** @brief Initialize the thread local storage.
   *  @param tls Array of thread-local storages.
   *  @param queue_size Size of the work-stealing queue. */
  template<typename RandomAccessIterator>
    inline void
    qsb_initialize(QSBThreadLocal<RandomAccessIterator>** tls, int queue_size)
    {
      int iam = omp_get_thread_num();
      tls[iam] = new QSBThreadLocal<RandomAccessIterator>(queue_size);
    }
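
  // Note (descriptive, based on the call in parallel_sort_qsb below):
  // qsb_initialize is meant to run inside an OpenMP parallel region, so that
  // omp_get_thread_num() selects a distinct slot of tls for every thread.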


  /** @brief Balanced quicksort divide step.
   *  @param begin Begin iterator of subsequence.
   *  @param end End iterator of subsequence.
   *  @param comp Comparator.
   *  @param num_threads Number of threads that are allowed to work on
   *  this part.
   *  @pre @c (end-begin)>=1 */
  template<typename RandomAccessIterator, typename Comparator>
    inline typename std::iterator_traits<RandomAccessIterator>::difference_type
    qsb_divide(RandomAccessIterator begin, RandomAccessIterator end,
               Comparator comp, int num_threads)
    {
      _GLIBCXX_PARALLEL_ASSERT(num_threads > 0);

      typedef std::iterator_traits<RandomAccessIterator> traits_type;
      typedef typename traits_type::value_type value_type;
      typedef typename traits_type::difference_type difference_type;

      RandomAccessIterator pivot_pos =
        median_of_three_iterators(begin, begin + (end - begin) / 2,
                                  end - 1, comp);

#if defined(_GLIBCXX_ASSERTIONS)
      // Must be in between somewhere.
      difference_type n = end - begin;

      _GLIBCXX_PARALLEL_ASSERT(
        (!comp(*pivot_pos, *begin) && !comp(*(begin + n / 2), *pivot_pos))
        || (!comp(*pivot_pos, *begin) && !comp(*(end - 1), *pivot_pos))
        || (!comp(*pivot_pos, *(begin + n / 2)) && !comp(*begin, *pivot_pos))
        || (!comp(*pivot_pos, *(begin + n / 2))
            && !comp(*(end - 1), *pivot_pos))
        || (!comp(*pivot_pos, *(end - 1)) && !comp(*begin, *pivot_pos))
        || (!comp(*pivot_pos, *(end - 1))
            && !comp(*(begin + n / 2), *pivot_pos)));
#endif

      // Swap pivot value to end.
      if (pivot_pos != (end - 1))
        std::swap(*pivot_pos, *(end - 1));
      pivot_pos = end - 1;

      __gnu_parallel::binder2nd<Comparator, value_type, value_type, bool>
        pred(comp, *pivot_pos);

      // Divide, returning end - begin - 1 in the worst case.
      difference_type split_pos =
        parallel_partition(begin, end - 1, pred, num_threads);

      // Swap back pivot to middle.
      std::swap(*(begin + split_pos), *pivot_pos);
      pivot_pos = begin + split_pos;

#if _GLIBCXX_ASSERTIONS
      RandomAccessIterator r;
      for (r = begin; r != pivot_pos; r++)
        _GLIBCXX_PARALLEL_ASSERT(comp(*r, *pivot_pos));
      for (; r != end; r++)
        _GLIBCXX_PARALLEL_ASSERT(!comp(*r, *pivot_pos));
#endif

      return split_pos;
    }
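
  // Illustrative summary of the post-condition asserted above: after
  // qsb_divide, the subsequence is arranged as
  //
  //   [begin, begin + split_pos)   elements x with  comp(x, pivot)
  //   begin + split_pos            the pivot itself
  //   (begin + split_pos, end)     elements x with !comp(x, pivot)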

  /** @brief Quicksort conquer step.
   *  @param tls Array of thread-local storages.
   *  @param begin Begin iterator of subsequence.
   *  @param end End iterator of subsequence.
   *  @param comp Comparator.
   *  @param iam Number of the thread processing this function.
   *  @param num_threads Number of threads that are allowed to work on
   *  this part. */
  template<typename RandomAccessIterator, typename Comparator>
    inline void
    qsb_conquer(QSBThreadLocal<RandomAccessIterator>** tls,
                RandomAccessIterator begin, RandomAccessIterator end,
                Comparator comp,
                thread_index_t iam, thread_index_t num_threads)
    {
      typedef std::iterator_traits<RandomAccessIterator> traits_type;
      typedef typename traits_type::value_type value_type;
      typedef typename traits_type::difference_type difference_type;

      difference_type n = end - begin;

      if (num_threads <= 1 || n < 2)
        {
          tls[iam]->initial.first = begin;
          tls[iam]->initial.second = end;

          qsb_local_sort_with_helping(tls, comp, iam);

          return;
        }

      // Divide step.
      difference_type split_pos = qsb_divide(begin, end, comp, num_threads);

#if _GLIBCXX_ASSERTIONS
      _GLIBCXX_PARALLEL_ASSERT(0 <= split_pos && split_pos < (end - begin));
#endif

      thread_index_t num_threads_leftside =
        std::max<thread_index_t>(1, std::min<thread_index_t>(
          num_threads - 1, split_pos * num_threads / n));

#pragma omp atomic
      *tls[iam]->elements_leftover -= (difference_type)1;

      // Conquer step.
#pragma omp parallel sections num_threads(2)
      {
#pragma omp section
        qsb_conquer(tls, begin, begin + split_pos, comp,
                    iam, num_threads_leftside);
        // The pivot_pos is left in place, to ensure termination.
#pragma omp section
        qsb_conquer(tls, begin + split_pos + 1, end, comp,
                    iam + num_threads_leftside,
                    num_threads - num_threads_leftside);
      }
    }

  /**
   *  @brief Quicksort step doing load-balanced local sort.
   *  @param tls Array of thread-local storages.
   *  @param comp Comparator.
   *  @param iam Number of the thread processing this function.
   */
  template<typename RandomAccessIterator, typename Comparator>
    inline void
    qsb_local_sort_with_helping(QSBThreadLocal<RandomAccessIterator>** tls,
                                Comparator& comp, int iam)
    {
      typedef std::iterator_traits<RandomAccessIterator> traits_type;
      typedef typename traits_type::value_type value_type;
      typedef typename traits_type::difference_type difference_type;
      typedef std::pair<RandomAccessIterator, RandomAccessIterator> Piece;

      QSBThreadLocal<RandomAccessIterator>& tl = *tls[iam];

      difference_type base_case_n = Settings::sort_qsb_base_case_maximal_n;
      if (base_case_n < 2)
        base_case_n = 2;
      thread_index_t num_threads = tl.num_threads;

      // Every thread has its own random number generator.
      random_number rng(iam + 1);

      Piece current = tl.initial;

      difference_type elements_done = 0;
#if _GLIBCXX_ASSERTIONS
      difference_type total_elements_done = 0;
#endif

      for (;;)
        {
          // Invariant: current must be a valid (maybe empty) range.
          RandomAccessIterator begin = current.first, end = current.second;
          difference_type n = end - begin;

          if (n > base_case_n)
            {
              // Divide.
              RandomAccessIterator pivot_pos = begin + rng(n);

              // Swap pivot_pos value to end.
              if (pivot_pos != (end - 1))
                std::swap(*pivot_pos, *(end - 1));
              pivot_pos = end - 1;

              __gnu_parallel::binder2nd<Comparator, value_type, value_type,
                                        bool> pred(comp, *pivot_pos);

              // Divide, leave pivot unchanged in last place.
              RandomAccessIterator split_pos1, split_pos2;
              split_pos1 = __gnu_sequential::partition(begin, end - 1, pred);

              // Left side: < pivot_pos; right side: >= pivot_pos.
#if _GLIBCXX_ASSERTIONS
              _GLIBCXX_PARALLEL_ASSERT(begin <= split_pos1
                                       && split_pos1 < end);
#endif
              // Swap pivot back to middle.
              if (split_pos1 != pivot_pos)
                std::swap(*split_pos1, *pivot_pos);
              pivot_pos = split_pos1;

              // In case all elements are equal, split_pos1 == begin.
              if ((split_pos1 + 1 - begin) < (n >> 7)
                  || (end - split_pos1) < (n >> 7))
                {
                  // Very unequal split: one part is smaller than one 128th
                  // of the whole. Separate off the elements not strictly
                  // larger than the pivot.
                  __gnu_parallel::unary_negate<__gnu_parallel::binder1st
                      <Comparator, value_type, value_type, bool>, value_type>
                    pred(__gnu_parallel::binder1st<Comparator, value_type,
                                                   value_type, bool>
                         (comp, *pivot_pos));

                  // Find other end of pivot-equal range.
                  split_pos2 = __gnu_sequential::partition(split_pos1 + 1,
                                                           end, pred);
                }
              else
                {
                  // Only skip the pivot.
                  split_pos2 = split_pos1 + 1;
                }

              // Elements equal to pivot are done.
              elements_done += (split_pos2 - split_pos1);
#if _GLIBCXX_ASSERTIONS
              total_elements_done += (split_pos2 - split_pos1);
#endif
              // Always push larger part onto stack.
              if (((split_pos1 + 1) - begin) < (end - (split_pos2)))
                {
                  // Right side larger.
                  if ((split_pos2) != end)
                    tl.leftover_parts.push_front(
                        std::make_pair(split_pos2, end));

                  //current.first = begin;    //already set anyway
                  current.second = split_pos1;
                  continue;
                }
              else
                {
                  // Left side larger.
                  if (begin != split_pos1)
                    tl.leftover_parts.push_front(
                        std::make_pair(begin, split_pos1));

                  current.first = split_pos2;
                  //current.second = end;     //already set anyway
                  continue;
                }
            }
          else
            {
              __gnu_sequential::sort(begin, end, comp);
              elements_done += n;
#if _GLIBCXX_ASSERTIONS
              total_elements_done += n;
#endif

              // Prefer own stack, small pieces.
              if (tl.leftover_parts.pop_front(current))
                continue;

#pragma omp atomic
              *tl.elements_leftover -= elements_done;
              elements_done = 0;

#if _GLIBCXX_ASSERTIONS
              double search_start = omp_get_wtime();
#endif

              // Look for new work.
              bool success = false;
              while (*tl.elements_leftover > 0 && !success
#if _GLIBCXX_ASSERTIONS
                     // Possible dead-lock.
                     && (omp_get_wtime() < (search_start + 1.0))
#endif
                     )
                {
                  thread_index_t victim;
                  victim = rng(num_threads);

                  // Large pieces.
                  success = (victim != iam)
                            && tls[victim]->leftover_parts.pop_back(current);
                  if (!success)
                    yield();
#if !defined(__ICC) && !defined(__ECC)
#pragma omp flush
#endif
                }

#if _GLIBCXX_ASSERTIONS
              if (omp_get_wtime() >= (search_start + 1.0))
                {
                  sleep(1);
                  _GLIBCXX_PARALLEL_ASSERT(omp_get_wtime()
                                           < (search_start + 1.0));
                }
#endif
              if (!success)
                {
#if _GLIBCXX_ASSERTIONS
                  _GLIBCXX_PARALLEL_ASSERT(*tl.elements_leftover == 0);
#endif
                  return;
                }
            }
        }
    }
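
  // Termination sketch (informal summary of the accounting above): every
  // element is charged exactly once against *elements_leftover: one pivot
  // per qsb_conquer split, the pivot-equal range of each divide step, and
  // all elements of a base case after the sequential sort. Once the shared
  // counter reaches zero, the condition (*tl.elements_leftover > 0) of the
  // steal loop fails and every thread returns.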

  /** @brief Top-level quicksort routine.
   *  @param begin Begin iterator of sequence.
   *  @param end End iterator of sequence.
   *  @param comp Comparator.
   *  @param n Length of the sequence to sort.
   *  @param num_threads Number of threads that are allowed to work on
   *  this part.
   */
  template<typename RandomAccessIterator, typename Comparator>
    inline void
    parallel_sort_qsb(RandomAccessIterator begin, RandomAccessIterator end,
                      Comparator comp,
                      typename std::iterator_traits<RandomAccessIterator>::difference_type n,
                      int num_threads)
    {
      _GLIBCXX_CALL(end - begin)

      typedef std::iterator_traits<RandomAccessIterator> traits_type;
      typedef typename traits_type::value_type value_type;
      typedef typename traits_type::difference_type difference_type;
      typedef std::pair<RandomAccessIterator, RandomAccessIterator> Piece;

      typedef QSBThreadLocal<RandomAccessIterator> tls_type;

      if (n <= 1)
        return;

      // At least one element per processor.
      if (num_threads > n)
        num_threads = static_cast<thread_index_t>(n);

      tls_type** tls = new tls_type*[num_threads];

#pragma omp parallel num_threads(num_threads)
      // Initialize variables per processor.
      qsb_initialize(tls, num_threads * (thread_index_t)(log2(n) + 1));

      // There can never be more than ceil(log2(n)) ranges on a stack, because
      // 1. only the owning thread pushes onto its stack,
      // 2. the largest range has at most length n,
      // 3. each pushed range is at least as large as the part that remains
      //    with the owner, so the remaining part at least halves with every
      //    push.
      volatile difference_type elements_leftover = n;
      for (int i = 0; i < num_threads; i++)
        {
          tls[i]->elements_leftover = &elements_leftover;
          tls[i]->num_threads = num_threads;
          tls[i]->global = std::make_pair(begin, end);

          // Just in case nothing is left to assign.
          tls[i]->initial = std::make_pair(end, end);
        }

      // Initial splitting, recursively.
      int old_nested = omp_get_nested();
      omp_set_nested(true);

      // Main recursion call.
      qsb_conquer(tls, begin, begin + n, comp, 0, num_threads);

      omp_set_nested(old_nested);

#if _GLIBCXX_ASSERTIONS
      // All stacks must be empty.
      Piece dummy;
      for (int i = 1; i < num_threads; i++)
        _GLIBCXX_PARALLEL_ASSERT(!tls[i]->leftover_parts.pop_back(dummy));
#endif

      for (int i = 0; i < num_threads; i++)
        delete tls[i];
      delete[] tls;
    }
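
  // Example call (illustrative only; this is an assumed direct use of the
  // routine above, whereas user code normally reaches it through the
  // parallel-mode sort dispatch):
  //
  //   std::vector<int> v(1000000);
  //   // ... fill v ...
  //   __gnu_parallel::parallel_sort_qsb(v.begin(), v.end(), std::less<int>(),
  //                                     v.size(), omp_get_max_threads());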
} // namespace __gnu_parallel

#endif