Skip to content

Commit c408bba

Browse files
committed
AVX2: Add native implementation of poly_reduce and poly_caddq
Signed-off-by: Jake Massimo <[email protected]>
1 parent d526633 commit c408bba

File tree

5 files changed

+151
-13
lines changed

5 files changed

+151
-13
lines changed

mldsa/native/api.h

Lines changed: 39 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -80,23 +80,49 @@ set if there are native implementations for NTT and INTT."
8080
* Arguments: - int32_t p[MLDSA_N]: pointer to in/output polynomial
8181
*
8282
**************************************************/
83-
static MLD_INLINE void mld_poly_permute_bitrev_to_custom(int32_t p[MLDSA_N])
83+
static MLD_INLINE void mld_poly_permute_bitrev_to_custom(int32_t p[MLDSA_N]);
8484
#endif /* MLD_USE_NATIVE_NTT_CUSTOM_ORDER */
8585

8686

8787
#if defined(MLD_USE_NATIVE_INTT)
88-
/*************************************************
89-
* Name: mld_intt_native
90-
*
91-
* Description: Computes inverse of negacyclic number-theoretic transform
92-
*(NTT) of a polynomial in place.
93-
*
94-
* The input polynomial is in bitreversed order.
95-
* The output polynomial is assumed to be in normal order.
96-
*
97-
* Arguments: - uint32_t p[MLDSA_N]: pointer to in/output polynomial
98-
**************************************************/
99-
static MLD_INLINE void mld_intt_native(int16_t p[MLDSA_N])
88+
/*************************************************
89+
* Name: mld_intt_native
90+
*
91+
* Description: Computes inverse of negacyclic number-theoretic transform
92+
*(NTT) of a polynomial in place.
93+
*
94+
* The input polynomial is in bitreversed order.
95+
* The output polynomial is assumed to be in normal order.
96+
*
97+
* Arguments: - uint32_t p[MLDSA_N]: pointer to in/output polynomial
98+
**************************************************/
99+
static MLD_INLINE void mld_intt_native(int32_t p[MLDSA_N]);
100100
#endif /* MLD_USE_NATIVE_INTT */
101101

102+
#if defined(MLD_USE_NATIVE_POLY_REDUCE)
103+
/*************************************************
104+
* Name: mld_poly_reduce_native
105+
*
106+
* Description: Inplace reduction of all coefficients of polynomial to
107+
* representative in [-6283009,6283008]. Assumes input
108+
* coefficients to be at most 2^31 - 2^22 - 1 in absolute
109+
*value.
110+
*
111+
* Arguments: - int32_t *a: pointer to input/output polynomial
112+
**************************************************/
113+
static MLD_INLINE void mld_poly_reduce_native(int32_t a[MLDSA_N]);
114+
#endif /* MLD_USE_NATIVE_POLY_REDUCE */
115+
116+
#if defined(MLD_USE_NATIVE_POLY_CADDQ)
117+
/*************************************************
118+
* Name: mld_poly_caddq_native
119+
*
120+
* Description: For all coefficients of in/out polynomial add Q if
121+
* coefficient is negative.
122+
*
123+
* Arguments: - int32_t *a: pointer to input/output polynomial
124+
**************************************************/
125+
static MLD_INLINE void mld_poly_caddq_native(int32_t a[MLDSA_N]);
126+
#endif /* MLD_USE_NATIVE_POLY_CADDQ */
127+
102128
#endif /* !MLD_NATIVE_API_H */

mldsa/native/x86_64/meta.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
#define MLD_USE_NATIVE_NTT_CUSTOM_ORDER
1515
#define MLD_USE_NATIVE_NTT
1616
#define MLD_USE_NATIVE_INTT
17+
#define MLD_USE_NATIVE_POLY_REDUCE
18+
#define MLD_USE_NATIVE_POLY_CADDQ
1719

1820
#if !defined(__ASSEMBLER__)
1921
#include <string.h>
@@ -34,6 +36,16 @@ static MLD_INLINE void mld_intt_native(int32_t data[MLDSA_N])
3436
mld_invntt_avx2((__m256i *)data, mld_qdata.vec);
3537
}
3638

39+
static MLD_INLINE void mld_poly_reduce_native(int32_t a[MLDSA_N])
40+
{
41+
mld_poly_reduce_avx2(a);
42+
}
43+
44+
static MLD_INLINE void mld_poly_caddq_native(int32_t a[MLDSA_N])
45+
{
46+
mld_poly_caddq_avx2(a);
47+
}
48+
3749
#endif /* !__ASSEMBLER__ */
3850

3951
#endif /* !MLD_NATIVE_X86_64_META_H */

mldsa/native/x86_64/src/arith_native_x86_64.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,10 @@ void mld_invntt_avx2(__m256i *r, const __m256i *mld_qdata);
1919
#define mld_nttunpack_avx2 MLD_NAMESPACE(nttunpack_avx2)
2020
void mld_nttunpack_avx2(__m256i *r);
2121

22+
#define mld_poly_reduce_avx2 MLD_NAMESPACE(poly_reduce_avx2)
23+
void mld_poly_reduce_avx2(int32_t *r);
24+
25+
#define mld_poly_caddq_avx2 MLD_NAMESPACE(poly_caddq_avx2)
26+
void mld_poly_caddq_avx2(int32_t *r);
27+
2228
#endif /* !MLD_NATIVE_X86_64_SRC_ARITH_NATIVE_X86_64_H */

mldsa/native/x86_64/src/reduce_avx2.c

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
/*
2+
* Copyright (c) The mldsa-native project authors
3+
* SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT
4+
*/
5+
6+
/*
7+
* This file is derived from the public domain
8+
* AVX2 Dilithium implementation @[REF_AVX2].
9+
*/
10+
11+
#include "../../../common.h"
12+
13+
#if defined(MLD_ARITH_BACKEND_X86_64_DEFAULT) && \
14+
!defined(MLD_CONFIG_MULTILEVEL_NO_SHARED)
15+
16+
#include <immintrin.h>
17+
#include <stdint.h>
18+
#include "../../../reduce.h"
19+
#include "arith_native_x86_64.h"
20+
#include "consts.h"
21+
22+
/*************************************************
23+
* Name: mld_poly_reduce_avx2
24+
*
25+
* Description: Inplace reduction of all coefficients of polynomial to
26+
* representative in [-6283009,6283008]. Assumes input
27+
* coefficients to be at most 2^31 - 2^22 - 1 in absolute value.
28+
*
29+
* Arguments: - int32_t *r: pointer to input/output polynomial
30+
**************************************************/
31+
void mld_poly_reduce_avx2(int32_t *r)
32+
{
33+
unsigned int i;
34+
__m256i f, g;
35+
const __m256i q = _mm256_set1_epi32(MLDSA_Q);
36+
const __m256i off = _mm256_set1_epi32(1 << 22);
37+
__m256i *rr = (__m256i *)r;
38+
39+
for (i = 0; i < MLDSA_N / 8; i++)
40+
{
41+
f = _mm256_load_si256(&rr[i]);
42+
g = _mm256_add_epi32(f, off);
43+
g = _mm256_srai_epi32(g, 23);
44+
g = _mm256_mullo_epi32(g, q);
45+
f = _mm256_sub_epi32(f, g);
46+
_mm256_store_si256(&rr[i], f);
47+
}
48+
}
49+
50+
/*************************************************
51+
* Name: mld_poly_caddq_avx2
52+
*
53+
* Description: For all coefficients of in/out polynomial add Q if
54+
* coefficient is negative.
55+
*
56+
* Arguments: - int32_t *r: pointer to input/output polynomial
57+
**************************************************/
58+
void mld_poly_caddq_avx2(int32_t *r)
59+
{
60+
unsigned int i;
61+
__m256i f, g;
62+
const __m256i q = _mm256_set1_epi32(MLDSA_Q);
63+
const __m256i zero = _mm256_setzero_si256();
64+
__m256i *rr = (__m256i *)r;
65+
66+
for (i = 0; i < MLDSA_N / 8; i++)
67+
{
68+
f = _mm256_load_si256(&rr[i]);
69+
g = _mm256_cmpgt_epi32(zero, f);
70+
g = _mm256_and_si256(g, q);
71+
f = _mm256_add_epi32(f, g);
72+
_mm256_store_si256(&rr[i], f);
73+
}
74+
}
75+
76+
#else /* MLD_ARITH_BACKEND_X86_64_DEFAULT && !MLD_CONFIG_MULTILEVEL_NO_SHARED \
77+
*/
78+
79+
MLD_EMPTY_CU(avx2_reduce)
80+
81+
#endif /* !(MLD_ARITH_BACKEND_X86_64_DEFAULT && \
82+
!MLD_CONFIG_MULTILEVEL_NO_SHARED) */

mldsa/poly.c

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,16 @@
1616

1717
void poly_reduce(poly *a)
1818
{
19+
#if !defined(MLD_USE_NATIVE_POLY_REDUCE)
1920
unsigned int i;
21+
#endif
2022
/* TODO: Introduce the following after using inclusive lower bounds in
2123
* the underlying debug function mld_debug_check_bounds(). */
2224
/* mld_assert_bound(a->coeffs, MLDSA_N, INT32_MIN, REDUCE_DOMAIN_MAX); */
2325

26+
#if defined(MLD_USE_NATIVE_POLY_REDUCE)
27+
mld_poly_reduce_native(a->coeffs);
28+
#else
2429
for (i = 0; i < MLDSA_N; ++i)
2530
__loop__(
2631
invariant(i <= MLDSA_N)
@@ -29,15 +34,21 @@ void poly_reduce(poly *a)
2934
{
3035
a->coeffs[i] = reduce32(a->coeffs[i]);
3136
}
37+
#endif /* !MLD_USE_NATIVE_POLY_REDUCE */
3238

3339
mld_assert_bound(a->coeffs, MLDSA_N, -REDUCE_RANGE_MAX, REDUCE_RANGE_MAX);
3440
}
3541

3642
void poly_caddq(poly *a)
3743
{
44+
#if !defined(MLD_USE_NATIVE_POLY_CADDQ)
3845
unsigned int i;
46+
#endif
3947
mld_assert_abs_bound(a->coeffs, MLDSA_N, MLDSA_Q);
4048

49+
#if defined(MLD_USE_NATIVE_POLY_CADDQ)
50+
mld_poly_caddq_native(a->coeffs);
51+
#else
4152
for (i = 0; i < MLDSA_N; ++i)
4253
__loop__(
4354
invariant(i <= MLDSA_N)
@@ -47,6 +58,7 @@ void poly_caddq(poly *a)
4758
{
4859
a->coeffs[i] = caddq(a->coeffs[i]);
4960
}
61+
#endif /* !MLD_USE_NATIVE_POLY_CADDQ */
5062

5163
mld_assert_bound(a->coeffs, MLDSA_N, 0, MLDSA_Q);
5264
}

0 commit comments

Comments
 (0)