Motivation

While inspecting the code generated by num_enum, I found that the code generated for TryFromPrimitive is suboptimal. In certain situations, it produces assembly that is both larger (in binary size) and slower than simple handwritten code using transmute. This is unfortunate, since one of the main benefits of num_enum is that it allows low-level enum "serialization" without the use of unsafe.
Description
The generated code uses a large match statement to match input values. This is a good general solution, but it can lead to inefficient code gen for certain enums.
In this issue, I will focus on enums where the discriminant values of all variants are statically known (= the derive macro knows the numerical value of each variant) and each variant has exactly one discriminant (= no alternatives). I will also ignore default and catch_all, but the idea I describe can easily be adapted to support them.
Inefficient code gen
Consider the following enum of the DXGI_FORMAT enumeration. The generated implementation for TryFromPrimitive::try_from_primitive is a large match, and it produces the following assembly on Rust 1.85.0 for O1, O2, O3, Os, and Oz (Compiler Explorer link).
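To make the shape of the current code gen concrete, here is a simplified sketch. The enum here is a hypothetical three-variant stand-in and the error type is hand-rolled; the issue's real example is the much larger DXGI_FORMAT enumeration, and num_enum uses its own error type.

```rust
// Hypothetical stand-in enum; the real example in this issue is DXGI_FORMAT.
#[derive(Debug, PartialEq)]
#[repr(u8)]
enum Simple {
    A = 0,
    B = 1,
    C = 5,
}

#[derive(Debug, PartialEq)]
struct TryFromPrimitiveError(u8);

// Simplified shape of what num_enum currently generates:
// one match arm per known discriminant value.
fn try_from_primitive(number: u8) -> Result<Simple, TryFromPrimitiveError> {
    match number {
        0 => Ok(Simple::A),
        1 => Ok(Simple::B),
        5 => Ok(Simple::C),
        _ => Err(TryFromPrimitiveError(number)),
    }
}
```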
I annotated the assembly a little, but it basically works in 2 steps:
1. Check whether the input primitive is > 191. If it is, return Err.
2. Otherwise, look up in a jump table whether to jump to the Ok or the Err branch.
I abbreviated the jump table here, but it has 192 entries (each 4 bytes). So to retrieve 1 bit of information, we need to read a jump table entry and then jmp to the retrieved location.
This is quite inefficient. Both the memory read and jmp are comparatively slow, and the jump table is 768 bytes in binary size.
More efficient code
Compare this to the following implementation I've written by hand:
```rust
fn try_from_primitive(number: u8) -> Result<DxgiFormatEnum, TryFromPrimitiveError> {
    let is_valid = matches!(number, 0..=115 | 130..=132 | 191);
    if is_valid {
        // SAFETY: `number` is a valid discriminant value
        Ok(unsafe { std::mem::transmute(number) })
    } else {
        Err(TryFromPrimitiveError(number))
    }
}
```
```asm
try_from_primitive:
        cmp     dil, 116                 ; compare input to 116
        jae     .LBB0_2                  ; if `input >= 116`, go to .LBB0_2
        xor     eax, eax                 ; otherwise, return Ok
        mov     edx, edi
        ret
.LBB0_2:
        lea     ecx, [rdi + 126]         ; compute `input - 130`
        mov     al, 1
        cmp     cl, 61                   ; compare `input - 130` to 61 (191 - 130 = 61)
        ja      .LBB0_4                  ; if `input - 130 > 61`, return Err
        movabs  rax, 2305843009213693944 ; this value is used as a bitset
        shr     rax, cl
.LBB0_4:
        mov     edx, edi
        ret
```
This is the output for O3 with other optimization modes producing virtually identical assembly. I want to point out that what the compiler produced is quite good here. It handled 0..=115 in its own branch and used a fast bitset lookup for 130..=132 | 191. Note that there is no jump table or anything. The assembly is just a few bytes and can easily be inlined.
The underlying issue
The underlying issue is two-fold:
1. LLVM has a hard time optimizing the large match statement generated by num_enum.
2. LLVM is inconsistent about when it optimizes the match statement generated by num_enum.
I want to expand on the second point. LLVM actually produces very efficient assembly for the match statement generated by num_enum IF the enum is sufficiently simple.
However, even relatively simple enums can cause large jump tables to be generated, and even small changes in discriminant values can have huge effects on the generated assembly.
Here is a Compiler Explorer link where you can quickly see the generated assembly for different variants. Just input the discriminant values into the test! macro.
Suggested fix
Since the compiler seems to be bad at optimizing the large match statement generated by num_enum, I would suggest splitting TryFromPrimitive::try_from_primitive into 2 parts:
1. The first part determines whether the given input is a valid discriminant value.
2. The second part uses the result from the first part and either transmutes the input or returns an error (or the default/catch_all variant).
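A sketch of what this split might look like, using a hypothetical enum and a hand-rolled error type (num_enum's real generated code would use its own error type and would also need to handle default/catch_all):

```rust
// Hypothetical enum and error type for illustration.
#[derive(Debug, PartialEq)]
#[repr(u8)]
enum Simple {
    A = 0,
    B = 1,
    C = 5,
}

#[derive(Debug, PartialEq)]
struct TryFromPrimitiveError(u8);

fn try_from_primitive(number: u8) -> Result<Simple, TryFromPrimitiveError> {
    // Part 1: a single boolean validity check the optimizer can work on in isolation.
    let is_valid = matches!(number, 0 | 1 | 5);
    // Part 2: transmute on success, error otherwise.
    if is_valid {
        // SAFETY: `number` was just checked to be a valid discriminant,
        // and `Simple` is `#[repr(u8)]`.
        Ok(unsafe { std::mem::transmute::<u8, Simple>(number) })
    } else {
        Err(TryFromPrimitiveError(number))
    }
}
```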
The compiler generally seems to be better at optimizing this form, leading to assembly that is as good or better than of the current code gen. You can test this out with this Compiler Explorer link.
Note however that matches! is not a perfect solution. It also generates jump tables sometimes, just less often. To fix this, I would further suggest manually choosing a good strategy for determining is_valid instead of leaving it entirely up to the compiler. I would suggest the following:
Let V be the set of all discriminant values. Then determine is_valid: bool as follows:
1. If V is an interval, generate is_valid = (V_min..=V_max).contains(&input).
2. Otherwise, if V_max - V_min <= 64, use a 64-bit bitset.
3. Otherwise, sort V and output a matches! with ranges instead of simple literal values, e.g. matches!(input, 1..=4 | 6..=8 | 100) instead of matches!(input, 1 | 2 | 3 | 4 | 6 | 7 | 8 | 100). LLVM tends to produce better assembly like this.
This approach should always produce assembly that is as good or better than simply using matches! with all variants.
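To illustrate the bitset strategy, here is a sketch for a hypothetical discriminant set {3, 4, 6, 60}, so V_min = 3 and V_max - V_min = 57 <= 64:

```rust
// Validity check for the hypothetical discriminant set {3, 4, 6, 60}.
fn is_valid(input: u8) -> bool {
    const MIN: u8 = 3;
    // Bit `i` is set iff `MIN + i` is a valid discriminant value.
    const BITSET: u64 =
        (1 << (3 - 3)) | (1 << (4 - 3)) | (1 << (6 - 3)) | (1 << (60 - 3));
    // Wrapping subtraction maps inputs below MIN to large offsets,
    // which the range check below rejects.
    let offset = input.wrapping_sub(MIN);
    offset < 64 && (BITSET >> offset) & 1 != 0
}
```

This compiles down to a subtraction, a comparison, and a shift, with no memory reads and no jump table.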
Note: If you choose not to implement the bitset method and want to rely on LLVM instead, I would advise against using matches! with ranges for small enums. LLVM tends to optimize those with a bitset, unless ranges are used in matches!. (Again, LLVM is really inconsistent in how it optimizes this.)