Skip to content

Commit 6269ca2

Browse files
move felt divition to a mlir function
1 parent d8c86dd commit 6269ca2

File tree

4 files changed

+51
-113
lines changed

4 files changed

+51
-113
lines changed

.cargo/config.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[install]
2+
root = "binaries/"

src/libfuncs/circuit.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -608,6 +608,7 @@ fn build_gate_evaluation<'ctx, 'this>(
608608
location,
609609
rhs_value,
610610
circuit_modulus_u768,
611+
u768_type
611612
)?;
612613
// Extract the values from the result struct
613614
let gcd =

src/libfuncs/felt252.rs

Lines changed: 35 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
33
use super::LibfuncHelper;
44
use crate::{
5-
error::Result,
6-
metadata::MetadataStorage,
5+
error::{panic::ToNativeAssertError, Result},
6+
metadata::{runtime_bindings::RuntimeBindingsMeta, MetadataStorage},
77
utils::{ProgramRegistryExt, PRIME},
88
};
99
use cairo_lang_sierra::{
@@ -19,11 +19,8 @@ use cairo_lang_sierra::{
1919
program_registry::ProgramRegistry,
2020
};
2121
use melior::{
22-
dialect::{
23-
arith::{self, CmpiPredicate},
24-
cf,
25-
},
26-
helpers::{ArithBlockExt, BuiltinBlockExt},
22+
dialect::arith::{self, CmpiPredicate},
23+
helpers::{ArithBlockExt, BuiltinBlockExt, LlvmBlockExt},
2724
ir::{r#type::IntegerType, Block, BlockLike, Location, Value, ValueLike},
2825
Context,
2926
};
@@ -149,80 +146,35 @@ pub fn build_binary_operation<'ctx, 'this>(
149146
entry.trunci(result, felt252_ty, location)?
150147
}
151148
Felt252BinaryOperator::Div => {
152-
// The extended euclidean algorithm calculates the greatest common divisor of two integers,
153-
// as well as the bezout coefficients x and y such that for inputs a and b, ax+by=gcd(a,b)
154-
// We use this in felt division to find the modular inverse of a given number
155-
// If a is the number we're trying to find the inverse of, we can do
156-
// ax+y*PRIME=gcd(a,PRIME)=1 => ax = 1 (mod PRIME)
157-
// Hence for input a, we return x
158-
// The input MUST be non-zero
159-
// See https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm
160-
let start_block = helper.append_block(Block::new(&[(i512, location)]));
161-
let loop_block = helper.append_block(Block::new(&[
162-
(i512, location),
163-
(i512, location),
164-
(i512, location),
165-
(i512, location),
166-
]));
167-
let negative_check_block = helper.append_block(Block::new(&[]));
168-
// Block containing final result
169-
let inverse_result_block = helper.append_block(Block::new(&[(i512, location)]));
170-
// Egcd works by calculating a series of remainders, each the remainder of dividing the previous two
171-
// For the initial setup, r0 = PRIME, r1 = a
172-
// This order is chosen because if we reverse them, then the first iteration will just swap them
173-
let prev_remainder =
174-
start_block.const_int_from_type(context, location, PRIME.clone(), i512)?;
175-
let remainder = start_block.arg(0)?;
176-
// Similarly we'll calculate another series which starts 0,1,... and from which we will retrieve the modular inverse of a
177-
let prev_inverse = start_block.const_int_from_type(context, location, 0, i512)?;
178-
let inverse = start_block.const_int_from_type(context, location, 1, i512)?;
179-
start_block.append_operation(cf::br(
180-
loop_block,
181-
&[prev_remainder, remainder, prev_inverse, inverse],
182-
location,
183-
));
184-
185-
//---Loop body---
186-
// Arguments are rem_(i-1), rem, inv_(i-1), inv
187-
let prev_remainder = loop_block.arg(0)?;
188-
let remainder = loop_block.arg(1)?;
189-
let prev_inverse = loop_block.arg(2)?;
190-
let inverse = loop_block.arg(3)?;
191-
192-
// First calculate q = rem_(i-1)/rem_i, rounded down
193-
let quotient =
194-
loop_block.append_op_result(arith::divui(prev_remainder, remainder, location))?;
195-
// Then r_(i+1) = r_(i-1) - q * r_i, and inv_(i+1) = inv_(i-1) - q * inv_i
196-
let rem_times_quo = loop_block.muli(remainder, quotient, location)?;
197-
let inv_times_quo = loop_block.muli(inverse, quotient, location)?;
198-
let next_remainder = loop_block.append_op_result(arith::subi(
199-
prev_remainder,
200-
rem_times_quo,
201-
location,
202-
))?;
203-
let next_inverse =
204-
loop_block.append_op_result(arith::subi(prev_inverse, inv_times_quo, location))?;
205-
206-
// If r_(i+1) is 0, then inv_i is the inverse
207-
let zero = loop_block.const_int_from_type(context, location, 0, i512)?;
208-
let next_remainder_eq_zero =
209-
loop_block.cmpi(context, CmpiPredicate::Eq, next_remainder, zero, location)?;
210-
loop_block.append_operation(cf::cond_br(
149+
let runtime_bindings_meta = metadata
150+
.get_mut::<RuntimeBindingsMeta>()
151+
.to_native_assert_error(
152+
"Unable to get the RuntimeBindingsMeta from MetadataStorage",
153+
)?;
154+
155+
let prime = entry.const_int_from_type(context, location, PRIME.clone(), i512)?;
156+
let lhs = entry.extui(lhs, i512, location)?;
157+
let rhs = entry.extui(rhs, i512, location)?;
158+
159+
// Find 1 / rhs.
160+
let euclidean_result = runtime_bindings_meta.extended_euclidean_algorithm(
211161
context,
212-
next_remainder_eq_zero,
213-
negative_check_block,
214-
loop_block,
215-
&[],
216-
&[remainder, next_remainder, inverse, next_inverse],
162+
helper.module,
163+
entry,
217164
location,
218-
));
165+
rhs,
166+
prime,
167+
i512
168+
)?;
169+
170+
let inverse = entry.extract_value(context, location, euclidean_result, i512, 1)?;
219171

220172
// egcd sometimes returns a negative number for the inverse,
221173
// in such cases we must simply wrap it around back into [0, PRIME)
222174
// this suffices because |inv_i| <= divfloor(PRIME,2)
223-
let zero = negative_check_block.const_int_from_type(context, location, 0, i512)?;
175+
let zero = entry.const_int_from_type(context, location, 0, i512)?;
224176

225-
let is_negative = negative_check_block
177+
let is_negative = entry
226178
.append_operation(arith::cmpi(
227179
context,
228180
CmpiPredicate::Slt,
@@ -233,46 +185,30 @@ pub fn build_binary_operation<'ctx, 'this>(
233185
.result(0)?
234186
.into();
235187
// if the inverse is < 0, add PRIME
236-
let prime =
237-
negative_check_block.const_int_from_type(context, location, PRIME.clone(), i512)?;
238-
let wrapped_inverse = negative_check_block.addi(inverse, prime, location)?;
239-
let inverse = negative_check_block.append_op_result(arith::select(
188+
let wrapped_inverse = entry.addi(inverse, prime, location)?;
189+
let inverse = entry.append_op_result(arith::select(
240190
is_negative,
241191
wrapped_inverse,
242192
inverse,
243193
location,
244194
))?;
245-
negative_check_block.append_operation(cf::br(
246-
inverse_result_block,
247-
&[inverse],
248-
location,
249-
));
250195

251-
// Div Logic Start
252-
// Fetch operands
253-
let lhs = entry.extui(lhs, i512, location)?;
254-
let rhs = entry.extui(rhs, i512, location)?;
255-
// Calculate inverse of rhs, callling the inverse implementation's starting block
256-
entry.append_operation(cf::br(start_block, &[rhs], location));
257-
// Fetch the inverse result from the result block
258-
let inverse = inverse_result_block.arg(0)?;
259-
// Peform lhs * (1/ rhs)
260-
let result = inverse_result_block.muli(lhs, inverse, location)?;
196+
// Peform lhs * (1 / rhs)
197+
let result = entry.muli(lhs, inverse, location)?;
261198
// Apply modulo and convert result to felt252
262-
let result_mod =
263-
inverse_result_block.append_op_result(arith::remui(result, prime, location))?;
199+
let result_mod = entry.append_op_result(arith::remui(result, prime, location))?;
264200
let is_out_of_range =
265-
inverse_result_block.cmpi(context, CmpiPredicate::Uge, result, prime, location)?;
201+
entry.cmpi(context, CmpiPredicate::Uge, result, prime, location)?;
266202

267-
let result = inverse_result_block.append_op_result(arith::select(
203+
let result = entry.append_op_result(arith::select(
268204
is_out_of_range,
269205
result_mod,
270206
result,
271207
location,
272208
))?;
273-
let result = inverse_result_block.trunci(result, felt252_ty, location)?;
209+
let result = entry.trunci(result, felt252_ty, location)?;
274210

275-
return helper.br(inverse_result_block, 0, &[result], location);
211+
return helper.br(entry, 0, &[result], location);
276212
}
277213
};
278214

src/metadata/runtime_bindings.rs

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,7 @@ impl RuntimeBindingsMeta {
210210
location: Location<'c>,
211211
a: Value<'c, '_>,
212212
b: Value<'c, '_>,
213+
integer_type: Type<'c>
213214
) -> Result<Value<'c, 'a>>
214215
where
215216
'c: 'a,
@@ -219,9 +220,8 @@ impl RuntimeBindingsMeta {
219220
.active_map
220221
.insert(RuntimeBinding::ExtendedEuclideanAlgorithm)
221222
{
222-
build_egcd_function(module, context, location, func_symbol)?;
223+
build_egcd_function(module, context, location, func_symbol, integer_type)?;
223224
}
224-
let integer_type: Type = IntegerType::new(context, 384 * 2).into();
225225
// The struct returned by the function that contains both of the results
226226
let return_type = llvm::r#type::r#struct(context, &[integer_type, integer_type], false);
227227
Ok(block
@@ -825,14 +825,24 @@ fn build_egcd_function<'ctx>(
825825
context: &'ctx Context,
826826
location: Location<'ctx>,
827827
func_symbol: &str,
828+
integer_type: Type,
828829
) -> Result<()> {
829-
let integer_type: Type = IntegerType::new(context, 384 * 2).into();
830830
let region = Region::new();
831831

832832
let entry_block = region.append_block(Block::new(&[
833833
(integer_type, location),
834834
(integer_type, location),
835835
]));
836+
let loop_block = region.append_block(Block::new(&[
837+
(integer_type, location),
838+
(integer_type, location),
839+
(integer_type, location),
840+
(integer_type, location),
841+
]));
842+
let end_block = region.append_block(Block::new(&[
843+
(integer_type, location),
844+
(integer_type, location),
845+
]));
836846

837847
let a = entry_block.arg(0)?;
838848
let b = entry_block.arg(1)?;
@@ -847,17 +857,6 @@ fn build_egcd_function<'ctx>(
847857
let prev_inverse = entry_block.const_int_from_type(context, location, 0, integer_type)?;
848858
let inverse = entry_block.const_int_from_type(context, location, 1, integer_type)?;
849859

850-
let loop_block = region.append_block(Block::new(&[
851-
(integer_type, location),
852-
(integer_type, location),
853-
(integer_type, location),
854-
(integer_type, location),
855-
]));
856-
let end_block = region.append_block(Block::new(&[
857-
(integer_type, location),
858-
(integer_type, location),
859-
]));
860-
861860
entry_block.append_operation(cf::br(
862861
&loop_block,
863862
&[prev_remainder, remainder, prev_inverse, inverse],

0 commit comments

Comments
 (0)