@@ -221,7 +221,7 @@ impl RuntimeBindingsMeta {
221221 {
222222 build_egcd_function ( module, context, location, func_symbol) ?;
223223 }
224- let integer_type: Type = IntegerType :: new ( context, 384 * 2 ) . into ( ) ;
224+ let integer_type: Type = IntegerType :: new ( context, 384 ) . into ( ) ;
225225 // The struct returned by the function that contains both of the results
226226 let return_type = llvm:: r#type:: r#struct ( context, & [ integer_type, integer_type] , false ) ;
227227 Ok ( block
@@ -813,105 +813,164 @@ pub fn setup_runtime(find_symbol_ptr: impl Fn(&str) -> Option<*mut c_void>) {
813813 }
814814}
815815
816- /// The extended euclidean algorithm calculates the greatest common divisor (gcd) of two integers a and b,
817- /// as well as the bezout coefficients x and y such that ax+by=gcd(a,b)
818- /// if gcd(a,b) = 1, then x is the modular multiplicative inverse of a modulo b.
819- /// See https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm
816+ /// Build the extended euclidean algorithm MLIR function.
820817///
821- /// This function declares a MLIR function that given two numbers a and b, returns a MLIR struct with gcd(a, b)
822- /// and the bezout coefficient x. The declaration is done in the body of the module.
818+ /// The extended euclidean algorithm calculates the greatest common divisor
819+ /// (gcd) of two integers `a` and `b`, as well as the Bézout coefficients `x`
820+ /// and `y` such that `ax + by = gcd(a,b)`. If `gcd(a,b) = 1`, then `x` is the
821+ /// modular multiplicative inverse of `a` modulo `b`.
822+ ///
823+ /// This function declares a MLIR function that given two 384 bit integers `a`
824+ /// and `b`, returns a MLIR struct with `gcd(a,b)` and the Bézout coefficient
825+ /// `x`. The declaration is done in the body of the module.
823826fn build_egcd_function < ' ctx > (
824827 module : & Module ,
825828 context : & ' ctx Context ,
826829 location : Location < ' ctx > ,
827830 func_symbol : & str ,
828831) -> Result < ( ) > {
829- let integer_type: Type = IntegerType :: new ( context, 384 * 2 ) . into ( ) ;
832+ let integer_width = 384 ;
833+ let integer_type = IntegerType :: new ( context, integer_width) . into ( ) ;
834+
835+ // Pseudocode for calculating the EGCD of two integers `a` and `b`.
836+ // https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm#Pseudocode.
837+ //
838+ // ```
839+ // (old_r, new_r) := (a, b)
840+ // (old_s, new_s) := (1, 0)
841+ //
842+ // while new_r != 0 do
843+ // quotient := old_r / new_r
844+ // (old_r, new_r) := (new_r, old_r − quotient * new_r)
845+ // (old_s, new_s) := (new_s, old_s − quotient * new_s)
846+ //
847+ // old_s is equal to Bézout coefficient X
848+ // old_r is equal to GCD
849+ // ```
850+ //
851+ // Note that when `b > a`, the first iteration inverts the values. Our
852+ // implementation does it manually as we already know that `b > a`.
853+ //
854+ // The core idea of the method is that `gcd(a,b) = gcd(a,b-a)`, and that
855+ // `gcd(a,b) = gcd(b,a)`. As an optimization, we can actually substract `a`
856+ // from `b` as many times as possible, so `gcd(a,b) = gcd(b%a,a)`.
857+ //
858+ // Take, for example, `a=21` and `b=54`:
859+ //
860+ // gcd(21, 54)
861+ // = gcd(12, 21)
862+ // = gcd(9, 12)
863+ // = gcd(3, 9)
864+ // = gcd(0, 3)
865+ // = 3
866+ //
867+ // Thus, the algorithm works by calculating a series of remainders `r` which
868+ // starts with b,a,... being `r[i]` the remainder of dividing `r[i-2]` by
869+ // `r[i-1]`. At each step, `r[i]` can be calculated as:
870+ //
871+ // r[i] = r[i-2] - r[i-1] * quotient
872+ //
873+ // The GCD will be the last non-zero remainder.
874+ //
875+ // [54; 21; 12; 9; 3; 0]
876+ // ^
877+ //
878+ // See Dr. Katherine Stange's Youtube video for a better explanation on how
879+ // this works: https://www.youtube.com/watch?v=Jwf6ncRmhPg.
880+ //
881+ // The extended algorithm also obtains the Bézout coefficients
882+ // by calculating a series of coefficients `s`. See Dr. Katherine
883+ // Stange's Youtube video for a better explanation on how this works:
884+ // https://www.youtube.com/watch?v=IwRtISxAHY4.
885+
886+ // Define entry block for function. Receives arguments `a` and `b`.
830887 let region = Region :: new ( ) ;
831-
832888 let entry_block = region. append_block ( Block :: new ( & [
833- ( integer_type, location) ,
834- ( integer_type, location) ,
889+ ( integer_type, location) , // a
890+ ( integer_type, location) , // b
835891 ] ) ) ;
836892
837- let a = entry_block. arg ( 0 ) ?;
838- let b = entry_block. arg ( 1 ) ?;
839- // The egcd algorithm works by calculating a series of remainders `rem`, being each `rem_i` the remainder of dividing `rem_{i-1}` with `rem_{i-2}`
840- // For the initial setup, rem_0 = b, rem_1 = a.
841- // This order is chosen because if we reverse them, then the first iteration will just swap them
842- let remainder = a;
843- let prev_remainder = b;
844-
845- // Similarly we'll calculate another series which starts 0,1,... and from which we
846- // will retrieve the modular inverse of a
847- let prev_inverse = entry_block. const_int_from_type ( context, location, 0 , integer_type) ?;
848- let inverse = entry_block. const_int_from_type ( context, location, 1 , integer_type) ?;
849-
893+ // Define loop block for function. Each iteration last two values from each series.
850894 let loop_block = region. append_block ( Block :: new ( & [
851- ( integer_type, location) ,
852- ( integer_type, location) ,
853- ( integer_type, location) ,
854- ( integer_type, location) ,
895+ ( integer_type, location) , // old_r
896+ ( integer_type, location) , // new_r
897+ ( integer_type, location) , // old_s
898+ ( integer_type, location) , // new_s
855899 ] ) ) ;
900+
901+ // Define end block for function.
856902 let end_block = region. append_block ( Block :: new ( & [
857- ( integer_type, location) ,
858- ( integer_type, location) ,
903+ ( integer_type, location) , // old_r
904+ ( integer_type, location) , // old_s
859905 ] ) ) ;
860906
907+ // Jump to loop block from entry block, with initial values.
908+ // - old_r = b
909+ // - new_r = a
910+ // - old_s = 0
911+ // - new_s = 1
861912 entry_block. append_operation ( cf:: br (
862913 & loop_block,
863- & [ prev_remainder, remainder, prev_inverse, inverse] ,
914+ & [
915+ entry_block. arg ( 1 ) ?,
916+ entry_block. arg ( 0 ) ?,
917+ entry_block. const_int_from_type ( context, location, 0 , integer_type) ?,
918+ entry_block. const_int_from_type ( context, location, 1 , integer_type) ?,
919+ ] ,
864920 location,
865921 ) ) ;
866922
867- // -- Loop body --
868- // Arguments are rem_(i-1), rem, inv_(i-1), inv
869- let prev_remainder = loop_block. arg ( 0 ) ?;
870- let remainder = loop_block. arg ( 1 ) ?;
871- let prev_inverse = loop_block. arg ( 2 ) ?;
872- let inverse = loop_block. arg ( 3 ) ?;
873-
874- // First calculate q = rem_(i-1)/rem_i, rounded down
875- let quotient =
876- loop_block. append_op_result ( arith:: divui ( prev_remainder, remainder, location) ) ?;
877-
878- // Then rem_(i+1) = rem_(i-1) - q * rem_i, and inv_(i+1) = inv_(i-1) - q * inv_i
879- let rem_times_quo = loop_block. muli ( remainder, quotient, location) ?;
880- let inv_times_quo = loop_block. muli ( inverse, quotient, location) ?;
881- let next_remainder =
882- loop_block. append_op_result ( arith:: subi ( prev_remainder, rem_times_quo, location) ) ?;
883- let next_inverse =
884- loop_block. append_op_result ( arith:: subi ( prev_inverse, inv_times_quo, location) ) ?;
885-
886- // Check if rem_(i+1) is 0
887- // If true, then:
888- // - rem_i is the gcd of a and b
889- // - inv_i is the bezout coefficient x
890- let zero = loop_block. const_int_from_type ( context, location, 0 , integer_type) ?;
891- let next_remainder_eq_zero =
892- loop_block. cmpi ( context, CmpiPredicate :: Eq , next_remainder, zero, location) ?;
893- loop_block. append_operation ( cf:: cond_br (
894- context,
895- next_remainder_eq_zero,
896- & end_block,
897- & loop_block,
898- & [ remainder, inverse] ,
899- & [ remainder, next_remainder, inverse, next_inverse] ,
900- location,
901- ) ) ;
923+ // LOOP BLOCK
924+ {
925+ let old_r = loop_block. arg ( 0 ) ?;
926+ let new_r = loop_block. arg ( 1 ) ?;
927+ let old_s = loop_block. arg ( 2 ) ?;
928+ let new_s = loop_block. arg ( 3 ) ?;
929+
930+ // First calculate quotient of old_r/new_r.
931+ let quotient = loop_block. append_op_result ( arith:: divui ( old_r, new_r, location) ) ?;
932+
933+ // Multiply quotient by new_r and new_s.
934+ let quotient_by_new_r = loop_block. muli ( quotient, new_r, location) ?;
935+ let quotient_by_new_s = loop_block. muli ( quotient, new_s, location) ?;
936+
937+ // Calculate new values for next iteration.
938+ // - next_new_r := old_r − quotient * new_r
939+ // - next_new_s := old_s − quotient * new_s
940+ let next_new_r =
941+ loop_block. append_op_result ( arith:: subi ( old_r, quotient_by_new_r, location) ) ?;
942+ let next_new_s =
943+ loop_block. append_op_result ( arith:: subi ( old_s, quotient_by_new_s, location) ) ?;
944+
945+ // Jump to end block if next_new_r is zero.
946+ let zero = loop_block. const_int_from_type ( context, location, 0 , integer_type) ?;
947+ let next_new_r_is_zero =
948+ loop_block. cmpi ( context, CmpiPredicate :: Eq , next_new_r, zero, location) ?;
949+ loop_block. append_operation ( cf:: cond_br (
950+ context,
951+ next_new_r_is_zero,
952+ & end_block,
953+ & loop_block,
954+ & [ new_r, new_s] ,
955+ & [ new_r, next_new_r, new_s, next_new_s] ,
956+ location,
957+ ) ) ;
958+ }
902959
903- // Create the struct that will contain the results
904- let results = end_block. append_op_result ( llvm:: undef (
905- llvm:: r#type:: r#struct ( context, & [ integer_type, integer_type] , false ) ,
906- location,
907- ) ) ?;
908- let results = end_block. insert_values (
909- context,
910- location,
911- results,
912- & [ end_block. arg ( 0 ) ?, end_block. arg ( 1 ) ?] ,
913- ) ?;
914- end_block. append_operation ( llvm:: r#return ( Some ( results) , location) ) ;
960+ // END BLOCK
961+ {
962+ let results = end_block. append_op_result ( llvm:: undef (
963+ llvm:: r#type:: r#struct ( context, & [ integer_type, integer_type] , false ) ,
964+ location,
965+ ) ) ?;
966+ let results = end_block. insert_values (
967+ context,
968+ location,
969+ results,
970+ & [ end_block. arg ( 0 ) ?, end_block. arg ( 1 ) ?] ,
971+ ) ?;
972+ end_block. append_operation ( llvm:: r#return ( Some ( results) , location) ) ;
973+ }
915974
916975 let func_name = StringAttribute :: new ( context, func_symbol) ;
917976 module. body ( ) . append_operation ( llvm:: func (
0 commit comments