Skip to content

Commit 0f20b84

Browse files
committed
Better parallelization of function compilation.
1 parent 97f166e commit 0f20b84

File tree

4 files changed

+142
-157
lines changed

4 files changed

+142
-157
lines changed

crates/cairo-lang-compiler/src/diagnostics.rs

Lines changed: 1 addition & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,13 @@
11
use std::fmt::Write;
22

3-
use cairo_lang_defs::db::DefsGroup;
43
use cairo_lang_defs::ids::ModuleId;
54
use cairo_lang_diagnostics::{
65
DiagnosticEntry, Diagnostics, FormattedDiagnosticEntry, PluginFileDiagnosticNotes, Severity,
76
};
87
use cairo_lang_filesystem::ids::{CrateId, FileLongId};
98
use cairo_lang_lowering::db::LoweringGroup;
10-
use cairo_lang_parser::db::ParserGroup;
11-
use cairo_lang_semantic::db::SemanticGroup;
129
use cairo_lang_utils::LookupIntern;
1310
use cairo_lang_utils::unordered_hash_set::UnorderedHashSet;
14-
use rayon::ThreadPool;
1511
use thiserror::Error;
1612

1713
use crate::db::RootDatabase;
@@ -139,7 +135,7 @@ impl<'a> DiagnosticsReporter<'a> {
139135
}
140136

141137
/// Returns the crate ids for which the diagnostics will be checked.
142-
fn crates_of_interest(&self, db: &dyn LoweringGroup) -> Vec<CrateId> {
138+
pub(crate) fn crates_of_interest(&self, db: &dyn LoweringGroup) -> Vec<CrateId> {
143139
if self.crate_ids.is_empty() { db.crates() } else { self.crate_ids.clone() }
144140
}
145141

@@ -250,35 +246,6 @@ impl<'a> DiagnosticsReporter<'a> {
250246
if self.check(db) { Err(DiagnosticsError) } else { Ok(()) }
251247
}
252248

253-
/// Spawns threads to compute the diagnostics queries, making sure later calls for these queries
254-
/// would be faster as the queries were already computed.
255-
pub(crate) fn warm_up_diagnostics(&self, db: &RootDatabase, pool: &ThreadPool) {
256-
let crates = self.crates_of_interest(db);
257-
for crate_id in crates {
258-
let snapshot = salsa::ParallelDatabase::snapshot(db);
259-
pool.spawn(move || {
260-
let db = &*snapshot;
261-
262-
let crate_modules = db.crate_modules(crate_id);
263-
for module_id in crate_modules.iter().copied() {
264-
let snapshot = salsa::ParallelDatabase::snapshot(db);
265-
rayon::spawn(move || {
266-
let db = &*snapshot;
267-
for file_id in
268-
db.module_files(module_id).unwrap_or_default().iter().copied()
269-
{
270-
db.file_syntax_diagnostics(file_id);
271-
}
272-
273-
let _ = db.module_semantic_diagnostics(module_id);
274-
275-
let _ = db.module_lowering_diagnostics(module_id);
276-
});
277-
}
278-
});
279-
}
280-
}
281-
282249
pub fn skip_lowering_diagnostics(mut self) -> Self {
283250
self.skip_lowering_diagnostics = true;
284251
self

crates/cairo-lang-compiler/src/lib.rs

Lines changed: 134 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,23 @@ use std::sync::{Arc, Mutex};
77

88
use ::cairo_lang_diagnostics::ToOption;
99
use anyhow::{Context, Result};
10-
use cairo_lang_filesystem::ids::CrateId;
10+
use cairo_lang_defs::db::DefsGroup;
11+
use cairo_lang_defs::ids::ModuleId;
12+
use cairo_lang_filesystem::ids::{CrateId, FileId};
13+
use cairo_lang_lowering::db::LoweringGroup;
1114
use cairo_lang_lowering::ids::ConcreteFunctionWithBodyId;
1215
use cairo_lang_lowering::utils::InliningStrategy;
16+
use cairo_lang_lowering::{self as lowering, LoweringStage};
17+
use cairo_lang_parser::db::ParserGroup;
18+
use cairo_lang_semantic::db::SemanticGroup;
1319
use cairo_lang_sierra::debug_info::{Annotations, DebugInfo};
1420
use cairo_lang_sierra::program::{Program, ProgramArtifact};
1521
use cairo_lang_sierra_generator::db::SierraGenGroup;
1622
use cairo_lang_sierra_generator::executables::{collect_executables, find_executable_function_ids};
17-
use cairo_lang_sierra_generator::program_generator::{
18-
SierraProgramWithDebug, try_get_function_with_body_id,
19-
};
23+
use cairo_lang_sierra_generator::program_generator::SierraProgramWithDebug;
2024
use cairo_lang_sierra_generator::replace_ids::replace_sierra_ids_in_program;
2125
use cairo_lang_utils::unordered_hash_set::UnorderedHashSet;
22-
use rayon::{ThreadPool, ThreadPoolBuilder};
26+
use rayon::ThreadPoolBuilder;
2327

2428
use crate::db::RootDatabase;
2529
use crate::diagnostics::{DiagnosticsError, DiagnosticsReporter};
@@ -154,136 +158,154 @@ pub fn compile_prepared_db(
154158
Ok(sierra_program_with_debug)
155159
}
156160

157-
/// Context for database warmup.
158-
///
159-
/// This struct will spawn a thread pool that can be used for parallel database warmup.
160-
/// This can be both diagnostics warmup and function compilation warmup.
161-
/// We encapsulate the thread pool here so that we can reuse it easily for both.
162-
/// Note: Usually diagnostics should be checked as early as possible to avoid running into
163-
/// compilation errors that have not been reported to the user yet (which can result in compiler
164-
/// panic). This requires us to split the diagnostics warmup and function compilation warmup into
165-
/// two separate steps (note that we don't usually know the `ConcreteFunctionWithBodyId` yet when
166-
/// calculating diagnostics).
167-
pub enum DbWarmupContext {
168-
Warmup { pool: ThreadPool },
169-
NoWarmup,
161+
/// Checks if parallelism is available for warmup.
162+
fn should_warmup() -> bool {
163+
rayon::current_num_threads() > 1
170164
}
171165

172-
impl DbWarmupContext {
173-
/// Creates a new thread pool.
174-
pub fn new() -> Self {
175-
if !Self::should_warmup() {
176-
return Self::NoWarmup;
177-
}
166+
/// Performs parallel database warmup for diagnostics (if possible).
167+
pub fn warmup_diagnostics(db: &RootDatabase, diagnostic_reporter: &DiagnosticsReporter<'_>) {
168+
if !should_warmup() {
169+
return;
170+
}
171+
let crates = diagnostic_reporter.crates_of_interest(db);
172+
let snapshot = salsa::ParallelDatabase::snapshot(db);
173+
rayon::spawn(move || {
174+
let db = &*snapshot;
178175
const MAX_WARMUP_PARALLELISM: usize = 4;
179176
let pool = ThreadPoolBuilder::new()
180177
.num_threads(rayon::current_num_threads().min(MAX_WARMUP_PARALLELISM))
181178
.build()
182179
.expect("failed to build rayon thread pool");
183-
Self::Warmup { pool }
184-
}
185-
186-
/// Checks if parallelism is available for warmup.
187-
fn should_warmup() -> bool {
188-
rayon::current_num_threads() > 1
189-
}
190-
191-
/// Performs parallel database warmup (if possible)
192-
fn warmup_diagnostics(
193-
&self,
194-
db: &RootDatabase,
195-
diagnostic_reporter: &mut DiagnosticsReporter<'_>,
196-
) {
197-
match self {
198-
Self::Warmup { pool } => diagnostic_reporter.warm_up_diagnostics(db, pool),
199-
Self::NoWarmup => {}
200-
}
201-
}
202-
203-
/// Checks if there are diagnostics and reports them to the provided callback as strings.
204-
/// Returns `Err` if diagnostics were found.
205-
///
206-
/// Performs parallel database warmup (if possible) and calls `DiagnosticsReporter::ensure`.
207-
pub fn ensure_diagnostics(
208-
&self,
209-
db: &RootDatabase,
210-
diagnostic_reporter: &mut DiagnosticsReporter<'_>,
211-
) -> std::result::Result<(), DiagnosticsError> {
212-
self.warmup_diagnostics(db, diagnostic_reporter);
213-
diagnostic_reporter.ensure(db)?;
214-
Ok(())
215-
}
216-
217-
/// Spawns a task to warm up the db for the requested functions (if possible).
218-
fn warmup_db(
219-
&self,
220-
db: &RootDatabase,
221-
requested_function_ids: Vec<ConcreteFunctionWithBodyId>,
222-
) {
223-
match self {
224-
Self::Warmup { pool } => {
225-
let snapshot = salsa::ParallelDatabase::snapshot(db);
226-
pool.spawn(move || warmup_db_blocking(snapshot, requested_function_ids));
227-
}
228-
Self::NoWarmup => {}
180+
for crate_id in crates {
181+
let snapshot = salsa::ParallelDatabase::snapshot(db);
182+
pool.spawn(move || {
183+
let db = &*snapshot;
184+
let processed_file_ids =
185+
Arc::new(Mutex::new(UnorderedHashSet::<FileId>::default()));
186+
fn handle_module(
187+
db: &RootDatabase,
188+
processed_file_ids: Arc<Mutex<UnorderedHashSet<FileId>>>,
189+
module_id: ModuleId,
190+
) {
191+
let mut has_inner_calls = false;
192+
if let Ok(submodule_ids) = db.module_submodules_ids(module_id) {
193+
for submodule_module_id in submodule_ids.iter().copied() {
194+
let snapshot = salsa::ParallelDatabase::snapshot(db);
195+
let processed_file_ids = processed_file_ids.clone();
196+
rayon::spawn(move || {
197+
let db = &*snapshot;
198+
handle_module(
199+
db,
200+
processed_file_ids,
201+
ModuleId::Submodule(submodule_module_id),
202+
);
203+
});
204+
has_inner_calls = true;
205+
}
206+
}
207+
if has_inner_calls {
208+
rayon::yield_local();
209+
}
210+
for file_id in db.module_files(module_id).unwrap_or_default().iter().copied() {
211+
if !processed_file_ids.lock().unwrap().insert(file_id) {
212+
continue;
213+
}
214+
db.file_syntax_diagnostics(file_id);
215+
}
216+
let _ = db.module_semantic_diagnostics(module_id);
217+
let _ = db.module_lowering_diagnostics(module_id);
218+
}
219+
handle_module(db, processed_file_ids, ModuleId::CrateRoot(crate_id));
220+
});
229221
}
230-
}
222+
});
223+
}
224+
// pub fn warmup_diagnostics_blocking(
225+
// db: &RootDatabase,
226+
// diagnostic_reporter: &DiagnosticsReporter<'_>,
227+
// ) {
228+
// fn handle_crate()
229+
// }
230+
231+
/// Checks if there are diagnostics and reports them to the provided callback as strings.
232+
/// Returns `Err` if diagnostics were found.
233+
///
234+
/// Performs parallel database warmup (if possible) and calls `DiagnosticsReporter::ensure`.
235+
pub fn ensure_diagnostics(
236+
db: &RootDatabase,
237+
diagnostic_reporter: &mut DiagnosticsReporter<'_>,
238+
) -> std::result::Result<(), DiagnosticsError> {
239+
warmup_diagnostics(db, diagnostic_reporter);
240+
diagnostic_reporter.ensure(db)?;
241+
Ok(())
231242
}
232243

233-
impl Default for DbWarmupContext {
234-
fn default() -> Self {
235-
Self::new()
244+
/// Spawns a task to warm up the db for the requested functions (if possible).
245+
fn warmup_functions(db: &RootDatabase, requested_function_ids: &[ConcreteFunctionWithBodyId]) {
246+
if !should_warmup() {
247+
return;
236248
}
249+
let requested_function_ids = requested_function_ids.to_vec();
250+
let snapshot = salsa::ParallelDatabase::snapshot(db);
251+
rayon::spawn(move || warmup_functions_blocking(snapshot, requested_function_ids));
237252
}
238253

239254
/// Spawns threads to compute the `function_with_body_sierra` query and all dependent queries for
240255
/// the requested functions and their dependencies.
241256
///
242257
/// Note that typically `warmup_functions` should be used, as this function is blocking.
243-
fn warmup_db_blocking(
258+
fn warmup_functions_blocking(
244259
snapshot: salsa::Snapshot<RootDatabase>,
245260
requested_function_ids: Vec<ConcreteFunctionWithBodyId>,
246261
) {
262+
fn handle_func<'a>(
263+
s: &rayon::Scope<'a>,
264+
processed_function_ids: &'a Mutex<UnorderedHashSet<ConcreteFunctionWithBodyId>>,
265+
snapshot: salsa::Snapshot<RootDatabase>,
266+
func_id: ConcreteFunctionWithBodyId,
267+
) {
268+
if !processed_function_ids.lock().unwrap().insert(func_id) {
269+
return;
270+
}
271+
s.spawn(move |s| {
272+
let db = &*snapshot;
273+
let Ok(lowered) = db.lowered_body(func_id, LoweringStage::Monomorphized) else {
274+
return;
275+
};
276+
let mut has_inner_calls = false;
277+
let mut handle_callee = |callee: lowering::ids::FunctionId| {
278+
if let Ok(Some(callee)) = callee.body(db) {
279+
let snapshot = salsa::ParallelDatabase::snapshot(&*snapshot);
280+
s.spawn(move |s| handle_func(s, processed_function_ids, snapshot, callee));
281+
has_inner_calls = true;
282+
}
283+
};
284+
for (_, block) in lowered.blocks.iter() {
285+
for statement in &block.statements {
286+
if let lowering::Statement::Call(call_stmt) = statement {
287+
handle_callee(call_stmt.function);
288+
}
289+
}
290+
if let lowering::BlockEnd::Match { info: lowering::MatchInfo::Extern(info) } =
291+
&block.end
292+
{
293+
handle_callee(info.function);
294+
}
295+
}
296+
if has_inner_calls {
297+
rayon::yield_local();
298+
}
299+
300+
let _ = db.function_with_body_sierra(func_id);
301+
});
302+
}
247303
let processed_function_ids =
248304
&Mutex::new(UnorderedHashSet::<ConcreteFunctionWithBodyId>::default());
249305
rayon::scope(move |s| {
250306
for func_id in requested_function_ids {
251307
let snapshot = salsa::ParallelDatabase::snapshot(&*snapshot);
252-
253-
s.spawn(move |_| {
254-
fn handle_func_inner(
255-
processed_function_ids: &Mutex<UnorderedHashSet<ConcreteFunctionWithBodyId>>,
256-
snapshot: salsa::Snapshot<RootDatabase>,
257-
func_id: ConcreteFunctionWithBodyId,
258-
) {
259-
if processed_function_ids.lock().unwrap().insert(func_id) {
260-
rayon::scope(move |s| {
261-
let db = &*snapshot;
262-
let Ok(function) = db.function_with_body_sierra(func_id) else {
263-
return;
264-
};
265-
for statement in &function.body {
266-
let Some(related_function_id) =
267-
try_get_function_with_body_id(db, statement)
268-
else {
269-
continue;
270-
};
271-
272-
let snapshot = salsa::ParallelDatabase::snapshot(&*snapshot);
273-
s.spawn(move |_| {
274-
handle_func_inner(
275-
processed_function_ids,
276-
snapshot,
277-
related_function_id,
278-
)
279-
})
280-
}
281-
});
282-
}
283-
}
284-
285-
handle_func_inner(processed_function_ids, snapshot, func_id)
286-
});
308+
s.spawn(move |s| handle_func(s, processed_function_ids, snapshot, func_id));
287309
}
288310
});
289311
}
@@ -293,11 +315,10 @@ fn warmup_db_blocking(
293315
pub fn get_sierra_program_for_functions(
294316
db: &RootDatabase,
295317
requested_function_ids: Vec<ConcreteFunctionWithBodyId>,
296-
context: DbWarmupContext,
297318
) -> Result<Arc<SierraProgramWithDebug>> {
298-
context.warmup_db(db, requested_function_ids.clone());
319+
warmup_functions(db, &requested_function_ids);
299320
db.get_sierra_program_for_functions(requested_function_ids)
300-
.to_option()
321+
.ok()
301322
.with_context(|| "Compilation failed without any diagnostics.")
302323
}
303324

0 commit comments

Comments
 (0)