Improve generated code for TryFromPrimitive #160

Open
RunDevelopment opened this issue Feb 27, 2025 · 0 comments

RunDevelopment commented Feb 27, 2025

Motivation

While inspecting the code generated by num_enum, I found that the code for TryFromPrimitive is suboptimal. In certain situations, it compiles to assembly that is both larger (binary size) and slower than simple handwritten code using transmute. This is very unfortunate, since one of the main benefits of num_enum is that it allows low-level enum "serialization" without the use of unsafe.

Description

The generated code uses a large match statement to match input values. This is a good general solution, but can lead to inefficient code gen for certain enums.

In this issue, I will focus on enums where the discriminant values of all variants are statically known (= the derive macro knows the numerical value of each variant) and each variant has exactly one discriminant (= no alternatives). I will also ignore default and catch_all, but the idea I talk about can easily be adapted to support them.

Inefficient code gen

Consider the following enum, modeled on the DXGI_FORMAT enumeration:

#[derive(TryFromPrimitive)]
#[repr(u8)]
enum DxgiFormatEnum {
    UNKNOWN = 0,
    R32G32B32A32_TYPELESS = 1,
    // ... 113 more variants
    B4G4R4A4_UNORM = 115,

    P208 = 130,
    V208 = 131,
    V408 = 132,

    A4B4G4R4_UNORM = 191,
}

The generated implementation for TryFromPrimitive::try_from_primitive currently looks like this (simplified):

fn try_from_primitive(number: u8) -> Result<DxgiFormatEnum, TryFromPrimitiveError> {
    #![allow(non_upper_case_globals)]
    const UNKNOWN__num_enum_0__: u8 = 0;
    const R32G32B32A32_TYPELESS__num_enum_0__: u8 = 1;
    // ... and so on
    const B4G4R4A4_UNORM__num_enum_0__: u8 = 115;
    const P208__num_enum_0__: u8 = 130;
    const V208__num_enum_0__: u8 = 131;
    const V408__num_enum_0__: u8 = 132;
    const A4B4G4R4_UNORM__num_enum_0__: u8 = 191;
    #[deny(unreachable_patterns)]
    match number {
        UNKNOWN__num_enum_0__ => Ok(DxgiFormatEnum::UNKNOWN),
        R32G32B32A32_TYPELESS__num_enum_0__ => Ok(DxgiFormatEnum::R32G32B32A32_TYPELESS),
        // ... and so on
        B4G4R4A4_UNORM__num_enum_0__ => Ok(DxgiFormatEnum::B4G4R4A4_UNORM),
        P208__num_enum_0__ => Ok(DxgiFormatEnum::P208),
        V208__num_enum_0__ => Ok(DxgiFormatEnum::V208),
        V408__num_enum_0__ => Ok(DxgiFormatEnum::V408),
        A4B4G4R4_UNORM__num_enum_0__ => Ok(DxgiFormatEnum::A4B4G4R4_UNORM),
        #[allow(unreachable_patterns)]
        _ => Err(TryFromPrimitiveError(number)),
    }
}

This generates the following assembly on Rust 1.85.0 for O1, O2, O3, Os, and Oz (Compiler Explorer link):

try_from_primitive:
        mov     al, 1
        cmp     dil, -65 ; compare the input primitive with 191
        ja      .LBB0_3  ; if the input is >191, go to the Err branch
        movzx   ecx, dil
        lea     rdx, [rip + .LJTI0_0]
        movsxd  rcx, dword ptr [rdx + 4*rcx]
        add     rcx, rdx
        jmp     rcx
.LBB0_2: ; Ok branch
        xor     eax, eax
.LBB0_3: ; Err branch
        mov     edx, edi
        ret
.LJTI0_0:
        .long   .LBB0_2-.LJTI0_0 ; 0 = UNKNOWN
        .long   .LBB0_2-.LJTI0_0 ; 1 = R32G32B32A32_TYPELESS
        ; ...
        .long   .LBB0_2-.LJTI0_0 ; 115 = B4G4R4A4_UNORM
        .long   .LBB0_3-.LJTI0_0 ; 116 Err
        ; ...
        .long   .LBB0_3-.LJTI0_0 ; 129 Err
        .long   .LBB0_2-.LJTI0_0 ; 130 = P208
        .long   .LBB0_2-.LJTI0_0 ; 131 = V208
        .long   .LBB0_2-.LJTI0_0 ; 132 = V408
        .long   .LBB0_3-.LJTI0_0 ; 133 Err
        ; ...
        .long   .LBB0_3-.LJTI0_0 ; 190 Err
        .long   .LBB0_2-.LJTI0_0 ; 191 = A4B4G4R4_UNORM

I annotated the assembly a little, but it basically works in 2 steps:

  1. Check whether the input primitive is > 191. If it is, then return Err.
  2. Otherwise, look up whether to jump to the Ok or Err branch in a jump table.

I abbreviated the jump table here, but it has 192 entries (each 4 bytes). So to retrieve 1 bit of information, we need to read a jump table entry and then jmp to the retrieved location.

This is quite inefficient. Both the memory read and the indirect jmp are comparatively slow, and the jump table alone takes up 768 bytes of binary size.
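A rough Rust model of what that jump table encodes may make the cost more concrete. This is an illustrative sketch, not how LLVM actually lowers the match: the hypothetical `TABLE` below stores one `bool` per value 0..=191, while the real `.LJTI0_0` table stores a 4-byte relative offset per entry, which is where the 768 bytes come from.

```rust
/// Discriminants of the DXGI example enum: 0..=115, 130..=132 and 191.
const fn is_discriminant(n: u8) -> bool {
    matches!(n, 0..=115 | 130..=132 | 191)
}

/// One entry per value 0..=191, mirroring the jump table's layout.
/// `true` plays the role of "jump to Ok", `false` of "jump to Err".
const TABLE: [bool; 192] = {
    let mut table = [false; 192];
    let mut i = 0;
    while i < 192 {
        table[i] = is_discriminant(i as u8);
        i += 1;
    }
    table
};

/// Step 1: bounds check (`cmp dil, -65` / `ja`). Step 2: table lookup.
fn is_valid_via_table(number: u8) -> bool {
    if number > 191 {
        return false;
    }
    TABLE[number as usize]
}

/// The real table: 192 entries, each a 4-byte (`i32`) relative offset.
const JUMP_TABLE_BYTES: usize = 192 * std::mem::size_of::<i32>();

fn main() {
    println!("jump table: {JUMP_TABLE_BYTES} bytes");
    println!("130 valid: {}", is_valid_via_table(130));
}
```

The point of the model is that a yes/no answer costs a load from a 192-entry table plus an indirect branch, even though the answer is a single bit.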

More efficient code

Compare this to the following implementation I've written by hand:

fn try_from_primitive(number: u8) -> Result<DxgiFormatEnum, TryFromPrimitiveError> {
    let is_valid = matches!(number, 0..=115 | 130..=132 | 191);

    if is_valid {
        // SAFETY: `number` is a valid discriminant value
        Ok(unsafe { std::mem::transmute(number) })
    } else {
        Err(TryFromPrimitiveError(number))
    }
}
try_from_primitive:
        cmp     dil, 116 ; compare input to 116
        jae     .LBB0_2  ; if the `input >=116`, go to LBB0_2
        xor     eax, eax ; Otherwise, return Ok
        mov     edx, edi
        ret
.LBB0_2:
        lea     ecx, [rdi + 126] ; compute `input - 130`
        mov     al, 1
        cmp     cl, 61  ; compare `input - 130` to 61 (191 - 130 = 61)
        ja      .LBB0_4 ; if `input - 130 > 61`, return Err
        movabs  rax, 2305843009213693944 ; bitset of invalid values: bits 3..=60 (inputs 133..=190) are set
        shr     rax, cl
.LBB0_4:
        mov     edx, edi
        ret

This is the output for O3; the other optimization modes produce virtually identical assembly. I want to point out that what the compiler produced is quite good here. It handled 0..=115 in its own branch and used a fast bitset lookup for 130..=132 | 191. Note that there is no jump table at all. The assembly is only a handful of instructions and can easily be inlined.
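The bitset constant can be checked by rebuilding it. The sketch below is a Rust transliteration of the assembly above (the names `INVALID_MASK` and `is_valid_like_asm` are made up for illustration). Bit `k` of the mask stands for the value `130 + k`, and a set bit marks an invalid value, so 2305843009213693944 = 2^61 - 8 has exactly bits 3..=60 (values 133..=190) set.

```rust
/// Bit `k` stands for the input `130 + k` (k in 0..=61); a set bit means
/// the input is NOT a discriminant, i.e. the function returns Err.
const INVALID_MASK: u64 = {
    let mut mask = 0u64;
    let mut k: u32 = 0;
    while k <= 61 {
        let value = 130 + k;
        if !matches!(value, 130..=132 | 191) {
            mask |= 1u64 << k;
        }
        k += 1;
    }
    mask
};

/// Follows the control flow of the optimized assembly step by step.
fn is_valid_like_asm(number: u8) -> bool {
    if number < 116 {
        return true; // `cmp dil, 116` / `jae`: the 0..=115 fast path
    }
    let offset = number.wrapping_sub(130); // `lea ecx, [rdi + 126]`
    if offset > 61 {
        // `cmp cl, 61` / `ja`: 116..=129 wrap around to 242..=255,
        // and 192..=255 map to 62..=125, so both are rejected here.
        return false;
    }
    (INVALID_MASK >> offset) & 1 == 0 // `shr rax, cl`: clear bit => Ok
}

fn main() {
    println!("mask = {INVALID_MASK} ({INVALID_MASK:#x})");
    println!("191 valid: {}", is_valid_like_asm(191));
}
```

Subtracting 130 first is what lets a single 64-bit register cover the whole sparse tail 130..=191.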

The underlying issue

The underlying issue is two-fold:

  1. LLVM has a hard time optimizing the large match statement generated by num_enum.
  2. LLVM is inconsistent when it optimizes the match statement generated by num_enum.

I want to expand on the second point. LLVM actually produces very efficient assembly for the match statement generated by num_enum IF the enum is sufficiently simple.

Example:
#[derive(IntoPrimitive, TryFromPrimitive)]
#[repr(u8)]
enum Simple3 {
    A = 0,
    B = 1,
    C = 2,
}
try_from_primitive:
        cmp     dil, 3
        setae   al
        mov     edx, edi
        ret

However, even relatively simple enums can cause large jump tables to be generated.

Example:
#[derive(IntoPrimitive, TryFromPrimitive)]
#[repr(u8)]
enum SimpleJumpTable {
    A = 0,
    B = 1,
    C = 2,
    D = 3,
    E = 4,
    F = 5,
    Large = 65,
}
try_from_primitive:
        mov     al, 1
        cmp     dil, 64
        ja      .LBB0_3
        movzx   ecx, dil
        lea     rdx, [rip + .LJTI0_0]
        movsxd  rcx, dword ptr [rdx + 4*rcx]
        add     rcx, rdx
        jmp     rcx
.LBB0_2:
        xor     eax, eax
.LBB0_3:
        mov     edx, edi
        ret
.LJTI0_0:
        .long   .LBB0_2-.LJTI0_0
        .long   .LBB0_2-.LJTI0_0
        ; 63 more entries

Even small changes in discriminant values can have huge effects on the generated assembly.

Here is a Compiler Explorer link where you can quickly see the generated assembly for different sets of discriminant values. Just enter the discriminant values into the test! macro.

Suggested fix

Since the compiler seems to be bad at optimizing the large match statement generated by num_enum, I would suggest splitting TryFromPrimitive::try_from_primitive into 2 parts:

  1. The first part will determine whether the given input is a valid discriminant value.
  2. The second part uses the result from the first part and transmutes the input or returns an error (or the default/catch_all variant).

In code, it would have the following form:

fn try_from_primitive(number: Self::Primitive) -> Result<Self, TryFromPrimitiveError> {
    const Variant0: u8 = ...;
    const Variant1: u8 = ...;
    const Variant2: u8 = ...;
    const VariantN: u8 = ...;

    let is_valid = matches!(number, Variant0 | Variant1 | Variant2 | ... | VariantN);

    if is_valid {
        // SAFETY: `number` is a valid discriminant value
        Ok(unsafe { std::mem::transmute(number) })
    } else {
        Err(TryFromPrimitiveError(number))
    }
}
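For illustration, here is that shape filled in for a small enum with gaps. `Sparse` and `TryFromPrimitiveError` are hypothetical local types (the latter is a stand-in, not num_enum's actual error type), so treat this as a sketch of what the macro output could look like rather than the implementation.

```rust
use std::mem::transmute;

/// Hypothetical stand-in for num_enum's error type.
#[derive(Debug, PartialEq, Eq)]
struct TryFromPrimitiveError(u8);

/// Hypothetical enum whose discriminants have gaps.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
enum Sparse {
    A = 0,
    B = 1,
    C = 2,
    P = 130,
    Q = 131,
    Z = 191,
}

fn try_from_primitive(number: u8) -> Result<Sparse, TryFromPrimitiveError> {
    // One constant per variant, as in the template above.
    const A: u8 = Sparse::A as u8;
    const B: u8 = Sparse::B as u8;
    const C: u8 = Sparse::C as u8;
    const P: u8 = Sparse::P as u8;
    const Q: u8 = Sparse::Q as u8;
    const Z: u8 = Sparse::Z as u8;

    // Part 1: decide validity without constructing any variant.
    let is_valid = matches!(number, A | B | C | P | Q | Z);

    // Part 2: reuse the primitive's bits, or report the error.
    if is_valid {
        // SAFETY: `Sparse` is `#[repr(u8)]` and `number` equals one of its
        // discriminants, so it is a valid bit pattern for `Sparse`.
        Ok(unsafe { transmute::<u8, Sparse>(number) })
    } else {
        Err(TryFromPrimitiveError(number))
    }
}

fn main() {
    println!("{:?}", try_from_primitive(131));
    println!("{:?}", try_from_primitive(3));
}
```

The `unsafe` block is only sound while the `matches!` list and the variant list stay in sync, which is something the derive macro can guarantee mechanically and handwritten code cannot.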

The compiler generally seems to be better at optimizing this form, leading to assembly that is as good as or better than that of the current codegen. You can test this out with this Compiler Explorer link.

Note, however, that matches! is not a perfect solution. It also generates jump tables sometimes; it just does so less often. To fix this, I would further suggest manually choosing a good strategy for determining is_valid instead of leaving it entirely up to the compiler. I would suggest the following:

Let V be the set of all discriminant values. Then determine is_valid: bool as follows:

  1. If V is an interval, generate is_valid = (V_min..=V_max).contains(&number).
  2. If V_max - V_min < 64 (so every value fits at a bit offset 0..=63 from V_min), use a 64-bit bitset.
  3. Otherwise, sort V and output a matches! with ranges instead of single literal values, e.g. matches!(number, 1..=4 | 6..=8 | 100) instead of matches!(number, 1 | 2 | 3 | 4 | 6 | 7 | 8 | 100). LLVM tends to produce better assembly this way.

This approach should always produce assembly that is as good as or better than simply using matches! with all variants.
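The three-way choice could be made at macro-expansion time. Below is a sketch of that selection logic, assuming distinct, non-negative discriminants that are already evaluated to integers (`u64` here); `Strategy`, `choose_strategy`, and `ranges_to_pattern` are hypothetical names, not num_enum APIs. A 64-bit mask covers offsets 0..=63, hence the `< 64` test.

```rust
/// How `is_valid` could be computed (hypothetical, for illustration).
#[derive(Debug, PartialEq, Eq)]
enum Strategy {
    /// valid iff `min <= number && number <= max`
    Interval { min: u64, max: u64 },
    /// valid iff `number - min` is in 0..=63 and that bit of `mask` is set
    Bitset { min: u64, mask: u64 },
    /// valid iff `number` falls into one of these sorted, merged ranges
    Ranges(Vec<(u64, u64)>),
}

/// Assumes `values` is non-empty and holds distinct discriminants.
fn choose_strategy(values: &[u64]) -> Strategy {
    let mut sorted = values.to_vec();
    sorted.sort_unstable();
    let (min, max) = (sorted[0], sorted[sorted.len() - 1]);

    // 1. No gaps: a single range check.
    if max - min == sorted.len() as u64 - 1 {
        return Strategy::Interval { min, max };
    }
    // 2. The whole span fits into one 64-bit mask (offsets 0..=63).
    if max - min < 64 {
        let mask = sorted.iter().fold(0u64, |m, &v| m | (1u64 << (v - min)));
        return Strategy::Bitset { min, mask };
    }
    // 3. Otherwise merge runs of consecutive values into ranges.
    let mut ranges: Vec<(u64, u64)> = Vec::new();
    for &v in &sorted {
        if let Some(last) = ranges.last_mut() {
            if last.1 + 1 == v {
                last.1 = v;
                continue;
            }
        }
        ranges.push((v, v));
    }
    Strategy::Ranges(ranges)
}

/// Renders merged ranges as the pattern part of a `matches!` call.
fn ranges_to_pattern(ranges: &[(u64, u64)]) -> String {
    ranges
        .iter()
        .map(|&(a, b)| if a == b { a.to_string() } else { format!("{a}..={b}") })
        .collect::<Vec<_>>()
        .join(" | ")
}

fn main() {
    // DXGI example from above: 0..=115, 130..=132, 191.
    let mut dxgi: Vec<u64> = (0..=115).collect();
    dxgi.extend([130, 131, 132, 191]);
    match choose_strategy(&dxgi) {
        Strategy::Ranges(r) => println!("matches!(number, {})", ranges_to_pattern(&r)),
        other => println!("{other:?}"),
    }
}
```

For the DXGI enum this takes the third branch and yields `matches!(number, 0..=115 | 130..=132 | 191)`, the same predicate as the handwritten version. `Simple3` becomes an interval check, and `SimpleJumpTable` (span 65) also falls through to ranges, `0..=5 | 65`.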

Note: If you choose not to implement the bitset method and want to rely on LLVM instead, I would advise against using matches! with ranges for small enums. LLVM tends to optimize those with a bitset, unless ranges are used in matches!. (Again, LLVM is really inconsistent in how it optimizes this.)
