From e06f5fc9cd0accac83d6af95fd537916bf6c2aeb Mon Sep 17 00:00:00 2001 From: "Matthias J. Kannwischer" Date: Wed, 2 Jul 2025 11:33:22 +0800 Subject: [PATCH 1/5] AArch64: Native poly_pointwise_montgomery Signed-off-by: Matthias J. Kannwischer --- mldsa/native/aarch64/meta.h | 8 ++ .../native/aarch64/src/arith_native_aarch64.h | 3 + .../native/aarch64/src/pointwise_montgomery.S | 119 ++++++++++++++++++ mldsa/native/api.h | 20 +++ mldsa/poly.c | 7 ++ 5 files changed, 157 insertions(+) create mode 100644 mldsa/native/aarch64/src/pointwise_montgomery.S diff --git a/mldsa/native/aarch64/meta.h b/mldsa/native/aarch64/meta.h index 6b044a93..64c88d05 100644 --- a/mldsa/native/aarch64/meta.h +++ b/mldsa/native/aarch64/meta.h @@ -10,6 +10,7 @@ /* Set of primitives that this backend replaces */ #define MLD_USE_NATIVE_NTT #define MLD_USE_NATIVE_INTT +#define MLD_USE_NATIVE_POINTWISE /* Identifier for this backend so that source and assembly files * in the build can be appropriately guarded. */ @@ -31,6 +32,13 @@ static MLD_INLINE void mld_intt_native(int32_t data[MLDSA_N]) mld_aarch64_intt_zetas_layer123456); } +static MLD_INLINE void mld_pointwise_montgomery_native( + int32_t out[MLDSA_N], const int32_t in0[MLDSA_N], + const int32_t in1[MLDSA_N]) +{ + mld_pointwise_montgomery_asm(out, in0, in1); +} + #endif /* !__ASSEMBLER__ */ #endif /* !MLD_NATIVE_AARCH64_META_H */ diff --git a/mldsa/native/aarch64/src/arith_native_aarch64.h b/mldsa/native/aarch64/src/arith_native_aarch64.h index d3528e6f..3acc873c 100644 --- a/mldsa/native/aarch64/src/arith_native_aarch64.h +++ b/mldsa/native/aarch64/src/arith_native_aarch64.h @@ -32,4 +32,7 @@ void mld_ntt_asm(int32_t *, const int32_t *, const int32_t *); #define mld_intt_asm MLD_NAMESPACE(intt_asm) void mld_intt_asm(int32_t *, const int32_t *, const int32_t *); +#define mld_pointwise_montgomery_asm MLD_NAMESPACE(mld_pointwise_montgomery_asm) +void mld_pointwise_montgomery_asm(int32_t *, const int32_t *, const int32_t *); + #endif /* !MLD_NATIVE_AARCH64_SRC_ARITH_NATIVE_AARCH64_H */ diff --git a/mldsa/native/aarch64/src/pointwise_montgomery.S b/mldsa/native/aarch64/src/pointwise_montgomery.S new file mode 100644 index 00000000..bfa1d7ad --- /dev/null +++ b/mldsa/native/aarch64/src/pointwise_montgomery.S @@ -0,0 +1,119 @@ +/* Copyright (c) The mldsa-native project authors + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + */ + +#include "../../../common.h" +#if defined(MLD_ARITH_BACKEND_AARCH64) + +.macro montgomery_reduce_long res, inl, inh + uzp1 t0.4s, \inl\().4s, \inh\().4s + mul t0.4s, t0.4s, modulus_twisted.4s + smlal \inl\().2d, t0.2s, modulus.2s + smlal2 \inh\().2d, t0.4s, modulus.4s + uzp2 \res\().4s, \inl\().4s, \inh\().4s +.endm + + +.macro pmull dl, dh, a, b + smull \dl\().2d, \a\().2s, \b\().2s + smull2 \dh\().2d, \a\().4s, \b\().4s +.endm + +.macro pmlal dl, dh, a, b + smlal \dl\().2d, \a\().2s, \b\().2s + smlal2 \dh\().2d, \a\().4s, \b\().4s +.endm + +.macro save_vregs + sub sp, sp, #(16*4) + stp d8, d9, [sp, #16*0] + stp d10, d11, [sp, #16*1] + stp d12, d13, [sp, #16*2] + stp d14, d15, [sp, #16*3] +.endm + +.macro restore_vregs + ldp d8, d9, [sp, #16*0] + ldp d10, d11, [sp, #16*1] + ldp d12, d13, [sp, #16*2] + ldp d14, d15, [sp, #16*3] + add sp, sp, #(16*4) +.endm + +.macro push_stack + save_vregs +.endm + +.macro pop_stack + restore_vregs +.endm + +out_ptr .req x0 +a0_ptr .req x1 +b0_ptr .req x2 +count .req x3 +wtmp .req w3 + +modulus .req v0 +modulus_twisted .req v1 + +aa .req v2 +bb .req v3 +res .req v4 +resl .req v5 +resh .req v6 +t0 .req 
v7 + +q_aa .req q2 +q_bb .req q3 +q_res .req q4 + +.text +.global MLD_ASM_NAMESPACE(mld_pointwise_montgomery_asm) +.balign 4 +MLD_ASM_FN_SYMBOL(mld_pointwise_montgomery_asm) + push_stack + + // load q = 8380417 + movz wtmp, #57345 + movk wtmp, #127, lsl #16 + dup modulus.4s, wtmp + + // load -q^-1 = 4236238847 + movz wtmp, #57343 + movk wtmp, #64639, lsl #16 + dup modulus_twisted.4s, wtmp + mov count, #(MLDSA_N / 4) +loop_start: + + + ldr q_aa, [a0_ptr], #64 + ldr q_bb, [b0_ptr], #64 + pmull resl, resh, aa, bb + montgomery_reduce_long res, resl, resh + str q_res, [out_ptr], #64 + + ldr q_aa, [a0_ptr, #-48] + ldr q_bb, [b0_ptr, #-48] + pmull resl, resh, aa, bb + montgomery_reduce_long res, resl, resh + str q_res, [out_ptr, #-48] + + ldr q_aa, [a0_ptr, #-32] + ldr q_bb, [b0_ptr, #-32] + pmull resl, resh, aa, bb + montgomery_reduce_long res, resl, resh + str q_res, [out_ptr, #-32] + + ldr q_aa, [a0_ptr, #-16] + ldr q_bb, [b0_ptr, #-16] + pmull resl, resh, aa, bb + montgomery_reduce_long res, resl, resh + str q_res, [out_ptr, #-16] + + subs count, count, #4 + cbnz count, loop_start + + pop_stack + ret +#endif /* MLD_ARITH_BACKEND_AARCH64 */ diff --git a/mldsa/native/api.h b/mldsa/native/api.h index eb3196db..7edc87a7 100644 --- a/mldsa/native/api.h +++ b/mldsa/native/api.h @@ -99,4 +99,24 @@ static MLD_INLINE void mld_poly_permute_bitrev_to_custom(int32_t p[MLDSA_N]) static MLD_INLINE void mld_intt_native(int16_t p[MLDSA_N]) #endif /* MLD_USE_NATIVE_INTT */ +#if defined(MLD_USE_NATIVE_POINTWISE) + /************************************************* + * Name: mld_pointwise_montgomery_native + * + * Description: Pointwise multiplication of polynomials in NTT domain + * representation and multiplication of resulting polynomial + * by 2^{-32}. + * + * Arguments: - int32_t out[MLDSA_N]: pointer to output polynomial + * - const int32_t in0[MLDSA_N]: pointer to first input + *polynomial + * - const int32_t in1[MLDSA_N]: pointer to second input + *polynomial + **************************************************/ + static MLD_INLINE + void mld_pointwise_montgomery_native(int32_t out[MLDSA_N], + const int32_t in0[MLDSA_N], + const int32_t in1[MLDSA_N]); +#endif /* MLD_USE_NATIVE_POINTWISE */ + #endif /* !MLD_NATIVE_API_H */ diff --git a/mldsa/poly.c b/mldsa/poly.c index 1d923782..a82ccc28 100644 --- a/mldsa/poly.c +++ b/mldsa/poly.c @@ -134,6 +134,7 @@ void poly_invntt_tomont(poly *a) } #endif /* MLD_USE_NATIVE_INTT */ +#if !defined(MLD_USE_NATIVE_POINTWISE) void poly_pointwise_montgomery(poly *c, const poly *a, const poly *b) { unsigned int i; @@ -145,6 +146,12 @@ void poly_pointwise_montgomery(poly *c, const poly *a, const poly *b) c->coeffs[i] = montgomery_reduce((int64_t)a->coeffs[i] * b->coeffs[i]); } } +#else /* !MLD_USE_NATIVE_POINTWISE */ +void poly_pointwise_montgomery(poly *c, const poly *a, const poly *b) +{ + mld_pointwise_montgomery_native(c->coeffs, a->coeffs, b->coeffs); +} +#endif /* MLD_USE_NATIVE_POINTWISE */ void poly_power2round(poly *a1, poly *a0, const poly *a) { From bf4f33d06043632a6f47e9829fc151c92d7e234e Mon Sep 17 00:00:00 2001 From: "Matthias J. Kannwischer" Date: Wed, 2 Jul 2025 11:55:06 +0800 Subject: [PATCH 2/5] AVX2: Native poly_pointwise_montgomery Signed-off-by: Matthias J. 
Kannwischer --- mldsa/native/x86_64/meta.h | 9 ++ mldsa/native/x86_64/src/arith_native_x86_64.h | 5 + mldsa/native/x86_64/src/pointwise.S | 124 ++++++++++++++++++ 3 files changed, 138 insertions(+) create mode 100644 mldsa/native/x86_64/src/pointwise.S diff --git a/mldsa/native/x86_64/meta.h b/mldsa/native/x86_64/meta.h index 373cf12b..87c21ff8 100644 --- a/mldsa/native/x86_64/meta.h +++ b/mldsa/native/x86_64/meta.h @@ -14,6 +14,7 @@ #define MLD_USE_NATIVE_NTT_CUSTOM_ORDER #define MLD_USE_NATIVE_NTT #define MLD_USE_NATIVE_INTT +#define MLD_USE_NATIVE_POINTWISE #if !defined(__ASSEMBLER__) #include @@ -34,6 +35,14 @@ static MLD_INLINE void mld_intt_native(int32_t data[MLDSA_N]) mld_invntt_avx2((__m256i *)data, mld_qdata.vec); } +static MLD_INLINE void mld_pointwise_montgomery_native( + int32_t out[MLDSA_N], const int32_t in0[MLDSA_N], + const int32_t in1[MLDSA_N]) +{ + mld_pointwise_montgomery_avx2((__m256i *)out, (const __m256i *)in0, + (const __m256i *)in1, mld_qdata.vec); +} + #endif /* !__ASSEMBLER__ */ #endif /* !MLD_NATIVE_X86_64_META_H */ diff --git a/mldsa/native/x86_64/src/arith_native_x86_64.h b/mldsa/native/x86_64/src/arith_native_x86_64.h index 602f7485..4058666d 100644 --- a/mldsa/native/x86_64/src/arith_native_x86_64.h +++ b/mldsa/native/x86_64/src/arith_native_x86_64.h @@ -19,4 +19,9 @@ void mld_invntt_avx2(__m256i *r, const __m256i *mld_qdata); #define mld_nttunpack_avx2 MLD_NAMESPACE(nttunpack_avx2) void mld_nttunpack_avx2(__m256i *r); +#define mld_pointwise_montgomery_avx2 \ + MLD_NAMESPACE(mld_pointwise_montgomery_avx2) +void mld_pointwise_montgomery_avx2(__m256i *r, const __m256i *a, + const __m256i *b, const __m256i *mld_qdata); + #endif /* !MLD_NATIVE_X86_64_SRC_ARITH_NATIVE_X86_64_H */ diff --git a/mldsa/native/x86_64/src/pointwise.S b/mldsa/native/x86_64/src/pointwise.S new file mode 100644 index 00000000..213862cc --- /dev/null +++ b/mldsa/native/x86_64/src/pointwise.S @@ -0,0 +1,124 @@ +/* + * Copyright (c) The mldsa-native project authors + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + */ + /* + * This file is derived from the public domain + * AVX2 Dilithium implementation @[REF_AVX2]. 
+ */ + +#include "../../../common.h" +#if defined(MLD_ARITH_BACKEND_X86_64_DEFAULT) && \ + !defined(MLD_CONFIG_MULTILEVEL_NO_SHARED) + +#include "consts.h" + +.text +.global MLD_ASM_NAMESPACE(mld_pointwise_montgomery_avx2) +MLD_ASM_FN_SYMBOL(mld_pointwise_montgomery_avx2) +#consts +vmovdqa MLD_AVX2_BACKEND_DATA_OFFSET_8XQINV*4(%rcx),%ymm0 +vmovdqa MLD_AVX2_BACKEND_DATA_OFFSET_8XQ*4(%rcx),%ymm1 + +xor %eax,%eax +_looptop1: +#load +vmovdqa (%rsi),%ymm2 +vmovdqa 32(%rsi),%ymm4 +vmovdqa 64(%rsi),%ymm6 +vmovdqa (%rdx),%ymm10 +vmovdqa 32(%rdx),%ymm12 +vmovdqa 64(%rdx),%ymm14 +vpsrlq $32,%ymm2,%ymm3 +vpsrlq $32,%ymm4,%ymm5 +vmovshdup %ymm6,%ymm7 +vpsrlq $32,%ymm10,%ymm11 +vpsrlq $32,%ymm12,%ymm13 +vmovshdup %ymm14,%ymm15 + +#mul +vpmuldq %ymm2,%ymm10,%ymm2 +vpmuldq %ymm3,%ymm11,%ymm3 +vpmuldq %ymm4,%ymm12,%ymm4 +vpmuldq %ymm5,%ymm13,%ymm5 +vpmuldq %ymm6,%ymm14,%ymm6 +vpmuldq %ymm7,%ymm15,%ymm7 + +#reduce +vpmuldq %ymm0,%ymm2,%ymm10 +vpmuldq %ymm0,%ymm3,%ymm11 +vpmuldq %ymm0,%ymm4,%ymm12 +vpmuldq %ymm0,%ymm5,%ymm13 +vpmuldq %ymm0,%ymm6,%ymm14 +vpmuldq %ymm0,%ymm7,%ymm15 +vpmuldq %ymm1,%ymm10,%ymm10 +vpmuldq %ymm1,%ymm11,%ymm11 +vpmuldq %ymm1,%ymm12,%ymm12 +vpmuldq %ymm1,%ymm13,%ymm13 +vpmuldq %ymm1,%ymm14,%ymm14 +vpmuldq %ymm1,%ymm15,%ymm15 +vpsubq %ymm10,%ymm2,%ymm2 +vpsubq %ymm11,%ymm3,%ymm3 +vpsubq %ymm12,%ymm4,%ymm4 +vpsubq %ymm13,%ymm5,%ymm5 +vpsubq %ymm14,%ymm6,%ymm6 +vpsubq %ymm15,%ymm7,%ymm7 +vpsrlq $32,%ymm2,%ymm2 +vpsrlq $32,%ymm4,%ymm4 +vmovshdup %ymm6,%ymm6 + +#store +vpblendd $0xAA,%ymm3,%ymm2,%ymm2 +vpblendd $0xAA,%ymm5,%ymm4,%ymm4 +vpblendd $0xAA,%ymm7,%ymm6,%ymm6 +vmovdqa %ymm2,(%rdi) +vmovdqa %ymm4,32(%rdi) +vmovdqa %ymm6,64(%rdi) + +add $96,%rdi +add $96,%rsi +add $96,%rdx +add $1,%eax +cmp $10,%eax +jb _looptop1 + +vmovdqa (%rsi),%ymm2 +vmovdqa 32(%rsi),%ymm4 +vmovdqa (%rdx),%ymm10 +vmovdqa 32(%rdx),%ymm12 +vpsrlq $32,%ymm2,%ymm3 +vpsrlq $32,%ymm4,%ymm5 +vmovshdup %ymm10,%ymm11 +vmovshdup %ymm12,%ymm13 + +#mul +vpmuldq %ymm2,%ymm10,%ymm2 +vpmuldq %ymm3,%ymm11,%ymm3 +vpmuldq %ymm4,%ymm12,%ymm4 +vpmuldq %ymm5,%ymm13,%ymm5 + +#reduce +vpmuldq %ymm0,%ymm2,%ymm10 +vpmuldq %ymm0,%ymm3,%ymm11 +vpmuldq %ymm0,%ymm4,%ymm12 +vpmuldq %ymm0,%ymm5,%ymm13 +vpmuldq %ymm1,%ymm10,%ymm10 +vpmuldq %ymm1,%ymm11,%ymm11 +vpmuldq %ymm1,%ymm12,%ymm12 +vpmuldq %ymm1,%ymm13,%ymm13 +vpsubq %ymm10,%ymm2,%ymm2 +vpsubq %ymm11,%ymm3,%ymm3 +vpsubq %ymm12,%ymm4,%ymm4 +vpsubq %ymm13,%ymm5,%ymm5 +vpsrlq $32,%ymm2,%ymm2 +vmovshdup %ymm4,%ymm4 + +#store +vpblendd $0x55,%ymm2,%ymm3,%ymm2 +vpblendd $0x55,%ymm4,%ymm5,%ymm4 +vmovdqa %ymm2,(%rdi) +vmovdqa %ymm4,32(%rdi) + +ret + +#endif /* MLD_ARITH_BACKEND_X86_64_DEFAULT && !MLD_CONFIG_MULTILEVEL_NO_SHARED */ From 1615f141e8bdf93125a6f91fd1b7247ab0426331 Mon Sep 17 00:00:00 2001 From: "Matthias J. Kannwischer" Date: Wed, 25 Jun 2025 13:02:05 +0800 Subject: [PATCH 3/5] Test performance penalty for not using lazy reduction in matrix-vector mul Signed-off-by: Matthias J. 
Kannwischer --- mldsa/poly.c | 13 +++++++++++++ mldsa/poly.h | 11 +++++++++++ mldsa/polyvec.c | 49 ++++--------------------------------------------- 3 files changed, 28 insertions(+), 45 deletions(-) diff --git a/mldsa/poly.c b/mldsa/poly.c index a82ccc28..a614da95 100644 --- a/mldsa/poly.c +++ b/mldsa/poly.c @@ -153,6 +153,19 @@ void poly_pointwise_montgomery(poly *c, const poly *a, const poly *b) } #endif /* MLD_USE_NATIVE_POINTWISE */ +void poly_pointwise_acc_montgomery(poly *c, const poly *a, const poly *b) +{ + unsigned int i; + + for (i = 0; i < MLDSA_N; ++i) + __loop__( + invariant(i <= MLDSA_N)) + { + c->coeffs[i] += montgomery_reduce((int64_t)a->coeffs[i] * b->coeffs[i]); + } +} + + void poly_power2round(poly *a1, poly *a0, const poly *a) { unsigned int i; diff --git a/mldsa/poly.h b/mldsa/poly.h index dbc00a0c..baf54a51 100644 --- a/mldsa/poly.h +++ b/mldsa/poly.h @@ -176,6 +176,17 @@ __contract__( assigns(memory_slice(c, sizeof(poly))) ); + +#define poly_pointwise_acc_montgomery \ + MLD_NAMESPACE(poly_pointwise_acc_montgomery) +void poly_pointwise_acc_montgomery(poly *c, const poly *a, const poly *b) +__contract__( + requires(memory_no_alias(a, sizeof(poly))) + requires(memory_no_alias(b, sizeof(poly))) + requires(memory_no_alias(c, sizeof(poly))) + assigns(memory_slice(c, sizeof(poly))) +); + #define poly_power2round MLD_NAMESPACE(poly_power2round) /************************************************* * Name: poly_power2round diff --git a/mldsa/polyvec.c b/mldsa/polyvec.c index 8a9a81e0..e00e5f48 100644 --- a/mldsa/polyvec.c +++ b/mldsa/polyvec.c @@ -231,52 +231,11 @@ void polyvecl_pointwise_poly_montgomery(polyvecl *r, const poly *a, void polyvecl_pointwise_acc_montgomery(poly *w, const polyvecl *u, const polyvecl *v) { - unsigned int i, j; - /* The first input is bounded by [0, Q-1] inclusive - * The second input is bounded by [-9Q+1, 9Q-1] inclusive . Hence, we can - * safely accumulate in 64-bits without intermediate reductions as - * MLDSA_L * (MLD_NTT_BOUND-1) * (Q-1) < INT64_MAX - * - * The worst case is ML-DSA-87: 7 * (9Q-1) * (Q-1) < 2**52 - * (and likewise for negative values) - */ - - for (i = 0; i < MLDSA_N; i++) - __loop__( - assigns(i, j, object_whole(w)) - invariant(i <= MLDSA_N) - invariant(array_abs_bound(w->coeffs, 0, i, MLDSA_Q)) - ) + unsigned int i; + poly_pointwise_montgomery(w, &u->vec[0], &v->vec[0]); + for (i = 1; i < MLDSA_L; i++) { - int64_t t = 0; - int32_t r; - for (j = 0; j < MLDSA_L; j++) - __loop__( - assigns(j, t) - invariant(j <= MLDSA_L) - invariant(t >= -(int64_t)j*(MLDSA_Q - 1)*(MLD_NTT_BOUND - 1)) - invariant(t <= (int64_t)j*(MLDSA_Q - 1)*(MLD_NTT_BOUND - 1)) - ) - { - t += (int64_t)u->vec[j].coeffs[i] * v->vec[j].coeffs[i]; - } - - /* Substitute j == MLSDA_L into the loop invariant to get... */ - cassert(j == MLDSA_L); - cassert(t >= -(int64_t)MLDSA_L * (MLDSA_Q - 1) * (MLD_NTT_BOUND - 1)); - cassert(t <= (int64_t)MLDSA_L * (MLDSA_Q - 1) * (MLD_NTT_BOUND - 1)); - - /* ...and therefore... */ - cassert(t >= -MONTGOMERY_REDUCE_STRONG_DOMAIN_MAX); - cassert(t < MONTGOMERY_REDUCE_STRONG_DOMAIN_MAX); - - /* ...which meets the "strong" case of montgomery_reduce() */ - r = montgomery_reduce(t); - - /* ...and therefore we can assert a stronger bound on r */ - cassert(r > -MLDSA_Q); - cassert(r < MLDSA_Q); - w->coeffs[i] = r; + poly_pointwise_acc_montgomery(w, &u->vec[i], &v->vec[i]); } } From 3e07fc318884835011d50baf5418e9462c5d752c Mon Sep 17 00:00:00 2001 From: "Matthias J. 
Kannwischer" Date: Wed, 25 Jun 2025 13:35:05 +0800 Subject: [PATCH 4/5] aarch64: add basemul asm Signed-off-by: Matthias J. Kannwischer --- mldsa/native/aarch64/meta.h | 8 ++ .../native/aarch64/src/arith_native_aarch64.h | 5 + .../aarch64/src/pointwise_acc_montgomery.S | 129 ++++++++++++++++++ mldsa/poly.c | 7 + 4 files changed, 149 insertions(+) create mode 100644 mldsa/native/aarch64/src/pointwise_acc_montgomery.S diff --git a/mldsa/native/aarch64/meta.h b/mldsa/native/aarch64/meta.h index 64c88d05..144fa131 100644 --- a/mldsa/native/aarch64/meta.h +++ b/mldsa/native/aarch64/meta.h @@ -39,6 +39,14 @@ static MLD_INLINE void mld_pointwise_montgomery_native( mld_pointwise_montgomery_asm(out, in0, in1); } +static MLD_INLINE void mld_pointwise_acc_montgomery_native( + int32_t out[MLDSA_N], const int32_t in0[MLDSA_N], + const int32_t in1[MLDSA_N]) +{ + mld_pointwise_acc_montgomery_asm(out, in0, in1); +} + + #endif /* !__ASSEMBLER__ */ #endif /* !MLD_NATIVE_AARCH64_META_H */ diff --git a/mldsa/native/aarch64/src/arith_native_aarch64.h b/mldsa/native/aarch64/src/arith_native_aarch64.h index 3acc873c..3a99fa6b 100644 --- a/mldsa/native/aarch64/src/arith_native_aarch64.h +++ b/mldsa/native/aarch64/src/arith_native_aarch64.h @@ -35,4 +35,9 @@ void mld_intt_asm(int32_t *, const int32_t *, const int32_t *); #define mld_pointwise_montgomery_asm MLD_NAMESPACE(mld_pointwise_montgomery_asm) void mld_pointwise_montgomery_asm(int32_t *, const int32_t *, const int32_t *); +#define mld_pointwise_acc_montgomery_asm \ + MLD_NAMESPACE(mld_pointwise_acc_montgomery_asm) +void mld_pointwise_acc_montgomery_asm(int32_t *, const int32_t *, + const int32_t *); + #endif /* !MLD_NATIVE_AARCH64_SRC_ARITH_NATIVE_AARCH64_H */ diff --git a/mldsa/native/aarch64/src/pointwise_acc_montgomery.S b/mldsa/native/aarch64/src/pointwise_acc_montgomery.S new file mode 100644 index 00000000..6ba92062 --- /dev/null +++ b/mldsa/native/aarch64/src/pointwise_acc_montgomery.S @@ -0,0 +1,129 @@ +/* Copyright (c) The mldsa-native project authors + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + */ + +#include "../../../common.h" +#if defined(MLD_ARITH_BACKEND_AARCH64) + +.macro montgomery_reduce_long res, inl, inh + uzp1 t0.4s, \inl\().4s, \inh\().4s + mul t0.4s, t0.4s, modulus_twisted.4s + smlal \inl\().2d, t0.2s, modulus.2s + smlal2 \inh\().2d, t0.4s, modulus.4s + uzp2 \res\().4s, \inl\().4s, \inh\().4s +.endm + + +.macro pmull dl, dh, a, b + smull \dl\().2d, \a\().2s, \b\().2s + smull2 \dh\().2d, \a\().4s, \b\().4s +.endm + +.macro pmlal dl, dh, a, b + smlal \dl\().2d, \a\().2s, \b\().2s + smlal2 \dh\().2d, \a\().4s, \b\().4s +.endm + +.macro save_vregs + sub sp, sp, #(16*4) + stp d8, d9, [sp, #16*0] + stp d10, d11, [sp, #16*1] + stp d12, d13, [sp, #16*2] + stp d14, d15, [sp, #16*3] +.endm + +.macro restore_vregs + ldp d8, d9, [sp, #16*0] + ldp d10, d11, [sp, #16*1] + ldp d12, d13, [sp, #16*2] + ldp d14, d15, [sp, #16*3] + add sp, sp, #(16*4) +.endm + +.macro push_stack + save_vregs +.endm + +.macro pop_stack + restore_vregs +.endm + +out_ptr .req x0 +a0_ptr .req x1 +b0_ptr .req x2 +count .req x3 +wtmp .req w3 + +modulus .req v0 +modulus_twisted .req v1 + +aa .req v2 +bb .req v3 +res .req v4 +resl .req v5 +resh .req v6 +t0 .req v7 +acc .req v8 + +q_aa .req q2 +q_bb .req q3 +q_res .req q4 +q_acc .req q8 + +.text +.global MLD_ASM_NAMESPACE(mld_pointwise_acc_montgomery_asm) +.balign 4 +MLD_ASM_FN_SYMBOL(mld_pointwise_acc_montgomery_asm) + push_stack + + // load q = 8380417 + movz wtmp, #57345 + movk wtmp, #127, lsl #16 + dup modulus.4s, 
wtmp + + // load -q^-1 = 4236238847 + movz wtmp, #57343 + movk wtmp, #64639, lsl #16 + dup modulus_twisted.4s, wtmp + mov count, #(MLDSA_N / 4) +loop_start: + + + ldr q_aa, [a0_ptr], #64 + ldr q_bb, [b0_ptr], #64 + ldr q_acc, [out_ptr] + pmull resl, resh, aa, bb + montgomery_reduce_long res, resl, resh + add res.4s, acc.4s, res.4s + str q_res, [out_ptr], #64 + + ldr q_aa, [a0_ptr, #-48] + ldr q_bb, [b0_ptr, #-48] + ldr q_acc, [out_ptr, #-48] + pmull resl, resh, aa, bb + montgomery_reduce_long res, resl, resh + add res.4s, acc.4s, res.4s + str q_res, [out_ptr, #-48] + + ldr q_aa, [a0_ptr, #-32] + ldr q_bb, [b0_ptr, #-32] + ldr q_acc, [out_ptr, #-32] + pmull resl, resh, aa, bb + montgomery_reduce_long res, resl, resh + add res.4s, acc.4s, res.4s + str q_res, [out_ptr, #-32] + + ldr q_aa, [a0_ptr, #-16] + ldr q_bb, [b0_ptr, #-16] + ldr q_acc, [out_ptr, #-16] + pmull resl, resh, aa, bb + montgomery_reduce_long res, resl, resh + add res.4s, acc.4s, res.4s + str q_res, [out_ptr, #-16] + + subs count, count, #4 + cbnz count, loop_start + + pop_stack + ret +#endif /* MLD_ARITH_BACKEND_AARCH64 */ diff --git a/mldsa/poly.c b/mldsa/poly.c index a614da95..dc9577d9 100644 --- a/mldsa/poly.c +++ b/mldsa/poly.c @@ -153,6 +153,7 @@ void poly_pointwise_montgomery(poly *c, const poly *a, const poly *b) } #endif /* MLD_USE_NATIVE_POINTWISE */ +#if !defined(MLD_USE_NATIVE_POINTWISE) void poly_pointwise_acc_montgomery(poly *c, const poly *a, const poly *b) { unsigned int i; @@ -164,6 +165,12 @@ void poly_pointwise_acc_montgomery(poly *c, const poly *a, const poly *b) c->coeffs[i] += montgomery_reduce((int64_t)a->coeffs[i] * b->coeffs[i]); } } +#else /* !MLD_USE_NATIVE_POINTWISE */ +void poly_pointwise_acc_montgomery(poly *c, const poly *a, const poly *b) +{ + mld_pointwise_acc_montgomery_native(c->coeffs, a->coeffs, b->coeffs); +} +#endif /* MLD_USE_NATIVE_POINTWISE */ void poly_power2round(poly *a1, poly *a0, const poly *a) From 9c8e59218fbe32a51d12cc0b0c3b3745d231243e Mon Sep 17 00:00:00 2001 From: "Matthias J. Kannwischer" Date: Fri, 4 Jul 2025 17:05:50 +0800 Subject: [PATCH 5/5] AVX2: Add native pointwise_acc_montgomery2 Signed-off-by: Matthias J. Kannwischer --- mldsa/native/api.h | 52 ++++--- mldsa/native/x86_64/meta.h | 8 ++ mldsa/native/x86_64/src/arith_native_x86_64.h | 6 + mldsa/native/x86_64/src/pointwise_acc.S | 129 ++++++++++++++++++ 4 files changed, 177 insertions(+), 18 deletions(-) create mode 100644 mldsa/native/x86_64/src/pointwise_acc.S diff --git a/mldsa/native/api.h b/mldsa/native/api.h index 7edc87a7..5d7a87a1 100644 --- a/mldsa/native/api.h +++ b/mldsa/native/api.h @@ -96,27 +96,43 @@ static MLD_INLINE void mld_poly_permute_bitrev_to_custom(int32_t p[MLDSA_N]) * * Arguments: - uint32_t p[MLDSA_N]: pointer to in/output polynomial **************************************************/ - static MLD_INLINE void mld_intt_native(int16_t p[MLDSA_N]) + static MLD_INLINE void mld_intt_native(int16_t p[MLDSA_N]); #endif /* MLD_USE_NATIVE_INTT */ #if defined(MLD_USE_NATIVE_POINTWISE) - /************************************************* - * Name: mld_pointwise_montgomery_native - * - * Description: Pointwise multiplication of polynomials in NTT domain - * representation and multiplication of resulting polynomial - * by 2^{-32}. 
- * - * Arguments: - int32_t out[MLDSA_N]: pointer to output polynomial - * - const int32_t in0[MLDSA_N]: pointer to first input - *polynomial - * - const int32_t in1[MLDSA_N]: pointer to second input - *polynomial - **************************************************/ - static MLD_INLINE - void mld_pointwise_montgomery_native(int32_t out[MLDSA_N], - const int32_t in0[MLDSA_N], - const int32_t in1[MLDSA_N]); +/************************************************* + * Name: mld_pointwise_montgomery_native + * + * Description: Pointwise multiplication of polynomials in NTT domain + * representation and multiplication of resulting polynomial + * by 2^{-32}. + * + * Arguments: - int32_t out[MLDSA_N]: pointer to output polynomial + * - const int32_t in0[MLDSA_N]: pointer to first input + *polynomial + * - const int32_t in1[MLDSA_N]: pointer to second input + *polynomial + **************************************************/ +static MLD_INLINE void mld_pointwise_montgomery_native( + int32_t out[MLDSA_N], const int32_t in0[MLDSA_N], + const int32_t in1[MLDSA_N]); + +/************************************************* + * Name: mld_pointwise_acc_montgomery_native + * + * Description: Pointwise multiplication of polynomials in NTT domain + * representation and multiplication of resulting polynomial + * by 2^{-32}, with accumulation. + * + * Arguments: - int32_t out[MLDSA_N]: pointer to output polynomial + * - const int32_t in0[MLDSA_N]: pointer to first input polynomial + * - const int32_t in1[MLDSA_N]: pointer to second input polynomial + * + * Note: out = out + in0 * in1 * 2^{-32} + **************************************************/ +static MLD_INLINE void mld_pointwise_acc_montgomery_native( + int32_t out[MLDSA_N], const int32_t in0[MLDSA_N], + const int32_t in1[MLDSA_N]); #endif /* MLD_USE_NATIVE_POINTWISE */ #endif /* !MLD_NATIVE_API_H */ diff --git a/mldsa/native/x86_64/meta.h b/mldsa/native/x86_64/meta.h index 87c21ff8..2fae0c13 100644 --- a/mldsa/native/x86_64/meta.h +++ b/mldsa/native/x86_64/meta.h @@ -43,6 +43,14 @@ static MLD_INLINE void mld_pointwise_montgomery_native( (const __m256i *)in1, mld_qdata.vec); } +static MLD_INLINE void mld_pointwise_acc_montgomery_native( + int32_t out[MLDSA_N], const int32_t in0[MLDSA_N], + const int32_t in1[MLDSA_N]) +{ + mld_pointwise_acc_montgomery_avx2((__m256i *)out, (const __m256i *)in0, + (const __m256i *)in1, mld_qdata.vec); +} + #endif /* !__ASSEMBLER__ */ #endif /* !MLD_NATIVE_X86_64_META_H */ diff --git a/mldsa/native/x86_64/src/arith_native_x86_64.h b/mldsa/native/x86_64/src/arith_native_x86_64.h index 4058666d..e51ec572 100644 --- a/mldsa/native/x86_64/src/arith_native_x86_64.h +++ b/mldsa/native/x86_64/src/arith_native_x86_64.h @@ -24,4 +24,10 @@ void mld_nttunpack_avx2(__m256i *r); void mld_pointwise_montgomery_avx2(__m256i *r, const __m256i *a, const __m256i *b, const __m256i *mld_qdata); +#define mld_pointwise_acc_montgomery_avx2 \ + MLD_NAMESPACE(mld_pointwise_acc_montgomery_avx2) +void mld_pointwise_acc_montgomery_avx2(__m256i *r, const __m256i *a, + const __m256i *b, + const __m256i *mld_qdata); + #endif /* !MLD_NATIVE_X86_64_SRC_ARITH_NATIVE_X86_64_H */ diff --git a/mldsa/native/x86_64/src/pointwise_acc.S b/mldsa/native/x86_64/src/pointwise_acc.S new file mode 100644 index 00000000..b6f3f095 --- /dev/null +++ b/mldsa/native/x86_64/src/pointwise_acc.S @@ -0,0 +1,129 @@ +/* + * Copyright (c) The mldsa-native project authors + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + */ + /* + * This file is derived from the public domain + * AVX2 
Dilithium implementation @[REF_AVX2]. + */ + +#include "../../../common.h" +#if defined(MLD_ARITH_BACKEND_X86_64_DEFAULT) && \ + !defined(MLD_CONFIG_MULTILEVEL_NO_SHARED) + +#include "consts.h" + +.text +.global MLD_ASM_NAMESPACE(mld_pointwise_acc_montgomery_avx2) +MLD_ASM_FN_SYMBOL(mld_pointwise_acc_montgomery_avx2) +#consts +vmovdqa MLD_AVX2_BACKEND_DATA_OFFSET_8XQINV*4(%rcx),%ymm0 +vmovdqa MLD_AVX2_BACKEND_DATA_OFFSET_8XQ*4(%rcx),%ymm1 + +xor %eax,%eax +_looptop1: +#load +vmovdqa (%rsi),%ymm2 +vmovdqa 32(%rsi),%ymm4 +vmovdqa 64(%rsi),%ymm6 +vmovdqa (%rdx),%ymm10 +vmovdqa 32(%rdx),%ymm12 +vmovdqa 64(%rdx),%ymm14 +vpsrlq $32,%ymm2,%ymm3 +vpsrlq $32,%ymm4,%ymm5 +vmovshdup %ymm6,%ymm7 +vpsrlq $32,%ymm10,%ymm11 +vpsrlq $32,%ymm12,%ymm13 +vmovshdup %ymm14,%ymm15 + +#mul +vpmuldq %ymm2,%ymm10,%ymm2 +vpmuldq %ymm3,%ymm11,%ymm3 +vpmuldq %ymm4,%ymm12,%ymm4 +vpmuldq %ymm5,%ymm13,%ymm5 +vpmuldq %ymm6,%ymm14,%ymm6 +vpmuldq %ymm7,%ymm15,%ymm7 + +#reduce +vpmuldq %ymm0,%ymm2,%ymm10 +vpmuldq %ymm0,%ymm3,%ymm11 +vpmuldq %ymm0,%ymm4,%ymm12 +vpmuldq %ymm0,%ymm5,%ymm13 +vpmuldq %ymm0,%ymm6,%ymm14 +vpmuldq %ymm0,%ymm7,%ymm15 +vpmuldq %ymm1,%ymm10,%ymm10 +vpmuldq %ymm1,%ymm11,%ymm11 +vpmuldq %ymm1,%ymm12,%ymm12 +vpmuldq %ymm1,%ymm13,%ymm13 +vpmuldq %ymm1,%ymm14,%ymm14 +vpmuldq %ymm1,%ymm15,%ymm15 +vpsubq %ymm10,%ymm2,%ymm2 +vpsubq %ymm11,%ymm3,%ymm3 +vpsubq %ymm12,%ymm4,%ymm4 +vpsubq %ymm13,%ymm5,%ymm5 +vpsubq %ymm14,%ymm6,%ymm6 +vpsubq %ymm15,%ymm7,%ymm7 +vpsrlq $32,%ymm2,%ymm2 +vpsrlq $32,%ymm4,%ymm4 +vmovshdup %ymm6,%ymm6 + +#store +vpblendd $0xAA,%ymm3,%ymm2,%ymm2 +vpblendd $0xAA,%ymm5,%ymm4,%ymm4 +vpblendd $0xAA,%ymm7,%ymm6,%ymm6 +vpaddd (%rdi),%ymm2,%ymm2 +vpaddd 32(%rdi),%ymm4,%ymm4 +vpaddd 64(%rdi),%ymm6,%ymm6 +vmovdqa %ymm2,(%rdi) +vmovdqa %ymm4,32(%rdi) +vmovdqa %ymm6,64(%rdi) + +add $96,%rdi +add $96,%rsi +add $96,%rdx +add $1,%eax +cmp $10,%eax +jb _looptop1 + +vmovdqa (%rsi),%ymm2 +vmovdqa 32(%rsi),%ymm4 +vmovdqa (%rdx),%ymm10 +vmovdqa 32(%rdx),%ymm12 +vpsrlq $32,%ymm2,%ymm3 +vpsrlq $32,%ymm4,%ymm5 +vmovshdup %ymm10,%ymm11 +vmovshdup %ymm12,%ymm13 + +#mul +vpmuldq %ymm2,%ymm10,%ymm2 +vpmuldq %ymm3,%ymm11,%ymm3 +vpmuldq %ymm4,%ymm12,%ymm4 +vpmuldq %ymm5,%ymm13,%ymm5 + +#reduce +vpmuldq %ymm0,%ymm2,%ymm10 +vpmuldq %ymm0,%ymm3,%ymm11 +vpmuldq %ymm0,%ymm4,%ymm12 +vpmuldq %ymm0,%ymm5,%ymm13 +vpmuldq %ymm1,%ymm10,%ymm10 +vpmuldq %ymm1,%ymm11,%ymm11 +vpmuldq %ymm1,%ymm12,%ymm12 +vpmuldq %ymm1,%ymm13,%ymm13 +vpsubq %ymm10,%ymm2,%ymm2 +vpsubq %ymm11,%ymm3,%ymm3 +vpsubq %ymm12,%ymm4,%ymm4 +vpsubq %ymm13,%ymm5,%ymm5 +vpsrlq $32,%ymm2,%ymm2 +vmovshdup %ymm4,%ymm4 + +#store +vpblendd $0x55,%ymm2,%ymm3,%ymm2 +vpblendd $0x55,%ymm4,%ymm5,%ymm4 +vpaddd (%rdi),%ymm2,%ymm2 +vpaddd 32(%rdi),%ymm4,%ymm4 +vmovdqa %ymm2,(%rdi) +vmovdqa %ymm4,32(%rdi) + +ret + +#endif /* MLD_ARITH_BACKEND_X86_64_DEFAULT && !MLD_CONFIG_MULTILEVEL_NO_SHARED */
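
For reference, the following standalone C sketch (not part of the patch series) models what both the AArch64 and AVX2 pointwise routines compute, mirroring the C fallback paths added to mldsa/poly.c: coefficient-wise products reduced by 2^{-32} modulo q = 8380417, plus the accumulating variant used for the matrix-vector multiplication. The *_ref names, the local N/Q/QINV/MONT defines and the small self-check in main are illustrative only and do not exist in the repository; montgomery_reduce_ref follows the standard Dilithium Montgomery reduction that the repository's montgomery_reduce() implements.

/* pointwise_ref.c -- illustrative sketch only, not part of the patches. */
#include <stdint.h>
#include <stdio.h>

#define N 256         /* MLDSA_N */
#define Q 8380417     /* MLDSA_Q */
#define QINV 58728449 /* q^-1 mod 2^32 */
#define MONT 4193792  /* 2^32 mod q */

/* Montgomery reduction: returns r congruent to a * 2^{-32} mod q with
 * |r| < q, for |a| <= 2^31 * q. Same algorithm as montgomery_reduce(). */
static int32_t montgomery_reduce_ref(int64_t a)
{
  int32_t t = (int64_t)(int32_t)a * QINV; /* a * q^-1 mod 2^32 */
  return (int32_t)((a - (int64_t)t * Q) >> 32);
}

/* What mld_pointwise_montgomery_{asm,avx2} compute:
 * c[i] = a[i] * b[i] * 2^{-32} mod q */
static void pointwise_ref(int32_t c[N], const int32_t a[N], const int32_t b[N])
{
  int i;
  for (i = 0; i < N; i++)
  {
    c[i] = montgomery_reduce_ref((int64_t)a[i] * b[i]);
  }
}

/* What mld_pointwise_acc_montgomery_{asm,avx2} compute:
 * c[i] += a[i] * b[i] * 2^{-32} mod q (accumulate into c) */
static void pointwise_acc_ref(int32_t c[N], const int32_t a[N],
                              const int32_t b[N])
{
  int i;
  for (i = 0; i < N; i++)
  {
    c[i] += montgomery_reduce_ref((int64_t)a[i] * b[i]);
  }
}

int main(void)
{
  int32_t a[N], b[N], c[N];
  int i;
  /* Multiplying by MONT = 2^32 mod q cancels the 2^{-32} factor,
   * so c[i] equals a[i] for small positive a[i]. */
  for (i = 0; i < N; i++)
  {
    a[i] = i + 1;
    b[i] = MONT;
  }
  pointwise_ref(c, a, b);
  printf("c[3] = %d (expected 4)\n", (int)c[3]);
  /* Accumulating the same product once more doubles the coefficient. */
  pointwise_acc_ref(c, a, b);
  printf("c[3] = %d (expected 8)\n", (int)c[3]);
  return 0;
}

As the comment removed from polyvec.c in patch 3 notes, the previous polyvecl_pointwise_acc_montgomery accumulated all MLDSA_L 64-bit products before a single reduction (safe since 7 * (9Q-1) * (Q-1) < 2^52); the refactoring instead reduces per term via poly_pointwise_acc_montgomery, which is what lets the shared native hooks above be reused, at the cost of the extra reductions whose performance impact this series measures.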