Skip to content

Commit 3a0545c

Browse files
authored
ml-kem: fix potential kyberslash attack (#18)
Updates compress method to remove division of secret values by public value (q) as described in the kyberslash attack
1 parent a299f50 commit 3a0545c

File tree

8 files changed

+141
-31
lines changed

8 files changed

+141
-31
lines changed

.github/workflows/workspace.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ jobs:
2121
timeout-minutes: 45
2222
steps:
2323
- uses: actions/checkout@v4
24-
- uses: dtolnay/rust-toolchain@1.74.0
24+
- uses: dtolnay/rust-toolchain@beta # TODO: use `1.79` after 2024-06-13
2525
with:
2626
components: clippy
2727
- run: cargo clippy -- -D warnings

Cargo.lock

Lines changed: 31 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

ml-kem/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ sha3 = { version = "0.10.8", default-features = false }
2828
criterion = "0.5.1"
2929
hex = "0.4.3"
3030
hex-literal = "0.4.1"
31+
num-rational = { version = "0.4.2", default-features = false, features = ["num-bigint"] }
3132
rand = "0.8.5"
3233
crypto-common = { version = "0.1.6", features = ["rand_core"] }
3334

ml-kem/src/algebra.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@ pub struct FieldElement(pub Integer);
1717
impl FieldElement {
1818
pub const Q: Integer = 3329;
1919
pub const Q32: u32 = Self::Q as u32;
20-
const Q64: u64 = Self::Q as u64;
20+
pub const Q64: u64 = Self::Q as u64;
2121
const BARRETT_SHIFT: usize = 24;
22+
#[allow(clippy::integer_division_remainder_used)]
2223
const BARRETT_MULTIPLIER: u64 = (1 << Self::BARRETT_SHIFT) / Self::Q64;
2324

2425
// A fast modular reduction for small numbers `x < 2*q`
@@ -263,6 +264,7 @@ impl NttPolynomial {
263264
#[allow(clippy::cast_possible_truncation)]
264265
const ZETA_POW_BITREV: [FieldElement; 128] = {
265266
const ZETA: u64 = 17;
267+
#[allow(clippy::integer_division_remainder_used)]
266268
const fn bitrev7(x: usize) -> usize {
267269
((x >> 6) % 2)
268270
| (((x >> 5) % 2) << 1)
@@ -277,6 +279,7 @@ const ZETA_POW_BITREV: [FieldElement; 128] = {
277279
let mut pow = [FieldElement(0); 128];
278280
let mut i = 0;
279281
let mut curr = 1u64;
282+
#[allow(clippy::integer_division_remainder_used)]
280283
while i < 128 {
281284
pow[i] = FieldElement(curr as u16);
282285
i += 1;
@@ -300,6 +303,7 @@ const GAMMA: [FieldElement; 128] = {
300303
let mut i = 0;
301304
while i < 128 {
302305
let zpr = ZETA_POW_BITREV[i].0 as u64;
306+
#[allow(clippy::integer_division_remainder_used)]
303307
let g = (zpr * zpr * ZETA) % FieldElement::Q64;
304308
gamma[i] = FieldElement(g as u16);
305309
i += 1;

ml-kem/src/compress.rs

Lines changed: 98 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ use crate::util::Truncate;
66
pub trait CompressionFactor: EncodingSize {
77
const POW2_HALF: u32;
88
const MASK: Integer;
9+
const DIV_SHIFT: usize;
10+
const DIV_MUL: u64;
911
}
1012

1113
impl<T> CompressionFactor for T
@@ -14,6 +16,9 @@ where
1416
{
1517
const POW2_HALF: u32 = 1 << (T::USIZE - 1);
1618
const MASK: Integer = ((1 as Integer) << T::USIZE) - 1;
19+
const DIV_SHIFT: usize = 34;
20+
#[allow(clippy::integer_division_remainder_used)]
21+
const DIV_MUL: u64 = (1 << T::DIV_SHIFT) / FieldElement::Q64;
1722
}
1823

1924
// Traits for objects that allow compression / decompression
@@ -25,25 +30,26 @@ pub trait Compress {
2530
impl Compress for FieldElement {
2631
// Equation 4.5: Compress_d(x) = round((2^d / q) x)
2732
//
28-
// Here and in decompression, we leverage the following fact:
33+
// Here and in decompression, we leverage the following facts:
2934
//
3035
// round(a / b) = floor((a + b/2) / b)
36+
// a / q ~= (a * x) >> s where x >> s ~= 1/q
3137
fn compress<D: CompressionFactor>(&mut self) -> &Self {
32-
const Q_HALF: u32 = (FieldElement::Q32 - 1) / 2;
33-
let x = u32::from(self.0);
34-
let y = ((x << D::USIZE) + Q_HALF) / FieldElement::Q32;
38+
const Q_HALF: u64 = (FieldElement::Q64 + 1) >> 1;
39+
let x = u64::from(self.0);
40+
let y = ((((x << D::USIZE) + Q_HALF) * D::DIV_MUL) >> D::DIV_SHIFT).truncate();
3541
self.0 = y.truncate() & D::MASK;
3642
self
3743
}
38-
// Equation 4.6: Decomporess_d(x) = round((q / 2^d) x)
44+
45+
// Equation 4.6: Decompress_d(x) = round((q / 2^d) x)
3946
fn decompress<D: CompressionFactor>(&mut self) -> &Self {
4047
let x = u32::from(self.0);
4148
let y = ((x * FieldElement::Q32) + D::POW2_HALF) >> D::USIZE;
4249
self.0 = y.truncate();
4350
self
4451
}
4552
}
46-
4753
impl Compress for Polynomial {
4854
fn compress<D: CompressionFactor>(&mut self) -> &Self {
4955
for x in &mut self.0 {
@@ -84,37 +90,100 @@ impl<K: ArraySize> Compress for PolynomialVector<K> {
8490
pub(crate) mod test {
8591
use super::*;
8692
use hybrid_array::typenum::{U1, U10, U11, U12, U4, U5, U6};
93+
use num_rational::Ratio;
94+
95+
fn rational_compress<D: CompressionFactor>(input: u16) -> u16 {
96+
let fraction = Ratio::new(u32::from(input) * (1 << D::USIZE), FieldElement::Q32);
97+
(fraction.round().to_integer() as u16) & D::MASK
98+
}
99+
100+
fn rational_decompress<D: CompressionFactor>(input: u16) -> u16 {
101+
let fraction = Ratio::new(u32::from(input) * FieldElement::Q32, 1 << D::USIZE);
102+
fraction.round().to_integer() as u16
103+
}
87104

88-
// Verify that the integer compression routine produces the same results as rounding with
89-
// floats.
90-
fn compression_known_answer_test<D: CompressionFactor>() {
91-
let fq: f64 = FieldElement::Q as f64;
92-
let f2d: f64 = 2.0_f64.powi(D::I32);
105+
// Verify against inequality 4.7
106+
#[allow(clippy::integer_division_remainder_used)]
107+
fn compression_decompression_inequality<D: CompressionFactor>() {
108+
const QI32: i32 = FieldElement::Q as i32;
109+
let error_threshold = Ratio::new(FieldElement::Q, 1 << D::USIZE).to_integer() as i32;
93110

94111
for x in 0..FieldElement::Q {
95-
let fx = x as f64;
96-
let mut x = FieldElement(x);
112+
let mut y = FieldElement(x);
113+
y.compress::<D>();
114+
y.decompress::<D>();
115+
116+
let mut error = i32::from(y.0) - i32::from(x) + QI32;
117+
if error > (QI32 - 1) / 2 {
118+
error -= QI32;
119+
}
120+
121+
assert!(
122+
error.abs() <= error_threshold,
123+
"Inequality failed for x = {x}: error = {}, error_threshold = {error_threshold}, D = {:?}",
124+
error.abs(),
125+
D::USIZE
126+
);
127+
}
128+
}
97129

98-
// Verify equivalence of compression
99-
x.compress::<D>();
100-
let fcx = ((f2d / fq * fx).round() as Integer) % (1 << D::USIZE);
101-
assert_eq!(x.0, fcx);
130+
fn decompression_compression_equality<D: CompressionFactor>() {
131+
for x in 0..(1 << D::USIZE) {
132+
let mut y = FieldElement(x);
133+
y.decompress::<D>();
134+
y.compress::<D>();
102135

103-
// Verify equivalence of decompression
104-
x.decompress::<D>();
105-
let fdx = (fq / f2d * (fcx as f64)).round() as Integer;
106-
assert_eq!(x.0, fdx);
136+
assert_eq!(y.0, x, "failed for x: {}, D: {}", x, D::USIZE);
137+
}
138+
}
139+
140+
fn decompress_KAT<D: CompressionFactor>() {
141+
for y in 0..(1 << D::USIZE) {
142+
let x_expected = rational_decompress::<D>(y);
143+
let mut x_actual = FieldElement(y);
144+
x_actual.decompress::<D>();
145+
146+
assert_eq!(x_expected, x_actual.0);
147+
}
148+
}
149+
150+
fn compress_KAT<D: CompressionFactor>() {
151+
for x in 0..FieldElement::Q {
152+
let y_expected = rational_compress::<D>(x);
153+
let mut y_actual = FieldElement(x);
154+
y_actual.compress::<D>();
155+
156+
assert_eq!(y_expected, y_actual.0, "for x: {}, D: {}", x, D::USIZE);
107157
}
108158
}
109159

160+
fn compress_decompress_properties<D: CompressionFactor>() {
161+
compression_decompression_inequality::<D>();
162+
decompression_compression_equality::<D>();
163+
}
164+
165+
fn compress_decompress_KATs<D: CompressionFactor>() {
166+
decompress_KAT::<D>();
167+
compress_KAT::<D>();
168+
}
169+
110170
#[test]
111-
fn compress_decompress() {
112-
compression_known_answer_test::<U1>();
113-
compression_known_answer_test::<U4>();
114-
compression_known_answer_test::<U5>();
115-
compression_known_answer_test::<U6>();
116-
compression_known_answer_test::<U10>();
117-
compression_known_answer_test::<U11>();
118-
compression_known_answer_test::<U12>();
171+
fn decompress_compress() {
172+
compress_decompress_properties::<U1>();
173+
compress_decompress_properties::<U4>();
174+
compress_decompress_properties::<U5>();
175+
compress_decompress_properties::<U6>();
176+
compress_decompress_properties::<U10>();
177+
compress_decompress_properties::<U11>();
178+
// preservation under decompression first only holds for d < 12
179+
compression_decompression_inequality::<U12>();
180+
181+
compress_decompress_KATs::<U1>();
182+
compress_decompress_KATs::<U4>();
183+
compress_decompress_KATs::<U5>();
184+
compress_decompress_KATs::<U6>();
185+
compress_decompress_KATs::<U10>();
186+
compress_decompress_KATs::<U11>();
187+
compress_decompress_KATs::<U12>();
119188
}
120189
}

ml-kem/src/encode.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,11 +165,13 @@ pub(crate) mod test {
165165
D: ArraySize + Rem<N>,
166166
Mod<D, N>: Zero,
167167
{
168+
#[allow(clippy::integer_division_remainder_used)]
168169
fn repeat(&self) -> Array<T, D> {
169170
Array::from_fn(|i| self[i % N::USIZE].clone())
170171
}
171172
}
172173

174+
#[allow(clippy::integer_division_remainder_used)]
173175
fn byte_codec_test<D>(decoded: DecodedValue, encoded: EncodedPolynomial<D>)
174176
where
175177
D: EncodingSize,
@@ -247,6 +249,7 @@ pub(crate) mod test {
247249
byte_codec_test::<U12>(decoded, encoded);
248250
}
249251

252+
#[allow(clippy::integer_division_remainder_used)]
250253
#[test]
251254
fn byte_codec_12_mod() {
252255
// DecodeBytes_12 is required to reduce mod q

ml-kem/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
html_favicon_url = "https://raw.githubusercontent.com/RustCrypto/meta/master/logo.svg"
77
)]
88
#![warn(clippy::pedantic)] // Be pedantic by default
9+
#![warn(clippy::integer_division_remainder_used)] // Be judicious about using `/` and `%`
910
#![allow(non_snake_case)] // Allow notation matching the spec
1011
#![allow(clippy::clone_on_copy)] // Be explicit about moving data
1112
#![deny(missing_docs)] // Require all public interfaces to be documented

ml-kem/src/param.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ where
117117
let mut x = 0usize;
118118
while x < max {
119119
let mut y = 0usize;
120+
#[allow(clippy::integer_division_remainder_used)]
120121
while y < max {
121122
let x_ones = x.count_ones() as u16;
122123
let y_ones = y.count_ones() as u16;

0 commit comments

Comments
 (0)