diff --git a/core/src/slice/sort/shared/smallsort.rs b/core/src/slice/sort/shared/smallsort.rs index 09f898309bd65..f6dcf42ba6037 100644 --- a/core/src/slice/sort/shared/smallsort.rs +++ b/core/src/slice/sort/shared/smallsort.rs @@ -387,7 +387,7 @@ unsafe fn swap_if_less(v_base: *mut T, a_pos: usize, b_pos: usize, is_less where F: FnMut(&T, &T) -> bool, { - // SAFETY: the caller must guarantee that `a` and `b` each added to `v_base` yield valid + // SAFETY: the caller must guarantee that `a_pos` and `b_pos` each added to `v_base` yield valid // pointers into `v_base`, and are properly aligned, and part of the same allocation. unsafe { let v_a = v_base.add(a_pos); @@ -404,16 +404,16 @@ where // The equivalent code with a branch would be: // // if should_swap { - // ptr::swap(left, right, 1); + // ptr::swap(v_a, v_b, 1); // } // The goal is to generate cmov instructions here. - let left_swap = if should_swap { v_b } else { v_a }; - let right_swap = if should_swap { v_a } else { v_b }; + let v_a_swap = should_swap.select_unpredictable(v_b, v_a); + let v_b_swap = should_swap.select_unpredictable(v_a, v_b); - let right_swap_tmp = ManuallyDrop::new(ptr::read(right_swap)); - ptr::copy(left_swap, v_a, 1); - ptr::copy_nonoverlapping(&*right_swap_tmp, v_b, 1); + let v_b_swap_tmp = ManuallyDrop::new(ptr::read(v_b_swap)); + ptr::copy(v_a_swap, v_a, 1); + ptr::copy_nonoverlapping(&*v_b_swap_tmp, v_b, 1); } } @@ -640,26 +640,21 @@ pub unsafe fn sort4_stable bool>( // 1, 1 | c b a d let c3 = is_less(&*c, &*a); let c4 = is_less(&*d, &*b); - let min = select(c3, c, a); - let max = select(c4, b, d); - let unknown_left = select(c3, a, select(c4, c, b)); - let unknown_right = select(c4, d, select(c3, b, c)); + let min = c3.select_unpredictable(c, a); + let max = c4.select_unpredictable(b, d); + let unknown_left = c3.select_unpredictable(a, c4.select_unpredictable(c, b)); + let unknown_right = c4.select_unpredictable(d, c3.select_unpredictable(b, c)); // Sort the last two unknown elements. let c5 = is_less(&*unknown_right, &*unknown_left); - let lo = select(c5, unknown_right, unknown_left); - let hi = select(c5, unknown_left, unknown_right); + let lo = c5.select_unpredictable(unknown_right, unknown_left); + let hi = c5.select_unpredictable(unknown_left, unknown_right); ptr::copy_nonoverlapping(min, dst, 1); ptr::copy_nonoverlapping(lo, dst.add(1), 1); ptr::copy_nonoverlapping(hi, dst.add(2), 1); ptr::copy_nonoverlapping(max, dst.add(3), 1); } - - #[inline(always)] - fn select(cond: bool, if_true: *const T, if_false: *const T) -> *const T { - if cond { if_true } else { if_false } - } } /// SAFETY: The caller MUST guarantee that `v_base` is valid for 8 reads and