diff --git a/Cargo.toml b/Cargo.toml index 9430a71710..a48c207aa7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -111,6 +111,7 @@ repository.workspace = true resolver = "2" [features] +default = ["with-debug-utils"] with-cheatcode = [] with-debug-utils = [] with-mem-tracing = [] @@ -130,6 +131,7 @@ cairo-lang-sierra.workspace = true cairo-lang-utils.workspace = true educe.workspace = true itertools.workspace = true +lambdaworks-math.workspace = true lazy_static.workspace = true libc.workspace = true libloading.workspace = true diff --git a/programs/benches/factorial_2M_inv.cairo b/programs/benches/factorial_2M_inv.cairo new file mode 100644 index 0000000000..187597f9bd --- /dev/null +++ b/programs/benches/factorial_2M_inv.cairo @@ -0,0 +1,15 @@ +fn factorial_inv(value: felt252, n: felt252) -> felt252 { + if (n == 1) { + value + } else { + factorial_inv(felt252_div(value, n.try_into().unwrap()), n - 1) + } +} + +fn main() { + let result = factorial_inv(0x4d6e41de886ac83938da3456ccf1481182687989ead34d9d35236f0864575a0, 2_000_000); + assert( + result == 1, + 'invalid result' + ); +} diff --git a/src/arch/aarch64.rs b/src/arch/aarch64.rs index 556de0c584..ec82017fa8 100644 --- a/src/arch/aarch64.rs +++ b/src/arch/aarch64.rs @@ -10,7 +10,11 @@ #![cfg(target_arch = "aarch64")] use super::AbiArgument; -use crate::{error::Error, starknet::U256, utils::get_integer_layout}; +use crate::{ + error::Error, + starknet::U256, + utils::{get_integer_layout, montgomery::MontyBytes}, +}; use cairo_lang_sierra::ids::ConcreteTypeId; use num_traits::ToBytes; use starknet_types_core::felt::Felt; @@ -197,7 +201,8 @@ impl AbiArgument for Felt { if buffer.len() >= 56 { align_to(buffer, get_integer_layout(252).align()); } - buffer.extend_from_slice(&self.to_bytes_le()); + + buffer.extend_from_slice(&self.to_bytes_le_raw()); Ok(()) } } diff --git a/src/arch/x86_64.rs b/src/arch/x86_64.rs index 59e274affd..f42a30278c 100644 --- a/src/arch/x86_64.rs +++ b/src/arch/x86_64.rs @@ -10,7 +10,11 @@ 
#![cfg(target_arch = "x86_64")] use super::AbiArgument; -use crate::{error::Error, starknet::U256, utils::get_integer_layout}; +use crate::{ + error::Error, + starknet::U256, + utils::{get_integer_layout, montgomery::MontyBytes}, +}; use cairo_lang_sierra::ids::ConcreteTypeId; use num_traits::ToBytes; use starknet_types_core::felt::Felt; @@ -159,7 +163,7 @@ impl AbiArgument for Felt { align_to(buffer, get_integer_layout(252).align()); } - buffer.extend_from_slice(&self.to_bytes_le()); + buffer.extend_from_slice(&self.to_bytes_le_raw()); Ok(()) } } diff --git a/src/executor.rs b/src/executor.rs index 2f9e97a27e..b620b31f4a 100644 --- a/src/executor.rs +++ b/src/executor.rs @@ -34,6 +34,7 @@ use cairo_lang_sierra::{ use libc::c_void; use num_bigint::BigInt; use num_traits::One; +use starknet_types_core::felt::Felt; use std::{alloc::Layout, arch::global_asm, ptr::NonNull}; mod aot; @@ -450,11 +451,17 @@ fn parse_result( #[cfg(target_arch = "aarch64")] Ok(Value::Felt252({ + use lambdaworks_math::{ + traits::ByteConversion, unsigned_integer::element::U256, + }; + let data = unsafe { std::mem::transmute::<&mut [u64; 4], &mut [u8; 32]>(&mut ret_registers) }; data[31] &= 0x0F; // Filter out first 4 bits (they're outside an i252). 
- starknet_types_core::felt::Felt::from_bytes_le(data) + let value = U256::from_bytes_le(data).unwrap(); + + Felt::from_raw(value.limbs) })) } }, diff --git a/src/executor/contract.rs b/src/executor/contract.rs index a2da8cb022..a30ef35403 100644 --- a/src/executor/contract.rs +++ b/src/executor/contract.rs @@ -51,7 +51,7 @@ use crate::{ types::TypeBuilder, utils::{ decode_error_message, generate_function_name, get_integer_layout, get_types_total_size, - libc_free, libc_malloc, BuiltinCosts, + libc_free, libc_malloc, montgomery::MontyBytes, BuiltinCosts, }, OptLevel, }; @@ -75,6 +75,7 @@ use cairo_lang_starknet_classes::{ }; use educe::Educe; use itertools::{chain, Itertools}; +use lambdaworks_math::{traits::ByteConversion, unsigned_integer::element::UnsignedInteger}; use libloading::Library; use serde::{Deserialize, Serialize}; use starknet_types_core::felt::Felt; @@ -500,7 +501,8 @@ impl AotContractExecutor { }; for (idx, elem) in args.iter().enumerate() { - let f = elem.to_bytes_le(); + let f = elem.to_bytes_le_raw(); + unsafe { std::ptr::copy_nonoverlapping( f.as_ptr().cast::(), @@ -666,9 +668,15 @@ impl AotContractExecutor { let cur_elem_ptr = unsafe { array_ptr.byte_add(elem_stride * i as usize) }; let mut data = unsafe { cur_elem_ptr.cast::<[u8; 32]>().read() }; + data[31] &= 0x0F; // Filter out first 4 bits (they're outside an i252). 
- array_value.push(Felt::from_bytes_le(&data)); + let felt = { + let data = UnsignedInteger::from_bytes_le(&data).unwrap(); + Felt::from_raw(data.limbs) + }; + + array_value.push(felt); } unsafe { diff --git a/src/libfuncs/bool.rs b/src/libfuncs/bool.rs index afa1c490b4..5fa17c1a69 100644 --- a/src/libfuncs/bool.rs +++ b/src/libfuncs/bool.rs @@ -5,7 +5,10 @@ use crate::{ error::{panic::ToNativeAssertError, Result}, metadata::MetadataStorage, types::TypeBuilder, - utils::ProgramRegistryExt, + utils::{ + montgomery::{self, MONTY_R2}, + ProgramRegistryExt, + }, }; use cairo_lang_sierra::{ extensions::{ @@ -200,9 +203,11 @@ pub fn build_bool_to_felt252<'ctx, 'this>( let value = entry.arg(0)?; let tag_value = entry.extract_value(context, location, value, tag_ty, 0)?; - let result = entry.extui(tag_value, felt252_ty, location)?; + // Convert into Montgomery representation. + let r2 = entry.const_int(context, location, *MONTY_R2, 257)?; + let felt = montgomery::mlir::monty_mul(context, entry, tag_value, r2, felt252_ty, location)?; - helper.br(entry, 0, &[result], location) + helper.br(entry, 0, &[felt], location) } #[cfg(test)] diff --git a/src/libfuncs/const.rs b/src/libfuncs/const.rs index d6c52c3e84..6652f1ca9b 100644 --- a/src/libfuncs/const.rs +++ b/src/libfuncs/const.rs @@ -2,12 +2,12 @@ use super::LibfuncHelper; use crate::{ - error::{Error, Result}, + error::{panic::ToNativeAssertError, Error, Result}, libfuncs::{r#enum::build_enum_value, r#struct::build_struct_value}, metadata::{realloc_bindings::ReallocBindingsMeta, MetadataStorage}, native_panic, types::TypeBuilder, - utils::{ProgramRegistryExt, RangeExt, PRIME}, + utils::{montgomery::monty_transform, ProgramRegistryExt, RangeExt, PRIME}, }; use cairo_lang_sierra::{ extensions::{ @@ -265,8 +265,10 @@ pub fn build_const_type_value<'ctx, 'this>( Sign::Minus => PRIME.clone() - value, _ => value, }; - - Ok(entry.const_int_from_type(context, location, value, inner_ty)?) 
+ let monty_value = monty_transform(&value, &PRIME).to_native_assert_error(&format!( + "could not transform felt252: {value} to Montgomery form" + ))?; + Ok(entry.const_int_from_type(context, location, monty_value, inner_ty)?) } CoreTypeConcrete::Starknet( StarknetTypeConcrete::ClassHash(_) | StarknetTypeConcrete::ContractAddress(_), @@ -281,8 +283,10 @@ pub fn build_const_type_value<'ctx, 'this>( Sign::Minus => PRIME.clone() - value, _ => value, }; - - Ok(entry.const_int_from_type(context, location, value, inner_ty)?) + let monty_value = monty_transform(&value, &PRIME).to_native_assert_error(&format!( + "could not transform felt252: {value} to Montgomery form" + ))?; + Ok(entry.const_int_from_type(context, location, monty_value, inner_ty)?) } CoreTypeConcrete::Uint8(_) | CoreTypeConcrete::Uint16(_) diff --git a/src/libfuncs/felt252.rs b/src/libfuncs/felt252.rs index af79646f12..53da58cacd 100644 --- a/src/libfuncs/felt252.rs +++ b/src/libfuncs/felt252.rs @@ -2,9 +2,9 @@ use super::LibfuncHelper; use crate::{ - error::Result, + error::{panic::ToNativeAssertError, Result}, metadata::MetadataStorage, - utils::{ProgramRegistryExt, PRIME}, + utils::{montgomery, ProgramRegistryExt, PRIME}, }; use cairo_lang_sierra::{ extensions::{ @@ -19,12 +19,9 @@ use cairo_lang_sierra::{ program_registry::ProgramRegistry, }; use melior::{ - dialect::{ - arith::{self, CmpiPredicate}, - cf, - }, + dialect::arith::{self, CmpiPredicate}, helpers::{ArithBlockExt, BuiltinBlockExt}, - ir::{r#type::IntegerType, Block, BlockLike, Location, Value, ValueLike}, + ir::{r#type::IntegerType, Block, Location, Value, ValueLike}, Context, }; use num_bigint::{BigInt, Sign}; @@ -73,7 +70,6 @@ pub fn build_binary_operation<'ctx, 'this>( &info.branch_signatures()[0].vars[0].ty, )?; let i256 = IntegerType::new(context, 256).into(); - let i512 = IntegerType::new(context, 512).into(); let (op, lhs, rhs) = match info { Felt252BinaryOperationConcrete::WithVar(operation) => { @@ -86,9 +82,12 @@ pub fn 
build_binary_operation<'ctx, 'this>( .clone(), _ => operation.c.magnitude().clone(), }; + let monty_value = montgomery::monty_transform(&value, &PRIME).to_native_assert_error( + &format!("could not transform felt252: {value} to Montgomery form"), + )?; // TODO: Ensure that the constant is on the correct side of the operation. - let rhs = entry.const_int_from_type(context, location, value, felt252_ty)?; + let rhs = entry.const_int_from_type(context, location, monty_value, felt252_ty)?; (operation.operator, entry.arg(0)?, rhs) } @@ -101,7 +100,7 @@ pub fn build_binary_operation<'ctx, 'this>( let result = entry.addi(lhs, rhs, location)?; let prime = entry.const_int_from_type(context, location, PRIME.clone(), i256)?; - let result_mod = entry.append_op_result(arith::subi(result, prime, location))?; + let result_mod = entry.subi(result, prime, location)?; let is_out_of_range = entry.cmpi(context, CmpiPredicate::Uge, result, prime, location)?; @@ -111,12 +110,13 @@ pub fn build_binary_operation<'ctx, 'this>( result, location, ))?; + entry.trunci(result, felt252_ty, location)? } Felt252BinaryOperator::Sub => { let lhs = entry.extui(lhs, i256, location)?; let rhs = entry.extui(rhs, i256, location)?; - let result = entry.append_op_result(arith::subi(lhs, rhs, location))?; + let result = entry.subi(lhs, rhs, location)?; let prime = entry.const_int_from_type(context, location, PRIME.clone(), i256)?; let result_mod = entry.addi(result, prime, location)?; @@ -131,148 +131,10 @@ pub fn build_binary_operation<'ctx, 'this>( entry.trunci(result, felt252_ty, location)? 
} Felt252BinaryOperator::Mul => { - let lhs = entry.extui(lhs, i512, location)?; - let rhs = entry.extui(rhs, i512, location)?; - let result = entry.muli(lhs, rhs, location)?; - - let prime = entry.const_int_from_type(context, location, PRIME.clone(), i512)?; - let result_mod = entry.append_op_result(arith::remui(result, prime, location))?; - let is_out_of_range = - entry.cmpi(context, CmpiPredicate::Uge, result, prime, location)?; - - let result = entry.append_op_result(arith::select( - is_out_of_range, - result_mod, - result, - location, - ))?; - entry.trunci(result, felt252_ty, location)? + montgomery::mlir::monty_mul(context, entry, lhs, rhs, felt252_ty, location)? } Felt252BinaryOperator::Div => { - // The extended euclidean algorithm calculates the greatest common divisor of two integers, - // as well as the bezout coefficients x and y such that for inputs a and b, ax+by=gcd(a,b) - // We use this in felt division to find the modular inverse of a given number - // If a is the number we're trying to find the inverse of, we can do - // ax+y*PRIME=gcd(a,PRIME)=1 => ax = 1 (mod PRIME) - // Hence for input a, we return x - // The input MUST be non-zero - // See https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm - let start_block = helper.append_block(Block::new(&[(i512, location)])); - let loop_block = helper.append_block(Block::new(&[ - (i512, location), - (i512, location), - (i512, location), - (i512, location), - ])); - let negative_check_block = helper.append_block(Block::new(&[])); - // Block containing final result - let inverse_result_block = helper.append_block(Block::new(&[(i512, location)])); - // Egcd works by calculating a series of remainders, each the remainder of dividing the previous two - // For the initial setup, r0 = PRIME, r1 = a - // This order is chosen because if we reverse them, then the first iteration will just swap them - let prev_remainder = - start_block.const_int_from_type(context, location, PRIME.clone(), i512)?; - let 
remainder = start_block.arg(0)?; - // Similarly we'll calculate another series which starts 0,1,... and from which we will retrieve the modular inverse of a - let prev_inverse = start_block.const_int_from_type(context, location, 0, i512)?; - let inverse = start_block.const_int_from_type(context, location, 1, i512)?; - start_block.append_operation(cf::br( - loop_block, - &[prev_remainder, remainder, prev_inverse, inverse], - location, - )); - - //---Loop body--- - // Arguments are rem_(i-1), rem, inv_(i-1), inv - let prev_remainder = loop_block.arg(0)?; - let remainder = loop_block.arg(1)?; - let prev_inverse = loop_block.arg(2)?; - let inverse = loop_block.arg(3)?; - - // First calculate q = rem_(i-1)/rem_i, rounded down - let quotient = - loop_block.append_op_result(arith::divui(prev_remainder, remainder, location))?; - // Then r_(i+1) = r_(i-1) - q * r_i, and inv_(i+1) = inv_(i-1) - q * inv_i - let rem_times_quo = loop_block.muli(remainder, quotient, location)?; - let inv_times_quo = loop_block.muli(inverse, quotient, location)?; - let next_remainder = loop_block.append_op_result(arith::subi( - prev_remainder, - rem_times_quo, - location, - ))?; - let next_inverse = - loop_block.append_op_result(arith::subi(prev_inverse, inv_times_quo, location))?; - - // If r_(i+1) is 0, then inv_i is the inverse - let zero = loop_block.const_int_from_type(context, location, 0, i512)?; - let next_remainder_eq_zero = - loop_block.cmpi(context, CmpiPredicate::Eq, next_remainder, zero, location)?; - loop_block.append_operation(cf::cond_br( - context, - next_remainder_eq_zero, - negative_check_block, - loop_block, - &[], - &[remainder, next_remainder, inverse, next_inverse], - location, - )); - - // egcd sometimes returns a negative number for the inverse, - // in such cases we must simply wrap it around back into [0, PRIME) - // this suffices because |inv_i| <= divfloor(PRIME,2) - let zero = negative_check_block.const_int_from_type(context, location, 0, i512)?; - - let is_negative 
= negative_check_block - .append_operation(arith::cmpi( - context, - CmpiPredicate::Slt, - inverse, - zero, - location, - )) - .result(0)? - .into(); - // if the inverse is < 0, add PRIME - let prime = - negative_check_block.const_int_from_type(context, location, PRIME.clone(), i512)?; - let wrapped_inverse = negative_check_block.addi(inverse, prime, location)?; - let inverse = negative_check_block.append_op_result(arith::select( - is_negative, - wrapped_inverse, - inverse, - location, - ))?; - negative_check_block.append_operation(cf::br( - inverse_result_block, - &[inverse], - location, - )); - - // Div Logic Start - // Fetch operands - let lhs = entry.extui(lhs, i512, location)?; - let rhs = entry.extui(rhs, i512, location)?; - // Calculate inverse of rhs, callling the inverse implementation's starting block - entry.append_operation(cf::br(start_block, &[rhs], location)); - // Fetch the inverse result from the result block - let inverse = inverse_result_block.arg(0)?; - // Peform lhs * (1/ rhs) - let result = inverse_result_block.muli(lhs, inverse, location)?; - // Apply modulo and convert result to felt252 - let result_mod = - inverse_result_block.append_op_result(arith::remui(result, prime, location))?; - let is_out_of_range = - inverse_result_block.cmpi(context, CmpiPredicate::Uge, result, prime, location)?; - - let result = inverse_result_block.append_op_result(arith::select( - is_out_of_range, - result_mod, - result, - location, - ))?; - let result = inverse_result_block.trunci(result, felt252_ty, location)?; - - return helper.br(inverse_result_block, 0, &[result], location); + montgomery::mlir::monty_div(context, entry, lhs, rhs, felt252_ty, location)? 
} }; @@ -303,7 +165,10 @@ pub fn build_const<'ctx, 'this>( &info.branch_signatures()[0].vars[0].ty, )?; - let value = entry.const_int_from_type(context, location, value, felt252_ty)?; + let monty_value = montgomery::monty_transform(&value, &PRIME).to_native_assert_error( + &format!("could not transform felt252: {value} to Montgomery form"), + )?; + let value = entry.const_int_from_type(context, location, monty_value, felt252_ty)?; helper.br(entry, 0, &[value], location) } diff --git a/src/libfuncs/int.rs b/src/libfuncs/int.rs index c8f1e5e478..3e54ff089c 100644 --- a/src/libfuncs/int.rs +++ b/src/libfuncs/int.rs @@ -6,7 +6,10 @@ use crate::{ metadata::MetadataStorage, native_panic, types::TypeBuilder, - utils::{ProgramRegistryExt, PRIME}, + utils::{ + montgomery::{self, MONTY_R2}, + ProgramRegistryExt, PRIME, + }, }; use cairo_lang_sierra::{ extensions::{ @@ -387,7 +390,6 @@ fn build_from_felt252<'ctx, 'this>( let value_ty = registry.get_type(&info.signature.branch_signatures[0].vars[1].ty)?; let threshold = value_ty.integer_range(registry)?; let threshold_size = threshold.size(); - let value_ty = value_ty.build( context, helper, @@ -396,7 +398,10 @@ fn build_from_felt252<'ctx, 'this>( &info.signature.branch_signatures[0].vars[1].ty, )?; + // We are casting from a felt, so we need to reduce it out of Montgomery form first. let input = entry.arg(1)?; + let k1 = entry.const_int_from_type(context, location, 1, input.r#type())?; + let input = montgomery::mlir::monty_mul(context, entry, input, k1, input.r#type(), location)?; // Handle signedness separately. let (is_in_range, value) = if threshold.lower.is_zero() { @@ -894,9 +899,13 @@ fn build_to_felt252<'ctx, 'this>( entry.append_op_result(arith::select(is_negative, neg_value, value, location))? } else { - entry.extui(entry.arg(0)?, felt252_ty, location)? + entry.arg(0)? }; + // We are casting to a felt, so we need to convert it into Montgomery space.
+ let r2 = entry.const_int(context, location, *MONTY_R2, 257)?; + let value = montgomery::mlir::monty_mul(context, entry, value, r2, felt252_ty, location)?; + helper.br(entry, 0, &[value], location) } @@ -911,10 +920,15 @@ fn build_u128s_from_felt252<'ctx, 'this>( ) -> Result<()> { let target_ty = IntegerType::new(context, 128).into(); - let lo = entry.trunci(entry.arg(1)?, target_ty, location)?; + // We are casting from a felt, so we need to reduce it out of Montgomery form first. + let felt = entry.arg(1)?; + let k1 = entry.const_int_from_type(context, location, 1, felt.r#type())?; + let lo = montgomery::mlir::monty_mul(context, entry, felt, k1, felt.r#type(), location)?; + + let k128 = entry.const_int_from_type(context, location, 128, felt.r#type())?; + let hi = entry.shrui(lo, k128, location)?; - let k128 = entry.const_int_from_type(context, location, 128, entry.arg(1)?.r#type())?; - let hi = entry.shrui(entry.arg(1)?, k128, location)?; + let lo = entry.trunci(lo, target_ty, location)?; let hi = entry.trunci(hi, target_ty, location)?; let k0 = entry.const_int_from_type(context, location, 0, target_ty)?; diff --git a/src/runtime.rs b/src/runtime.rs index 249d562fad..760be5fdbe 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -5,6 +5,7 @@ use cairo_lang_sierra_gas::core_libfunc_cost::{ DICT_SQUASH_REPEATED_ACCESS_COST, DICT_SQUASH_UNIQUE_KEY_COST, }; use itertools::Itertools; +use lambdaworks_math::{traits::ByteConversion, unsigned_integer::element::U256}; use lazy_static::lazy_static; use num_bigint::BigInt; use num_traits::{ToPrimitive, Zero}; @@ -319,7 +320,8 @@ pub unsafe extern "C" fn cairo_native__dict_squash( let no_big_keys = dict .mappings .keys() - .map(Felt::from_bytes_le) + .map(|b| U256::from_bytes_le(b).expect("felt bytes should be valid")) + .map(|v| Felt::from_raw(v.limbs)) .all(|key| key < Felt::from(BigInt::from(1).shl(128))); let number_of_keys = dict.mappings.len() as u64; diff --git a/src/utils.rs b/src/utils.rs index 6398d96be5..5376731024 100644 --- a/src/utils.rs +++
b/src/utils.rs @@ -32,6 +32,7 @@ use std::{ use thiserror::Error; pub mod mem_tracing; +pub mod montgomery; mod program_registry_ext; mod range_ext; #[cfg(feature = "with-segfault-catcher")] diff --git a/src/utils/montgomery.rs b/src/utils/montgomery.rs new file mode 100644 index 0000000000..18025e20e2 --- /dev/null +++ b/src/utils/montgomery.rs @@ -0,0 +1,581 @@ +//! # Montgomery implementation for Felt252. +//! +//! This module holds utility functions for performing arithmetic operations +//! inside the Montgomery space. +//! +//! Representing felts in the Montgomery space allows for optimizations when +//! performing multiplication and division operations. This is because it +//! avoids having to perform modulo operations and even divisions. Montgomery +//! reduces these operations to shifts and simple arithmetic operations such as +//! additions and subtractions. +//! +//! The way this works is by representing a value as x' = x * r mod n. This +//! introduces a new constant `r` which, for performance reasons, is defined +//! as r = 2^{k} where k should be big enough to satisfy r > n. +//! +//! For more information, see: https://en.wikipedia.org/wiki/Montgomery_modular_multiplication. + +use std::sync::LazyLock; + +use lambdaworks_math::{ + errors::CreationError, + traits::ByteConversion, + unsigned_integer::{ + element::{UnsignedInteger, U256}, + montgomery::MontgomeryAlgorithms, + }, +}; +use num_bigint::BigUint; +use num_traits::Num; +use starknet_types_core::felt::Felt; + +// R parameter for felts. R = 2^{256}, a power of 2 greater than prime that spans the four 64-bit limbs. +pub static MONTY_R: LazyLock = LazyLock::new(|| BigUint::from(1u64) << 256); +// R2 parameter for felts. R2 = 2^{256 * 2} mod prime. This value is a U256 instead of a +// BigUint to integrate with lambdaworks with ease.
+pub static MONTY_R2: LazyLock = LazyLock::new(|| { + UnsignedInteger::from_hex_unchecked( + "7FFD4AB5E008810FFFFFFFFFF6F800000000001330FFFFFFFFFFD737E000401", + ) +}); +// MU parameter for felts. MU = -prime^{-1} mod 2^{64}. This u64 variant is used to +// allow a better integration with lambdaworks. +// Check: https://github.com/lambdaclass/lambdaworks/blob/main/crates/math/src/field/fields/montgomery_backed_prime_fields.rs#L60 +pub const MONTY_MU_U64: u64 = 18446744073709551615; +// MU parameter for felts. MU = prime^{-1} mod R. +pub static MONTY_MU_U256: LazyLock = LazyLock::new(|| { + BigUint::from_str_radix( + "f7ffffffffffffef000000000000000000000000000000000000000000000001", + 16, + ) + .expect("hardcoded mu constant should be valid") +}); + +pub trait MontyBytes { + fn to_bytes_le_raw(&self) -> [u8; 32]; +} + +impl MontyBytes for Felt { + fn to_bytes_le_raw(&self) -> [u8; 32] { + let limbs = self.to_raw(); + let mut buffer = [0; 32]; + + for i in (0..4).rev() { + let bytes = limbs[i].to_le_bytes(); + let init = (3 - i) * 8; + buffer[init..init + 8].copy_from_slice(&bytes); + } + + buffer + } +} + +/// Montgomery reduction: multiplies `x` by 1 with CIOS, taking it out of Montgomery form. +/// NOTE(review): `MONTY_MU_U64` is fixed, so `modulus` is presumably always `PRIME` - confirm. +pub fn monty_reduction(x: &BigUint, modulus: &BigUint) -> Result { + let x = U256::from_hex(&x.to_str_radix(16))?; + let modulus = U256::from_hex(&modulus.to_str_radix(16))?; + + let reduced = MontgomeryAlgorithms::cios(&x, &U256::from_u64(1), &modulus, &MONTY_MU_U64); + + Ok(BigUint::from_bytes_le(&reduced.to_bytes_le())) +} + +/// Montgomery transform: multiplies `x` by `MONTY_R2` with CIOS, taking it into Montgomery form. +/// NOTE(review): `MONTY_R2` is fixed, so `modulus` is presumably always `PRIME` - confirm.
+pub fn monty_transform(x: &BigUint, modulus: &BigUint) -> Result { + let x = U256::from_hex(&x.to_str_radix(16))?; + let modulus = U256::from_hex(&modulus.to_str_radix(16))?; + + let reduced = MontgomeryAlgorithms::cios(&x, &MONTY_R2, &modulus, &MONTY_MU_U64); + + Ok(BigUint::from_bytes_le(&reduced.to_bytes_le())) +} + +pub fn monty_inverse() { + let felt = Felt::from(1).inverse().unwrap(); + let felt = U256::from_bytes_le(&felt.to_bytes_le_raw()).unwrap(); + dbg!(felt.to_hex()); +} + +pub mod mlir { + use crate::{ + error::Result, + utils::{ + montgomery::{MONTY_MU_U256, MONTY_R, MONTY_R2}, + PRIME, + }, + }; + use melior::{ + dialect::{arith, ods, scf}, + helpers::{ArithBlockExt, BuiltinBlockExt}, + ir::{r#type::IntegerType, Block, BlockLike, Location, Region, Type, Value, ValueLike}, + Context, + }; + + pub fn monty_mul<'c, 'a>( + context: &'c Context, + block: &'a Block<'c>, + lhs: Value<'c, '_>, + rhs: Value<'c, '_>, + res_ty: Type<'c>, + location: Location<'c>, + ) -> Result> { + let i512 = IntegerType::new(context, 512).into(); + + let lhs = block.extui(lhs, i512, location)?; + let rhs = block.extui(rhs, i512, location)?; + + let t = block.muli(lhs, rhs, location)?; + + let result = monty_reduce(context, block, t, location)?; + + Ok(block.trunci(result, res_ty, location)?) + } + + pub fn monty_div<'c, 'a>( + context: &'c Context, + block: &'a Block<'c>, + lhs: Value<'c, '_>, + rhs: Value<'c, '_>, + res_ty: Type<'c>, + location: Location<'c>, + ) -> Result> { + let inv_rhs = monty_inverse(context, block, rhs, location)?; + monty_mul(context, block, lhs, inv_rhs, res_ty, location) + } + + /// Compute Montgomery modular inverse. + /// + /// The algorithm is given by B. S. Kaliski Jr. in "The Montgomery Inverse + /// and Its Applications". The algorithm consists of two phases: + /// 1. Compute x = a^{-1}2^{k} mod p, where n < k < 2n (denoted as + /// almost inverse). + /// 2. Corrects the result from phase 1 so that x = a^{-1}2^{n} mod p. 
+ /// The algorithm can be checked + /// [here](https://www.researchgate.net/publication/225962646_Efficient_Software-Implementation_of_Finite_Fields_with_Applications_to_Cryptography) + /// (Algorithm 17). + fn monty_inverse<'c, 'a>( + context: &'c Context, + block: &'a Block<'c>, + value: Value<'c, '_>, + location: Location<'c>, + ) -> Result> { + let value = block.extui(value, IntegerType::new(context, 256).into(), location)?; + let (r, k) = almost_inverse(context, block, value, location)?; + let inverse = inverse_correction(context, block, r, k, location)?; + + let r2 = block.const_int_from_type(context, location, *MONTY_R2, inverse.r#type())?; + monty_mul(context, block, inverse, r2, inverse.r#type(), location) + } + + fn inverse_correction<'c, 'a>( + context: &'c Context, + block: &'a Block<'c>, + r: Value<'c, '_>, + k: Value<'c, '_>, + location: Location<'c>, + ) -> Result> { + let i16 = IntegerType::new(context, 16).into(); + let i256 = IntegerType::new(context, 256).into(); + + let k0 = block.const_int(context, location, 0, 256)?; + let k0_i16 = block.const_int(context, location, 0, 16)?; + let k1_i16 = block.const_int(context, location, 1, 16)?; + let k1 = block.const_int(context, location, 1, 256)?; + let k256 = block.const_int(context, location, 256, 16)?; + + let loop_limit = block.subi(k, k256, location)?; + + let result = block.append_operation( + ods::scf::r#for( + context, + &[i256], + k0_i16, + loop_limit, + k1_i16, + &[r], + { + let region = Region::new(); + let loop_block = + region.append_block(Block::new(&[(i16, location), (i256, location)])); + + let r = loop_block.arg(1)?; + + let r_and_one = loop_block.andi(r, k1, location)?; + let is_r_even = loop_block.cmpi( + context, + arith::CmpiPredicate::Eq, + r_and_one, + k0, + location, + )?; + + let next_r = loop_block.append_op_result(scf::r#if( + is_r_even, + &[i256], + { + let region = Region::new(); + let block_then = region.append_block(Block::new(&[])); + + let result = block_then.shrui(r, k1,
location)?; + + block_then.append_operation(scf::r#yield(&[result], location)); + + region + }, + { + let region = Region::new(); + let block_else = region.append_block(Block::new(&[])); + + let prime = + block_else.const_int(context, location, PRIME.clone(), 256)?; + + let result = block_else.addi(r, prime, location)?; + let result = block_else.shrui(result, k1, location)?; + + block_else.append_operation(scf::r#yield(&[result], location)); + + region + }, + location, + ))?; + + loop_block.append_operation(scf::r#yield(&[next_r], location)); + + region + }, + location, + ) + .into(), + ); + + Ok(result.result(0)?.into()) + } + + fn almost_inverse<'c, 'a>( + context: &'c Context, + block: &'a Block<'c>, + value: Value<'c, '_>, + location: Location<'c>, + ) -> Result<(Value<'c, 'a>, Value<'c, 'a>)> { + let i16 = IntegerType::new(context, 16).into(); + let value_ty = value.r#type(); + + let k0 = block.const_int_from_type(context, location, 0, value_ty)?; + let k0_i16 = block.const_int(context, location, 0, 16)?; + let prime = block.const_int_from_type(context, location, PRIME.clone(), value_ty)?; + let k1 = block.const_int_from_type(context, location, 1, value_ty)?; + let k1_i16 = block.const_int(context, location, 1, 16)?; + + let result = block.append_operation(scf::r#while( + &[prime, value, k0, k1, k0_i16], + &[value_ty, value_ty, value_ty, value_ty, i16], + { + let region = Region::new(); + let cond_block = region.append_block(Block::new(&[ + (value_ty, location), + (value_ty, location), + (value_ty, location), + (value_ty, location), + (i16, location), + ])); + let u = cond_block.arg(0)?; + let v = cond_block.arg(1)?; + let r = cond_block.arg(2)?; + let s = cond_block.arg(3)?; + + let u_is_even = { + let u_and_one = cond_block.andi(u, k1, location)?; + cond_block.cmpi(context, arith::CmpiPredicate::Eq, u_and_one, k0, location)? 
+ }; + + // if u is even then + // u = u / 2 + // s = 2 * s + // else if v is even then + // v = v / 2 + // r = 2 * r + // else if u > v then + // u = (u − v) / 2 + // r = r + s + // s = 2 * s + // else if u <= v then + // v = (v − u) / 2 + // s = r + s + // r = 2 * r + let result = cond_block.append_operation(scf::r#if( + u_is_even, + &[value_ty, value_ty, value_ty, value_ty], + { + let region = Region::new(); + let u_even_block = region.append_block(Block::new(&[])); + + let u = u_even_block.shrui(u, k1, location)?; + let s = u_even_block.shli(s, k1, location)?; + + u_even_block.append_operation(scf::r#yield(&[u, v, r, s], location)); + + region + }, + { + let region = Region::new(); + let u_not_even_block = region.append_block(Block::new(&[])); + + let v_is_even = { + let v_and_one = u_not_even_block.andi(v, k1, location)?; + u_not_even_block.cmpi( + context, + arith::CmpiPredicate::Eq, + v_and_one, + k0, + location, + )? + }; + + let result = u_not_even_block.append_operation(scf::r#if( + v_is_even, + &[value_ty, value_ty, value_ty, value_ty], + { + let region = Region::new(); + let v_even_block = region.append_block(Block::new(&[])); + + let v = v_even_block.shrui(v, k1, location)?; + let r = v_even_block.shli(r, k1, location)?; + + v_even_block + .append_operation(scf::r#yield(&[u, v, r, s], location)); + + region + }, + { + let region = Region::new(); + let v_not_even_block = region.append_block(Block::new(&[])); + + let is_u_gt_v = v_not_even_block.cmpi( + context, + arith::CmpiPredicate::Ugt, + u, + v, + location, + )?; + + let result = v_not_even_block.append_operation(scf::r#if( + is_u_gt_v, + &[value_ty, value_ty, value_ty, value_ty], + { + let region = Region::new(); + let u_gt_v_block = region.append_block(Block::new(&[])); + + let u = { + let u_min_v = u_gt_v_block.subi(u, v, location)?; + u_gt_v_block.shrui(u_min_v, k1, location)?
+ }; + let r = u_gt_v_block.addi(r, s, location)?; + let s = u_gt_v_block.shli(s, k1, location)?; + + u_gt_v_block.append_operation(scf::r#yield( + &[u, v, r, s], + location, + )); + + region + }, + { + let region = Region::new(); + let v_ge_u_block = region.append_block(Block::new(&[])); + + let v = { + let v_min_u = v_ge_u_block.subi(v, u, location)?; + v_ge_u_block.shrui(v_min_u, k1, location)? + }; + let s = v_ge_u_block.addi(r, s, location)?; + let r = v_ge_u_block.shli(r, k1, location)?; + + v_ge_u_block.append_operation(scf::r#yield( + &[u, v, r, s], + location, + )); + + region + }, + location, + )); + + let u = result.result(0)?.into(); + let v = result.result(1)?.into(); + let r = result.result(2)?.into(); + let s = result.result(3)?.into(); + + v_not_even_block + .append_operation(scf::r#yield(&[u, v, r, s], location)); + + region + }, + location, + )); + + let u = result.result(0)?.into(); + let v = result.result(1)?.into(); + let r = result.result(2)?.into(); + let s = result.result(3)?.into(); + + u_not_even_block.append_operation(scf::r#yield(&[u, v, r, s], location)); + region + }, + location, + )); + + let u = result.result(0)?.into(); + let v = result.result(1)?.into(); + let r = result.result(2)?.into(); + let s = result.result(3)?.into(); + let k = cond_block.addi(cond_block.arg(4)?, k1_i16, location)?; + + let is_v_gt_zero = + cond_block.cmpi(context, arith::CmpiPredicate::Ugt, v, k0, location)?; + + cond_block.append_operation(scf::condition( + is_v_gt_zero, + &[u, v, r, s, k], + location, + )); + + region + }, + { + let region = Region::new(); + let loop_block = region.append_block(Block::new(&[ + (value_ty, location), + (value_ty, location), + (value_ty, location), + (value_ty, location), + (i16, location), + ])); + + let u = loop_block.arg(0)?; + let v = loop_block.arg(1)?; + let r = loop_block.arg(2)?; + let s = loop_block.arg(3)?; + let k = loop_block.arg(4)?; + + loop_block.append_operation(scf::r#yield(&[u, v, r, s, k], location)); + + 
region + }, + location, + )); + + let (almost_inv, k) = { + // if r >= p: + // r = r − p + // else: + // r + // return (p - r), k + let k = result.result(4)?.into(); + let r = { + let r = result.result(2)?.into(); + let r_wrapped = block.subi(r, prime, location)?; + let r_ge_prime = + block.cmpi(context, arith::CmpiPredicate::Uge, r, prime, location)?; + let r = + block.append_op_result(arith::select(r_ge_prime, r_wrapped, r, location))?; + + block.subi(prime, r, location)? + }; + + (r, k) + }; + + Ok((almost_inv, k)) + } + + fn monty_reduce<'c, 'a>( + context: &'c Context, + block: &'a Block<'c>, + x: Value<'c, '_>, + location: Location<'c>, + ) -> Result> { /* Appends to `block` the ops that Montgomery-reduce `x`: q is (x * mu) masked with MONTY_R minus one, m is q * PRIME, and the result is (x minus m) shifted right by 256 bits, with PRIME added back when m > x. Every constant is built 512 bits wide, so `x` is presumably a 512-bit integer value (TODO confirm against the caller). */ + let mu = block.const_int(context, location, MONTY_MU_U256.clone(), 512)?; + let r_minus_1 = block.const_int(context, location, MONTY_R.clone() - 1u8, 512)?; + let k256 = block.const_int(context, location, 256, 512)?; + let modulus = block.const_int(context, location, PRIME.clone(), 512)?; + + // q = (x * mu) mod r. + let q = block.muli(x, mu, location)?; + let q = block.andi(q, r_minus_1, location)?; + // m = q * modulus. + let m = block.muli(q, modulus, location)?; + // y = (x - m) / r. + let y = block.subi(x, m, location)?; + let y = block.shrui(y, k256, location)?; + // if (m > x): + // y = y + modulus + let y_plus_mod = block.addi(y, modulus, location)?; + + let is_negative = block.cmpi(context, arith::CmpiPredicate::Ugt, m, x, location)?; + + Ok(block.append_op_result(arith::select(is_negative, y_plus_mod, y, location))?) 
+ } +} + +#[cfg(test)] +mod tests { + use crate::utils::{ + montgomery::{monty_reduction, monty_transform, MontyBytes}, + PRIME, + }; + use lambdaworks_math::{traits::ByteConversion, unsigned_integer::element::U256}; + use starknet_types_core::felt::Felt; + + #[test] + fn felt_to_bytes_raw() { /* Round-trip check: the bytes from `to_bytes_le_raw`, parsed with `U256::from_bytes_le` and passed to `Felt::from_raw`, must give back the same felt. Exercised with Felt::from(10), Felt::from(-10) and Felt::from(PRIME). */ + let felt = Felt::from(10); + let bytes = felt.to_bytes_le_raw(); + let felt_from_raw = { + let value = U256::from_bytes_le(&bytes).unwrap(); + Felt::from_raw(value.limbs) + }; + + assert_eq!(felt_from_raw, felt); + + let felt = Felt::from(-10); + let bytes = felt.to_bytes_le_raw(); + let felt_from_raw = { + let value = U256::from_bytes_le(&bytes).unwrap(); + Felt::from_raw(value.limbs) + }; + + assert_eq!(felt_from_raw, felt); + + let felt = Felt::from(PRIME.clone()); + let bytes = felt.to_bytes_le_raw(); + let felt_from_raw = { + let value = U256::from_bytes_le(&bytes).unwrap(); + Felt::from_raw(value.limbs) + }; + + assert_eq!(felt_from_raw, felt); + } + + #[test] + fn felt_to_monty_to_felt() { /* Round-trip check: `monty_reduction` applied to the output of `monty_transform` (both taken modulo PRIME) must return the starting value. Exercised with the same three felts, converted with `to_biguint`. */ + let felt = Felt::from(10).to_biguint(); + let monty_felt = monty_transform(&felt, &PRIME).unwrap(); + let reduced_monty_felt = monty_reduction(&monty_felt, &PRIME).unwrap(); + + assert_eq!(reduced_monty_felt, felt); + + let felt = Felt::from(-10).to_biguint(); + let monty_felt = monty_transform(&felt, &PRIME).unwrap(); + let reduced_monty_felt = monty_reduction(&monty_felt, &PRIME).unwrap(); + + assert_eq!(reduced_monty_felt, felt); + + let felt = Felt::from(PRIME.clone()).to_biguint(); + let monty_felt = monty_transform(&felt, &PRIME).unwrap(); + let reduced_monty_felt = monty_reduction(&monty_felt, &PRIME).unwrap(); + + assert_eq!(reduced_monty_felt, felt); + } +} diff --git a/src/values.rs b/src/values.rs index c561cf6ba2..eee7077fbb 100644 --- a/src/values.rs +++ b/src/values.rs @@ -9,7 +9,8 @@ use crate::{ starknet::{Secp256k1Point, Secp256r1Point}, types::TypeBuilder, utils::{ - felt252_bigint, get_integer_layout, layout_repeat, libc_free, libc_malloc, RangeExt, PRIME, + 
felt252_bigint, get_integer_layout, layout_repeat, libc_free, libc_malloc, + montgomery::MontyBytes, RangeExt, PRIME, }, }; use bumpalo::Bump; @@ -24,6 +25,7 @@ use cairo_lang_sierra::{ program_registry::ProgramRegistry, }; use educe::Educe; +use lambdaworks_math::{traits::ByteConversion, unsigned_integer::element::U256}; use num_bigint::{BigInt, BigUint, Sign}; use num_traits::{Euclid, One}; use starknet_types_core::felt::Felt; @@ -175,7 +177,7 @@ impl Value { Self::Felt252(value) => { let ptr = arena.alloc_layout(get_integer_layout(252)).cast(); - let data = felt252_bigint(value.to_bigint()).to_bytes_le(); + let data = value.to_bytes_le_raw(); ptr.cast::<[u8; 32]>().as_mut().copy_from_slice(&data); ptr } @@ -431,7 +433,7 @@ impl Value { // next key must be called before next_value for (key, value) in map.iter() { - let key = key.to_bytes_le(); + let key = key.to_bytes_le_raw(); let value = value.to_ptr(arena, registry, &info.ty, find_dict_drop_override)?; @@ -739,8 +741,8 @@ impl Value { CoreTypeConcrete::Felt252(_) => { let data = ptr.cast::<[u8; 32]>().as_mut(); data[31] &= 0x0F; // Filter out first 4 bits (they're outside an i252). - let data = Felt::from_bytes_le_slice(data); - Self::Felt252(data) + let data = U256::from_bytes_le(data).unwrap(); + Self::Felt252(Felt::from_raw(data.limbs)) } CoreTypeConcrete::Uint8(_) => Self::Uint8(*ptr.cast::().as_ref()), CoreTypeConcrete::Uint16(_) => Self::Uint16(*ptr.cast::().as_ref()), @@ -866,7 +868,11 @@ impl Value { let mut key = key; key[31] &= 0x0F; // Filter out first 4 bits (they're outside an i252). - let key = Felt::from_bytes_le(&key); + // Keys are written with `to_bytes_le_raw` (raw Montgomery-form bytes), so rebuild the felt from its raw limbs. + let key = { + let key = U256::from_bytes_le(&key).unwrap(); + Felt::from_raw(key.limbs) + }; // The dictionary items are not being dropped here. They'll be dropped along // with the dictionary (if requested using `should_drop`). 
output_map.insert( @@ -920,8 +926,8 @@ impl Value { // felt values let data = ptr.cast::<[u8; 32]>().as_mut(); data[31] &= 0x0F; // Filter out first 4 bits (they're outside an i252). - let data = Felt::from_bytes_le(data); - Self::Felt252(data) + let data = U256::from_bytes_le(data).unwrap(); + Self::Felt252(Felt::from_raw(data.limbs)) } StarknetTypeConcrete::System(_) => { native_panic!("should be handled before") @@ -1195,6 +1201,7 @@ mod test { let registry = ProgramRegistry::::new(&program).unwrap(); + // Assert bytes of Montgomery form of 42_felt252. assert_eq!( unsafe { *Value::Felt252(Felt::from(42)) @@ -1208,9 +1215,13 @@ mod test { .cast::<[u32; 8]>() .as_ptr() }, - [42, 0, 0, 0, 0, 0, 0, 0] + [ + 4294965953, 4294967295, 4294967295, 4294967295, 4294967295, 4294967295, 4294944464, + 134217727 + ] ); + // Assert bytes of Montgomery form of Felt::MAX. assert_eq!( unsafe { *Value::Felt252(Felt::MAX) @@ -1224,8 +1235,8 @@ mod test { .cast::<[u32; 8]>() .as_ptr() }, - // 0x800000000000011000000000000000000000000000000000000000000000001 - 1 - [0, 0, 0, 0, 0, 0, 17, 134217728] + // Montgomery(0x800000000000011000000000000000000000000000000000000000000000001 - 1) + [32, 0, 0, 0, 0, 0, 544, 0] ); assert_eq!(