diff --git a/Cargo.toml b/Cargo.toml index 08b46bc6bc..b5873fed54 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -61,6 +61,7 @@ with-cheatcode = [] with-debug-utils = [] with-mem-tracing = [] with-libfunc-profiling = [] +with-libfunc-counter = [] with-segfault-catcher = [] with-trace-dump = ["dep:sierra-emu"] diff --git a/src/bin/cairo-native-run.rs b/src/bin/cairo-native-run.rs index ce40390b2f..af8a9025cf 100644 --- a/src/bin/cairo-native-run.rs +++ b/src/bin/cairo-native-run.rs @@ -6,6 +6,8 @@ use cairo_lang_runner::short_string::as_cairo_short_string; #[cfg(feature = "with-libfunc-profiling")] use cairo_lang_sierra::ids::ConcreteLibfuncId; use cairo_lang_sierra_to_casm::metadata::MetadataComputationConfig; +#[cfg(feature = "with-libfunc-counter")] +use cairo_native::metadata::libfunc_counter::libfunc_counter_runtime::CountersArrayGuard; #[cfg(feature = "with-libfunc-profiling")] use cairo_native::metadata::profiler::LibfuncProfileData; use cairo_native::{ @@ -57,6 +59,11 @@ struct Args { /// The output path for the libfunc profilling results profiler_output: Option, + #[cfg(feature = "with-libfunc-counter")] + #[arg(long)] + /// The output path for the libfunc counter results + libfunc_counter_output: Option, + #[cfg(feature = "with-trace-dump")] #[arg(long)] /// The output path for the execution trace @@ -106,6 +113,9 @@ fn main() -> anyhow::Result<()> { .compile(&sierra_program, false, Some(Default::default()), None) .unwrap(); + #[cfg(feature = "with-libfunc-counter")] + let libfuncs_amount = sierra_program.clone().libfunc_declarations.len(); + let native_executor: Box _> = match args.run_mode { RunMode::Aot => { let executor = @@ -120,13 +130,48 @@ fn main() -> anyhow::Result<()> { } } + #[cfg(feature = "with-libfunc-counter")] + { + use cairo_native::metadata::libfunc_counter::LibfuncCounterBinding; + if let Some(counter_id) = + executor.find_symbol_ptr(LibfuncCounterBinding::CounterId.symbol()) + { + let counter_id = counter_id.cast::(); + unsafe { *counter_id = 0
}; + } + } + Box::new(move |function_id, args, gas, syscall_handler| { - executor.invoke_dynamic_with_syscall_handler( + #[cfg(feature = "with-libfunc-counter")] + let array_counter_guard = CountersArrayGuard::init(libfuncs_amount); + + let result = executor.invoke_dynamic_with_syscall_handler( function_id, args, gas, syscall_handler, - ) + ); + + #[cfg(feature = "with-libfunc-counter")] + { + use cairo_native::metadata::libfunc_counter::libfunc_counter_runtime; + use cairo_native::metadata::libfunc_counter::LibfuncCounterBinding; + + let counter_id_ptr = executor + .find_symbol_ptr(LibfuncCounterBinding::CounterId.symbol()) + .expect(""); + + unsafe { + libfunc_counter_runtime::store_and_free_counters_array( + counter_id_ptr as *mut u64, + libfuncs_amount, + ); + } + + drop(array_counter_guard); + }; + + result }) } RunMode::Jit => { @@ -136,6 +181,7 @@ fn main() -> anyhow::Result<()> { #[cfg(feature = "with-trace-dump")] { use cairo_native::metadata::trace_dump::TraceBinding; + if let Some(trace_id) = executor.find_symbol_ptr(TraceBinding::TraceId.symbol()) { let trace_id = trace_id.cast::(); unsafe { *trace_id = 0 }; @@ -154,13 +200,49 @@ fn main() -> anyhow::Result<()> { } } + #[cfg(feature = "with-libfunc-counter")] + { + use cairo_native::metadata::libfunc_counter::LibfuncCounterBinding; + + if let Some(counter_id) = + executor.find_symbol_ptr(LibfuncCounterBinding::CounterId.symbol()) + { + let counter_id = counter_id.cast::(); + unsafe { *counter_id = 0 }; + } + } + Box::new(move |function_id, args, gas, syscall_handler| { - executor.invoke_dynamic_with_syscall_handler( + #[cfg(feature = "with-libfunc-counter")] + let array_counter_guard = CountersArrayGuard::init(libfuncs_amount); + + let result = executor.invoke_dynamic_with_syscall_handler( function_id, args, gas, syscall_handler, - ) + ); + + #[cfg(feature = "with-libfunc-counter")] + { + use cairo_native::metadata::libfunc_counter::libfunc_counter_runtime; + use 
cairo_native::metadata::libfunc_counter::LibfuncCounterBinding; + + let counter_id_ptr = executor + .find_symbol_ptr(LibfuncCounterBinding::CounterId.symbol()) + .expect(""); + + unsafe { + libfunc_counter_runtime::store_and_free_counters_array( + counter_id_ptr as *mut u64, + libfuncs_amount, + ); + } + + drop(array_counter_guard); + }; + + result }) } }; @@ -281,6 +363,35 @@ fn main() -> anyhow::Result<()> { } } + #[cfg(feature = "with-libfunc-counter")] + if let Some(libfunc_counter_output) = args.libfunc_counter_output { + use std::collections::HashMap; + + let counters = + cairo_native::metadata::libfunc_counter::libfunc_counter_runtime::LIBFUNC_COUNTER + .lock() + .unwrap(); + assert_eq!(counters.len(), 1); + + let libfunc_counter = counters.values().next().unwrap(); + + let libfunc_counts = libfunc_counter + .iter() + .enumerate() + .map(|(i, count)| { + let libfunc = &sierra_program.libfunc_declarations[i]; + let debug_name = libfunc.id.debug_name.clone().unwrap().to_string(); + + (debug_name, *count) + }) + .collect::>(); + serde_json::to_writer_pretty( + std::fs::File::create(libfunc_counter_output).unwrap(), + &libfunc_counts, + ) + .unwrap(); + } + #[cfg(feature = "with-trace-dump")] if let Some(trace_output) = args.trace_output { let traces = cairo_native::metadata::trace_dump::trace_dump_runtime::TRACE_DUMP diff --git a/src/compiler.rs b/src/compiler.rs index cecf22f5f9..9de3925c2c 100644 --- a/src/compiler.rs +++ b/src/compiler.rs @@ -60,6 +60,8 @@ use crate::{ utils::{generate_function_name, walk_ir::walk_mlir_block, BlockExt}, }; use bumpalo::Bump; +#[cfg(feature = "with-libfunc-counter")] +use cairo_lang_sierra::ids::ConcreteLibfuncId; use cairo_lang_sierra::{ edit_state, extensions::{ @@ -151,6 +153,14 @@ pub fn compile( let n_libfuncs = program.libfunc_declarations.len() + 1; let sierra_stmt_start_offset = num_types + n_libfuncs + 1; + #[cfg(feature = "with-libfunc-counter")] + let libfunc_indexes = program + .libfunc_declarations + .iter() + 
.enumerate() + .map(|(idx, libf)| (libf.id.clone(), idx)) + .collect::>(); + for function in &program.funcs { tracing::info!("Compiling function `{}`.", function.id); compile_func( @@ -159,6 +169,8 @@ pub fn compile( registry, function, &program.statements, + #[cfg(feature = "with-libfunc-counter")] + &libfunc_indexes, metadata, di_compile_unit_id, sierra_stmt_start_offset, @@ -186,6 +198,7 @@ fn compile_func( registry: &ProgramRegistry, function: &Function, statements: &[Statement], + #[cfg(feature = "with-libfunc-counter")] libfunc_indexes: &HashMap, metadata: &mut MetadataStorage, di_compile_unit_id: Attribute, sierra_stmt_start_offset: usize, @@ -640,6 +653,23 @@ fn compile_func( }, }; + #[cfg(feature = "with-libfunc-counter")] + { + // Can't fail. If we had a key from the invocation statement that is not included in this hashmap, + // that would mean that there's an error in the sierra program since we got to invoke a libfunc that has + // not been declared. + let libfunc_idx = libfunc_indexes.get(&invocation.libfunc_id).unwrap(); + + crate::metadata::libfunc_counter::libfunc_counter_runtime::count_libfunc( + context, + module, + block, + location, + metadata, + *libfunc_idx, + )?; + } + libfunc.build( context, registry, diff --git a/src/executor/aot.rs b/src/executor/aot.rs index 56ee1b2b11..113810ee09 100644 --- a/src/executor/aot.rs +++ b/src/executor/aot.rs @@ -63,6 +63,9 @@ impl AotNativeExecutor { #[cfg(feature = "with-libfunc-profiling")] crate::metadata::profiler::setup_runtime(|name| executor.find_symbol_ptr(name)); + #[cfg(feature = "with-libfunc-counter")] + crate::metadata::libfunc_counter::setup_runtime(|name| executor.find_symbol_ptr(name)); + executor } diff --git a/src/executor/contract.rs b/src/executor/contract.rs index cebb976115..8acb9ef116 100644 --- a/src/executor/contract.rs +++ b/src/executor/contract.rs @@ -337,6 +337,9 @@ impl AotContractExecutor { #[cfg(feature = "with-libfunc-profiling")] 
crate::metadata::profiler::setup_runtime(|name| executor.find_symbol_ptr(name)); + #[cfg(feature = "with-libfunc-counter")] + crate::metadata::libfunc_counter::setup_runtime(|name| executor.find_symbol_ptr(name)); + Ok(Some(executor)) } diff --git a/src/executor/jit.rs b/src/executor/jit.rs index f7cbe58c74..af38f8d362 100644 --- a/src/executor/jit.rs +++ b/src/executor/jit.rs @@ -74,6 +74,9 @@ impl<'m> JitNativeExecutor<'m> { #[cfg(feature = "with-libfunc-profiling")] crate::metadata::profiler::setup_runtime(|name| executor.find_symbol_ptr(name)); + #[cfg(feature = "with-libfunc-counter")] + crate::metadata::libfunc_counter::setup_runtime(|name| executor.find_symbol_ptr(name)); + Ok(executor) } diff --git a/src/metadata.rs b/src/metadata.rs index e33656b0cf..01fe487f45 100644 --- a/src/metadata.rs +++ b/src/metadata.rs @@ -21,6 +21,7 @@ pub mod dup_overrides; pub mod enum_snapshot_variants; pub mod felt252_dict; pub mod gas; +pub mod libfunc_counter; pub mod profiler; pub mod realloc_bindings; pub mod runtime_bindings; diff --git a/src/metadata/libfunc_counter.rs b/src/metadata/libfunc_counter.rs new file mode 100644 index 0000000000..bbec512905 --- /dev/null +++ b/src/metadata/libfunc_counter.rs @@ -0,0 +1,307 @@ +#![cfg(feature = "with-libfunc-counter")] +//! The libfunc counter feature is used to generate information counting how many times a libfunc has been called. +//! +//! When this feature is used, the compiler will call one main method: +//! +//! 1. `count_libfunc`: called before every libfunc execution. This method will handle the counting. Given the index +//! of a libfunc (relative to its declaration order), it accesses the array of counters and updates the counter. +//! +//! In the context of Starknet contracts, we need to add support for building the array of counters for multiple executions. +//! To do so, we need two important elements which must be set before every contract execution: +//! +//!
* A counter to track the ID of the current array of counters, which gets updated every time we switch to another +//! contract. Since a contract can call other contracts, we need a way of restoring the counter after every execution. +//! +//! * An array-of-counters guard. Every time a new entrypoint is executed, a new array of counters needs to be created. +//! The guard keeps the previously used array so that it can be restored once the inner entrypoint execution has finished. +//! +//! See `cairo-native-run` for an example of how to do it. +use std::collections::HashSet; + +use melior::{ + dialect::{llvm, memref, ods}, + ir::{ + attribute::{FlatSymbolRefAttribute, StringAttribute, TypeAttribute}, + operation::OperationBuilder, + r#type::{IntegerType, MemRefType}, + Attribute, Block, BlockLike, Location, Module, Region, Value, + }, + Context, +}; + +use crate::{ + error::{Error, Result}, + utils::{BlockExt, GepIndex}, +}; + +#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)] +pub enum LibfuncCounterBinding { + CounterId, + GetCounterArray, +} + +impl LibfuncCounterBinding { + pub const fn symbol(self) -> &'static str { + match self { + LibfuncCounterBinding::CounterId => "cairo_native__counter_id", + LibfuncCounterBinding::GetCounterArray => "cairo_native__get_counters_array", + } + } + + pub const fn function_ptr(self) -> *const () { + match self { + LibfuncCounterBinding::CounterId => std::ptr::null(), + LibfuncCounterBinding::GetCounterArray => { + libfunc_counter_runtime::get_counters_array as *const () + } + } + } +} + +#[derive(Clone, Default)] +pub struct LibfuncCounterMeta { + active_map: HashSet, +} + +impl LibfuncCounterMeta { + pub fn new() -> Self { + Self { + active_map: HashSet::new(), + } + } + + /// Register the global for the given binding, if not yet registered, and return + /// a pointer to the stored value. + /// + /// For the function to be available, `setup_runtime` must be called before running the module.
+ pub fn build_function<'c, 'a>( + &mut self, + context: &'c Context, + module: &Module, + block: &'a Block<'c>, + location: Location<'c>, + binding: LibfuncCounterBinding, + ) -> Result> { + if self.active_map.insert(binding) { + module.body().append_operation( + ods::llvm::mlir_global( + context, + Region::new(), + TypeAttribute::new(llvm::r#type::pointer(context, 0)), + StringAttribute::new(context, binding.symbol()), + Attribute::parse(context, "#llvm.linkage") + .ok_or(Error::ParseAttributeError)?, + location, + ) + .into(), + ); + } + + let global_address = block.append_op_result( + ods::llvm::mlir_addressof( + context, + llvm::r#type::pointer(context, 0), + FlatSymbolRefAttribute::new(context, binding.symbol()), + location, + ) + .into(), + )?; + + block.load( + context, + location, + global_address, + llvm::r#type::pointer(context, 0), + ) + } + + fn build_counter_id<'c, 'a>( + &mut self, + context: &'c Context, + module: &Module, + block: &'a Block<'c>, + location: Location<'c>, + ) -> Result> { + if self.active_map.insert(LibfuncCounterBinding::CounterId) { + module.body().append_operation(memref::global( + context, + LibfuncCounterBinding::CounterId.symbol(), + None, + MemRefType::new(IntegerType::new(context, 64).into(), &[], None, None), + None, + false, + None, + location, + )); + } + + let libfunc_counter_id_ptr = block + .append_op_result(memref::get_global( + context, + LibfuncCounterBinding::CounterId.symbol(), + MemRefType::new(IntegerType::new(context, 64).into(), &[], None, None), + location, + )) + .unwrap(); + + block.append_op_result(memref::load(libfunc_counter_id_ptr, &[], location)) + } + + /// Returns the array of counters. 
+ fn get_array_counter<'c, 'a>( + &mut self, + context: &'c Context, + module: &Module, + block: &'a Block<'c>, + location: Location<'c>, + ) -> Result> { + self.build_counter_id(context, module, block, location)?; + + let function_ptr = self.build_function( + context, + module, + block, + location, + LibfuncCounterBinding::GetCounterArray, + )?; + + block.append_op_result( + OperationBuilder::new("llvm.call", location) + .add_operands(&[function_ptr]) + .add_results(&[llvm::r#type::pointer(context, 0)]) + .build()?, + ) + } + + pub fn count_libfunc( + &mut self, + context: &Context, + module: &Module, + block: &Block<'_>, + location: Location, + libfunc_idx: usize, + ) -> Result<()> { + let u32_ty = IntegerType::new(context, 32).into(); + let k1 = block.const_int(context, location, 1, 32)?; + + let array_counter_ptr = self.get_array_counter(context, module, block, location)?; + + let value_counter_ptr = block.gep( + context, + location, + array_counter_ptr, + &[GepIndex::Const(libfunc_idx as i32)], + u32_ty, + )?; + + let value_counter = block.load(context, location, value_counter_ptr, u32_ty)?; + let value_incremented = block.addi(value_counter, k1, location)?; + + block.store(context, location, value_counter_ptr, value_incremented)?; + + Ok(()) + } +} + +pub fn setup_runtime(find_symbol_ptr: impl Fn(&str) -> Option<*mut libc::c_void>) { + let bindings = &[LibfuncCounterBinding::GetCounterArray]; + + for binding in bindings { + if let Some(global) = find_symbol_ptr(binding.symbol()) { + let global = global.cast::<*const ()>(); + unsafe { *global = binding.function_ptr() }; + } + } +} + +pub mod libfunc_counter_runtime { + use core::slice; + use std::{ + cell::Cell, + collections::HashMap, + sync::{LazyLock, Mutex}, + }; + + use melior::{ + ir::{Block, Location, Module}, + Context, + }; + + use crate::{ + error::Result, + metadata::{libfunc_counter::LibfuncCounterMeta, MetadataStorage}, + utils::{libc_free, libc_malloc}, + }; + + /// Contains an array of vector 
for each execution completed. + pub static LIBFUNC_COUNTER: LazyLock>>> = + LazyLock::new(|| Mutex::new(HashMap::new())); + + thread_local! { + pub(crate) static COUNTERS_ARRAY: Cell<*mut u32> = const { + // This value will be overwritten before executing the code + Cell::new(std::ptr::null_mut()) + }; + } + + /// In the context of Starknet, a contract may call another. This means we + /// need as many arrays of counters as contracts are invoked during execution. + /// This struct is used to hold the current array before calling the next contract + /// so that it can then be restored. + pub struct CountersArrayGuard(pub *mut u32); + + impl CountersArrayGuard { + pub fn init(libfuncs_amount: usize) -> CountersArrayGuard { + let u32_libfuncs_amount = libfuncs_amount * 4; + let new_array: *mut u32 = unsafe { libc_malloc(u32_libfuncs_amount).cast() }; + + // All positions in the array must be initialized with 0. Since + // some declared libfuncs may not be called, their respective counters + // won't be updated. + for i in 0..libfuncs_amount { + unsafe { *(new_array.add(i)) = 0 }; + } + + Self(COUNTERS_ARRAY.replace(new_array)) + } + } + + impl Drop for CountersArrayGuard { + fn drop(&mut self) { + COUNTERS_ARRAY.set(self.0); + } + } + + /// Update the libfunc's counter based on its index, relative to the order of declaration. + pub fn count_libfunc( + context: &Context, + module: &Module, + block: &Block, + location: Location, + metadata: &mut MetadataStorage, + libfunc_idx: usize, + ) -> Result<()> { + let libfunc_counter = metadata.get_or_insert_with(LibfuncCounterMeta::default); + + libfunc_counter.count_libfunc(context, module, block, location, libfunc_idx) + } + + pub extern "C" fn get_counters_array() -> *mut u32 { + COUNTERS_ARRAY.with(|x| x.get()) + } + + /// Converts the pointer to the counters into a Rust `Vec` and stores it. Then, it frees the pointer. + /// + /// This method should be called at the end of an entrypoint execution.
+ pub unsafe fn store_and_free_counters_array(counter_id_ptr: *mut u64, libfuncs_amount: usize) { + let counter_array_ptr = get_counters_array(); + let counters_vec = slice::from_raw_parts(counter_array_ptr, libfuncs_amount).to_vec(); + + LIBFUNC_COUNTER + .lock() + .unwrap() + .insert(*counter_id_ptr, counters_vec); + + libc_free(counter_array_ptr as *mut libc::c_void); + } +}