Skip to content

Commit adb5821

Browse files
authored
Add an empty merge_block_builders with a test. (#7952)
1 parent fb2d6de commit adb5821

File tree

9 files changed

+350
-4
lines changed

9 files changed

+350
-4
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/cairo-lang-lowering/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ cairo-lang-syntax = { path = "../cairo-lang-syntax", version = "~2.11.4" }
1818
cairo-lang-utils = { path = "../cairo-lang-utils", version = "~2.11.4" }
1919
assert_matches.workspace = true
2020
id-arena.workspace = true
21+
indent.workspace = true
2122
itertools = { workspace = true, default-features = true }
2223
log.workspace = true
2324
num-bigint = { workspace = true, default-features = true }

crates/cairo-lang-lowering/src/fmt.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ impl DebugWithDb<LoweredFormatter<'_>> for BlockEnd {
102102
BlockEnd::Goto(block_id, remapping) => {
103103
return write!(f, " Goto({:?}, {:?})", block_id.debug(ctx), remapping.debug(ctx));
104104
}
105-
BlockEnd::NotSet => unreachable!(),
105+
BlockEnd::NotSet => return write!(f, " Not set"),
106106
BlockEnd::Match { info } => {
107107
return write!(f, " Match({:?})", info.debug(ctx));
108108
}

crates/cairo-lang-lowering/src/lower/block_builder.rs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
use cairo_lang_debug::DebugWithDb;
12
use cairo_lang_defs::ids::{MemberId, NamedLanguageElementId};
23
use cairo_lang_diagnostics::Maybe;
34
use cairo_lang_semantic as semantic;
5+
use cairo_lang_semantic::expr::fmt::ExprFormatter;
46
use cairo_lang_semantic::types::{peel_snapshots, wrap_in_snapshots};
57
use cairo_lang_semantic::usage::{MemberPath, Usage};
68
use cairo_lang_syntax::node::TypedStablePtr;
@@ -406,6 +408,22 @@ impl BlockBuilder {
406408
}
407409
}
408410

411+
impl<'a> DebugWithDb<ExprFormatter<'a>> for BlockBuilder {
412+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>, db: &ExprFormatter<'a>) -> std::fmt::Result {
413+
writeln!(f, "block_id: {:?}", self.block_id)?;
414+
if !self.statements.statements.is_empty() {
415+
writeln!(f, "statements:")?;
416+
for statement in &self.statements.statements {
417+
writeln!(f, " {statement:?}")?;
418+
}
419+
}
420+
writeln!(f, "semantics:")?;
421+
write!(f, "{}", indent::indent_all_with(" ", format!("{:?}", self.semantics.debug(db))))?;
422+
423+
Ok(())
424+
}
425+
}
426+
409427
/// Gets the type of a semantic variable.
410428
fn get_ty(ctx: &LoweringContext<'_, '_>, member_path: &MemberPath) -> semantic::TypeId {
411429
match member_path {
@@ -537,3 +555,28 @@ impl StructRecomposer for BlockStructRecomposer<'_, '_, '_> {
537555
self.ctx.db
538556
}
539557
}
558+
559+
/// Given a list of block builders, creates a new single block builder and finalizes all
560+
/// the block builders with a [BlockEnd::Goto] to the new block.
561+
///
562+
/// The mapping from semantic variables to lowered variables in the new block follows these rules:
563+
///
564+
/// * Variables mapped to the same lowered variable across all input blocks are kept as-is.
565+
/// * Local variables that appear in only a subset of the blocks are removed.
566+
/// * Variables with different mappings across blocks are remapped to a new lowered variable.
567+
///
568+
/// If only one parent builder is given, returns it without creating a new block.
569+
// TODO(lior): Remove `allow(dead_code)` once the function is used.
570+
#[allow(dead_code)]
571+
pub fn merge_block_builders(
572+
_ctx: &mut LoweringContext<'_, '_>,
573+
parent_builders: Vec<BlockBuilder>,
574+
_location: LocationId,
575+
) -> BlockBuilder {
576+
// If there is only one parent builder, return it.
577+
if parent_builders.len() == 1 {
578+
return parent_builders.into_iter().next().unwrap();
579+
}
580+
581+
todo!("Merging multiple block builders is not supported yet.");
582+
}
Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
use cairo_lang_debug::DebugWithDb;
2+
use cairo_lang_semantic::corelib::unit_ty;
3+
use cairo_lang_semantic::expr::fmt::ExprFormatter;
4+
use cairo_lang_semantic::test_utils::{TestFunction, setup_test_function};
5+
use cairo_lang_semantic::usage::MemberPath;
6+
use cairo_lang_semantic::{self as semantic, Expr, Statement, StatementId};
7+
use cairo_lang_syntax::node::TypedStablePtr;
8+
use cairo_lang_test_utils::parse_test_file::TestRunnerResult;
9+
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
10+
use cairo_lang_utils::unordered_hash_set::UnorderedHashSet;
11+
use cairo_lang_utils::{Upcast, extract_matches};
12+
use itertools::Itertools;
13+
14+
use super::block_builder::{BlockBuilder, merge_block_builders};
15+
use super::context::{LoweringContext, VarRequest};
16+
use super::test_utils::{create_encapsulating_ctx, create_lowering_context};
17+
use crate::VariableId;
18+
use crate::fmt::LoweredFormatter;
19+
use crate::test_utils::LoweringDatabaseForTesting;
20+
21+
const N_LOWERING_VARS: usize = 100;
22+
23+
cairo_lang_test_utils::test_file_test!(
24+
test_merge_block_builders,
25+
"src/lower/test_data",
26+
{
27+
merge_block_builders: "merge_block_builders",
28+
},
29+
test_merge_block_builders
30+
);
31+
32+
/// Tests the [merge_block_builders] function.
33+
///
34+
/// Each test case has the following input sections:
35+
/// - `variables`: A comma-separated list of "var_name: type".
36+
/// - `block_definitions`: Defines the semantic to lowering map of each input block.
37+
///
38+
/// For example:
39+
/// ```ignore
40+
/// ((x, 0),);
41+
/// ((x.a, 1), (y, 2));
42+
/// ```
43+
/// represents two blocks, where:
44+
/// * The first maps `x` to lowered variable 0.
45+
/// * The second maps `x.a` to lowered variable 1, and `y` to lowered variable 2.
46+
///
47+
/// Note that `x` and `x.*` should not be specified together in one block.
48+
///
49+
/// - `module_code`: Additional code for defining structs and helper functions.
50+
fn test_merge_block_builders(
51+
inputs: &OrderedHashMap<String, String>,
52+
_args: &OrderedHashMap<String, String>,
53+
) -> TestRunnerResult {
54+
let db = LoweringDatabaseForTesting::default();
55+
// Create a function with the given variables as parameters and the given block definitions
56+
// as a dummy function body.
57+
// Note that the function body is not lowered at any point, and it is only used to parse the
58+
// semantic to lowering map.
59+
let test_function = setup_test_function(
60+
&db,
61+
&format!("fn foo ({}) {{ {} }}", &inputs["variables"], &inputs["block_definitions"]),
62+
"foo",
63+
inputs.get("module_code").unwrap_or(&"".into()),
64+
)
65+
.unwrap();
66+
67+
let mut encapsulating_ctx =
68+
create_encapsulating_ctx(&db, test_function.function_id, &test_function.signature);
69+
70+
let mut ctx = create_lowering_context(
71+
&db,
72+
test_function.function_id,
73+
&test_function.signature,
74+
&mut encapsulating_ctx,
75+
);
76+
77+
// Create dummy lowering variables.
78+
let dummy_location = ctx.get_location(test_function.signature.stable_ptr.untyped());
79+
let lowering_vars: Vec<VariableId> = (0..N_LOWERING_VARS)
80+
.map(|_| ctx.new_var(VarRequest { ty: unit_ty(ctx.db), location: dummy_location }))
81+
.collect();
82+
83+
let expr_formatter = ExprFormatter { db: db.upcast(), function_id: test_function.function_id };
84+
85+
let input_blocks = create_block_builders(&mut ctx, &test_function, &lowering_vars);
86+
let input_blocks_str =
87+
input_blocks.iter().map(|b| format!("{:?}", b.debug(&expr_formatter))).join("\n");
88+
89+
// Invoke [merge_block_builders] on the input blocks.
90+
let merged_block = merge_block_builders(&mut ctx, input_blocks, dummy_location);
91+
92+
let lowered_formatter = LoweredFormatter::new(db.upcast(), &ctx.variables.variables);
93+
let lowered_blocks = ctx.blocks.build().unwrap();
94+
let lowered_str = lowered_blocks
95+
.iter()
96+
.map(|(block_id, block)| {
97+
format!(
98+
"{:?}:\n{:?}\n",
99+
block_id.debug(&lowered_formatter),
100+
block.debug(&lowered_formatter)
101+
)
102+
})
103+
.join("");
104+
105+
TestRunnerResult {
106+
outputs: OrderedHashMap::from([
107+
("input_blocks".into(), input_blocks_str),
108+
("merged_block_builder".into(), format!("{:?}", merged_block.debug(&expr_formatter))),
109+
("lowered".into(), lowered_str),
110+
]),
111+
error: None,
112+
}
113+
}
114+
115+
/// Creates a block builder for each semantic "statement" in the function body.
116+
///
117+
/// See [create_block_builder] for more details.
118+
fn create_block_builders(
119+
ctx: &mut LoweringContext<'_, '_>,
120+
test_function: &TestFunction,
121+
lowering_vars: &[VariableId],
122+
) -> Vec<BlockBuilder> {
123+
let expr = ctx.function_body.arenas.exprs[test_function.body].clone();
124+
let block_expr = extract_matches!(expr, Expr::Block);
125+
126+
block_expr
127+
.statements
128+
.iter()
129+
.map(|statement_id| create_block_builder(ctx, *statement_id, lowering_vars))
130+
.collect()
131+
}
132+
133+
/// Given a semantic "statement" of the form:
134+
/// `((member_path, lower_var_idx), ...)`
135+
/// creates a block builder with a semantic mapping that maps each member path to the corresponding
136+
/// given lowered variable.
137+
///
138+
/// Assumption: if a certain semantic variable is mapped, all its children should not be mapped.
139+
///
140+
/// Note that the statement is not a real statement - it is not lowered, and it is only used to
141+
/// define the semantic mapping.
142+
fn create_block_builder(
143+
ctx: &mut LoweringContext<'_, '_>,
144+
statement_id: StatementId,
145+
lowering_vars: &[VariableId],
146+
) -> BlockBuilder {
147+
let block_id = ctx.blocks.alloc_empty();
148+
let mut block_builder = BlockBuilder::root(block_id);
149+
let mut visited_vars: UnorderedHashSet<semantic::VarId> = Default::default();
150+
151+
let statement_expr =
152+
extract_matches!(&ctx.function_body.arenas.statements[statement_id], Statement::Expr);
153+
let external_tuple =
154+
extract_matches!(&ctx.function_body.arenas.exprs[statement_expr.expr], Expr::Tuple);
155+
156+
let expr_ids = external_tuple.items.clone();
157+
for expr_id in expr_ids {
158+
let inner_tuple = extract_matches!(&ctx.function_body.arenas.exprs[expr_id], Expr::Tuple);
159+
let lower_var_idx: usize = (&extract_matches!(
160+
&ctx.function_body.arenas.exprs[inner_tuple.items[1]],
161+
Expr::Literal
162+
)
163+
.value)
164+
.try_into()
165+
.unwrap();
166+
167+
match &ctx.function_body.arenas.exprs[inner_tuple.items[0]] {
168+
Expr::MemberAccess(member_access) => {
169+
let member_path: MemberPath = (member_access.member_path.as_ref().unwrap()).into();
170+
let mut var = &member_path;
171+
while let MemberPath::Member { parent: v, .. } = var {
172+
var = v;
173+
}
174+
let var_id = extract_matches!(var, MemberPath::Var);
175+
176+
if visited_vars.insert(*var_id) {
177+
block_builder.put_semantic(*var_id, lowering_vars[lower_var_idx]);
178+
}
179+
180+
let location = ctx.get_location(member_access.stable_ptr.untyped());
181+
block_builder.update_ref_raw(
182+
ctx,
183+
member_path,
184+
lowering_vars[lower_var_idx],
185+
location,
186+
);
187+
// Remove the statements that were created as part of the `update_ref_raw` call.
188+
block_builder.statements.statements.clear();
189+
}
190+
Expr::Var(var) => {
191+
if visited_vars.insert(var.var) {
192+
block_builder.put_semantic(var.var, lowering_vars[lower_var_idx]);
193+
}
194+
}
195+
expr => {
196+
panic!("Unexpected expression: {expr:?}");
197+
}
198+
}
199+
}
200+
201+
block_builder
202+
}

crates/cairo-lang-lowering/src/lower/mod.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,12 @@ mod lower_let_else;
7070
mod lower_match;
7171
pub mod refs;
7272

73+
#[cfg(test)]
74+
mod test_utils;
75+
76+
#[cfg(test)]
77+
mod block_builder_test;
78+
7379
#[cfg(test)]
7480
mod generated_test;
7581

crates/cairo-lang-lowering/src/lower/refs.rs

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use cairo_lang_semantic::usage::MemberPath;
55
use cairo_lang_semantic::{self as semantic};
66
use cairo_lang_utils::extract_matches;
77
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
8-
use itertools::chain;
8+
use itertools::{Itertools, chain};
99

1010
use crate::VariableId;
1111
use crate::db::LoweringGroup;
@@ -154,6 +154,15 @@ impl SemanticLoweringMapping {
154154
}
155155
}
156156

157+
impl<'a> cairo_lang_debug::debug::DebugWithDb<ExprFormatter<'a>> for SemanticLoweringMapping {
158+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>, db: &ExprFormatter<'a>) -> std::fmt::Result {
159+
for (member_path, value) in self.scattered.iter() {
160+
writeln!(f, "{:?}: {value}", member_path.debug(db))?;
161+
}
162+
Ok(())
163+
}
164+
}
165+
157166
/// A trait for deconstructing and constructing structs.
158167
pub trait StructRecomposer {
159168
fn deconstruct(
@@ -183,11 +192,26 @@ pub trait StructRecomposer {
183192
enum Value {
184193
/// The value of member path is stored in a lowered variable.
185194
Var(VariableId),
186-
/// The value of the member path is not stored. It should be reconstructed from the member
187-
/// values.
195+
/// The value of the member path is not stored. If needed, it should be reconstructed from the
196+
/// member values.
188197
Scattered(Box<Scattered>),
189198
}
190199

200+
impl std::fmt::Display for Value {
201+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
202+
match self {
203+
Value::Var(var) => write!(f, "v{}", var.index()),
204+
Value::Scattered(scattered) => {
205+
write!(
206+
f,
207+
"Scattered({})",
208+
scattered.members.values().map(|value| value.to_string()).join(", ")
209+
)
210+
}
211+
}
212+
}
213+
}
214+
191215
/// A value for a non-stored member path. Recursively holds the [Value] for the members.
192216
#[derive(Clone, Debug, DebugWithDb)]
193217
#[debug_db(ExprFormatter<'a>)]
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
//! > Single block
2+
3+
//! > test_runner_name
4+
test_merge_block_builders
5+
6+
//! > variables
7+
x: felt252
8+
9+
//! > block_definitions
10+
// A single block with a single variable `x`, mapped to lowered variable `v0`.
11+
(
12+
(x, 0),
13+
);
14+
15+
//! > input_blocks
16+
block_id: BlockId(0)
17+
semantics:
18+
Var(ParamId(test::x)): v0
19+
20+
//! > merged_block_builder
21+
block_id: BlockId(0)
22+
semantics:
23+
Var(ParamId(test::x)): v0
24+
25+
//! > lowered
26+
blk0:
27+
Statements:
28+
End:
29+
Not set

0 commit comments

Comments
 (0)