Skip to content

Commit c6074dd

Browse files
committed
performance(sierra-gas): Gas computation performance improvements.
* Calculating branches only once. * Using Vec<_> instead of UnorderedMap<StatementIdx, _>. SIERRA_UPDATE_PATCH_CHANGE_TAG=Just performance gain - no interface effect.
1 parent a7b510b commit c6074dd

File tree

1 file changed

+31
-41
lines changed

1 file changed

+31
-41
lines changed

crates/cairo-lang-sierra-gas/src/compute_costs.rs

Lines changed: 31 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -125,19 +125,24 @@ pub fn compute_costs<
125125
) -> Result<GasInfo, CostError> {
126126
let mut context = CostContext {
127127
program,
128-
get_cost_fn,
128+
branch_costs: Default::default(),
129129
enforced_wallet_values,
130130
costs: Default::default(),
131131
target_values: Default::default(),
132132
};
133+
for statement in &program.statements {
134+
context.branch_costs.push(match statement {
135+
Statement::Invocation(invocation) => get_cost_fn(&invocation.libfunc_id),
136+
Statement::Return(_) => vec![],
137+
});
138+
}
133139

134140
context.prepare_wallet(specific_cost_context)?;
135141

136142
// Compute the excess cost and the corresponding target value for each statement.
137143
context.target_values = context.compute_target_values(specific_cost_context)?;
138144

139145
// Recompute the wallet values for each statement, after setting the target values.
140-
context.costs = Default::default();
141146
context.prepare_wallet(specific_cost_context)?;
142147

143148
// Check that enforcing the wallet values succeeded.
@@ -236,30 +241,29 @@ fn get_branch_requirements<
236241
fn analyze_gas_statements<
237242
CostType: CostTypeTrait,
238243
SpecificCostContext: SpecificCostContextTrait<CostType>,
239-
GetCostFn: Fn(&ConcreteLibfuncId) -> Vec<BranchCost>,
240244
>(
241-
context: &CostContext<'_, CostType, GetCostFn>,
245+
context: &CostContext<'_, CostType>,
242246
specific_context: &SpecificCostContext,
243247
idx: &StatementIdx,
244248
variable_values: &mut VariableValues,
245249
) -> Result<(), CostError> {
246250
let Statement::Invocation(invocation) = &context.program.get_statement(idx).unwrap() else {
247251
return Ok(());
248252
};
249-
let libfunc_cost: Vec<BranchCost> = context.get_cost(&invocation.libfunc_id);
253+
let libfunc_cost = &context.branch_costs[idx.0];
250254
let branch_requirements: Vec<WalletInfo<CostType>> = get_branch_requirements(
251255
specific_context,
252256
&|statement_idx| context.wallet_at(statement_idx),
253257
idx,
254258
invocation,
255-
&libfunc_cost,
259+
libfunc_cost,
256260
false,
257261
);
258262

259263
let wallet_value = context.wallet_at(idx).value;
260264

261265
for (branch_info, branch_cost, branch_requirement) in
262-
zip_eq3(&invocation.branches, &libfunc_cost, &branch_requirements)
266+
zip_eq3(&invocation.branches, libfunc_cost, &branch_requirements)
263267
{
264268
if let BranchCost::WithdrawGas(WithdrawGasBranchInfo { success: true, .. }) = branch_cost {
265269
// Note that `idx.next(&branch_info.target)` is indeed branch align due to
@@ -397,32 +401,21 @@ impl<CostType: CostTypeTrait> std::ops::Add for WalletInfo<CostType> {
397401
}
398402

399403
/// Helper struct for computing the wallet value at each statement.
400-
struct CostContext<
401-
'a,
402-
CostType: CostTypeTrait,
403-
GetCostFn: Fn(&ConcreteLibfuncId) -> Vec<BranchCost>,
404-
> {
404+
struct CostContext<'a, CostType: CostTypeTrait> {
405405
/// The Sierra program.
406406
program: &'a Program,
407-
/// A callback function returning the cost of a libfunc for every output branch.
408-
get_cost_fn: &'a GetCostFn,
407+
/// The branch costs per statement.
408+
branch_costs: Vec<Vec<BranchCost>>,
409409
/// A map from statement index to an enforced wallet value. For example, some functions
410410
/// may have a required cost, in this case the functions entry points should have a predefined
411411
/// wallet value.
412412
enforced_wallet_values: &'a OrderedHashMap<StatementIdx, CostType>,
413413
/// The cost before executing a Sierra statement.
414-
costs: UnorderedHashMap<StatementIdx, WalletInfo<CostType>>,
414+
costs: Vec<WalletInfo<CostType>>,
415415
/// A partial map from StatementIdx to a requested lower bound on the wallet value.
416-
target_values: UnorderedHashMap<StatementIdx, CostType>,
416+
target_values: Vec<CostType>,
417417
}
418-
impl<CostType: CostTypeTrait, GetCostFn: Fn(&ConcreteLibfuncId) -> Vec<BranchCost>>
419-
CostContext<'_, CostType, GetCostFn>
420-
{
421-
/// Returns the cost of a libfunc for every output branch.
422-
fn get_cost(&self, libfunc_id: &ConcreteLibfuncId) -> Vec<BranchCost> {
423-
(self.get_cost_fn)(libfunc_id)
424-
}
425-
418+
impl<CostType: CostTypeTrait> CostContext<'_, CostType> {
426419
/// Returns the required value in the wallet before executing statement `idx`.
427420
///
428421
/// Assumes that [Self::prepare_wallet] was called before.
@@ -444,10 +437,7 @@ impl<CostType: CostTypeTrait, GetCostFn: Fn(&ConcreteLibfuncId) -> Vec<BranchCos
444437
return WalletInfo::from(enforced_wallet_value.clone());
445438
}
446439

447-
self.costs
448-
.get(idx)
449-
.unwrap_or_else(|| panic!("Wallet value for statement {idx} was not yet computed."))
450-
.clone()
440+
self.costs[idx.0].clone()
451441
}
452442

453443
/// Prepares the values for [Self::wallet_at].
@@ -465,21 +455,21 @@ impl<CostType: CostTypeTrait, GetCostFn: Fn(&ConcreteLibfuncId) -> Vec<BranchCos
465455
vec![]
466456
}
467457
Statement::Invocation(invocation) => {
468-
let libfunc_cost: Vec<BranchCost> = self.get_cost(&invocation.libfunc_id);
458+
let libfunc_cost = &self.branch_costs[current_idx.0];
469459

470-
get_branch_requirements_dependencies(current_idx, invocation, &libfunc_cost)
460+
get_branch_requirements_dependencies(current_idx, invocation, libfunc_cost)
471461
.into_iter()
472462
.collect()
473463
}
474464
}
475465
},
476466
)?;
477467

468+
self.costs.resize(self.program.statements.len(), Default::default());
478469
for current_idx in rev_topological_order {
479470
// The computation of the dependencies was completed.
480-
let res = self.no_cache_compute_wallet_at(&current_idx, specific_cost_context);
481-
// Update the cache with the result.
482-
self.costs.insert(current_idx, res.clone());
471+
self.costs[current_idx.0] =
472+
self.no_cache_compute_wallet_at(&current_idx, specific_cost_context);
483473
}
484474

485475
Ok(())
@@ -496,21 +486,21 @@ impl<CostType: CostTypeTrait, GetCostFn: Fn(&ConcreteLibfuncId) -> Vec<BranchCos
496486
match &self.program.get_statement(idx).unwrap() {
497487
Statement::Return(_) => Default::default(),
498488
Statement::Invocation(invocation) => {
499-
let libfunc_cost: Vec<BranchCost> = self.get_cost(&invocation.libfunc_id);
489+
let libfunc_cost = &self.branch_costs[idx.0];
500490

501491
// For each branch, compute the required value for the wallet.
502492
let branch_requirements: Vec<WalletInfo<CostType>> = get_branch_requirements(
503493
specific_cost_context,
504494
&|statement_idx| self.wallet_at(statement_idx),
505495
idx,
506496
invocation,
507-
&libfunc_cost,
497+
libfunc_cost,
508498
true,
509499
);
510500

511501
// The wallet value at the beginning of the statement is the maximal value
512502
// required by all the branches.
513-
WalletInfo::merge(&libfunc_cost, branch_requirements, self.target_values.get(idx))
503+
WalletInfo::merge(libfunc_cost, branch_requirements, self.target_values.get(idx.0))
514504
}
515505
}
516506
}
@@ -521,7 +511,7 @@ impl<CostType: CostTypeTrait, GetCostFn: Fn(&ConcreteLibfuncId) -> Vec<BranchCos
521511
fn compute_target_values<SpecificCostContext: SpecificCostContextTrait<CostType>>(
522512
&self,
523513
specific_cost_context: &SpecificCostContext,
524-
) -> Result<UnorderedHashMap<StatementIdx, CostType>, CostError> {
514+
) -> Result<Vec<CostType>, CostError> {
525515
// Compute a reverse topological order of the statements.
526516
// Unlike `prepare_wallet`:
527517
// * function calls are not treated as edges and
@@ -566,7 +556,7 @@ impl<CostType: CostTypeTrait, GetCostFn: Fn(&ConcreteLibfuncId) -> Vec<BranchCos
566556
.map(|i| {
567557
let idx = StatementIdx(i);
568558
let original_wallet_value = self.wallet_at_ex(&idx, false).value;
569-
(idx, original_wallet_value + excess.get(&idx).cloned().unwrap_or_default())
559+
original_wallet_value + excess.get(&idx).cloned().unwrap_or_default()
570560
})
571561
.collect())
572562
}
@@ -608,20 +598,20 @@ impl<CostType: CostTypeTrait, GetCostFn: Fn(&ConcreteLibfuncId) -> Vec<BranchCos
608598
}
609599
};
610600

611-
let libfunc_cost: Vec<BranchCost> = self.get_cost(&invocation.libfunc_id);
601+
let libfunc_cost = &self.branch_costs[idx.0];
612602

613603
let branch_requirements = get_branch_requirements(
614604
specific_cost_context,
615605
&|statement_idx| self.wallet_at(statement_idx),
616606
idx,
617607
invocation,
618-
&libfunc_cost,
608+
libfunc_cost,
619609
false,
620610
);
621611

622612
// Pass the excess to the branches.
623613
for (branch_info, branch_cost, branch_requirement) in
624-
zip_eq3(&invocation.branches, &libfunc_cost, branch_requirements)
614+
zip_eq3(&invocation.branches, libfunc_cost, branch_requirements)
625615
{
626616
let branch_statement = idx.next(&branch_info.target);
627617
if finalized_excess_statements.contains(&branch_statement) {

0 commit comments

Comments
 (0)