Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 85 additions & 19 deletions crates/cairo-lang-lowering/src/optimizations/const_folding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,22 +238,32 @@ pub fn const_folding(
ctx.maybe_replace_inputs(info.inputs_mut());
match info {
MatchInfo::Enum(MatchEnumInfo { input, arms, .. }) => {
if let Some(VarInfo::Const(ConstValue::Enum(variant, value))) =
ctx.var_info.get(&input.var_id)
if let Some((
n_snapshots,
VarInfo::Const(ConstValue::Enum(variant, value)),
)) = ctx.peel_snapshots(input.var_id)
{
let arm = &arms[variant.idx];
let value = value.as_ref().clone();
let output = arm.var_ids[0];
if ctx.variables[input.var_id].droppable.is_ok()
&& ctx.variables[output].copyable.is_ok()
&& ctx
.try_generate_const_statements(
&value,
output,
n_snapshots,
&mut block.statements,
)
.is_some()
{
if let Some(stmt) = ctx.try_generate_const_statement(&value, output)
{
block.statements.push(stmt);
block.end = BlockEnd::Goto(arm.block_id, Default::default());
}
block.end = BlockEnd::Goto(arm.block_id, Default::default());
}
let mut info = VarInfo::Const(value);
for _ in 0..n_snapshots {
info = VarInfo::Snapshot(Box::new(info));
}
ctx.var_info.insert(output, VarInfo::Const(value));
ctx.var_info.insert(output, info);
}
}
MatchInfo::Value(info) => {
Expand Down Expand Up @@ -633,20 +643,65 @@ impl ConstFoldingContext<'_> {
self.propagate_const_and_get_statement(BigInt::zero(), output, false)
}

/// Returns a statement that introduces the requested value into `output`, or None if fails.
fn try_generate_const_statement(
&self,
/// Adds statements that introduces the requested value into `output` and returns `Some(())`,
/// or `None` if fails.
fn try_generate_const_statements(
&mut self,
value: &ConstValue,
output: VariableId,
) -> Option<Statement> {
if self.db.type_size_info(self.variables[output].ty) == Ok(TypeSizeInformation::Other) {
Some(Statement::Const(StatementConst { value: value.clone(), output }))
} else if matches!(value, ConstValue::Struct(members, _) if members.is_empty()) {
// Handling const empty structs - which are not supported in sierra-gen.
Some(Statement::StructConstruct(StatementStructConstruct { inputs: vec![], output }))
} else {
None
n_snapshots: usize,
statements: &mut Vec<Statement>,
) -> Option<()> {
let var = &self.variables[output];
let output_ty = var.ty;
let size_info = self.db.type_size_info(output_ty).ok()?;
if size_info == TypeSizeInformation::ZeroSized
&& !matches!(value, ConstValue::Struct(members, _) if members.is_empty())
{
return None;
}
let get_base_stmt = |var| {
if size_info == TypeSizeInformation::ZeroSized {
Statement::StructConstruct(StatementStructConstruct { inputs: vec![], output: var })
} else {
Statement::Const(StatementConst { value: value.clone(), output: var })
}
};
if n_snapshots == 0 {
statements.push(get_base_stmt(output));
return Some(());
}
let location = var.location;
let mut curr_ty = value.ty(self.db).ok()?;
let mut new_var = |ty| {
self.variables.alloc(Variable::new(self.db, ImplLookupContext::default(), ty, location))
};
let mut curr_var = new_var(curr_ty);
statements.push(get_base_stmt(curr_var));
for _ in 1..n_snapshots {
let unused_orig = new_var(curr_ty);
let snapped_ty = TypeLongId::Snapshot(curr_ty).intern(self.db);
let snapped_var = new_var(snapped_ty);
statements.push(Statement::Snapshot(StatementSnapshot {
input: VarUsage { var_id: curr_var, location },
outputs: [unused_orig, snapped_var],
}));
curr_ty = snapped_ty;
curr_var = snapped_var;
}
let unused_orig = new_var(curr_ty);
statements.push(Statement::Snapshot(StatementSnapshot {
input: VarUsage { var_id: curr_var, location },
outputs: [unused_orig, output],
}));
let final_ty = TypeLongId::Snapshot(curr_ty).intern(self.db);
assert!(
final_ty == output_ty,
"{} != {}",
final_ty.format(self.db),
output_ty.format(self.db)
);
Some(())
}

/// Handles the end of an extern block.
Expand Down Expand Up @@ -990,6 +1045,17 @@ impl ConstFoldingContext<'_> {
try_extract_matches!(self.var_info.get(&var_id)?, VarInfo::Const)
}

/// Returns the number of snapshots and the `VarInfo` wrapped by the snapshots.
fn peel_snapshots(&self, var_id: VariableId) -> Option<(usize, &VarInfo)> {
let mut n_snapshots = 0;
let mut curr = self.var_info.get(&var_id)?;
while let VarInfo::Snapshot(next) = curr {
n_snapshots += 1;
curr = next.as_ref();
}
Some((n_snapshots, curr))
}

/// Return the const value as an int if it exists and is an integer, additionally, if it is of a
/// non-zero type.
fn as_int_ex(&self, var_id: VariableId) -> Option<(&BigInt, bool)> {
Expand Down
169 changes: 169 additions & 0 deletions crates/cairo-lang-lowering/src/optimizations/test_data/const_folding
Original file line number Diff line number Diff line change
Expand Up @@ -5903,3 +5903,172 @@ End:
Return(v7)

//! > lowering_diagnostics

//! > ==========================================================================

//! > Match enum const.

//! > test_runner_name
test_match_optimizer

//! > function
fn foo() -> u8 {
match E::V1(2) {
E::V0(x) => x,
E::V1(x) => x,
}
}

//! > function_name
foo

//! > module_code
#[derive(Drop)]
enum E {
V0: u8,
V1: u8,
}

//! > semantic_diagnostics

//! > before
Parameters:
blk0 (root):
Statements:
(v0: core::integer::u8) <- 2
(v1: test::E) <- E::V1(v0)
End:
Match(match_enum(v1) {
E::V0(v2) => blk1,
E::V1(v3) => blk2,
})

blk1:
Statements:
End:
Goto(blk3, {v2 -> v4})

blk2:
Statements:
End:
Goto(blk3, {v3 -> v4})

blk3:
Statements:
End:
Return(v4)

//! > after
Parameters:
blk0 (root):
Statements:
(v0: core::integer::u8) <- 2
(v1: test::E) <- E::V1(v0)
(v3: core::integer::u8) <- 2
End:
Goto(blk2, {})

blk1:
Statements:
End:
Goto(blk3, {v2 -> v4})

blk2:
Statements:
End:
Goto(blk3, {v3 -> v4})

blk3:
Statements:
End:
Return(v4)

//! > lowering_diagnostics

//! > ==========================================================================

//! > Match enum snapshot const.

//! > test_runner_name
test_match_optimizer

//! > function
fn foo() -> u8 {
match @E::V1(2) {
E::V0(x) => *x,
E::V1(x) => *x,
}
}

//! > function_name
foo

//! > module_code
#[derive(Drop)]
enum E {
V0: u8,
V1: u8,
}

//! > semantic_diagnostics

//! > before
Parameters:
blk0 (root):
Statements:
(v0: core::integer::u8) <- 2
(v1: test::E) <- E::V1(v0)
(v2: test::E, v3: @test::E) <- snapshot(v1)
End:
Match(match_enum(v3) {
E::V0(v4) => blk1,
E::V1(v5) => blk2,
})

blk1:
Statements:
(v6: core::integer::u8) <- desnap(v4)
End:
Goto(blk3, {v6 -> v7})

blk2:
Statements:
(v8: core::integer::u8) <- desnap(v5)
End:
Goto(blk3, {v8 -> v7})

blk3:
Statements:
End:
Return(v7)

//! > after
Parameters:
blk0 (root):
Statements:
(v0: core::integer::u8) <- 2
(v1: test::E) <- E::V1(v0)
(v2: test::E, v3: @test::E) <- snapshot(v1)
(v9: core::integer::u8) <- 2
(v10: core::integer::u8, v5: @core::integer::u8) <- snapshot(v9)
End:
Goto(blk2, {})

blk1:
Statements:
(v6: core::integer::u8) <- desnap(v4)
End:
Goto(blk3, {v6 -> v7})

blk2:
Statements:
(v8: core::integer::u8) <- desnap(v5)
End:
Goto(blk3, {v8 -> v7})

blk3:
Statements:
End:
Return(v7)

//! > lowering_diagnostics
Loading