Skip to content

Commit f2d8abd

Browse files
authored
Merge pull request #348 from pq-code-package/aarch64-rej-eta-implementation
Add native imlementation of rej_eta (AArch64 + AVX2)
2 parents dc2f5f3 + 57493dd commit f2d8abd

File tree

12 files changed

+1701
-6
lines changed

12 files changed

+1701
-6
lines changed

mldsa/native/aarch64/meta.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
#define MLD_USE_NATIVE_NTT
1212
#define MLD_USE_NATIVE_INTT
1313
#define MLD_USE_NATIVE_REJ_UNIFORM
14+
#define MLD_USE_NATIVE_REJ_UNIFORM_ETA2
15+
#define MLD_USE_NATIVE_REJ_UNIFORM_ETA4
1416

1517
/* Identifier for this backend so that source and assembly files
1618
* in the build can be appropriately guarded. */
@@ -43,6 +45,34 @@ static MLD_INLINE int mld_rej_uniform_native(int32_t *r, unsigned len,
4345
return (int)mld_rej_uniform_asm(r, buf, buflen, mld_rej_uniform_table);
4446
}
4547

48+
static MLD_INLINE int mld_rej_uniform_eta2_native(int32_t *r, unsigned len,
49+
const uint8_t *buf,
50+
unsigned buflen)
51+
{
52+
/* AArch64 implementation assumes specific buffer lengths */
53+
if (len != MLDSA_N || buflen != MLD_AARCH64_REJ_UNIFORM_ETA2_BUFLEN)
54+
{
55+
return -1;
56+
}
57+
58+
return (int)mld_rej_uniform_eta2_asm(r, buf, buflen,
59+
mld_rej_uniform_eta_table);
60+
}
61+
62+
static MLD_INLINE int mld_rej_uniform_eta4_native(int32_t *r, unsigned len,
63+
const uint8_t *buf,
64+
unsigned buflen)
65+
{
66+
/* AArch64 implementation assumes specific buffer lengths */
67+
if (len != MLDSA_N || buflen != MLD_AARCH64_REJ_UNIFORM_ETA4_BUFLEN)
68+
{
69+
return -1;
70+
}
71+
72+
return (int)mld_rej_uniform_eta4_asm(r, buf, buflen,
73+
mld_rej_uniform_eta_table);
74+
}
75+
4676
#endif /* !__ASSEMBLER__ */
4777

4878
#endif /* !MLD_NATIVE_AARCH64_META_H */

mldsa/native/aarch64/src/arith_native_aarch64.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,22 @@ extern const int32_t mld_aarch64_intt_zetas_layer78[];
2727
extern const int32_t mld_aarch64_intt_zetas_layer123456[];
2828

2929
extern const uint8_t mld_rej_uniform_table[];
30+
extern const uint8_t mld_rej_uniform_eta_table[];
31+
32+
33+
/*
34+
* Sampling 256 coefficients mod 15 using rejection sampling from 4 bits.
35+
* Expected number of required bytes: (256 * (16/15))/2 = 136.5 bytes.
36+
* We sample 1 block (=136 bytes) of SHAKE256_RATE output initially.
37+
* Sampling 2 blocks initially results in slightly worse performance.
38+
*/
39+
#define MLD_AARCH64_REJ_UNIFORM_ETA2_BUFLEN (1 * 136)
40+
/*
41+
* Sampling 256 coefficients mod 9 using rejection sampling from 4 bits.
42+
* Expected number of required bytes: (256 * (16/9))/2 = 227.5 bytes.
43+
* We sample 2 blocks (=272 bytes) of SHAKE256_RATE output initially.
44+
*/
45+
#define MLD_AARCH64_REJ_UNIFORM_ETA4_BUFLEN (2 * 136)
3046

3147
#define mld_ntt_asm MLD_NAMESPACE(ntt_asm)
3248
void mld_ntt_asm(int32_t *, const int32_t *, const int32_t *);
@@ -38,4 +54,12 @@ void mld_intt_asm(int32_t *, const int32_t *, const int32_t *);
3854
uint64_t mld_rej_uniform_asm(int32_t *r, const uint8_t *buf, unsigned buflen,
3955
const uint8_t *table);
4056

57+
#define mld_rej_uniform_eta2_asm MLD_NAMESPACE(rej_uniform_eta2_asm)
58+
unsigned mld_rej_uniform_eta2_asm(int32_t *r, const uint8_t *buf,
59+
unsigned buflen, const uint8_t *table);
60+
61+
#define mld_rej_uniform_eta4_asm MLD_NAMESPACE(rej_uniform_eta4_asm)
62+
unsigned mld_rej_uniform_eta4_asm(int32_t *r, const uint8_t *buf,
63+
unsigned buflen, const uint8_t *table);
64+
4165
#endif /* !MLD_NATIVE_AARCH64_SRC_ARITH_NATIVE_AARCH64_H */

0 commit comments

Comments
 (0)