diff --git a/CHANGELOG.md b/CHANGELOG.md index 338efc08..8ccfaed6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Added `Uint::square_redc`. ([#402]) - Support for diesel @ 2.2 ([#404]) - Support for sqlx @ 0.8 ([#400]) - Support for fastrlp @ 0.4 ([#401]) @@ -21,10 +22,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Support for sqlx @ 0.7. This is a breaking change, outside of regular semver policy, as 0.7 contains a security vulnerability ([#400]) +### Fixed + +- `Uint::mul_redc` is now alloc free ([#402]) + [#399]: https://github.com/recmo/uint/pull/399 [#400]: https://github.com/recmo/uint/pull/400 [#401]: https://github.com/recmo/uint/pull/401 [#404]: https://github.com/recmo/uint/pull/404 +[#402]: https://github.com/recmo/uint/pull/402 ## [1.12.3] - 2024-06-03 diff --git a/Cargo.toml b/Cargo.toml index e5a2473b..54a9dd7b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -66,7 +66,7 @@ parity-scale-codec = { version = "3", optional = true, features = [ "max-encoded-len", ], default-features = false } primitive-types = { version = "0.12", optional = true, default-features = false } -proptest = { version = "1.3", optional = true, default-features = false } +proptest = { version = "=1.5", optional = true, default-features = false } pyo3 = { version = "0.19", optional = true, default-features = false } quickcheck = { version = "1", optional = true, default-features = false } rand = { version = "0.8", optional = true, default-features = false } @@ -105,7 +105,7 @@ bincode = "1.3" hex = "0.4" hex-literal = "0.4" postgres = "0.19" -proptest = "1.2" +proptest = "=1.5" serde_json = "1.0" [features] diff --git a/proptest-regressions/algorithms/mul_redc.txt b/proptest-regressions/algorithms/mul_redc.txt new file mode 100644 index 00000000..11b3c11d --- /dev/null +++ b/proptest-regressions/algorithms/mul_redc.txt @@ -0,0 +1,11 @@ 
+# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. +cc f8106a52136ed4aac61eec137e12c5da14188344055de1af2c70670a9bdcb685 # shrinks to a = 1, b = 1, m = 6096796062212595973 +cc d8b943c322534ac6073b169239a6a57d92ed874bd1af77abcaca5fd6c17d0922 # shrinks to a = 1, m = 6467196249019906631 +cc 7bce74ed04eba0d78a0753b45256b4fa898002bfbfdd2a45e2fa8692bddda85b # shrinks to a = 17273988827536164680, m = 4783851910396016589 +cc 31a52325174b546a906b7327ea489c5df219ac159f819e2cd0240aaab7fe1da6 # shrinks to a = 14240046082810188870, b = 16896972505368501529, m = 6144969318566343923 +cc 14972b6b9d20d6efe0cc3595dd60936ae21f0985b25cf47b5bf667cdf0b6f9b3 # shrinks to mut a = 194085243466426527248460240309849653567, m = 302498102704436076702507509051420483905 diff --git a/proptest-regressions/modular.txt b/proptest-regressions/modular.txt index 3c9ef0d9..60f1d75c 100644 --- a/proptest-regressions/modular.txt +++ b/proptest-regressions/modular.txt @@ -9,3 +9,4 @@ cc d3df2bf31e0850f89f640c6a35c5ddb7fd15bc57c04eb60df5c04aea86d4b27a # shrinks to cc 530d6a1671f6f937904c349f7ae6504c7cd2e05c31bf194e3aa38d7161a5fc2a # shrinks to a = 0x00_U2, b = 0x00_U2, m = 0x03_U2 cc d7d611337732de2c417788803637c596949792f5bcb942956ea0ec3e8b889d82 # shrinks to a = 0x00_U2, b = 0x00_U2, c = 0x00_U2, m = 0x03_U2 cc f3498e21378eea45e82848df9f85d19faa37874e686e4874bb1da885f3a2ac38 # shrinks to a = 0x00_U2, b = 0x00_U2, c = 0x00_U2, m = 0x03_U2 +cc e40e555ad3f7103369086e7a27e810f9223e6e67d575d7b26a5a43fc8b103afc # shrinks to a = 2251333493155034715, b = 12864035474233633436, m = 17624464859391105743 diff --git a/src/add.rs b/src/add.rs index 77269b5c..9c508b73 100644 --- a/src/add.rs +++ b/src/add.rs @@ -1,4 +1,7 @@ -use crate::Uint; +use crate::{ + 
algorithms::{borrowing_sub, carrying_add}, + Uint, +}; use core::{ iter::Sum, ops::{Add, AddAssign, Neg, Sub, SubAssign}, @@ -56,24 +59,16 @@ impl Uint { #[inline] #[must_use] pub const fn overflowing_add(mut self, rhs: Self) -> (Self, bool) { - // TODO: Replace with `u64::carrying_add` once stable. - #[inline] - const fn u64_carrying_add(lhs: u64, rhs: u64, carry: bool) -> (u64, bool) { - let (a, b) = lhs.overflowing_add(rhs); - let (c, d) = a.overflowing_add(carry as u64); - (c, b | d) - } - if BITS == 0 { return (Self::ZERO, false); } let mut carry = false; let mut i = 0; while i < LIMBS { - (self.limbs[i], carry) = u64_carrying_add(self.limbs[i], rhs.limbs[i], carry); + (self.limbs[i], carry) = carrying_add(self.limbs[i], rhs.limbs[i], carry); i += 1; } - let overflow = carry || self.limbs[LIMBS - 1] > Self::MASK; + let overflow = carry | (self.limbs[LIMBS - 1] > Self::MASK); self.limbs[LIMBS - 1] &= Self::MASK; (self, overflow) } @@ -98,24 +93,16 @@ impl Uint { #[inline] #[must_use] pub const fn overflowing_sub(mut self, rhs: Self) -> (Self, bool) { - // TODO: Replace with `u64::borrowing_sub` once stable. 
- #[inline] - const fn u64_borrowing_sub(lhs: u64, rhs: u64, borrow: bool) -> (u64, bool) { - let (a, b) = lhs.overflowing_sub(rhs); - let (c, d) = a.overflowing_sub(borrow as u64); - (c, b | d) - } - if BITS == 0 { return (Self::ZERO, false); } let mut borrow = false; let mut i = 0; while i < LIMBS { - (self.limbs[i], borrow) = u64_borrowing_sub(self.limbs[i], rhs.limbs[i], borrow); + (self.limbs[i], borrow) = borrowing_sub(self.limbs[i], rhs.limbs[i], borrow); i += 1; } - let overflow = borrow || self.limbs[LIMBS - 1] > Self::MASK; + let overflow = borrow | (self.limbs[LIMBS - 1] > Self::MASK); self.limbs[LIMBS - 1] &= Self::MASK; (self, overflow) } diff --git a/src/algorithms/gcd/mod.rs b/src/algorithms/gcd/mod.rs index 0898bfef..b6f328c4 100644 --- a/src/algorithms/gcd/mod.rs +++ b/src/algorithms/gcd/mod.rs @@ -1,5 +1,7 @@ #![allow(clippy::module_name_repetitions)] +// TODO: https://github.com/bitcoin-core/secp256k1/blob/master/doc/safegcd_implementation.md + // TODO: Make these algorithms work on limb slices. mod matrix; diff --git a/src/algorithms/mod.rs b/src/algorithms/mod.rs index f056e405..487daf71 100644 --- a/src/algorithms/mod.rs +++ b/src/algorithms/mod.rs @@ -11,7 +11,6 @@ mod add; pub mod div; mod gcd; mod mul; -#[cfg(feature = "alloc")] // TODO: Make mul_redc alloc-free mod mul_redc; mod ops; mod shift; @@ -21,11 +20,10 @@ pub use self::{ div::div, gcd::{gcd, gcd_extended, inv_mod, LehmerMatrix}, mul::{add_nx1, addmul, addmul_n, addmul_nx1, mul_nx1, submul_nx1}, + mul_redc::{mul_redc, square_redc}, ops::{adc, sbb}, shift::{shift_left_small, shift_right_small}, }; -#[cfg(feature = "alloc")] -pub use mul_redc::mul_redc; trait DoubleWord: Sized + Copy { fn join(high: T, low: T) -> Self; @@ -116,3 +114,21 @@ pub fn cmp(left: &[u64], right: &[u64]) -> Ordering { left.len().cmp(&right.len()) } + +// Helper while [Rust#85532](https://github.com/rust-lang/rust/issues/85532) stabilizes. 
+#[inline] +#[must_use] +pub const fn carrying_add(lhs: u64, rhs: u64, carry: bool) -> (u64, bool) { + let (result, carry_1) = lhs.overflowing_add(rhs); + let (result, carry_2) = result.overflowing_add(carry as u64); + (result, carry_1 | carry_2) +} + +// Helper while [Rust#85532](https://github.com/rust-lang/rust/issues/85532) stabilizes. +#[inline] +#[must_use] +pub const fn borrowing_sub(lhs: u64, rhs: u64, borrow: bool) -> (u64, bool) { + let (result, borrow_1) = lhs.overflowing_sub(rhs); + let (result, borrow_2) = result.overflowing_sub(borrow as u64); + (result, borrow_1 | borrow_2) +} diff --git a/src/algorithms/mul_redc.rs b/src/algorithms/mul_redc.rs index 0cbbcdc0..3781a7f0 100644 --- a/src/algorithms/mul_redc.rs +++ b/src/algorithms/mul_redc.rs @@ -1,70 +1,260 @@ -use super::addmul; -use core::iter::zip; +// TODO: https://baincapitalcrypto.com/optimizing-montgomery-multiplication-in-webassembly/ -/// See Handbook of Applied Cryptography, Algorithm 14.32, p. 601. -#[allow(clippy::cognitive_complexity)] // REFACTOR: Improve +use super::{borrowing_sub, carrying_add, cmp}; +use core::{cmp::Ordering, iter::zip}; + +/// Computes a * b * 2^(-BITS) mod modulus +/// +/// Requires that `inv` is the inverse of `-modulus[0]` modulo `2^64`. +/// Requires that `a` and `b` are less than `modulus`. #[inline] -pub fn mul_redc(a: &[u64], b: &[u64], result: &mut [u64], m: &[u64], inv: u64) { - debug_assert!(!m.is_empty()); - debug_assert_eq!(a.len(), m.len()); - debug_assert_eq!(b.len(), m.len()); - debug_assert_eq!(result.len(), m.len()); - debug_assert_eq!(inv.wrapping_mul(m[0]), u64::MAX); - - // Compute temp full product. - // OPT: Do combined multiplication and reduction. - let mut temp = vec![0; 2 * m.len() + 1]; - addmul(&mut temp, a, b); - - // Reduce temp. - for i in 0..m.len() { - let u = temp[i].wrapping_mul(inv); - - // REFACTOR: Create add_mul1 routine. 
- let mut carry = 0; - #[allow(clippy::cast_possible_truncation)] // Intentional - for j in 0..m.len() { - carry += u128::from(temp[i + j]) + u128::from(m[j]) * u128::from(u); - temp[i + j] = carry as u64; - carry >>= 64; - } - #[allow(clippy::cast_possible_truncation)] // Intentional - for j in m.len()..(temp.len() - i) { - carry += u128::from(temp[i + j]); - temp[i + j] = carry as u64; - carry >>= 64; - } - debug_assert!(carry == 0); - } - debug_assert!(temp[temp.len() - 1] <= 1); // Basically a carry flag. - - // Copy result. - result.copy_from_slice(&temp[m.len()..2 * m.len()]); - - // Subtract one more m if result >= m - let mut reduce = true; - // REFACTOR: Create cmp routine - if temp[temp.len() - 1] == 0 { - for (r, m) in zip(result.iter().rev(), m.iter().rev()) { - if r < m { - reduce = false; - break; +#[must_use] +pub fn mul_redc<const N: usize>(a: [u64; N], b: [u64; N], modulus: [u64; N], inv: u64) -> [u64; N] { + debug_assert_eq!(inv.wrapping_mul(modulus[0]), u64::MAX); + debug_assert_eq!(cmp(&a, &modulus), Ordering::Less); + debug_assert_eq!(cmp(&b, &modulus), Ordering::Less); + + // Coarsely Integrated Operand Scanning (CIOS) + // See + // See + // See + let mut result = [0; N]; + let mut carry = false; + for b in b { + let mut m = 0; + let mut carry_1 = 0; + let mut carry_2 = 0; + for i in 0..N { + // Add limb product + let (value, next_carry) = carrying_mul_add(a[i], b, result[i], carry_1); + carry_1 = next_carry; + + if i == 0 { + // Compute reduction factor + m = value.wrapping_mul(inv); } - if r > m { - break; + + // Add m * modulus to acc to clear next_result[0] + let (value, next_carry) = carrying_mul_add(modulus[i], m, value, carry_2); + carry_2 = next_carry; + + // Shift result + if i > 0 { + result[i - 1] = value; + } else { + debug_assert_eq!(value, 0); } } + + // Add carries + let (value, next_carry) = carrying_add(carry_1, carry_2, carry); + result[N - 1] = value; + if modulus[N - 1] >= 0x7fff_ffff_ffff_ffff { + carry = next_carry; + } else { + 
debug_assert!(!next_carry); + } } - if reduce { - // REFACTOR: Create sub routine - let mut carry = 0; - #[allow(clippy::cast_possible_truncation)] // Intentional - #[allow(clippy::cast_sign_loss)] // Intentional - for (r, m) in zip(result.iter_mut(), m.iter().copied()) { - carry += i128::from(*r) - i128::from(m); - *r = carry as u64; - carry >>= 64; // Sign extending shift + + // Compute reduced product. + reduce1_carry(result, modulus, carry) +} + +/// Computes a^2 * 2^(-BITS) mod modulus +/// +/// Requires that `inv` is the inverse of `-modulus[0]` modulo `2^64`. +/// Requires that `a` is less than `modulus`. +#[inline] +#[must_use] +#[allow(clippy::cast_possible_truncation)] +pub fn square_redc<const N: usize>(a: [u64; N], modulus: [u64; N], inv: u64) -> [u64; N] { + debug_assert_eq!(inv.wrapping_mul(modulus[0]), u64::MAX); + debug_assert_eq!(cmp(&a, &modulus), Ordering::Less); + + let mut result = [0; N]; + let mut carry_outer = 0; + for i in 0..N { + // Add limb product + let (value, mut carry_lo) = carrying_mul_add(a[i], a[i], result[i], 0); + let mut carry_hi = false; + result[i] = value; + for j in (i + 1)..N { + let (value, next_carry_lo, next_carry_hi) = + carrying_double_mul_add(a[i], a[j], result[j], carry_lo, carry_hi); + result[j] = value; + carry_lo = next_carry_lo; + carry_hi = next_carry_hi; + } + + // Add m times modulus to result and shift one limb + let m = result[0].wrapping_mul(inv); + let (value, mut carry) = carrying_mul_add(m, modulus[0], result[0], 0); + debug_assert_eq!(value, 0); + for j in 1..N { + let (value, next_carry) = carrying_mul_add(modulus[j], m, result[j], carry); + result[j - 1] = value; + carry = next_carry; + } + + // Add carries + if modulus[N - 1] >= 0x3fff_ffff_ffff_ffff { + let wide = (carry_outer as u128) + .wrapping_add(carry_lo as u128) + .wrapping_add((carry_hi as u128) << 64) + .wrapping_add(carry as u128); + result[N - 1] = wide as u64; + + // Note carry_outer can be {0, 1, 2}. 
+ carry_outer = (wide >> 64) as u64; + debug_assert!(carry_outer <= 2); + } else { + // `carry_outer` and `carry_hi` are always zero. + debug_assert!(!carry_hi); + debug_assert_eq!(carry_outer, 0); + let (value, carry) = carry_lo.overflowing_add(carry); + debug_assert!(!carry); + result[N - 1] = value; } - debug_assert!(carry == 0 || temp[temp.len() - 1] == 1); + } + + // Compute reduced product. + debug_assert!(carry_outer <= 1); + reduce1_carry(result, modulus, carry_outer > 0) +} + +#[inline] +#[must_use] +#[allow(clippy::needless_bitwise_bool)] +fn reduce1_carry<const N: usize>(value: [u64; N], modulus: [u64; N], carry: bool) -> [u64; N] { + let (reduced, borrow) = sub(value, modulus); + // TODO: Ideally this turns into a cmov, which makes the whole mul_redc constant + // time. + if carry | !borrow { + reduced + } else { + value + } +} + +#[inline] +#[must_use] +fn sub<const N: usize>(lhs: [u64; N], rhs: [u64; N]) -> ([u64; N], bool) { + let mut result = [0; N]; + let mut borrow = false; + for (result, (lhs, rhs)) in zip(&mut result, zip(lhs, rhs)) { + let (value, next_borrow) = borrowing_sub(lhs, rhs, borrow); + *result = value; + borrow = next_borrow; + } + (result, borrow) +} + +/// Compute `lhs * rhs + add + carry`. +/// The output can not overflow for any input values. +#[inline] +#[must_use] +#[allow(clippy::cast_possible_truncation)] +const fn carrying_mul_add(lhs: u64, rhs: u64, add: u64, carry: u64) -> (u64, u64) { + let wide = (lhs as u128) + .wrapping_mul(rhs as u128) + .wrapping_add(add as u128) + .wrapping_add(carry as u128); + (wide as u64, (wide >> 64) as u64) +} + +/// Compute `2 * lhs * rhs + add + carry_lo + 2^64 * carry_hi`. +/// The output can not overflow for any input values. 
+#[inline] +#[must_use] +#[allow(clippy::cast_possible_truncation)] +const fn carrying_double_mul_add( + lhs: u64, + rhs: u64, + add: u64, + carry_lo: u64, + carry_hi: bool, +) -> (u64, u64, bool) { + let wide = (lhs as u128).wrapping_mul(rhs as u128); + let (wide, carry_1) = wide.overflowing_add(wide); + let carries = (add as u128) + .wrapping_add(carry_lo as u128) + .wrapping_add((carry_hi as u128) << 64); + let (wide, carry_2) = wide.overflowing_add(carries); + (wide as u64, (wide >> 64) as u64, carry_1 | carry_2) +} + +#[cfg(test)] +mod test { + use core::ops::Neg; + + use super::{ + super::{addmul, div}, + *, + }; + use crate::{aliases::U64, const_for, nlimbs, Uint}; + use proptest::{prop_assert_eq, proptest}; + + fn modmul<const N: usize>(a: [u64; N], b: [u64; N], modulus: [u64; N]) -> [u64; N] { + // Compute a * b + let mut product = vec![0; 2 * N]; + addmul(&mut product, &a, &b); + + // Compute product mod modulus + let mut reduced = modulus; + div(&mut product, &mut reduced); + reduced + } + + fn mul_base<const N: usize>(a: [u64; N], modulus: [u64; N]) -> [u64; N] { + // Compute a * 2^(N * 64) + let mut product = vec![0; 2 * N]; + product[N..].copy_from_slice(&a); + + // Compute product mod modulus + let mut reduced = modulus; + div(&mut product, &mut reduced); + reduced + } + + #[test] + fn test_mul_redc() { + const_for!(BITS in NON_ZERO if (BITS >= 16) { + const LIMBS: usize = nlimbs(BITS); + type U = Uint<BITS, LIMBS>; + proptest!(|(mut a: U, mut b: U, mut m: U)| { + m |= U::from(1_u64); // Make sure m is odd. + a %= m; // Make sure a is less than m. + b %= m; // Make sure b is less than m. 
+ let a = *a.as_limbs(); + let b = *b.as_limbs(); + let m = *m.as_limbs(); + let inv = U64::from(m[0]).inv_ring().unwrap().neg().as_limbs()[0]; + + let result = mul_base(mul_redc(a, b, m, inv), m); + let expected = modmul(a, b, m); + + prop_assert_eq!(result, expected); + }); + }); + } + + #[test] + fn test_square_redc() { + const_for!(BITS in NON_ZERO if (BITS >= 16) { + const LIMBS: usize = nlimbs(BITS); + type U = Uint<BITS, LIMBS>; + proptest!(|(mut a: U, mut m: U)| { + m |= U::from(1_u64); // Make sure m is odd. + a %= m; // Make sure a is less than m. + let a = *a.as_limbs(); + let m = *m.as_limbs(); + let inv = U64::from(m[0]).inv_ring().unwrap().neg().as_limbs()[0]; + + let result = mul_base(square_redc(a, m, inv), m); + let expected = modmul(a, a, m); + + prop_assert_eq!(result, expected); + }); + }); } } diff --git a/src/modular.rs b/src/modular.rs index 30335a26..95e095fa 100644 --- a/src/modular.rs +++ b/src/modular.rs @@ -115,6 +115,8 @@ impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> { /// Montgomery multiplication. /// + /// Requires `self` and `other` to be less than `modulus`. + /// /// Computes /// /// $$ @@ -150,23 +152,30 @@ /// /// # Panics /// - /// Panics if `inv` is not correct. + /// Panics if `inv` is not correct in debug mode. #[inline] #[must_use] - #[cfg(feature = "alloc")] // TODO: Make mul_redc alloc-free pub fn mul_redc(self, other: Self, modulus: Self, inv: u64) -> Self { if BITS == 0 { return Self::ZERO; } - assert_eq!(inv.wrapping_mul(modulus.limbs[0]), u64::MAX); - let mut result = Self::ZERO; - algorithms::mul_redc( - self.as_limbs(), - other.as_limbs(), - &mut result.limbs, - modulus.as_limbs(), - inv, - ); + let result = algorithms::mul_redc(self.limbs, other.limbs, modulus.limbs, inv); + let result = Self::from_limbs(result); + debug_assert!(result < modulus); + result + } + + /// Montgomery squaring. + /// + /// See [Self::mul_redc]. 
+#[inline] + #[must_use] + pub fn square_redc(self, modulus: Self, inv: u64) -> Self { + if BITS == 0 { + return Self::ZERO; + } + let result = algorithms::square_redc(self.limbs, modulus.limbs, inv); + let result = Self::from_limbs(result); debug_assert!(result < modulus); result } @@ -301,4 +310,26 @@ mod tests { }); }); } + + #[test] + fn test_square_redc() { + const_for!(BITS in NON_ZERO if (BITS >= 16) { + const LIMBS: usize = nlimbs(BITS); + type U = Uint<BITS, LIMBS>; + proptest!(|(a: U, m: U)| { + prop_assume!(m >= U::from(2)); + if let Some(inv) = U64::from(m.as_limbs()[0]).inv_ring() { + let inv = (-inv).as_limbs()[0]; + + let r = U::from(2).pow_mod(U::from(64 * LIMBS), m); + let ar = a.mul_mod(r, m); + // TODO: Test for larger (>= m) values of a, b. + + let expected = a.mul_mod(a, m).mul_mod(r, m); + + assert_eq!(ar.square_redc(m, inv), expected); + } + }); + }); + } }