]> git.ipfire.org Git - thirdparty/postgresql.git/commitdiff
Skip common prefixes during radix sort
authorJohn Naylor <john.naylor@postgresql.org>
Wed, 1 Apr 2026 07:18:57 +0000 (14:18 +0700)
committerJohn Naylor <john.naylor@postgresql.org>
Wed, 1 Apr 2026 07:18:57 +0000 (14:18 +0700)
During the counting step, keep track of the bits that are the same
for the entire input.  If we counted only a single distinct byte,
the next recursion will start at the next byte position that has
more than one distinct byte in the input. This allows us to skip over
multiple passes where the byte is the same for the entire input.

This provides a significant speedup for integers that have some upper
bytes with all-zeros or all-ones, which is common.

Reviewed-by: Chengpeng Yan <chengpeng_yan@outlook.com>
Reviewed-by: ChangAo Chen <cca5507@qq.com>
Discussion: https://postgr.es/m/CANWCAZYpGMDSSwAa18fOxJGXaPzVdyPsWpOkfCX32DWh3Qznzw@mail.gmail.com

src/backend/utils/sort/tuplesort.c

index 1fc440ea6ca1ce7848c5df7d3433e121a5eb92dd..72c2c2995d8da98253f98e96508b2186ab924de2 100644 (file)
 #include "commands/tablespace.h"
 #include "miscadmin.h"
 #include "pg_trace.h"
+#include "port/pg_bitutils.h"
 #include "storage/shmem.h"
 #include "utils/guc.h"
 #include "utils/memutils.h"
@@ -2659,17 +2660,25 @@ radix_sort_recursive(SortTuple *begin, size_t n_elems, int level, Tuplesortstate
        int                     num_partitions = 0;
        int                     num_remaining;
        SortSupport ssup = &state->base.sortKeys[0];
+       Datum           ref_datum;
+       Datum           common_upper_bits = 0;
        size_t          start_offset = 0;
        SortTuple  *partition_begin = begin;
+       int                     next_level;
 
        /* count number of occurrences of each byte */
+       ref_datum = normalize_datum(begin[0].datum1, ssup);
        for (SortTuple *st = begin; st < begin + n_elems; st++)
        {
+               Datum           this_datum;
                uint8           this_partition;
 
+               this_datum = normalize_datum(st->datum1, ssup);
+               /* accumulate bits different from the reference datum */
+               common_upper_bits |= ref_datum ^ this_datum;
+
                /* extract the byte for this level from the normalized datum */
-               this_partition = current_byte(normalize_datum(st->datum1, ssup),
-                                                                         level);
+               this_partition = current_byte(this_datum, level);
 
                /* save it for the permutation step */
                st->curbyte = this_partition;
@@ -2747,6 +2756,33 @@ radix_sort_recursive(SortTuple *begin, size_t n_elems, int level, Tuplesortstate
        }
 
        /* recurse */
+
+       if (num_partitions == 1)
+       {
+               /*
+                * There is only one distinct byte at the current level. It can happen
+                * that some subsequent bytes are also the same for all input values,
+                * such as the upper bytes of small integers. To skip unproductive
+                * passes for that case, we compute the level where the input has more
+                * than one distinct byte, so that the next recursion can start there.
+                */
+               if (common_upper_bits == 0)
+                       next_level = sizeof(Datum);
+               else
+               {
+                       int                     diffpos;
+
+                       /*
+                        * The upper bits of common_upper_bits are zero where all datums
+                        * have the same bits.
+                        */
+                       diffpos = pg_leftmost_one_pos64(DatumGetUInt64(common_upper_bits));
+                       next_level = sizeof(Datum) - 1 - (diffpos / BITS_PER_BYTE);
+               }
+       }
+       else
+               next_level = level + 1;
+
        for (uint8 *rp = remaining_partitions;
                 rp < remaining_partitions + num_partitions;
                 rp++)
@@ -2757,7 +2793,7 @@ radix_sort_recursive(SortTuple *begin, size_t n_elems, int level, Tuplesortstate
 
                if (num_elements > 1)
                {
-                       if (level < sizeof(Datum) - 1)
+                       if (next_level < sizeof(Datum))
                        {
                                if (num_elements < QSORT_THRESHOLD)
                                {
@@ -2770,7 +2806,7 @@ radix_sort_recursive(SortTuple *begin, size_t n_elems, int level, Tuplesortstate
                                {
                                        radix_sort_recursive(partition_begin,
                                                                                 num_elements,
-                                                                                level + 1,
+                                                                                next_level,
                                                                                 state);
                                }
                        }