@@ -125,19 +125,24 @@ pub fn compute_costs<
125125) -> Result < GasInfo , CostError > {
126126 let mut context = CostContext {
127127 program,
128- get_cost_fn ,
128+ branch_costs : Default :: default ( ) ,
129129 enforced_wallet_values,
130130 costs : Default :: default ( ) ,
131131 target_values : Default :: default ( ) ,
132132 } ;
133+ for statement in & program. statements {
134+ context. branch_costs . push ( match statement {
135+ Statement :: Invocation ( invocation) => get_cost_fn ( & invocation. libfunc_id ) ,
136+ Statement :: Return ( _) => vec ! [ ] ,
137+ } ) ;
138+ }
133139
134140 context. prepare_wallet ( specific_cost_context) ?;
135141
136142 // Compute the excess cost and the corresponding target value for each statement.
137143 context. target_values = context. compute_target_values ( specific_cost_context) ?;
138144
139145 // Recompute the wallet values for each statement, after setting the target values.
140- context. costs = Default :: default ( ) ;
141146 context. prepare_wallet ( specific_cost_context) ?;
142147
143148 // Check that enforcing the wallet values succeeded.
@@ -236,30 +241,29 @@ fn get_branch_requirements<
236241fn analyze_gas_statements <
237242 CostType : CostTypeTrait ,
238243 SpecificCostContext : SpecificCostContextTrait < CostType > ,
239- GetCostFn : Fn ( & ConcreteLibfuncId ) -> Vec < BranchCost > ,
240244> (
241- context : & CostContext < ' _ , CostType , GetCostFn > ,
245+ context : & CostContext < ' _ , CostType > ,
242246 specific_context : & SpecificCostContext ,
243247 idx : & StatementIdx ,
244248 variable_values : & mut VariableValues ,
245249) -> Result < ( ) , CostError > {
246250 let Statement :: Invocation ( invocation) = & context. program . get_statement ( idx) . unwrap ( ) else {
247251 return Ok ( ( ) ) ;
248252 } ;
249- let libfunc_cost: Vec < BranchCost > = context. get_cost ( & invocation . libfunc_id ) ;
253+ let libfunc_cost = & context. branch_costs [ idx . 0 ] ;
250254 let branch_requirements: Vec < WalletInfo < CostType > > = get_branch_requirements (
251255 specific_context,
252256 & |statement_idx| context. wallet_at ( statement_idx) ,
253257 idx,
254258 invocation,
255- & libfunc_cost,
259+ libfunc_cost,
256260 false ,
257261 ) ;
258262
259263 let wallet_value = context. wallet_at ( idx) . value ;
260264
261265 for ( branch_info, branch_cost, branch_requirement) in
262- zip_eq3 ( & invocation. branches , & libfunc_cost, & branch_requirements)
266+ zip_eq3 ( & invocation. branches , libfunc_cost, & branch_requirements)
263267 {
264268 if let BranchCost :: WithdrawGas ( WithdrawGasBranchInfo { success : true , .. } ) = branch_cost {
265269 // Note that `idx.next(&branch_info.target)` is indeed branch align due to
@@ -397,32 +401,21 @@ impl<CostType: CostTypeTrait> std::ops::Add for WalletInfo<CostType> {
397401}
398402
399403/// Helper struct for computing the wallet value at each statement.
400- struct CostContext <
401- ' a ,
402- CostType : CostTypeTrait ,
403- GetCostFn : Fn ( & ConcreteLibfuncId ) -> Vec < BranchCost > ,
404- > {
404+ struct CostContext < ' a , CostType : CostTypeTrait > {
405405 /// The Sierra program.
406406 program : & ' a Program ,
407- /// A callback function returning the cost of a libfunc for every output branch .
408- get_cost_fn : & ' a GetCostFn ,
407+ /// The branch costs per statement .
408+ branch_costs : Vec < Vec < BranchCost > > ,
409409 /// A map from statement index to an enforced wallet value. For example, some functions
410410 /// may have a required cost, in this case the functions entry points should have a predefined
411411 /// wallet value.
412412 enforced_wallet_values : & ' a OrderedHashMap < StatementIdx , CostType > ,
413413 /// The cost before executing a Sierra statement.
414- costs : UnorderedHashMap < StatementIdx , WalletInfo < CostType > > ,
414+ costs : Vec < WalletInfo < CostType > > ,
415415 /// A partial map from StatementIdx to a requested lower bound on the wallet value.
416- target_values : UnorderedHashMap < StatementIdx , CostType > ,
416+ target_values : Vec < CostType > ,
417417}
418- impl < CostType : CostTypeTrait , GetCostFn : Fn ( & ConcreteLibfuncId ) -> Vec < BranchCost > >
419- CostContext < ' _ , CostType , GetCostFn >
420- {
421- /// Returns the cost of a libfunc for every output branch.
422- fn get_cost ( & self , libfunc_id : & ConcreteLibfuncId ) -> Vec < BranchCost > {
423- ( self . get_cost_fn ) ( libfunc_id)
424- }
425-
418+ impl < CostType : CostTypeTrait > CostContext < ' _ , CostType > {
426419 /// Returns the required value in the wallet before executing statement `idx`.
427420 ///
428421 /// Assumes that [Self::prepare_wallet] was called before.
@@ -444,10 +437,7 @@ impl<CostType: CostTypeTrait, GetCostFn: Fn(&ConcreteLibfuncId) -> Vec<BranchCos
444437 return WalletInfo :: from ( enforced_wallet_value. clone ( ) ) ;
445438 }
446439
447- self . costs
448- . get ( idx)
449- . unwrap_or_else ( || panic ! ( "Wallet value for statement {idx} was not yet computed." ) )
450- . clone ( )
440+ self . costs [ idx. 0 ] . clone ( )
451441 }
452442
453443 /// Prepares the values for [Self::wallet_at].
@@ -465,21 +455,21 @@ impl<CostType: CostTypeTrait, GetCostFn: Fn(&ConcreteLibfuncId) -> Vec<BranchCos
465455 vec ! [ ]
466456 }
467457 Statement :: Invocation ( invocation) => {
468- let libfunc_cost: Vec < BranchCost > = self . get_cost ( & invocation . libfunc_id ) ;
458+ let libfunc_cost = & self . branch_costs [ current_idx . 0 ] ;
469459
470- get_branch_requirements_dependencies ( current_idx, invocation, & libfunc_cost)
460+ get_branch_requirements_dependencies ( current_idx, invocation, libfunc_cost)
471461 . into_iter ( )
472462 . collect ( )
473463 }
474464 }
475465 } ,
476466 ) ?;
477467
468+ self . costs . resize ( self . program . statements . len ( ) , Default :: default ( ) ) ;
478469 for current_idx in rev_topological_order {
479470 // The computation of the dependencies was completed.
480- let res = self . no_cache_compute_wallet_at ( & current_idx, specific_cost_context) ;
481- // Update the cache with the result.
482- self . costs . insert ( current_idx, res. clone ( ) ) ;
471+ self . costs [ current_idx. 0 ] =
472+ self . no_cache_compute_wallet_at ( & current_idx, specific_cost_context) ;
483473 }
484474
485475 Ok ( ( ) )
@@ -496,21 +486,21 @@ impl<CostType: CostTypeTrait, GetCostFn: Fn(&ConcreteLibfuncId) -> Vec<BranchCos
496486 match & self . program . get_statement ( idx) . unwrap ( ) {
497487 Statement :: Return ( _) => Default :: default ( ) ,
498488 Statement :: Invocation ( invocation) => {
499- let libfunc_cost: Vec < BranchCost > = self . get_cost ( & invocation . libfunc_id ) ;
489+ let libfunc_cost = & self . branch_costs [ idx . 0 ] ;
500490
501491 // For each branch, compute the required value for the wallet.
502492 let branch_requirements: Vec < WalletInfo < CostType > > = get_branch_requirements (
503493 specific_cost_context,
504494 & |statement_idx| self . wallet_at ( statement_idx) ,
505495 idx,
506496 invocation,
507- & libfunc_cost,
497+ libfunc_cost,
508498 true ,
509499 ) ;
510500
511501 // The wallet value at the beginning of the statement is the maximal value
512502 // required by all the branches.
513- WalletInfo :: merge ( & libfunc_cost, branch_requirements, self . target_values . get ( idx) )
503+ WalletInfo :: merge ( libfunc_cost, branch_requirements, self . target_values . get ( idx. 0 ) )
514504 }
515505 }
516506 }
@@ -521,7 +511,7 @@ impl<CostType: CostTypeTrait, GetCostFn: Fn(&ConcreteLibfuncId) -> Vec<BranchCos
521511 fn compute_target_values < SpecificCostContext : SpecificCostContextTrait < CostType > > (
522512 & self ,
523513 specific_cost_context : & SpecificCostContext ,
524- ) -> Result < UnorderedHashMap < StatementIdx , CostType > , CostError > {
514+ ) -> Result < Vec < CostType > , CostError > {
525515 // Compute a reverse topological order of the statements.
526516 // Unlike `prepare_wallet`:
527517 // * function calls are not treated as edges and
@@ -566,7 +556,7 @@ impl<CostType: CostTypeTrait, GetCostFn: Fn(&ConcreteLibfuncId) -> Vec<BranchCos
566556 . map ( |i| {
567557 let idx = StatementIdx ( i) ;
568558 let original_wallet_value = self . wallet_at_ex ( & idx, false ) . value ;
569- ( idx , original_wallet_value + excess. get ( & idx) . cloned ( ) . unwrap_or_default ( ) )
559+ original_wallet_value + excess. get ( & idx) . cloned ( ) . unwrap_or_default ( )
570560 } )
571561 . collect ( ) )
572562 }
@@ -608,20 +598,20 @@ impl<CostType: CostTypeTrait, GetCostFn: Fn(&ConcreteLibfuncId) -> Vec<BranchCos
608598 }
609599 } ;
610600
611- let libfunc_cost: Vec < BranchCost > = self . get_cost ( & invocation . libfunc_id ) ;
601+ let libfunc_cost = & self . branch_costs [ idx . 0 ] ;
612602
613603 let branch_requirements = get_branch_requirements (
614604 specific_cost_context,
615605 & |statement_idx| self . wallet_at ( statement_idx) ,
616606 idx,
617607 invocation,
618- & libfunc_cost,
608+ libfunc_cost,
619609 false ,
620610 ) ;
621611
622612 // Pass the excess to the branches.
623613 for ( branch_info, branch_cost, branch_requirement) in
624- zip_eq3 ( & invocation. branches , & libfunc_cost, branch_requirements)
614+ zip_eq3 ( & invocation. branches , libfunc_cost, branch_requirements)
625615 {
626616 let branch_statement = idx. next ( & branch_info. target ) ;
627617 if finalized_excess_statements. contains ( & branch_statement) {
0 commit comments