@@ -24,6 +24,7 @@ import (
24
24
"maps"
25
25
"math"
26
26
"math/big"
27
+ "math/bits"
27
28
28
29
"github.com/consensys/gnark-crypto/ecc"
29
30
bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381"
@@ -38,6 +39,7 @@ import (
38
39
"github.com/ethereum/go-ethereum/crypto/kzg4844"
39
40
"github.com/ethereum/go-ethereum/crypto/secp256r1"
40
41
"github.com/ethereum/go-ethereum/params"
42
+ "github.com/holiman/uint256"
41
43
"golang.org/x/crypto/ripemd160"
42
44
)
43
45
@@ -378,142 +380,223 @@ type bigModExp struct {
378
380
eip7883 bool
379
381
}
380
382
381
- var (
382
- big1 = big .NewInt (1 )
383
- big3 = big .NewInt (3 )
384
- big7 = big .NewInt (7 )
385
- big20 = big .NewInt (20 )
386
- big32 = big .NewInt (32 )
387
- big64 = big .NewInt (64 )
388
- big96 = big .NewInt (96 )
389
- big480 = big .NewInt (480 )
390
- big1024 = big .NewInt (1024 )
391
- big3072 = big .NewInt (3072 )
392
- big199680 = big .NewInt (199680 )
393
- )
394
-
395
- // modexpMultComplexity implements bigModexp multComplexity formula, as defined in EIP-198
383
+ // byzantiumMultComplexity implements the bigModexp multComplexity formula, as defined in EIP-198.
396
384
//
397
385
// def mult_complexity(x):
398
386
// if x <= 64: return x ** 2
399
387
// elif x <= 1024: return x ** 2 // 4 + 96 * x - 3072
400
388
// else: return x ** 2 // 16 + 480 * x - 199680
401
389
//
402
390
// where is x is max(length_of_MODULUS, length_of_BASE)
403
- func modexpMultComplexity (x * big.Int ) * big.Int {
391
+ // returns MaxUint64 if an overflow occurred.
392
+ func byzantiumMultComplexity (x uint64 ) uint64 {
404
393
switch {
405
- case x .Cmp (big64 ) <= 0 :
406
- x .Mul (x , x ) // x ** 2
407
- case x .Cmp (big1024 ) <= 0 :
408
- // (x ** 2 // 4 ) + ( 96 * x - 3072)
409
- x = new (big.Int ).Add (
410
- new (big.Int ).Rsh (new (big.Int ).Mul (x , x ), 2 ),
411
- new (big.Int ).Sub (new (big.Int ).Mul (big96 , x ), big3072 ),
412
- )
394
+ case x <= 64 :
395
+ return x * x
396
+ case x <= 1024 :
397
+ // x^2 / 4 + 96*x - 3072
398
+ return x * x / 4 + 96 * x - 3072
399
+
413
400
default :
414
- // (x ** 2 // 16) + (480 * x - 199680)
415
- x = new (big.Int ).Add (
416
- new (big.Int ).Rsh (new (big.Int ).Mul (x , x ), 4 ),
417
- new (big.Int ).Sub (new (big.Int ).Mul (big480 , x ), big199680 ),
418
- )
401
+ // For large x, use uint256 arithmetic to avoid overflow
402
+ // x^2 / 16 + 480*x - 199680
403
+
404
+ // xSqr = x^2 / 16
405
+ carry , xSqr := bits .Mul64 (x , x )
406
+ if carry != 0 {
407
+ return math .MaxUint64
408
+ }
409
+ xSqr = xSqr >> 4
410
+
411
+ // Calculate 480 * x (can't overflow if x^2 didn't overflow)
412
+ x480 := x * 480
413
+ // Calculate 480 * x - 199680 (will not underflow, since x > 1024)
414
+ x480 = x480 - 199680
415
+
416
+ // xSqr + x480
417
+ sum , carry := bits .Add64 (xSqr , x480 , 0 )
418
+ if carry != 0 {
419
+ return math .MaxUint64
420
+ }
421
+ return sum
422
+ }
423
+ }
424
+
425
+ // berlinMultComplexity implements the multiplication complexity formula for Berlin.
426
+ //
427
+ // def mult_complexity(x):
428
+ //
429
+ // ceiling(x/8)^2
430
+ //
431
+ // where is x is max(length_of_MODULUS, length_of_BASE)
432
+ func berlinMultComplexity (x uint64 ) uint64 {
433
+ // x = (x + 7) / 8
434
+ x , carry := bits .Add64 (x , 7 , 0 )
435
+ if carry != 0 {
436
+ return math .MaxUint64
437
+ }
438
+ x /= 8
439
+
440
+ // x^2
441
+ carry , x = bits .Mul64 (x , x )
442
+ if carry != 0 {
443
+ return math .MaxUint64
419
444
}
420
445
return x
421
446
}
422
447
448
+ // osakaMultComplexity implements the multiplication complexity formula for Osaka.
449
+ //
450
+ // For x <= 32: returns 16
451
+ // For x > 32: returns 2 * ceiling(x/8)^2
452
+ func osakaMultComplexity (x uint64 ) uint64 {
453
+ if x <= 32 {
454
+ return 16
455
+ }
456
+ // For x > 32, return 2 * berlinMultComplexity(x)
457
+ result := berlinMultComplexity (x )
458
+ carry , result := bits .Mul64 (result , 2 )
459
+ if carry != 0 {
460
+ return math .MaxUint64
461
+ }
462
+ return result
463
+ }
464
+
465
+ // modexpIterationCount calculates the number of iterations for the modexp precompile.
466
+ // This is the adjusted exponent length used in gas calculation.
467
+ func modexpIterationCount (expLen uint64 , expHead uint256.Int , multiplier uint64 ) uint64 {
468
+ var iterationCount uint64
469
+
470
+ // For large exponents (expLen > 32), add (expLen - 32) * multiplier
471
+ if expLen > 32 {
472
+ iterationCount = (expLen - 32 ) * multiplier
473
+ }
474
+
475
+ // Add the MSB position - 1 if expHead is non-zero
476
+ if bitLen := expHead .BitLen (); bitLen > 0 {
477
+ iterationCount += uint64 (bitLen - 1 )
478
+ }
479
+
480
+ return max (iterationCount , 1 )
481
+ }
482
+
483
+ // byzantiumModexpGas calculates the gas cost for the modexp precompile using Byzantium rules.
484
+ func byzantiumModexpGas (baseLen , expLen , modLen uint64 , expHead uint256.Int ) uint64 {
485
+ const (
486
+ multiplier = 8
487
+ divisor = 20
488
+ )
489
+
490
+ maxLen := max (baseLen , modLen )
491
+ multComplexity := byzantiumMultComplexity (maxLen )
492
+ if multComplexity == math .MaxUint64 {
493
+ return math .MaxUint64
494
+ }
495
+ iterationCount := modexpIterationCount (expLen , expHead , multiplier )
496
+
497
+ // Calculate gas: (multComplexity * iterationCount) / divisor
498
+ carry , gas := bits .Mul64 (iterationCount , multComplexity )
499
+ gas /= divisor
500
+ if carry != 0 {
501
+ return math .MaxUint64
502
+ }
503
+ return gas
504
+ }
505
+
506
+ // berlinModexpGas calculates the gas cost for the modexp precompile using Berlin rules.
507
+ func berlinModexpGas (baseLen , expLen , modLen uint64 , expHead uint256.Int ) uint64 {
508
+ const (
509
+ multiplier = 8
510
+ divisor = 3
511
+ minGas = 200
512
+ )
513
+
514
+ maxLen := max (baseLen , modLen )
515
+ multComplexity := berlinMultComplexity (maxLen )
516
+ if multComplexity == math .MaxUint64 {
517
+ return math .MaxUint64
518
+ }
519
+ iterationCount := modexpIterationCount (expLen , expHead , multiplier )
520
+
521
+ // Calculate gas: (multComplexity * iterationCount) / divisor
522
+ carry , gas := bits .Mul64 (iterationCount , multComplexity )
523
+ gas /= divisor
524
+ if carry != 0 {
525
+ return math .MaxUint64
526
+ }
527
+ return max (gas , minGas )
528
+ }
529
+
530
+ // osakaModexpGas calculates the gas cost for the modexp precompile using Osaka rules.
531
+ func osakaModexpGas (baseLen , expLen , modLen uint64 , expHead uint256.Int ) uint64 {
532
+ const (
533
+ multiplier = 16
534
+ divisor = 3
535
+ minGas = 500
536
+ )
537
+
538
+ maxLen := max (baseLen , modLen )
539
+ multComplexity := osakaMultComplexity (maxLen )
540
+ if multComplexity == math .MaxUint64 {
541
+ return math .MaxUint64
542
+ }
543
+ iterationCount := modexpIterationCount (expLen , expHead , multiplier )
544
+
545
+ // Calculate gas: (multComplexity * iterationCount) / osakaDivisor
546
+ carry , gas := bits .Mul64 (iterationCount , multComplexity )
547
+ if carry != 0 {
548
+ return math .MaxUint64
549
+ }
550
+ return max (gas , minGas )
551
+ }
552
+
423
553
// RequiredGas returns the gas required to execute the pre-compiled contract.
424
554
func (c * bigModExp ) RequiredGas (input []byte ) uint64 {
425
- var (
426
- baseLen = new (big.Int ).SetBytes (getData (input , 0 , 32 ))
427
- expLen = new (big.Int ).SetBytes (getData (input , 32 , 32 ))
428
- modLen = new (big.Int ).SetBytes (getData (input , 64 , 32 ))
429
- )
555
+ // Parse input lengths
556
+ baseLenBig := new (uint256.Int ).SetBytes (getData (input , 0 , 32 ))
557
+ expLenBig := new (uint256.Int ).SetBytes (getData (input , 32 , 32 ))
558
+ modLenBig := new (uint256.Int ).SetBytes (getData (input , 64 , 32 ))
559
+
560
+ // Convert to uint64, capping at max value
561
+ baseLen := baseLenBig .Uint64 ()
562
+ if ! baseLenBig .IsUint64 () {
563
+ baseLen = math .MaxUint64
564
+ }
565
+ expLen := expLenBig .Uint64 ()
566
+ if ! expLenBig .IsUint64 () {
567
+ expLen = math .MaxUint64
568
+ }
569
+ modLen := modLenBig .Uint64 ()
570
+ if ! modLenBig .IsUint64 () {
571
+ modLen = math .MaxUint64
572
+ }
573
+
574
+ // Skip the header
430
575
if len (input ) > 96 {
431
576
input = input [96 :]
432
577
} else {
433
578
input = input [:0 ]
434
579
}
580
+
435
581
// Retrieve the head 32 bytes of exp for the adjusted exponent length
436
- var expHead * big.Int
437
- if big .NewInt (int64 (len (input ))).Cmp (baseLen ) <= 0 {
438
- expHead = new (big.Int )
439
- } else {
440
- if expLen .Cmp (big32 ) > 0 {
441
- expHead = new (big.Int ).SetBytes (getData (input , baseLen .Uint64 (), 32 ))
582
+ var expHead uint256.Int
583
+ if uint64 (len (input )) > baseLen {
584
+ if expLen > 32 {
585
+ expHead .SetBytes (getData (input , baseLen , 32 ))
442
586
} else {
443
- expHead = new (big.Int ).SetBytes (getData (input , baseLen .Uint64 (), expLen .Uint64 ()))
587
+ // TODO: Check that if expLen < baseLen, then getData will return an empty slice
588
+ expHead .SetBytes (getData (input , baseLen , expLen ))
444
589
}
445
590
}
446
- // Calculate the adjusted exponent length
447
- var msb int
448
- if bitlen := expHead .BitLen (); bitlen > 0 {
449
- msb = bitlen - 1
450
- }
451
- adjExpLen := new (big.Int )
452
- if expLen .Cmp (big32 ) > 0 {
453
- adjExpLen .Sub (expLen , big32 )
454
- if c .eip7883 {
455
- adjExpLen .Lsh (adjExpLen , 4 )
456
- } else {
457
- adjExpLen .Lsh (adjExpLen , 3 )
458
- }
459
- }
460
- adjExpLen .Add (adjExpLen , big .NewInt (int64 (msb )))
461
- // Calculate the gas cost of the operation
462
- gas := new (big.Int )
463
- if modLen .Cmp (baseLen ) < 0 {
464
- gas .Set (baseLen )
465
- } else {
466
- gas .Set (modLen )
467
- }
468
-
469
- maxLenOver32 := gas .Cmp (big32 ) > 0
470
- if c .eip2565 {
471
- // EIP-2565 (Berlin fork) has three changes:
472
- //
473
- // 1. Different multComplexity (inlined here)
474
- // in EIP-2565 (https://eips.ethereum.org/EIPS/eip-2565):
475
- //
476
- // def mult_complexity(x):
477
- // ceiling(x/8)^2
478
- //
479
- // where is x is max(length_of_MODULUS, length_of_BASE)
480
- gas .Add (gas , big7 )
481
- gas .Rsh (gas , 3 )
482
- gas .Mul (gas , gas )
483
-
484
- var minPrice uint64 = 200
485
- if c .eip7883 {
486
- minPrice = 500
487
- if maxLenOver32 {
488
- gas .Add (gas , gas )
489
- } else {
490
- gas = big .NewInt (16 )
491
- }
492
- }
493
-
494
- if adjExpLen .Cmp (big1 ) > 0 {
495
- gas .Mul (gas , adjExpLen )
496
- }
497
- // 2. Different divisor (`GQUADDIVISOR`) (3)
498
- if ! c .eip7883 {
499
- gas .Div (gas , big3 )
500
- }
501
- if gas .BitLen () > 64 {
502
- return math .MaxUint64
503
- }
504
- return max (minPrice , gas .Uint64 ())
505
- }
506
591
507
- // Pre-Berlin logic.
508
- gas = modexpMultComplexity (gas )
509
- if adjExpLen .Cmp (big1 ) > 0 {
510
- gas .Mul (gas , adjExpLen )
511
- }
512
- gas .Div (gas , big20 )
513
- if gas .BitLen () > 64 {
514
- return math .MaxUint64
592
+ // Choose the appropriate gas calculation based on the EIP flags
593
+ if c .eip7883 {
594
+ return osakaModexpGas (baseLen , expLen , modLen , expHead )
595
+ } else if c .eip2565 {
596
+ return berlinModexpGas (baseLen , expLen , modLen , expHead )
597
+ } else {
598
+ return byzantiumModexpGas (baseLen , expLen , modLen , expHead )
515
599
}
516
- return gas .Uint64 ()
517
600
}
518
601
519
602
func (c * bigModExp ) Run (input []byte ) ([]byte , error ) {
0 commit comments