Skip to content

Commit

Permalink
Use more explicit and reliable ptr select in sort impls
Browse files Browse the repository at this point in the history
Using if ... with the intent to avoid branches can be surprising to readers and
carries the risk of turning into jumps/branches generated by some future
compiler version, breaking crucial optimizations.

This commit replaces their usage with the explicit and IR annotated
`bool::select_unpredictable`.
  • Loading branch information
Voultapher authored and gitbot committed Mar 4, 2025
1 parent 4f9a538 commit 8bb4f5a
Showing 1 changed file with 13 additions and 18 deletions.
31 changes: 13 additions & 18 deletions core/src/slice/sort/shared/smallsort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ unsafe fn swap_if_less<T, F>(v_base: *mut T, a_pos: usize, b_pos: usize, is_less
where
F: FnMut(&T, &T) -> bool,
{
// SAFETY: the caller must guarantee that `a` and `b` each added to `v_base` yield valid
// SAFETY: the caller must guarantee that `a_pos` and `b_pos` each added to `v_base` yield valid
// pointers into `v_base`, and are properly aligned, and part of the same allocation.
unsafe {
let v_a = v_base.add(a_pos);
Expand All @@ -404,16 +404,16 @@ where
// The equivalent code with a branch would be:
//
// if should_swap {
// ptr::swap(left, right, 1);
// ptr::swap(v_a, v_b, 1);
// }

// The goal is to generate cmov instructions here.
let left_swap = if should_swap { v_b } else { v_a };
let right_swap = if should_swap { v_a } else { v_b };
let v_a_swap = should_swap.select_unpredictable(v_b, v_a);
let v_b_swap = should_swap.select_unpredictable(v_a, v_b);

let right_swap_tmp = ManuallyDrop::new(ptr::read(right_swap));
ptr::copy(left_swap, v_a, 1);
ptr::copy_nonoverlapping(&*right_swap_tmp, v_b, 1);
let v_b_swap_tmp = ManuallyDrop::new(ptr::read(v_b_swap));
ptr::copy(v_a_swap, v_a, 1);
ptr::copy_nonoverlapping(&*v_b_swap_tmp, v_b, 1);
}
}

Expand Down Expand Up @@ -640,26 +640,21 @@ pub unsafe fn sort4_stable<T, F: FnMut(&T, &T) -> bool>(
// 1, 1 | c b a d
let c3 = is_less(&*c, &*a);
let c4 = is_less(&*d, &*b);
let min = select(c3, c, a);
let max = select(c4, b, d);
let unknown_left = select(c3, a, select(c4, c, b));
let unknown_right = select(c4, d, select(c3, b, c));
let min = c3.select_unpredictable(c, a);
let max = c4.select_unpredictable(b, d);
let unknown_left = c3.select_unpredictable(a, c4.select_unpredictable(c, b));
let unknown_right = c4.select_unpredictable(d, c3.select_unpredictable(b, c));

// Sort the last two unknown elements.
let c5 = is_less(&*unknown_right, &*unknown_left);
let lo = select(c5, unknown_right, unknown_left);
let hi = select(c5, unknown_left, unknown_right);
let lo = c5.select_unpredictable(unknown_right, unknown_left);
let hi = c5.select_unpredictable(unknown_left, unknown_right);

ptr::copy_nonoverlapping(min, dst, 1);
ptr::copy_nonoverlapping(lo, dst.add(1), 1);
ptr::copy_nonoverlapping(hi, dst.add(2), 1);
ptr::copy_nonoverlapping(max, dst.add(3), 1);
}

#[inline(always)]
fn select<T>(cond: bool, if_true: *const T, if_false: *const T) -> *const T {
if cond { if_true } else { if_false }
}
}

/// SAFETY: The caller MUST guarantee that `v_base` is valid for 8 reads and
Expand Down

0 comments on commit 8bb4f5a

Please sign in to comment.