Skip to content

Commit d4ca793

Browse files
committed
Const folding enums over snapshots.
1 parent b923d52 commit d4ca793

File tree

2 files changed

+254
-19
lines changed

2 files changed

+254
-19
lines changed

crates/cairo-lang-lowering/src/optimizations/const_folding.rs

Lines changed: 85 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -238,22 +238,32 @@ pub fn const_folding(
238238
ctx.maybe_replace_inputs(info.inputs_mut());
239239
match info {
240240
MatchInfo::Enum(MatchEnumInfo { input, arms, .. }) => {
241-
if let Some(VarInfo::Const(ConstValue::Enum(variant, value))) =
242-
ctx.var_info.get(&input.var_id)
241+
if let Some((
242+
n_snapshots,
243+
VarInfo::Const(ConstValue::Enum(variant, value)),
244+
)) = ctx.peel_snapshots(input.var_id)
243245
{
244246
let arm = &arms[variant.idx];
245247
let value = value.as_ref().clone();
246248
let output = arm.var_ids[0];
247249
if ctx.variables[input.var_id].droppable.is_ok()
248250
&& ctx.variables[output].copyable.is_ok()
251+
&& ctx
252+
.try_generate_const_statements(
253+
&value,
254+
output,
255+
n_snapshots,
256+
&mut block.statements,
257+
)
258+
.is_some()
249259
{
250-
if let Some(stmt) = ctx.try_generate_const_statement(&value, output)
251-
{
252-
block.statements.push(stmt);
253-
block.end = BlockEnd::Goto(arm.block_id, Default::default());
254-
}
260+
block.end = BlockEnd::Goto(arm.block_id, Default::default());
261+
}
262+
let mut info = VarInfo::Const(value);
263+
for _ in 0..n_snapshots {
264+
info = VarInfo::Snapshot(Box::new(info));
255265
}
256-
ctx.var_info.insert(output, VarInfo::Const(value));
266+
ctx.var_info.insert(output, info);
257267
}
258268
}
259269
MatchInfo::Value(info) => {
@@ -633,20 +643,65 @@ impl ConstFoldingContext<'_> {
633643
self.propagate_const_and_get_statement(BigInt::zero(), output, false)
634644
}
635645

636-
/// Returns a statement that introduces the requested value into `output`, or None if fails.
637-
fn try_generate_const_statement(
638-
&self,
646+
/// Addes statements that introduces the requested value into `output` and returns `Some(())`,
647+
/// or `None` if fails.
648+
fn try_generate_const_statements(
649+
&mut self,
639650
value: &ConstValue,
640651
output: VariableId,
641-
) -> Option<Statement> {
642-
if self.db.type_size_info(self.variables[output].ty) == Ok(TypeSizeInformation::Other) {
643-
Some(Statement::Const(StatementConst { value: value.clone(), output }))
644-
} else if matches!(value, ConstValue::Struct(members, _) if members.is_empty()) {
645-
// Handling const empty structs - which are not supported in sierra-gen.
646-
Some(Statement::StructConstruct(StatementStructConstruct { inputs: vec![], output }))
647-
} else {
648-
None
652+
n_snapshots: usize,
653+
statements: &mut Vec<Statement>,
654+
) -> Option<()> {
655+
let var = &self.variables[output];
656+
let output_ty = var.ty;
657+
let size_info = self.db.type_size_info(output_ty).ok()?;
658+
if size_info == TypeSizeInformation::ZeroSized
659+
&& !matches!(value, ConstValue::Struct(members, _) if members.is_empty())
660+
{
661+
return None;
649662
}
663+
let get_base_stmt = |var| {
664+
if size_info == TypeSizeInformation::ZeroSized {
665+
Statement::StructConstruct(StatementStructConstruct { inputs: vec![], output: var })
666+
} else {
667+
Statement::Const(StatementConst { value: value.clone(), output: var })
668+
}
669+
};
670+
if n_snapshots == 0 {
671+
statements.push(get_base_stmt(output));
672+
return Some(());
673+
}
674+
let location = var.location;
675+
let mut curr_ty = value.ty(self.db).ok()?;
676+
let mut new_var = |ty| {
677+
self.variables.alloc(Variable::new(self.db, ImplLookupContext::default(), ty, location))
678+
};
679+
let mut curr_var = new_var(curr_ty);
680+
statements.push(get_base_stmt(curr_var));
681+
for _ in 1..n_snapshots {
682+
let unused_orig = new_var(curr_ty);
683+
let snapped_ty = TypeLongId::Snapshot(curr_ty).intern(self.db);
684+
let snapped_var = new_var(snapped_ty);
685+
statements.push(Statement::Snapshot(StatementSnapshot {
686+
input: VarUsage { var_id: curr_var, location },
687+
outputs: [unused_orig, snapped_var],
688+
}));
689+
curr_ty = snapped_ty;
690+
curr_var = snapped_var;
691+
}
692+
let unused_orig = new_var(curr_ty);
693+
statements.push(Statement::Snapshot(StatementSnapshot {
694+
input: VarUsage { var_id: curr_var, location },
695+
outputs: [unused_orig, output],
696+
}));
697+
let final_ty = TypeLongId::Snapshot(curr_ty).intern(self.db);
698+
assert!(
699+
final_ty == output_ty,
700+
"{} != {}",
701+
final_ty.format(self.db),
702+
output_ty.format(self.db)
703+
);
704+
Some(())
650705
}
651706

652707
/// Handles the end of an extern block.
@@ -990,6 +1045,17 @@ impl ConstFoldingContext<'_> {
9901045
try_extract_matches!(self.var_info.get(&var_id)?, VarInfo::Const)
9911046
}
9921047

1048+
/// Returns the number of snapshots and the `VarInfo` wrapped by the snapshots.
1049+
fn peel_snapshots(&self, var_id: VariableId) -> Option<(usize, &VarInfo)> {
1050+
let mut n_snapshots = 0;
1051+
let mut curr = self.var_info.get(&var_id)?;
1052+
while let VarInfo::Snapshot(next) = curr {
1053+
n_snapshots += 1;
1054+
curr = next.as_ref();
1055+
}
1056+
Some((n_snapshots, curr))
1057+
}
1058+
9931059
/// Return the const value as an int if it exists and is an integer, additionally, if it is of a
9941060
/// non-zero type.
9951061
fn as_int_ex(&self, var_id: VariableId) -> Option<(&BigInt, bool)> {

crates/cairo-lang-lowering/src/optimizations/test_data/const_folding

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5903,3 +5903,172 @@ End:
59035903
Return(v7)
59045904

59055905
//! > lowering_diagnostics
5906+
5907+
//! > ==========================================================================
5908+
5909+
//! > Match enum const.
5910+
5911+
//! > test_runner_name
5912+
test_match_optimizer
5913+
5914+
//! > function
5915+
fn foo() -> u8 {
5916+
match E::V1(2) {
5917+
E::V0(x) => x,
5918+
E::V1(x) => x,
5919+
}
5920+
}
5921+
5922+
//! > function_name
5923+
foo
5924+
5925+
//! > module_code
5926+
#[derive(Drop)]
5927+
enum E {
5928+
V0: u8,
5929+
V1: u8,
5930+
}
5931+
5932+
//! > semantic_diagnostics
5933+
5934+
//! > before
5935+
Parameters:
5936+
blk0 (root):
5937+
Statements:
5938+
(v0: core::integer::u8) <- 2
5939+
(v1: test::E) <- E::V1(v0)
5940+
End:
5941+
Match(match_enum(v1) {
5942+
E::V0(v2) => blk1,
5943+
E::V1(v3) => blk2,
5944+
})
5945+
5946+
blk1:
5947+
Statements:
5948+
End:
5949+
Goto(blk3, {v2 -> v4})
5950+
5951+
blk2:
5952+
Statements:
5953+
End:
5954+
Goto(blk3, {v3 -> v4})
5955+
5956+
blk3:
5957+
Statements:
5958+
End:
5959+
Return(v4)
5960+
5961+
//! > after
5962+
Parameters:
5963+
blk0 (root):
5964+
Statements:
5965+
(v0: core::integer::u8) <- 2
5966+
(v1: test::E) <- E::V1(v0)
5967+
(v3: core::integer::u8) <- 2
5968+
End:
5969+
Goto(blk2, {})
5970+
5971+
blk1:
5972+
Statements:
5973+
End:
5974+
Goto(blk3, {v2 -> v4})
5975+
5976+
blk2:
5977+
Statements:
5978+
End:
5979+
Goto(blk3, {v3 -> v4})
5980+
5981+
blk3:
5982+
Statements:
5983+
End:
5984+
Return(v4)
5985+
5986+
//! > lowering_diagnostics
5987+
5988+
//! > ==========================================================================
5989+
5990+
//! > Match enum snapshot const.
5991+
5992+
//! > test_runner_name
5993+
test_match_optimizer
5994+
5995+
//! > function
5996+
fn foo() -> u8 {
5997+
match @E::V1(2) {
5998+
E::V0(x) => *x,
5999+
E::V1(x) => *x,
6000+
}
6001+
}
6002+
6003+
//! > function_name
6004+
foo
6005+
6006+
//! > module_code
6007+
#[derive(Drop)]
6008+
enum E {
6009+
V0: u8,
6010+
V1: u8,
6011+
}
6012+
6013+
//! > semantic_diagnostics
6014+
6015+
//! > before
6016+
Parameters:
6017+
blk0 (root):
6018+
Statements:
6019+
(v0: core::integer::u8) <- 2
6020+
(v1: test::E) <- E::V1(v0)
6021+
(v2: test::E, v3: @test::E) <- snapshot(v1)
6022+
End:
6023+
Match(match_enum(v3) {
6024+
E::V0(v4) => blk1,
6025+
E::V1(v5) => blk2,
6026+
})
6027+
6028+
blk1:
6029+
Statements:
6030+
(v6: core::integer::u8) <- desnap(v4)
6031+
End:
6032+
Goto(blk3, {v6 -> v7})
6033+
6034+
blk2:
6035+
Statements:
6036+
(v8: core::integer::u8) <- desnap(v5)
6037+
End:
6038+
Goto(blk3, {v8 -> v7})
6039+
6040+
blk3:
6041+
Statements:
6042+
End:
6043+
Return(v7)
6044+
6045+
//! > after
6046+
Parameters:
6047+
blk0 (root):
6048+
Statements:
6049+
(v0: core::integer::u8) <- 2
6050+
(v1: test::E) <- E::V1(v0)
6051+
(v2: test::E, v3: @test::E) <- snapshot(v1)
6052+
(v9: core::integer::u8) <- 2
6053+
(v10: core::integer::u8, v5: @core::integer::u8) <- snapshot(v9)
6054+
End:
6055+
Goto(blk2, {})
6056+
6057+
blk1:
6058+
Statements:
6059+
(v6: core::integer::u8) <- desnap(v4)
6060+
End:
6061+
Goto(blk3, {v6 -> v7})
6062+
6063+
blk2:
6064+
Statements:
6065+
(v8: core::integer::u8) <- desnap(v5)
6066+
End:
6067+
Goto(blk3, {v8 -> v7})
6068+
6069+
blk3:
6070+
Statements:
6071+
End:
6072+
Return(v7)
6073+
6074+
//! > lowering_diagnostics

0 commit comments

Comments
 (0)