Skip to content

Test performance penalty for not using lazy reduction in matrix-vector mul #334

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions mldsa/native/aarch64/meta.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
/* Set of primitives that this backend replaces */
#define MLD_USE_NATIVE_NTT
#define MLD_USE_NATIVE_INTT
#define MLD_USE_NATIVE_POINTWISE

/* Identifier for this backend so that source and assembly files
* in the build can be appropriately guarded. */
Expand All @@ -31,6 +32,21 @@ static MLD_INLINE void mld_intt_native(int32_t data[MLDSA_N])
mld_aarch64_intt_zetas_layer123456);
}

static MLD_INLINE void mld_pointwise_montgomery_native(
int32_t out[MLDSA_N], const int32_t in0[MLDSA_N],
const int32_t in1[MLDSA_N])
{
mld_pointwise_montgomery_asm(out, in0, in1);
}

static MLD_INLINE void mld_pointwise_acc_montgomery_native(
int32_t out[MLDSA_N], const int32_t in0[MLDSA_N],
const int32_t in1[MLDSA_N])
{
mld_pointwise_acc_montgomery_asm(out, in0, in1);
}


#endif /* !__ASSEMBLER__ */

#endif /* !MLD_NATIVE_AARCH64_META_H */
8 changes: 8 additions & 0 deletions mldsa/native/aarch64/src/arith_native_aarch64.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,12 @@ void mld_ntt_asm(int32_t *, const int32_t *, const int32_t *);
#define mld_intt_asm MLD_NAMESPACE(intt_asm)
void mld_intt_asm(int32_t *, const int32_t *, const int32_t *);

#define mld_pointwise_montgomery_asm MLD_NAMESPACE(mld_pointwise_montgomery_asm)
void mld_pointwise_montgomery_asm(int32_t *, const int32_t *, const int32_t *);

#define mld_pointwise_acc_montgomery_asm \
MLD_NAMESPACE(mld_pointwise_acc_montgomery_asm)
void mld_pointwise_acc_montgomery_asm(int32_t *, const int32_t *,
const int32_t *);

#endif /* !MLD_NATIVE_AARCH64_SRC_ARITH_NATIVE_AARCH64_H */
129 changes: 129 additions & 0 deletions mldsa/native/aarch64/src/pointwise_acc_montgomery.S
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
/* Copyright (c) The mldsa-native project authors
* SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT
*/

#include "../../../common.h"
#if defined(MLD_ARITH_BACKEND_AARCH64)

.macro montgomery_reduce_long res, inl, inh
uzp1 t0.4s, \inl\().4s, \inh\().4s
mul t0.4s, t0.4s, modulus_twisted.4s
smlal \inl\().2d, t0.2s, modulus.2s
smlal2 \inh\().2d, t0.4s, modulus.4s
uzp2 \res\().4s, \inl\().4s, \inh\().4s
.endm


.macro pmull dl, dh, a, b
smull \dl\().2d, \a\().2s, \b\().2s
smull2 \dh\().2d, \a\().4s, \b\().4s
.endm

.macro pmlal dl, dh, a, b
smlal \dl\().2d, \a\().2s, \b\().2s
smlal2 \dh\().2d, \a\().4s, \b\().4s
.endm

.macro save_vregs
sub sp, sp, #(16*4)
stp d8, d9, [sp, #16*0]
stp d10, d11, [sp, #16*1]
stp d12, d13, [sp, #16*2]
stp d14, d15, [sp, #16*3]
.endm

.macro restore_vregs
ldp d8, d9, [sp, #16*0]
ldp d10, d11, [sp, #16*1]
ldp d12, d13, [sp, #16*2]
ldp d14, d15, [sp, #16*3]
add sp, sp, #(16*4)
.endm

.macro push_stack
save_vregs
.endm

.macro pop_stack
restore_vregs
.endm

out_ptr .req x0
a0_ptr .req x1
b0_ptr .req x2
count .req x3
wtmp .req w3

modulus .req v0
modulus_twisted .req v1

aa .req v2
bb .req v3
res .req v4
resl .req v5
resh .req v6
t0 .req v7
acc .req v8

q_aa .req q2
q_bb .req q3
q_res .req q4
q_acc .req q8

.text
.global MLD_ASM_NAMESPACE(mld_pointwise_acc_montgomery_asm)
.balign 4
MLD_ASM_FN_SYMBOL(mld_pointwise_acc_montgomery_asm)
push_stack

// load q = 8380417
movz wtmp, #57345
movk wtmp, #127, lsl #16
dup modulus.4s, wtmp

// load -q^-1 = 4236238847
movz wtmp, #57343
movk wtmp, #64639, lsl #16
dup modulus_twisted.4s, wtmp
mov count, #(MLDSA_N / 4)
loop_start:


ldr q_aa, [a0_ptr], #64
ldr q_bb, [b0_ptr], #64
ldr q_acc, [out_ptr]
pmull resl, resh, aa, bb
montgomery_reduce_long res, resl, resh
add res.4s, acc.4s, res.4s
str q_res, [out_ptr], #64

ldr q_aa, [a0_ptr, #-48]
ldr q_bb, [b0_ptr, #-48]
ldr q_acc, [out_ptr, #-48]
pmull resl, resh, aa, bb
montgomery_reduce_long res, resl, resh
add res.4s, acc.4s, res.4s
str q_res, [out_ptr, #-48]

ldr q_aa, [a0_ptr, #-32]
ldr q_bb, [b0_ptr, #-32]
ldr q_acc, [out_ptr, #-32]
pmull resl, resh, aa, bb
montgomery_reduce_long res, resl, resh
add res.4s, acc.4s, res.4s
str q_res, [out_ptr, #-32]

ldr q_aa, [a0_ptr, #-16]
ldr q_bb, [b0_ptr, #-16]
ldr q_acc, [out_ptr, #-16]
pmull resl, resh, aa, bb
montgomery_reduce_long res, resl, resh
add res.4s, acc.4s, res.4s
str q_res, [out_ptr, #-16]

subs count, count, #4
cbnz count, loop_start

pop_stack
ret
#endif /* MLD_ARITH_BACKEND_AARCH64 */
119 changes: 119 additions & 0 deletions mldsa/native/aarch64/src/pointwise_montgomery.S
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
/* Copyright (c) The mldsa-native project authors
* SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT
*/

#include "../../../common.h"
#if defined(MLD_ARITH_BACKEND_AARCH64)

.macro montgomery_reduce_long res, inl, inh
uzp1 t0.4s, \inl\().4s, \inh\().4s
mul t0.4s, t0.4s, modulus_twisted.4s
smlal \inl\().2d, t0.2s, modulus.2s
smlal2 \inh\().2d, t0.4s, modulus.4s
uzp2 \res\().4s, \inl\().4s, \inh\().4s
.endm


.macro pmull dl, dh, a, b
smull \dl\().2d, \a\().2s, \b\().2s
smull2 \dh\().2d, \a\().4s, \b\().4s
.endm

.macro pmlal dl, dh, a, b
smlal \dl\().2d, \a\().2s, \b\().2s
smlal2 \dh\().2d, \a\().4s, \b\().4s
.endm

.macro save_vregs
sub sp, sp, #(16*4)
stp d8, d9, [sp, #16*0]
stp d10, d11, [sp, #16*1]
stp d12, d13, [sp, #16*2]
stp d14, d15, [sp, #16*3]
.endm

.macro restore_vregs
ldp d8, d9, [sp, #16*0]
ldp d10, d11, [sp, #16*1]
ldp d12, d13, [sp, #16*2]
ldp d14, d15, [sp, #16*3]
add sp, sp, #(16*4)
.endm

.macro push_stack
save_vregs
.endm

.macro pop_stack
restore_vregs
.endm

out_ptr .req x0
a0_ptr .req x1
b0_ptr .req x2
count .req x3
wtmp .req w3

modulus .req v0
modulus_twisted .req v1

aa .req v2
bb .req v3
res .req v4
resl .req v5
resh .req v6
t0 .req v7

q_aa .req q2
q_bb .req q3
q_res .req q4

.text
.global MLD_ASM_NAMESPACE(mld_pointwise_montgomery_asm)
.balign 4
MLD_ASM_FN_SYMBOL(mld_pointwise_montgomery_asm)
push_stack

// load q = 8380417
movz wtmp, #57345
movk wtmp, #127, lsl #16
dup modulus.4s, wtmp

// load -q^-1 = 4236238847
movz wtmp, #57343
movk wtmp, #64639, lsl #16
dup modulus_twisted.4s, wtmp
mov count, #(MLDSA_N / 4)
loop_start:


ldr q_aa, [a0_ptr], #64
ldr q_bb, [b0_ptr], #64
pmull resl, resh, aa, bb
montgomery_reduce_long res, resl, resh
str q_res, [out_ptr], #64

ldr q_aa, [a0_ptr, #-48]
ldr q_bb, [b0_ptr, #-48]
pmull resl, resh, aa, bb
montgomery_reduce_long res, resl, resh
str q_res, [out_ptr, #-48]

ldr q_aa, [a0_ptr, #-32]
ldr q_bb, [b0_ptr, #-32]
pmull resl, resh, aa, bb
montgomery_reduce_long res, resl, resh
str q_res, [out_ptr, #-32]

ldr q_aa, [a0_ptr, #-16]
ldr q_bb, [b0_ptr, #-16]
pmull resl, resh, aa, bb
montgomery_reduce_long res, resl, resh
str q_res, [out_ptr, #-16]

subs count, count, #4
cbnz count, loop_start

pop_stack
ret
#endif /* MLD_ARITH_BACKEND_AARCH64 */
38 changes: 37 additions & 1 deletion mldsa/native/api.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,43 @@ static MLD_INLINE void mld_poly_permute_bitrev_to_custom(int32_t p[MLDSA_N])
*
* Arguments: - uint32_t p[MLDSA_N]: pointer to in/output polynomial
**************************************************/
static MLD_INLINE void mld_intt_native(int16_t p[MLDSA_N])
static MLD_INLINE void mld_intt_native(int16_t p[MLDSA_N]);
#endif /* MLD_USE_NATIVE_INTT */

#if defined(MLD_USE_NATIVE_POINTWISE)
/*************************************************
* Name: mld_pointwise_montgomery_native
*
* Description: Pointwise multiplication of polynomials in NTT domain
* representation and multiplication of resulting polynomial
* by 2^{-32}.
*
* Arguments: - int32_t out[MLDSA_N]: pointer to output polynomial
* - const int32_t in0[MLDSA_N]: pointer to first input
*polynomial
* - const int32_t in1[MLDSA_N]: pointer to second input
*polynomial
**************************************************/
static MLD_INLINE void mld_pointwise_montgomery_native(
int32_t out[MLDSA_N], const int32_t in0[MLDSA_N],
const int32_t in1[MLDSA_N]);

/*************************************************
* Name: mld_pointwise_acc_montgomery_native
*
* Description: Pointwise multiplication of polynomials in NTT domain
* representation and multiplication of resulting polynomial
* by 2^{-32}, with accumulation.
*
* Arguments: - int32_t out[MLDSA_N]: pointer to output polynomial
* - const int32_t in0[MLDSA_N]: pointer to first input polynomial
* - const int32_t in1[MLDSA_N]: pointer to second input polynomial
*
* Note: out = out + in0 * in1 * 2^{-32}
**************************************************/
static MLD_INLINE void mld_pointwise_acc_montgomery_native(
int32_t out[MLDSA_N], const int32_t in0[MLDSA_N],
const int32_t in1[MLDSA_N]);
#endif /* MLD_USE_NATIVE_POINTWISE */

#endif /* !MLD_NATIVE_API_H */
17 changes: 17 additions & 0 deletions mldsa/native/x86_64/meta.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#define MLD_USE_NATIVE_NTT_CUSTOM_ORDER
#define MLD_USE_NATIVE_NTT
#define MLD_USE_NATIVE_INTT
#define MLD_USE_NATIVE_POINTWISE

#if !defined(__ASSEMBLER__)
#include <string.h>
Expand All @@ -34,6 +35,22 @@ static MLD_INLINE void mld_intt_native(int32_t data[MLDSA_N])
mld_invntt_avx2((__m256i *)data, mld_qdata.vec);
}

static MLD_INLINE void mld_pointwise_montgomery_native(
int32_t out[MLDSA_N], const int32_t in0[MLDSA_N],
const int32_t in1[MLDSA_N])
{
mld_pointwise_montgomery_avx2((__m256i *)out, (const __m256i *)in0,
(const __m256i *)in1, mld_qdata.vec);
}

static MLD_INLINE void mld_pointwise_acc_montgomery_native(
int32_t out[MLDSA_N], const int32_t in0[MLDSA_N],
const int32_t in1[MLDSA_N])
{
mld_pointwise_acc_montgomery_avx2((__m256i *)out, (const __m256i *)in0,
(const __m256i *)in1, mld_qdata.vec);
}

#endif /* !__ASSEMBLER__ */

#endif /* !MLD_NATIVE_X86_64_META_H */
11 changes: 11 additions & 0 deletions mldsa/native/x86_64/src/arith_native_x86_64.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,15 @@ void mld_invntt_avx2(__m256i *r, const __m256i *mld_qdata);
#define mld_nttunpack_avx2 MLD_NAMESPACE(nttunpack_avx2)
void mld_nttunpack_avx2(__m256i *r);

#define mld_pointwise_montgomery_avx2 \
MLD_NAMESPACE(mld_pointwise_montgomery_avx2)
void mld_pointwise_montgomery_avx2(__m256i *r, const __m256i *a,
const __m256i *b, const __m256i *mld_qdata);

#define mld_pointwise_acc_montgomery_avx2 \
MLD_NAMESPACE(mld_pointwise_acc_montgomery_avx2)
void mld_pointwise_acc_montgomery_avx2(__m256i *r, const __m256i *a,
const __m256i *b,
const __m256i *mld_qdata);

#endif /* !MLD_NATIVE_X86_64_SRC_ARITH_NATIVE_X86_64_H */
Loading
Loading