Skip to content

Commit 742d3f0

Browse files
committed
Auto merge of #102872 - mikebenfield:better-get-discr, r=nagisa
rustc_codegen_ssa: Better code generation for niche discriminants. In some cases we can avoid arithmetic before checking whether a niche is a tag. Also rename some identifiers around niches. This is relevant to #101872
2 parents b7b7f27 + 51918dc commit 742d3f0

File tree

3 files changed

+371
-49
lines changed

3 files changed

+371
-49
lines changed

compiler/rustc_codegen_ssa/src/mir/place.rs

+145-49
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,9 @@ impl<'a, 'tcx, V: CodegenObject> PlaceRef<'tcx, V> {
209209
bx: &mut Bx,
210210
cast_to: Ty<'tcx>,
211211
) -> V {
212-
let cast_to = bx.cx().immediate_backend_type(bx.cx().layout_of(cast_to));
212+
let cast_to_layout = bx.cx().layout_of(cast_to);
213+
let cast_to_size = cast_to_layout.layout.size();
214+
let cast_to = bx.cx().immediate_backend_type(cast_to_layout);
213215
if self.layout.abi.is_uninhabited() {
214216
return bx.cx().const_undef(cast_to);
215217
}
@@ -229,7 +231,8 @@ impl<'a, 'tcx, V: CodegenObject> PlaceRef<'tcx, V> {
229231

230232
// Read the tag/niche-encoded discriminant from memory.
231233
let tag = self.project_field(bx, tag_field);
232-
let tag = bx.load_operand(tag);
234+
let tag_op = bx.load_operand(tag);
235+
let tag_imm = tag_op.immediate();
233236

234237
// Decode the discriminant (specifically if it's niche-encoded).
235238
match *tag_encoding {
@@ -242,68 +245,161 @@ impl<'a, 'tcx, V: CodegenObject> PlaceRef<'tcx, V> {
242245
Int(_, signed) => !tag_scalar.is_bool() && signed,
243246
_ => false,
244247
};
245-
bx.intcast(tag.immediate(), cast_to, signed)
248+
bx.intcast(tag_imm, cast_to, signed)
246249
}
247250
TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start } => {
248-
// Rebase from niche values to discriminants, and check
249-
// whether the result is in range for the niche variants.
250-
let niche_llty = bx.cx().immediate_backend_type(tag.layout);
251-
let tag = tag.immediate();
252-
253-
// We first compute the "relative discriminant" (wrt `niche_variants`),
254-
// that is, if `n = niche_variants.end() - niche_variants.start()`,
255-
// we remap `niche_start..=niche_start + n` (which may wrap around)
256-
// to (non-wrap-around) `0..=n`, to be able to check whether the
257-
// discriminant corresponds to a niche variant with one comparison.
258-
// We also can't go directly to the (variant index) discriminant
259-
// and check that it is in the range `niche_variants`, because
260-
// that might not fit in the same type, on top of needing an extra
261-
// comparison (see also the comment on `let niche_discr`).
262-
let relative_discr = if niche_start == 0 {
263-
// Avoid subtracting `0`, which wouldn't work for pointers.
264-
// FIXME(eddyb) check the actual primitive type here.
265-
tag
251+
// Cast to an integer so we don't have to treat a pointer as a
252+
// special case.
253+
let (tag, tag_llty) = if tag_scalar.primitive().is_ptr() {
254+
let t = bx.type_isize();
255+
let tag = bx.ptrtoint(tag_imm, t);
256+
(tag, t)
266257
} else {
267-
bx.sub(tag, bx.cx().const_uint_big(niche_llty, niche_start))
258+
(tag_imm, bx.cx().immediate_backend_type(tag_op.layout))
268259
};
260+
261+
let tag_size = tag_scalar.size(bx.cx());
262+
let max_unsigned = tag_size.unsigned_int_max();
263+
let max_signed = tag_size.signed_int_max() as u128;
264+
let min_signed = max_signed + 1;
269265
let relative_max = niche_variants.end().as_u32() - niche_variants.start().as_u32();
270-
let is_niche = if relative_max == 0 {
271-
// Avoid calling `const_uint`, which wouldn't work for pointers.
272-
// Also use canonical == 0 instead of non-canonical u<= 0.
273-
// FIXME(eddyb) check the actual primitive type here.
274-
bx.icmp(IntPredicate::IntEQ, relative_discr, bx.cx().const_null(niche_llty))
266+
let niche_end = niche_start.wrapping_add(relative_max as u128) & max_unsigned;
267+
let range = tag_scalar.valid_range(bx.cx());
268+
269+
let sle = |lhs: u128, rhs: u128| -> bool {
270+
// Signed and unsigned comparisons give the same results,
271+
// except that in signed comparisons an integer with the
272+
// sign bit set is less than one with the sign bit clear.
273+
// Toggle the sign bit to do a signed comparison.
274+
(lhs ^ min_signed) <= (rhs ^ min_signed)
275+
};
276+
277+
// We have a subrange `niche_start..=niche_end` inside `range`.
278+
// If the value of the tag is inside this subrange, it's a
279+
// "niche value", an increment of the discriminant. Otherwise it
280+
// indicates the untagged variant.
281+
// A general algorithm to extract the discriminant from the tag
282+
// is:
283+
// relative_tag = tag - niche_start
284+
// is_niche = relative_tag <= (ule) relative_max
285+
// discr = if is_niche {
286+
// cast(relative_tag) + niche_variants.start()
287+
// } else {
288+
// untagged_variant
289+
// }
290+
// However, we will likely be able to emit simpler code.
291+
292+
// Find the least and greatest values in `range`, considered
293+
// both as signed and unsigned.
294+
let (low_unsigned, high_unsigned) = if range.start <= range.end {
295+
(range.start, range.end)
296+
} else {
297+
(0, max_unsigned)
298+
};
299+
let (low_signed, high_signed) = if sle(range.start, range.end) {
300+
(range.start, range.end)
275301
} else {
276-
let relative_max = bx.cx().const_uint(niche_llty, relative_max as u64);
277-
bx.icmp(IntPredicate::IntULE, relative_discr, relative_max)
302+
(min_signed, max_signed)
303+
};
304+
305+
let niches_ule = niche_start <= niche_end;
306+
let niches_sle = sle(niche_start, niche_end);
307+
let cast_smaller = cast_to_size <= tag_size;
308+
309+
// In the algorithm above, we can change
310+
// cast(relative_tag) + niche_variants.start()
311+
// into
312+
// cast(tag) + (niche_variants.start() - niche_start)
313+
// if either the casted type is no larger than the original
314+
// type, or if the niche values are contiguous (in either the
315+
// signed or unsigned sense).
316+
let can_incr_after_cast = cast_smaller || niches_ule || niches_sle;
317+
318+
let data_for_boundary_niche = || -> Option<(IntPredicate, u128)> {
319+
if !can_incr_after_cast {
320+
None
321+
} else if niche_start == low_unsigned {
322+
Some((IntPredicate::IntULE, niche_end))
323+
} else if niche_end == high_unsigned {
324+
Some((IntPredicate::IntUGE, niche_start))
325+
} else if niche_start == low_signed {
326+
Some((IntPredicate::IntSLE, niche_end))
327+
} else if niche_end == high_signed {
328+
Some((IntPredicate::IntSGE, niche_start))
329+
} else {
330+
None
331+
}
278332
};
279333

280-
// NOTE(eddyb) this addition needs to be performed on the final
281-
// type, in case the niche itself can't represent all variant
282-
// indices (e.g. `u8` niche with more than `256` variants,
283-
// but enough uninhabited variants so that the remaining variants
284-
// fit in the niche).
285-
// In other words, `niche_variants.end - niche_variants.start`
286-
// is representable in the niche, but `niche_variants.end`
287-
// might not be, in extreme cases.
288-
let niche_discr = {
289-
let relative_discr = if relative_max == 0 {
290-
// HACK(eddyb) since we have only one niche, we know which
291-
// one it is, and we can avoid having a dynamic value here.
292-
bx.cx().const_uint(cast_to, 0)
334+
let (is_niche, tagged_discr, delta) = if relative_max == 0 {
335+
// Best case scenario: only one tagged variant. This will
336+
// likely become just a comparison and a jump.
337+
// The algorithm is:
338+
// is_niche = tag == niche_start
339+
// discr = if is_niche {
340+
// niche_start
341+
// } else {
342+
// untagged_variant
343+
// }
344+
let niche_start = bx.cx().const_uint_big(tag_llty, niche_start);
345+
let is_niche = bx.icmp(IntPredicate::IntEQ, tag, niche_start);
346+
let tagged_discr =
347+
bx.cx().const_uint(cast_to, niche_variants.start().as_u32() as u64);
348+
(is_niche, tagged_discr, 0)
349+
} else if let Some((predicate, constant)) = data_for_boundary_niche() {
350+
// The niche values are either the lowest or the highest in
351+
// `range`. We can avoid the first subtraction in the
352+
// algorithm.
353+
// The algorithm is now this:
354+
// is_niche = tag <= niche_end
355+
// discr = if is_niche {
356+
// cast(tag) + (niche_variants.start() - niche_start)
357+
// } else {
358+
// untagged_variant
359+
// }
360+
// (the first line may instead be tag >= niche_start,
361+
// and may be a signed or unsigned comparison)
362+
let is_niche =
363+
bx.icmp(predicate, tag, bx.cx().const_uint_big(tag_llty, constant));
364+
let cast_tag = if cast_smaller {
365+
bx.intcast(tag, cast_to, false)
366+
} else if niches_ule {
367+
bx.zext(tag, cast_to)
293368
} else {
294-
bx.intcast(relative_discr, cast_to, false)
369+
bx.sext(tag, cast_to)
295370
};
296-
bx.add(
371+
372+
let delta = (niche_variants.start().as_u32() as u128).wrapping_sub(niche_start);
373+
(is_niche, cast_tag, delta)
374+
} else {
375+
// The special cases don't apply, so we'll have to go with
376+
// the general algorithm.
377+
let relative_discr = bx.sub(tag, bx.cx().const_uint_big(tag_llty, niche_start));
378+
let cast_tag = bx.intcast(relative_discr, cast_to, false);
379+
let is_niche = bx.icmp(
380+
IntPredicate::IntULE,
297381
relative_discr,
298-
bx.cx().const_uint(cast_to, niche_variants.start().as_u32() as u64),
299-
)
382+
bx.cx().const_uint(tag_llty, relative_max as u64),
383+
);
384+
(is_niche, cast_tag, niche_variants.start().as_u32() as u128)
300385
};
301386

302-
bx.select(
387+
let tagged_discr = if delta == 0 {
388+
tagged_discr
389+
} else {
390+
bx.add(tagged_discr, bx.cx().const_uint_big(cast_to, delta))
391+
};
392+
393+
let discr = bx.select(
303394
is_niche,
304-
niche_discr,
395+
tagged_discr,
305396
bx.cx().const_uint(cast_to, untagged_variant.as_u32() as u64),
306-
)
397+
);
398+
399+
// In principle we could insert assumes on the possible range of `discr`, but
400+
// currently in LLVM this seems to be a pessimization.
401+
402+
discr
307403
}
308404
}
309405
}

src/test/codegen/enum-match.rs

+112
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
// compile-flags: -Copt-level=1
2+
// only-x86_64
3+
4+
#![crate_type = "lib"]
5+
6+
// Check each of the 3 cases for `codegen_get_discr`.
7+
8+
// Case 0: One tagged variant.
9+
pub enum Enum0 {
10+
A(bool),
11+
B,
12+
}
13+
14+
// CHECK: define i8 @match0{{.*}}
15+
// CHECK-NEXT: start:
16+
// CHECK-NEXT: %1 = icmp eq i8 %0, 2
17+
// CHECK-NEXT: %2 = and i8 %0, 1
18+
// CHECK-NEXT: %.0 = select i1 %1, i8 13, i8 %2
19+
#[no_mangle]
20+
pub fn match0(e: Enum0) -> u8 {
21+
use Enum0::*;
22+
match e {
23+
A(b) => b as u8,
24+
B => 13,
25+
}
26+
}
27+
28+
// Case 1: Niche values are on a boundary for `range`.
29+
pub enum Enum1 {
30+
A(bool),
31+
B,
32+
C,
33+
}
34+
35+
// CHECK: define i8 @match1{{.*}}
36+
// CHECK-NEXT: start:
37+
// CHECK-NEXT: %1 = icmp ugt i8 %0, 1
38+
// CHECK-NEXT: %2 = zext i8 %0 to i64
39+
// CHECK-NEXT: %3 = add nsw i64 %2, -1
40+
// CHECK-NEXT: %_2 = select i1 %1, i64 %3, i64 0
41+
// CHECK-NEXT: switch i64 %_2, label {{.*}} [
42+
#[no_mangle]
43+
pub fn match1(e: Enum1) -> u8 {
44+
use Enum1::*;
45+
match e {
46+
A(b) => b as u8,
47+
B => 13,
48+
C => 100,
49+
}
50+
}
51+
52+
// Case 2: Special cases don't apply.
53+
pub enum X {
54+
_2=2, _3, _4, _5, _6, _7, _8, _9, _10, _11,
55+
_12, _13, _14, _15, _16, _17, _18, _19, _20,
56+
_21, _22, _23, _24, _25, _26, _27, _28, _29,
57+
_30, _31, _32, _33, _34, _35, _36, _37, _38,
58+
_39, _40, _41, _42, _43, _44, _45, _46, _47,
59+
_48, _49, _50, _51, _52, _53, _54, _55, _56,
60+
_57, _58, _59, _60, _61, _62, _63, _64, _65,
61+
_66, _67, _68, _69, _70, _71, _72, _73, _74,
62+
_75, _76, _77, _78, _79, _80, _81, _82, _83,
63+
_84, _85, _86, _87, _88, _89, _90, _91, _92,
64+
_93, _94, _95, _96, _97, _98, _99, _100, _101,
65+
_102, _103, _104, _105, _106, _107, _108, _109,
66+
_110, _111, _112, _113, _114, _115, _116, _117,
67+
_118, _119, _120, _121, _122, _123, _124, _125,
68+
_126, _127, _128, _129, _130, _131, _132, _133,
69+
_134, _135, _136, _137, _138, _139, _140, _141,
70+
_142, _143, _144, _145, _146, _147, _148, _149,
71+
_150, _151, _152, _153, _154, _155, _156, _157,
72+
_158, _159, _160, _161, _162, _163, _164, _165,
73+
_166, _167, _168, _169, _170, _171, _172, _173,
74+
_174, _175, _176, _177, _178, _179, _180, _181,
75+
_182, _183, _184, _185, _186, _187, _188, _189,
76+
_190, _191, _192, _193, _194, _195, _196, _197,
77+
_198, _199, _200, _201, _202, _203, _204, _205,
78+
_206, _207, _208, _209, _210, _211, _212, _213,
79+
_214, _215, _216, _217, _218, _219, _220, _221,
80+
_222, _223, _224, _225, _226, _227, _228, _229,
81+
_230, _231, _232, _233, _234, _235, _236, _237,
82+
_238, _239, _240, _241, _242, _243, _244, _245,
83+
_246, _247, _248, _249, _250, _251, _252, _253,
84+
}
85+
86+
pub enum Enum2 {
87+
A(X),
88+
B,
89+
C,
90+
D,
91+
E,
92+
}
93+
94+
// CHECK: define i8 @match2{{.*}}
95+
// CHECK-NEXT: start:
96+
// CHECK-NEXT: %1 = add i8 %0, 2
97+
// CHECK-NEXT: %2 = zext i8 %1 to i64
98+
// CHECK-NEXT: %3 = icmp ult i8 %1, 4
99+
// CHECK-NEXT: %4 = add nuw nsw i64 %2, 1
100+
// CHECK-NEXT: %_2 = select i1 %3, i64 %4, i64 0
101+
// CHECK-NEXT: switch i64 %_2, label {{.*}} [
102+
#[no_mangle]
103+
pub fn match2(e: Enum2) -> u8 {
104+
use Enum2::*;
105+
match e {
106+
A(b) => b as u8,
107+
B => 13,
108+
C => 100,
109+
D => 200,
110+
E => 250,
111+
}
112+
}

0 commit comments

Comments
 (0)