Skip to content

Commit 4bda03b

Browse files
committed
tools: lint x25519/x448 twist flag
1 parent 353821c commit 4bda03b

File tree

4 files changed

+250
-1
lines changed

4 files changed

+250
-1
lines changed

.github/workflows/vectorlint.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,6 @@ jobs:
3939

4040
- name: Run vectorlint
4141
run: go run ./tools/vectorlint
42+
43+
- name: Run twistcheck
44+
run: go run ./tools/twistcheck

go.mod

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@ module github.com/c2sp/wycheproof
22

33
go 1.23.6
44

5-
require github.com/santhosh-tekuri/jsonschema/v6 v6.0.1
5+
require (
6+
filippo.io/edwards25519 v1.1.0
7+
github.com/santhosh-tekuri/jsonschema/v6 v6.0.1
8+
)
69

710
require golang.org/x/text v0.14.0 // indirect

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
2+
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
13
github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI=
24
github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
35
github.com/santhosh-tekuri/jsonschema/v6 v6.0.1 h1:PKK9DyHxif4LZo+uQSgXNqs0jj5+xZwwfKHgph2lxBw=

tools/twistcheck/twistcheck.go

Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
// twistcheck verifies that X25519 and X448 test vectors with twist points are
2+
// correctly marked with the Twist flag. Additionally, the result must be
3+
// "acceptable" since specific implementations may choose to reject twist points.
4+
package main
5+
6+
import (
7+
"crypto/x509/pkix"
8+
"encoding/asn1"
9+
"encoding/base64"
10+
"encoding/hex"
11+
"encoding/json"
12+
"encoding/pem"
13+
"flag"
14+
"fmt"
15+
"log"
16+
"math/big"
17+
"os"
18+
"slices"
19+
"strings"
20+
21+
"filippo.io/edwards25519/field"
22+
)
23+
24+
var (
25+
vectorFile = flag.String("vectors", "", "path to test vector file")
26+
)
27+
28+
func main() {
29+
flag.Parse()
30+
31+
files := []string{
32+
"testvectors_v1/x25519_test.json",
33+
"testvectors_v1/x448_test.json",
34+
"testvectors_v1/x25519_pem_test.json",
35+
"testvectors_v1/x448_pem_test.json",
36+
"testvectors_v1/x25519_jwk_test.json",
37+
"testvectors_v1/x448_jwk_test.json",
38+
}
39+
if *vectorFile != "" {
40+
files = []string{*vectorFile}
41+
}
42+
43+
totalErrors := 0
44+
for _, file := range files {
45+
log.Printf("Checking %s...\n", file)
46+
totalErrors += checkVectorFile(file)
47+
log.Println()
48+
}
49+
50+
if totalErrors > 0 {
51+
os.Exit(1)
52+
}
53+
}
54+
55+
func checkVectorFile(filename string) int {
56+
data, err := os.ReadFile(filename)
57+
if err != nil {
58+
panic(fmt.Sprintf("failed to read vector file: %v", err))
59+
}
60+
61+
var vectors TestVector
62+
if err := json.Unmarshal(data, &vectors); err != nil {
63+
panic(fmt.Sprintf("failed to parse vector JSON: %v", err))
64+
}
65+
66+
errors := 0
67+
for _, group := range vectors.TestGroups {
68+
for _, test := range group.Tests {
69+
70+
publicKeyBytes, err := extractPublicKey(test.Public)
71+
if err != nil || slices.Contains(test.Flags, "InvalidPublic") {
72+
// Skip test vectors with invalid public keys (different test concern)
73+
continue
74+
}
75+
76+
var expectedLen int
77+
switch group.Curve {
78+
case "curve25519":
79+
expectedLen = 32
80+
case "curve448":
81+
expectedLen = 56
82+
default:
83+
panic(fmt.Sprintf("unknown curve: %s", group.Curve))
84+
}
85+
86+
if len(publicKeyBytes) != expectedLen {
87+
// Skip test vectors with invalid key lengths (different test concern)
88+
continue
89+
}
90+
91+
isOnTwist, err := isPointOnTwist(publicKeyBytes, group.Curve)
92+
if err != nil {
93+
log.Printf("❌ tcId %d: error checking twist: %v", test.TcId, err)
94+
errors++
95+
continue
96+
}
97+
98+
hasTwistFlag := slices.Contains(test.Flags, "Twist")
99+
100+
if !isOnTwist && hasTwistFlag {
101+
log.Printf("❌ tcId %d: point is not on twist but has 'Twist' flag", test.TcId)
102+
errors++
103+
} else if isOnTwist && !hasTwistFlag {
104+
log.Printf("❌ tcId %d: point is on twist but missing 'Twist' flag", test.TcId)
105+
errors++
106+
} else if !isOnTwist {
107+
continue
108+
}
109+
110+
if test.Result != "acceptable" {
111+
log.Printf("❌ tcId %d: point is on twist but result is %q (expected 'acceptable')", test.TcId, test.Result)
112+
errors++
113+
}
114+
}
115+
}
116+
117+
log.Printf("Errors: %d\n", errors)
118+
return errors
119+
}
120+
121+
type TestVector struct {
122+
TestGroups []TestGroup `json:"testGroups"`
123+
}
124+
125+
type TestGroup struct {
126+
Curve string `json:"curve"`
127+
Tests []TestCase `json:"tests"`
128+
}
129+
130+
type TestCase struct {
131+
TcId int `json:"tcId"`
132+
Flags []string `json:"flags"`
133+
Public json.RawMessage `json:"public"`
134+
Result string `json:"result"`
135+
}
136+
137+
func extractPublicKey(publicRaw json.RawMessage) ([]byte, error) {
138+
// Try to parse as a plain string (hex or PEM format)
139+
var publicStr string
140+
if err := json.Unmarshal(publicRaw, &publicStr); err == nil {
141+
if strings.HasPrefix(publicStr, "-----BEGIN") {
142+
return extractFromPEM(publicStr)
143+
}
144+
return hex.DecodeString(publicStr)
145+
}
146+
147+
// Try to parse as JWK object
148+
var jwk struct {
149+
X string `json:"x"`
150+
}
151+
if err := json.Unmarshal(publicRaw, &jwk); err == nil && jwk.X != "" {
152+
return base64.RawURLEncoding.DecodeString(jwk.X)
153+
}
154+
155+
return nil, fmt.Errorf("unknown public key format")
156+
}
157+
158+
func extractFromPEM(pemStr string) ([]byte, error) {
159+
block, _ := pem.Decode([]byte(pemStr))
160+
if block == nil {
161+
return nil, fmt.Errorf("failed to decode PEM block")
162+
}
163+
164+
var spki struct {
165+
Algorithm pkix.AlgorithmIdentifier
166+
SubjectPublicKey asn1.BitString
167+
}
168+
if _, err := asn1.Unmarshal(block.Bytes, &spki); err != nil {
169+
return nil, fmt.Errorf("failed to parse SubjectPublicKeyInfo: %w", err)
170+
}
171+
172+
return spki.SubjectPublicKey.Bytes, nil
173+
}
174+
175+
func isPointOnTwist(publicKey []byte, curve string) (bool, error) {
176+
switch curve {
177+
case "curve25519":
178+
return isPointOnTwist25519(publicKey)
179+
case "curve448":
180+
return isPointOnTwist448(publicKey)
181+
default:
182+
return false, fmt.Errorf("unknown curve: %s", curve)
183+
}
184+
}
185+
186+
// isPointOnTwist25519 checks if a point is on the twist of Curve25519.
187+
// A point is on the twist if x³ + 486662x² + x is NOT a quadratic residue mod p.
188+
func isPointOnTwist25519(publicKey []byte) (bool, error) {
189+
x := new(field.Element)
190+
if _, err := x.SetBytes(publicKey); err != nil {
191+
return false, fmt.Errorf("invalid field element: %w", err)
192+
}
193+
194+
x2 := new(field.Element).Square(x)
195+
x3 := new(field.Element).Multiply(x2, x)
196+
ax2 := new(field.Element).Mult32(x2, 486662)
197+
rhs := new(field.Element).Add(x3, ax2)
198+
rhs.Add(rhs, x)
199+
200+
_, wasSquare := new(field.Element).SqrtRatio(rhs, new(field.Element).One())
201+
202+
return wasSquare == 0, nil
203+
}
204+
205+
// isPointOnTwist448 checks if a point is on the twist of Curve448.
206+
// A point is on the twist if x³ + 156326x² + x is NOT a quadratic residue mod p.
207+
func isPointOnTwist448(publicKey []byte) (bool, error) {
208+
// Curve448 field: p = 2^448 - 2^224 - 1
209+
p := new(big.Int).Lsh(big.NewInt(1), 448)
210+
p.Sub(p, new(big.Int).Lsh(big.NewInt(1), 224))
211+
p.Sub(p, big.NewInt(1))
212+
213+
slices.Reverse(publicKey) // Little endian -> Big endian
214+
x := new(big.Int).SetBytes(publicKey)
215+
x.Mod(x, p)
216+
217+
x2 := new(big.Int).Mul(x, x)
218+
x2.Mod(x2, p)
219+
220+
x3 := new(big.Int).Mul(x2, x)
221+
x3.Mod(x3, p)
222+
223+
a := big.NewInt(156326)
224+
ax2 := new(big.Int).Mul(a, x2)
225+
ax2.Mod(ax2, p)
226+
227+
rhs := new(big.Int).Add(x3, ax2)
228+
rhs.Add(rhs, x)
229+
rhs.Mod(rhs, p)
230+
231+
exp := new(big.Int).Sub(p, big.NewInt(1))
232+
exp.Div(exp, big.NewInt(2))
233+
legendre := new(big.Int).Exp(rhs, exp, p)
234+
235+
// If legendre is 0 or 1, point is on main curve
236+
if legendre.Cmp(big.NewInt(0)) == 0 || legendre.Cmp(big.NewInt(1)) == 0 {
237+
return false, nil
238+
}
239+
240+
return true, nil
241+
}

0 commit comments

Comments
 (0)