Skip to content

Commit 89ef937

Browse files
Merge branch 'main' into montgmomery-felts
2 parents cdd4e66 + ac3cb7f commit 89ef937

File tree

4 files changed

+178
-95
lines changed

4 files changed

+178
-95
lines changed

src/executor.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,13 +110,17 @@ fn invoke_dynamic(
110110
.peekable();
111111

112112
let num_return_args = ret_types_iter.clone().count();
113+
// If there is more than one return value, or the return value is _complex_,
114+
// as defined by the architecture ABI, then we pass a return pointer as
115+
// the first argument to the program entrypoint.
113116
let mut return_ptr = if num_return_args > 1
114117
|| ret_types_iter
115118
.peek()
116119
.map(|id| registry.get_type(id)?.is_complex(registry))
117120
.transpose()?
118121
== Some(true)
119122
{
123+
// The return pointer should be able to hold all the return values.
120124
let layout = ret_types_iter.try_fold(Layout::new::<()>(), |layout, id| {
121125
let type_info = registry.get_type(id)?;
122126
Result::<_, Error>::Ok(layout.extend(type_info.layout(registry)?)?.0)

src/libfuncs/circuit.rs

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -595,28 +595,23 @@ fn build_gate_evaluation<'ctx, 'this>(
595595
}
596596
// INV: lhs = 1 / rhs
597597
(None, Some(rhs_value), Some(_)) => {
598-
// Extend to avoid overflow
599-
let u768_type = IntegerType::new(context, 768).into();
600-
let rhs_value = block.extui(rhs_value, u768_type, location)?;
601-
let circuit_modulus_u768 = block.extui(circuit_modulus, u768_type, location)?;
602-
603598
// Apply egcd to find gcd and inverse
604599
let euclidean_result = runtime_bindings_meta.extended_euclidean_algorithm(
605600
context,
606601
helper.module,
607602
block,
608603
location,
609604
rhs_value,
610-
circuit_modulus_u768,
605+
circuit_modulus,
611606
)?;
612607
// Extract the values from the result struct
613608
let gcd =
614-
block.extract_value(context, location, euclidean_result, u768_type, 0)?;
609+
block.extract_value(context, location, euclidean_result, u384_type, 0)?;
615610
let inverse =
616-
block.extract_value(context, location, euclidean_result, u768_type, 1)?;
611+
block.extract_value(context, location, euclidean_result, u384_type, 1)?;
617612

618613
// if the gcd is not 1, then fail (a and b are not coprimes)
619-
let one = block.const_int_from_type(context, location, 1, u768_type)?;
614+
let one = block.const_int_from_type(context, location, 1, u384_type)?;
620615
let gate_offset_idx_value = block.const_int_from_type(
621616
context,
622617
location,
@@ -637,7 +632,7 @@ fn build_gate_evaluation<'ctx, 'this>(
637632
block = has_inverse_block;
638633

639634
// if the inverse is negative, then add modulus
640-
let zero = block.const_int_from_type(context, location, 0, u768_type)?;
635+
let zero = block.const_int_from_type(context, location, 0, u384_type)?;
641636
let is_negative = block
642637
.append_operation(arith::cmpi(
643638
context,
@@ -648,17 +643,14 @@ fn build_gate_evaluation<'ctx, 'this>(
648643
))
649644
.result(0)?
650645
.into();
651-
let wrapped_inverse = block.addi(inverse, circuit_modulus_u768, location)?;
646+
let wrapped_inverse = block.addi(inverse, circuit_modulus, location)?;
652647
let inverse = block.append_op_result(arith::select(
653648
is_negative,
654649
wrapped_inverse,
655650
inverse,
656651
location,
657652
))?;
658653

659-
// Truncate back
660-
let inverse = block.trunci(inverse, u384_type, location)?;
661-
662654
gates[gate_offset.lhs] = Some(inverse);
663655
}
664656
// The impossibility to solve this mul gate offset would render the circuit unsolvable

src/metadata/runtime_bindings.rs

Lines changed: 137 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ impl RuntimeBindingsMeta {
221221
{
222222
build_egcd_function(module, context, location, func_symbol)?;
223223
}
224-
let integer_type: Type = IntegerType::new(context, 384 * 2).into();
224+
let integer_type: Type = IntegerType::new(context, 384).into();
225225
// The struct returned by the function that contains both of the results
226226
let return_type = llvm::r#type::r#struct(context, &[integer_type, integer_type], false);
227227
Ok(block
@@ -813,105 +813,164 @@ pub fn setup_runtime(find_symbol_ptr: impl Fn(&str) -> Option<*mut c_void>) {
813813
}
814814
}
815815

816-
/// The extended euclidean algorithm calculates the greatest common divisor (gcd) of two integers a and b,
817-
/// as well as the bezout coefficients x and y such that ax+by=gcd(a,b)
818-
/// if gcd(a,b) = 1, then x is the modular multiplicative inverse of a modulo b.
819-
/// See https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm
816+
/// Build the extended euclidean algorithm MLIR function.
820817
///
821-
/// This function declares a MLIR function that given two numbers a and b, returns a MLIR struct with gcd(a, b)
822-
/// and the bezout coefficient x. The declaration is done in the body of the module.
818+
/// The extended euclidean algorithm calculates the greatest common divisor
819+
/// (gcd) of two integers `a` and `b`, as well as the Bézout coefficients `x`
820+
/// and `y` such that `ax + by = gcd(a,b)`. If `gcd(a,b) = 1`, then `x` is the
821+
/// modular multiplicative inverse of `a` modulo `b`.
822+
///
823+
/// This function declares a MLIR function that given two 384 bit integers `a`
824+
/// and `b`, returns a MLIR struct with `gcd(a,b)` and the Bézout coefficient
825+
/// `x`. The declaration is done in the body of the module.
823826
fn build_egcd_function<'ctx>(
824827
module: &Module,
825828
context: &'ctx Context,
826829
location: Location<'ctx>,
827830
func_symbol: &str,
828831
) -> Result<()> {
829-
let integer_type: Type = IntegerType::new(context, 384 * 2).into();
832+
let integer_width = 384;
833+
let integer_type = IntegerType::new(context, integer_width).into();
834+
835+
// Pseudocode for calculating the EGCD of two integers `a` and `b`.
836+
// https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm#Pseudocode.
837+
//
838+
// ```
839+
// (old_r, new_r) := (a, b)
840+
// (old_s, new_s) := (1, 0)
841+
//
842+
// while new_r != 0 do
843+
// quotient := old_r / new_r
844+
// (old_r, new_r) := (new_r, old_r − quotient * new_r)
845+
// (old_s, new_s) := (new_s, old_s − quotient * new_s)
846+
//
847+
// old_s is equal to Bézout coefficient X
848+
// old_r is equal to GCD
849+
// ```
850+
//
851+
// Note that when `b > a`, the first iteration swaps the values. Our
852+
// implementation does it manually as we already know that `b > a`.
853+
//
854+
// The core idea of the method is that `gcd(a,b) = gcd(a,b-a)`, and that
855+
// `gcd(a,b) = gcd(b,a)`. As an optimization, we can actually subtract `a`
856+
// from `b` as many times as possible, so `gcd(a,b) = gcd(b%a,a)`.
857+
//
858+
// Take, for example, `a=21` and `b=54`:
859+
//
860+
// gcd(21, 54)
861+
// = gcd(12, 21)
862+
// = gcd(9, 12)
863+
// = gcd(3, 9)
864+
// = gcd(0, 3)
865+
// = 3
866+
//
867+
// Thus, the algorithm works by calculating a series of remainders `r` which
868+
// starts with `b, a, ...`, where `r[i]` is the remainder of dividing `r[i-2]` by
869+
// `r[i-1]`. At each step, `r[i]` can be calculated as:
870+
//
871+
// r[i] = r[i-2] - r[i-1] * quotient
872+
//
873+
// The GCD will be the last non-zero remainder.
874+
//
875+
// [54; 21; 12; 9; 3; 0]
876+
// ^
877+
//
878+
// See Dr. Katherine Stange's Youtube video for a better explanation on how
879+
// this works: https://www.youtube.com/watch?v=Jwf6ncRmhPg.
880+
//
881+
// The extended algorithm also obtains the Bézout coefficients
882+
// by calculating a series of coefficients `s`. See Dr. Katherine
883+
// Stange's Youtube video for a better explanation on how this works:
884+
// https://www.youtube.com/watch?v=IwRtISxAHY4.
885+
886+
// Define entry block for function. Receives arguments `a` and `b`.
830887
let region = Region::new();
831-
832888
let entry_block = region.append_block(Block::new(&[
833-
(integer_type, location),
834-
(integer_type, location),
889+
(integer_type, location), // a
890+
(integer_type, location), // b
835891
]));
836892

837-
let a = entry_block.arg(0)?;
838-
let b = entry_block.arg(1)?;
839-
// The egcd algorithm works by calculating a series of remainders `rem`, being each `rem_i` the remainder of dividing `rem_{i-1}` with `rem_{i-2}`
840-
// For the initial setup, rem_0 = b, rem_1 = a.
841-
// This order is chosen because if we reverse them, then the first iteration will just swap them
842-
let remainder = a;
843-
let prev_remainder = b;
844-
845-
// Similarly we'll calculate another series which starts 0,1,... and from which we
846-
// will retrieve the modular inverse of a
847-
let prev_inverse = entry_block.const_int_from_type(context, location, 0, integer_type)?;
848-
let inverse = entry_block.const_int_from_type(context, location, 1, integer_type)?;
849-
893+
// Define loop block for function. Each iteration receives the last two values from each series.
850894
let loop_block = region.append_block(Block::new(&[
851-
(integer_type, location),
852-
(integer_type, location),
853-
(integer_type, location),
854-
(integer_type, location),
895+
(integer_type, location), // old_r
896+
(integer_type, location), // new_r
897+
(integer_type, location), // old_s
898+
(integer_type, location), // new_s
855899
]));
900+
901+
// Define end block for function.
856902
let end_block = region.append_block(Block::new(&[
857-
(integer_type, location),
858-
(integer_type, location),
903+
(integer_type, location), // old_r
904+
(integer_type, location), // old_s
859905
]));
860906

907+
// Jump to loop block from entry block, with initial values.
908+
// - old_r = b
909+
// - new_r = a
910+
// - old_s = 0
911+
// - new_s = 1
861912
entry_block.append_operation(cf::br(
862913
&loop_block,
863-
&[prev_remainder, remainder, prev_inverse, inverse],
914+
&[
915+
entry_block.arg(1)?,
916+
entry_block.arg(0)?,
917+
entry_block.const_int_from_type(context, location, 0, integer_type)?,
918+
entry_block.const_int_from_type(context, location, 1, integer_type)?,
919+
],
864920
location,
865921
));
866922

867-
// -- Loop body --
868-
// Arguments are rem_(i-1), rem, inv_(i-1), inv
869-
let prev_remainder = loop_block.arg(0)?;
870-
let remainder = loop_block.arg(1)?;
871-
let prev_inverse = loop_block.arg(2)?;
872-
let inverse = loop_block.arg(3)?;
873-
874-
// First calculate q = rem_(i-1)/rem_i, rounded down
875-
let quotient =
876-
loop_block.append_op_result(arith::divui(prev_remainder, remainder, location))?;
877-
878-
// Then rem_(i+1) = rem_(i-1) - q * rem_i, and inv_(i+1) = inv_(i-1) - q * inv_i
879-
let rem_times_quo = loop_block.muli(remainder, quotient, location)?;
880-
let inv_times_quo = loop_block.muli(inverse, quotient, location)?;
881-
let next_remainder =
882-
loop_block.append_op_result(arith::subi(prev_remainder, rem_times_quo, location))?;
883-
let next_inverse =
884-
loop_block.append_op_result(arith::subi(prev_inverse, inv_times_quo, location))?;
885-
886-
// Check if rem_(i+1) is 0
887-
// If true, then:
888-
// - rem_i is the gcd of a and b
889-
// - inv_i is the bezout coefficient x
890-
let zero = loop_block.const_int_from_type(context, location, 0, integer_type)?;
891-
let next_remainder_eq_zero =
892-
loop_block.cmpi(context, CmpiPredicate::Eq, next_remainder, zero, location)?;
893-
loop_block.append_operation(cf::cond_br(
894-
context,
895-
next_remainder_eq_zero,
896-
&end_block,
897-
&loop_block,
898-
&[remainder, inverse],
899-
&[remainder, next_remainder, inverse, next_inverse],
900-
location,
901-
));
923+
// LOOP BLOCK
924+
{
925+
let old_r = loop_block.arg(0)?;
926+
let new_r = loop_block.arg(1)?;
927+
let old_s = loop_block.arg(2)?;
928+
let new_s = loop_block.arg(3)?;
929+
930+
// First calculate quotient of old_r/new_r.
931+
let quotient = loop_block.append_op_result(arith::divui(old_r, new_r, location))?;
932+
933+
// Multiply quotient by new_r and new_s.
934+
let quotient_by_new_r = loop_block.muli(quotient, new_r, location)?;
935+
let quotient_by_new_s = loop_block.muli(quotient, new_s, location)?;
936+
937+
// Calculate new values for next iteration.
938+
// - next_new_r := old_r − quotient * new_r
939+
// - next_new_s := old_s − quotient * new_s
940+
let next_new_r =
941+
loop_block.append_op_result(arith::subi(old_r, quotient_by_new_r, location))?;
942+
let next_new_s =
943+
loop_block.append_op_result(arith::subi(old_s, quotient_by_new_s, location))?;
944+
945+
// Jump to end block if next_new_r is zero.
946+
let zero = loop_block.const_int_from_type(context, location, 0, integer_type)?;
947+
let next_new_r_is_zero =
948+
loop_block.cmpi(context, CmpiPredicate::Eq, next_new_r, zero, location)?;
949+
loop_block.append_operation(cf::cond_br(
950+
context,
951+
next_new_r_is_zero,
952+
&end_block,
953+
&loop_block,
954+
&[new_r, new_s],
955+
&[new_r, next_new_r, new_s, next_new_s],
956+
location,
957+
));
958+
}
902959

903-
// Create the struct that will contain the results
904-
let results = end_block.append_op_result(llvm::undef(
905-
llvm::r#type::r#struct(context, &[integer_type, integer_type], false),
906-
location,
907-
))?;
908-
let results = end_block.insert_values(
909-
context,
910-
location,
911-
results,
912-
&[end_block.arg(0)?, end_block.arg(1)?],
913-
)?;
914-
end_block.append_operation(llvm::r#return(Some(results), location));
960+
// END BLOCK
961+
{
962+
let results = end_block.append_op_result(llvm::undef(
963+
llvm::r#type::r#struct(context, &[integer_type, integer_type], false),
964+
location,
965+
))?;
966+
let results = end_block.insert_values(
967+
context,
968+
location,
969+
results,
970+
&[end_block.arg(0)?, end_block.arg(1)?],
971+
)?;
972+
end_block.append_operation(llvm::r#return(Some(results), location));
973+
}
915974

916975
let func_name = StringAttribute::new(context, func_symbol);
917976
module.body().append_operation(llvm::func(

src/types.rs

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,11 +85,26 @@ pub trait TypeBuilder {
8585

8686
/// Return whether the type is a builtin.
8787
fn is_builtin(&self) -> bool;
88-
/// Return whether the type requires a return pointer when returning.
88+
/// Return whether the type requires a return pointer when returning,
89+
/// instead of using the CPU registers.
90+
///
91+
/// This attribute does not modify the compilation, and it only reflects
92+
/// what the ABI of the target architecture already specifies.
93+
/// - For x86-64: https://gitlab.com/x86-psABIs/x86-64-ABI.
94+
/// - For AArch64: https://github.com/ARM-software/abi-aa.
95+
///
96+
/// We can validate this empirically, by building a Cairo program that
97+
/// returns a particular type, and seeing how it is lowered to machine code.
98+
///
99+
/// ```bash
100+
/// llc a.llvmir -o - --mtriple "aarch64"
101+
/// llc a.llvmir -o - --mtriple "x86_64"
102+
/// ```
89103
fn is_complex(
90104
&self,
91105
registry: &ProgramRegistry<CoreType, CoreLibfunc>,
92106
) -> Result<bool, Self::Error>;
107+
93108
/// Return whether the Sierra type resolves to a zero-sized type.
94109
fn is_zst(
95110
&self,
@@ -104,8 +119,21 @@ pub trait TypeBuilder {
104119
registry: &ProgramRegistry<CoreType, CoreLibfunc>,
105120
) -> Result<Layout, Self::Error>;
106121

107-
/// Whether the layout should be allocated in memory (either the stack or the heap) when used as
108-
/// a function invocation argument or return value.
122+
/// Whether the layout should be allocated in memory (either the stack or
123+
/// the heap) when used as a function invocation argument or return value.
124+
///
125+
/// Unlike `is_complex`, this attribute alters the compilation:
126+
///
127+
/// - When passing a memory allocated value to a function, we allocate that
128+
/// value on the stack, and pass a pointer to it.
129+
///
130+
/// - If a function returns a memory allocated value, we receive a return
131+
/// pointer as its first argument, and write the return value there
132+
/// instead.
133+
///
134+
/// The rationale behind allocating a value in memory, rather than
135+
/// registers, is to avoid putting too much pressure on the register
136+
/// allocation pass for really complex types, like enums.
109137
fn is_memory_allocated(
110138
&self,
111139
registry: &ProgramRegistry<CoreType, CoreLibfunc>,

0 commit comments

Comments
 (0)