Skip to content

Commit

Permalink
Significantly optimize cost calculation model
Browse files Browse the repository at this point in the history
Comparing profile data for the original Zopfli implementation in C vs.
this Rust port on a realistic test file highlighted that the Rust code
spent a significantly larger portion of time computing the translated
`GetCostStat` function, `zopfli::squeeze::get_cost_stat`. This is one of
the hottest functions, responsible for more than 5% of the total samples
collected by the sampling profiler I used.

When inspecting the generated x64 assembly code for this function in
Rust vs. in C, I noticed that the Rust function took ~200 lines of
assembly, while in C the same function is translated to ~50 lines. Upon
closer inspection, I noticed Rust's codegen for this function was
particularly suboptimal for two main reasons:

- Safe Rust performs array index bounds checks at runtime unless the
  compiler can prove that the index value is always in bounds.
- Functions such as `get_dist_extra_bits`, which are implemented via
  `match` expressions with lots of patterns, got translated into
  many inefficient test and jump instructions, while the related C code
  leveraged intrinsics to compute the same result with better
  readability and far fewer, much faster instructions.

This change improves performance on both fronts: first, it tactically
pads lookup tables and inserts assertions so that the overall number of
bounds checks is reduced, leading to increased optimization opportunities
for the compiler; second, `match` expressions are replaced by more
readable, efficient, and correct alternatives.

With this change, the Rust function now takes ~90 lines of assembly, a
55% reduction, and overall file compression is ~10% faster for several
test files.

Related to #12.
  • Loading branch information
AlexTMjugador committed Feb 23, 2025
1 parent f2b32cc commit 718fb62
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 77 deletions.
10 changes: 5 additions & 5 deletions src/deflate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ fn calculate_block_symbol_size_small(
LitLen::LengthDist(litlens_i, dists_i) => {
debug_assert!(litlens_i < 259);
let ll_symbol = get_length_symbol(litlens_i as usize);
let d_symbol = get_dist_symbol(dists_i);
let d_symbol = get_dist_symbol(dists_i) as usize;
result += ll_lengths[ll_symbol];
result += d_lengths[d_symbol];
result += get_length_symbol_extra_bits(ll_symbol);
Expand Down Expand Up @@ -1022,20 +1022,20 @@ fn add_lz77_data<W: Write>(
}
LitLen::LengthDist(len, dist) => {
let litlen = len as usize;
assert!((3..=288).contains(&litlen)); // Eases inlining and gets rid of index bound checks below
let lls = get_length_symbol(litlen);
let ds = get_dist_symbol(dist);
debug_assert!((3..=288).contains(&litlen));
let ds = get_dist_symbol(dist) as usize;
debug_assert!(ll_lengths[lls] > 0);
debug_assert!(d_lengths[ds] > 0);
bitwise_writer.add_huffman_bits(ll_symbols[lls], ll_lengths[lls])?;
bitwise_writer.add_bits(
get_length_extra_bits_value(litlen),
get_length_extra_bits(litlen) as u32,
get_length_extra_bits(litlen),
)?;
bitwise_writer.add_huffman_bits(d_symbols[ds], d_lengths[ds])?;
bitwise_writer.add_bits(
u32::from(get_dist_extra_bits_value(dist)),
get_dist_extra_bits(dist) as u32,
get_dist_extra_bits(dist),
)?;
testlength += litlen;
}
Expand Down
4 changes: 2 additions & 2 deletions src/lz77.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,9 @@ impl Lz77Store {
LitLen::LengthDist(length, dist) => {
let len_sym = get_length_symbol(length as usize);
self.ll_symbol.push(len_sym as u16);
self.d_symbol.push(get_dist_symbol(dist) as u16);
self.d_symbol.push(get_dist_symbol(dist));
self.ll_counts[llstart + len_sym] += 1;
self.d_counts[dstart + get_dist_symbol(dist)] += 1;
self.d_counts[dstart + get_dist_symbol(dist) as usize] += 1;
}
}
}
Expand Down
7 changes: 4 additions & 3 deletions src/squeeze.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,20 @@ fn get_cost_fixed(litlen: usize, dist: u16) -> f64 {
let lbits = get_length_extra_bits(litlen);
let lsym = get_length_symbol(litlen);
// Every dist symbol has length 5.
7 + usize::from(lsym > 279) + 5 + dbits + lbits
7 + u32::from(lsym > 279) + 5 + dbits + lbits
};
result as f64
}

/// Cost model based on symbol statistics.
fn get_cost_stat(litlen: usize, dist: u16, stats: &SymbolStats) -> f64 {
assert!(litlen < ZOPFLI_NUM_LL); // Eases inlining and gets rid of index bound checks below
if dist == 0 {
stats.ll_symbols[litlen]
} else {
let lsym = get_length_symbol(litlen);
let lbits = get_length_extra_bits(litlen) as f64;
let dsym = get_dist_symbol(dist);
let dsym = get_dist_symbol(dist) as usize;
let dbits = get_dist_extra_bits(dist) as f64;
lbits + dbits + stats.ll_symbols[lsym] + stats.d_symbols[dsym]
}
Expand Down Expand Up @@ -168,7 +169,7 @@ impl SymbolStats {
LitLen::Literal(lit) => self.litlens[lit as usize] += 1,
LitLen::LengthDist(len, dist) => {
self.litlens[get_length_symbol(len as usize)] += 1;
self.dists[get_dist_symbol(dist)] += 1;
self.dists[get_dist_symbol(dist) as usize] += 1;
}
}
}
Expand Down
94 changes: 27 additions & 67 deletions src/symbols.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
const LENGTH_SYMBOL_TABLE: [usize; 259] = [
const LENGTH_SYMBOL_TABLE: [usize; 288 /* Originally, 259 entries */] = [
0, 0, 0, 257, 258, 259, 260, 261, 262, 263, 264, 265, 265, 266, 266, 267, 267, 268, 268, 269,
269, 269, 269, 270, 270, 270, 270, 271, 271, 271, 271, 272, 272, 272, 272, 273, 273, 273, 273,
273, 273, 273, 273, 274, 274, 274, 274, 274, 274, 274, 274, 275, 275, 275, 275, 275, 275, 275,
Expand All @@ -13,6 +13,8 @@ const LENGTH_SYMBOL_TABLE: [usize; 259] = [
283, 283, 283, 283, 283, 283, 283, 283, 283, 283, 283, 283, 283, 283, 283, 283, 283, 284, 284,
284, 284, 284, 284, 284, 284, 284, 284, 284, 284, 284, 284, 284, 284, 284, 284, 284, 284, 284,
284, 284, 284, 284, 284, 284, 284, 284, 284, 284, 285,
// Padding to 288 entries for reduced bounds checking
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
];

/// Gets the symbol for the given length, cfr. the DEFLATE spec.
Expand All @@ -22,78 +24,33 @@ pub const fn get_length_symbol(length: usize) -> usize {
}

/// Gets the amount of extra bits for the given dist, cfr. the DEFLATE spec.
pub const fn get_dist_extra_bits(dist: u16) -> usize {
(match dist {
0..=4 => 0,
5..=8 => 1,
9..=16 => 2,
17..=32 => 3,
33..=64 => 4,
65..=128 => 5,
129..=256 => 6,
257..=512 => 7,
513..=1024 => 8,
1025..=2048 => 9,
2049..=4096 => 10,
4097..=8192 => 11,
8193..=16384 => 12,
_ => 13,
}) as usize
pub const fn get_dist_extra_bits(dist: u16) -> u32 {
if dist < 5 {
return 0;
}
(dist - 1).ilog2() - 1
}

/// Gets value of the extra bits for the given dist, cfr. the DEFLATE spec.
pub const fn get_dist_extra_bits_value(dist: u16) -> u16 {
match dist {
0..=4 => 0,
5..=8 => (dist - 5) & 1,
9..=16 => (dist - 9) & 3,
17..=32 => (dist - 17) & 7,
33..=64 => (dist - 33) & 15,
65..=128 => (dist - 65) & 31,
129..=256 => (dist - 129) & 63,
257..=512 => (dist - 257) & 127,
513..=1024 => (dist - 513) & 255,
1025..=2048 => (dist - 1025) & 511,
2049..=4096 => (dist - 2049) & 1023,
4097..=8192 => (dist - 4097) & 2047,
8193..=16384 => (dist - 8193) & 4095,
_ => (dist - 16385) & 8191,
if dist < 5 {
return 0;
}
let l = (dist - 1).ilog2();
(dist - (1 + (1 << l))) & ((1 << (l - 1)) - 1)
}

pub const fn get_dist_symbol(dist: u16) -> usize {
(match dist {
0..=4 => dist - 1,
5..=6 => 4,
7..=8 => 5,
9..=12 => 6,
13..=16 => 7,
17..=24 => 8,
25..=32 => 9,
33..=48 => 10,
49..=64 => 11,
65..=96 => 12,
97..=128 => 13,
129..=192 => 14,
193..=256 => 15,
257..=384 => 16,
385..=512 => 17,
513..=768 => 18,
769..=1024 => 19,
1025..=1536 => 20,
1537..=2048 => 21,
2049..=3072 => 22,
3073..=4096 => 23,
4097..=6144 => 24,
6145..=8192 => 25,
8193..=12288 => 26,
12289..=16384 => 27,
16385..=24576 => 28,
_ => 29,
}) as usize
pub const fn get_dist_symbol(dist: u16) -> u16 {
if dist < 5 {
// dist should never equal zero, and wrapping generates more efficient code
return dist.wrapping_sub(1);
}
let l = (dist - 1).ilog2();
let r = ((dist - 1) >> (l - 1)) & 1;
l as u16 * 2 + r
}

const LENGTH_EXTRA_BITS: [usize; 259] = [
const LENGTH_EXTRA_BITS: [u32; 288 /* Originally, 259 entries */] = [
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
Expand All @@ -102,15 +59,16 @@ const LENGTH_EXTRA_BITS: [usize; 259] = [
5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
5, 5, 0,
5, 5, 0, // Padding to 288 entries for reduced bounds checking
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
];

/// Gets the amount of extra bits for the given length, cfr. the DEFLATE spec.
pub const fn get_length_extra_bits(l: usize) -> usize {
pub const fn get_length_extra_bits(l: usize) -> u32 {
LENGTH_EXTRA_BITS[l]
}

const LENGTH_EXTRA_BITS_VALUE: [u32; 259] = [
const LENGTH_EXTRA_BITS_VALUE: [u32; 288 /* Originally, 259 entries */] = [
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0,
1, 2, 3, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4,
5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
Expand All @@ -121,6 +79,8 @@ const LENGTH_EXTRA_BITS_VALUE: [u32; 259] = [
4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,
23, 24, 25, 26, 27, 28, 29, 30, 0,
// Padding to 288 entries for reduced bounds checking
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
];

/// Gets value of the extra bits for the given length, cfr. the DEFLATE spec.
Expand Down

0 comments on commit 718fb62

Please sign in to comment.