From 0d05388b23d3a58e781896b785f97487535e4c61 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rom=C3=A1n=20C=C3=A1rdenas=20Rodr=C3=ADguez?= Date: Tue, 8 Jul 2025 20:02:36 +0200 Subject: [PATCH] Add __post_init to riscv-rt --- riscv-rt/CHANGELOG.md | 5 + riscv-rt/Cargo.toml | 3 +- riscv-rt/macros/src/lib.rs | 129 +++++++++++++++--- riscv-rt/src/lib.rs | 15 ++ tests-trybuild/Cargo.toml | 2 +- .../riscv-rt/post_init/fail_arg_count.rs | 4 + .../riscv-rt/post_init/fail_arg_count.stderr | 5 + .../tests/riscv-rt/post_init/fail_arg_type.rs | 4 + .../riscv-rt/post_init/fail_arg_type.stderr | 5 + .../tests/riscv-rt/post_init/fail_async.rs | 4 + .../riscv-rt/post_init/fail_async.stderr | 5 + .../tests/riscv-rt/post_init/pass_empty.rs | 4 + .../tests/riscv-rt/post_init/pass_safe.rs | 4 + .../tests/riscv-rt/post_init/pass_unsafe.rs | 4 + 14 files changed, 171 insertions(+), 22 deletions(-) create mode 100644 tests-trybuild/tests/riscv-rt/post_init/fail_arg_count.rs create mode 100644 tests-trybuild/tests/riscv-rt/post_init/fail_arg_count.stderr create mode 100644 tests-trybuild/tests/riscv-rt/post_init/fail_arg_type.rs create mode 100644 tests-trybuild/tests/riscv-rt/post_init/fail_arg_type.stderr create mode 100644 tests-trybuild/tests/riscv-rt/post_init/fail_async.rs create mode 100644 tests-trybuild/tests/riscv-rt/post_init/fail_async.stderr create mode 100644 tests-trybuild/tests/riscv-rt/post_init/pass_empty.rs create mode 100644 tests-trybuild/tests/riscv-rt/post_init/pass_safe.rs create mode 100644 tests-trybuild/tests/riscv-rt/post_init/pass_unsafe.rs diff --git a/riscv-rt/CHANGELOG.md b/riscv-rt/CHANGELOG.md index 658b77ef..fd8c0e80 100644 --- a/riscv-rt/CHANGELOG.md +++ b/riscv-rt/CHANGELOG.md @@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](http://semver.org/). ## [Unreleased] +### Added + +- New `post-init` feature to run a Rust `__post_init` function before jumping to `main`. +- New `#[riscv_rt::post_init]` attribute to aid in the definition of the `__post_init` function. + ### Changed - `main` function no longer needs to be close to `_start`. A linker script may copy diff --git a/riscv-rt/Cargo.toml b/riscv-rt/Cargo.toml index 761cd52b..e400473c 100644 --- a/riscv-rt/Cargo.toml +++ b/riscv-rt/Cargo.toml @@ -14,7 +14,7 @@ links = "riscv-rt" # Prevent multiple versions of riscv-rt being linked [package.metadata.docs.rs] default-target = "riscv64imac-unknown-none-elf" -features = ["pre-init"] +features = ["pre-init", "post-init"] targets = [ "riscv32i-unknown-none-elf", "riscv32imc-unknown-none-elf", "riscv32imac-unknown-none-elf", "riscv64imac-unknown-none-elf", "riscv64gc-unknown-none-elf", @@ -33,6 +33,7 @@ panic-halt = "1.0.0" [features] pre-init = [] +post-init = [] s-mode = ["riscv-rt-macros/s-mode"] single-hart = [] v-trap = ["riscv-rt-macros/v-trap"] diff --git a/riscv-rt/macros/src/lib.rs b/riscv-rt/macros/src/lib.rs index 3e7ddd4f..d174a9bb 100644 --- a/riscv-rt/macros/src/lib.rs +++ b/riscv-rt/macros/src/lib.rs @@ -63,26 +63,6 @@ pub fn entry(args: TokenStream, input: TokenStream) -> TokenStream { .into(); } - fn check_correct_type(argument: &PatType, ty: &str) -> Option { - let inv_type_message = format!("argument type must be {ty}"); - - if !is_correct_type(&argument.ty, ty) { - let error = parse::Error::new(argument.ty.span(), inv_type_message); - - Some(error.to_compile_error().into()) - } else { - None - } - } - fn check_argument_type(argument: &FnArg, ty: &str) -> Option { - let argument_error = parse::Error::new(argument.span(), "invalid argument"); - let argument_error = argument_error.to_compile_error().into(); - - match argument { - FnArg::Typed(argument) => check_correct_type(argument, ty), - FnArg::Receiver(_) => Some(argument_error), - } - } #[cfg(not(feature = "u-boot"))] for argument in f.sig.inputs.iter() { if let Some(message) = check_argument_type(argument, "usize") { @@ -181,6 +161,28 @@ fn is_correct_type(ty: &Type, name: &str) -> bool { } } +fn check_correct_type(argument: &PatType, ty: &str) -> Option { + let inv_type_message = format!("argument type must be {ty}"); + + if !is_correct_type(&argument.ty, ty) { + let error = parse::Error::new(argument.ty.span(), inv_type_message); + + Some(error.to_compile_error().into()) + } else { + None + } +} + +fn check_argument_type(argument: &FnArg, ty: &str) -> Option { + let argument_error = parse::Error::new(argument.span(), "invalid argument"); + let argument_error = argument_error.to_compile_error().into(); + + match argument { + FnArg::Typed(argument) => check_correct_type(argument, ty), + FnArg::Receiver(_) => Some(argument_error), + } +} + /// Attribute to mark which function will be called at the beginning of the reset handler. /// You must enable the `pre_init` feature in the `riscv-rt` crate to use this macro. /// @@ -263,6 +265,93 @@ pub fn pre_init(args: TokenStream, input: TokenStream) -> TokenStream { .into() } +/// Attribute to mark which function will be called before jumping to the entry point. +/// You must enable the `post-init` feature in the `riscv-rt` crate to use this macro. +/// +/// In contrast with `__pre_init`, this function is called after the static variables +/// are initialized, so it is safe to access them. It is also safe to run Rust code. +/// +/// The function must have the signature of `[unsafe] fn([usize])`, where the argument +/// corresponds to the hart ID of the current hart. This is useful for multi-hart systems +/// to perform hart-specific initialization. +/// +/// # IMPORTANT +/// +/// This attribute can appear at most *once* in the dependency graph. +/// +/// # Examples +/// +/// ``` +/// use riscv_rt_macros::post_init; +/// #[post_init] +/// unsafe fn before_main(hart_id: usize) { +/// // do something here +/// } +/// ``` +#[proc_macro_attribute] +pub fn post_init(args: TokenStream, input: TokenStream) -> TokenStream { + let f = parse_macro_input!(input as ItemFn); + + // check the function arguments + if f.sig.inputs.len() > 1 { + return parse::Error::new( + f.sig.inputs.last().unwrap().span(), + "`#[post_init]` function has too many arguments", + ) + .to_compile_error() + .into(); + } + for argument in f.sig.inputs.iter() { + if let Some(message) = check_argument_type(argument, "usize") { + return message; + }; + } + + // check the function signature + let valid_signature = f.sig.constness.is_none() + && f.sig.asyncness.is_none() + && f.vis == Visibility::Inherited + && f.sig.abi.is_none() + && f.sig.generics.params.is_empty() + && f.sig.generics.where_clause.is_none() + && f.sig.variadic.is_none() + && match f.sig.output { + ReturnType::Default => true, + ReturnType::Type(_, ref ty) => match **ty { + Type::Tuple(ref tuple) => tuple.elems.is_empty(), + _ => false, + }, + }; + + if !valid_signature { + return parse::Error::new( + f.span(), + "`#[post_init]` function must have signature `[unsafe] fn([usize])`", + ) + .to_compile_error() + .into(); + } + + if !args.is_empty() { + return parse::Error::new(Span::call_site(), "This attribute accepts no arguments") + .to_compile_error() + .into(); + } + + // XXX should we blacklist other attributes? + let attrs = f.attrs; + let ident = f.sig.ident; + let args = f.sig.inputs; + let block = f.block; + + quote!( + #[export_name = "__post_init"] + #(#attrs)* + unsafe fn #ident(#args) #block + ) + .into() +} + struct AsmLoopArgs { asm_template: String, count_from: usize, diff --git a/riscv-rt/src/lib.rs b/riscv-rt/src/lib.rs index a65001e1..78bd071d 100644 --- a/riscv-rt/src/lib.rs +++ b/riscv-rt/src/lib.rs @@ -553,6 +553,14 @@ //! ); //! ``` //! +//! ## `post-init` +//! +//! When enabled, the runtime will execute the `__post_init` function to be run before jumping to the main function. +//! If the feature is enabled, the `__post_init` function must be defined in the user code (i.e., no default implementation +//! is provided by this crate). If the feature is disabled, the `__post_init` function is not required. +//! +//! You can use the [`#[post_init]`][attr-post-init] attribute to define a post-init function with Rust. +//! //! ## `single-hart` //! //! Saves a little code size if there is only one hart on the target. @@ -595,6 +603,7 @@ //! [attr-external-interrupt]: attr.external_interrupt.html //! [attr-core-interrupt]: attr.core_interrupt.html //! [attr-pre-init]: attr.pre_init.html +//! [attr-post-init]: attr.post_init.html // NOTE: Adapted from cortex-m/src/lib.rs #![no_std] @@ -624,6 +633,8 @@ use riscv::register::{ pub use riscv_pac::*; pub use riscv_rt_macros::{core_interrupt, entry, exception, external_interrupt}; +#[cfg(feature = "post-init")] +pub use riscv_rt_macros::post_init; #[cfg(feature = "pre-init")] pub use riscv_rt_macros::pre_init; @@ -650,10 +661,14 @@ pub static __ONCE__: () = (); #[export_name = "_start_rust"] pub unsafe extern "C" fn start_rust(a0: usize, a1: usize, a2: usize) -> ! { extern "Rust" { + #[cfg(feature = "post-init")] + fn __post_init(a0: usize); fn _setup_interrupts(); fn hal_main(a0: usize, a1: usize, a2: usize) -> !; } + #[cfg(feature = "post-init")] + __post_init(a0); _setup_interrupts(); hal_main(a0, a1, a2); } diff --git a/tests-trybuild/Cargo.toml b/tests-trybuild/Cargo.toml index c5141b43..40270009 100644 --- a/tests-trybuild/Cargo.toml +++ b/tests-trybuild/Cargo.toml @@ -5,7 +5,7 @@ edition = "2021" [dependencies] riscv = { path = "../riscv" } -riscv-rt = { path = "../riscv-rt", features = ["no-exceptions", "no-interrupts"] } +riscv-rt = { path = "../riscv-rt", features = ["no-exceptions", "no-interrupts", "post-init"] } trybuild = "1.0" [features] diff --git a/tests-trybuild/tests/riscv-rt/post_init/fail_arg_count.rs b/tests-trybuild/tests/riscv-rt/post_init/fail_arg_count.rs new file mode 100644 index 00000000..6c94330d --- /dev/null +++ b/tests-trybuild/tests/riscv-rt/post_init/fail_arg_count.rs @@ -0,0 +1,4 @@ +#[riscv_rt::post_init] +fn before_main(_hart_id: usize, _dtb: usize) {} + +fn main() {} diff --git a/tests-trybuild/tests/riscv-rt/post_init/fail_arg_count.stderr b/tests-trybuild/tests/riscv-rt/post_init/fail_arg_count.stderr new file mode 100644 index 00000000..32ac7e78 --- /dev/null +++ b/tests-trybuild/tests/riscv-rt/post_init/fail_arg_count.stderr @@ -0,0 +1,5 @@ +error: `#[post_init]` function has too many arguments + --> tests/riscv-rt/post_init/fail_arg_count.rs:2:33 + | +2 | fn before_main(_hart_id: usize, _dtb: usize) {} + | ^^^^ diff --git a/tests-trybuild/tests/riscv-rt/post_init/fail_arg_type.rs b/tests-trybuild/tests/riscv-rt/post_init/fail_arg_type.rs new file mode 100644 index 00000000..fc5894c5 --- /dev/null +++ b/tests-trybuild/tests/riscv-rt/post_init/fail_arg_type.rs @@ -0,0 +1,4 @@ +#[riscv_rt::post_init] +fn before_main(_hart_id: String) {} + +fn main() {} diff --git a/tests-trybuild/tests/riscv-rt/post_init/fail_arg_type.stderr b/tests-trybuild/tests/riscv-rt/post_init/fail_arg_type.stderr new file mode 100644 index 00000000..73f4f80c --- /dev/null +++ b/tests-trybuild/tests/riscv-rt/post_init/fail_arg_type.stderr @@ -0,0 +1,5 @@ +error: argument type must be usize + --> tests/riscv-rt/post_init/fail_arg_type.rs:2:26 + | +2 | fn before_main(_hart_id: String) {} + | ^^^^^^ diff --git a/tests-trybuild/tests/riscv-rt/post_init/fail_async.rs b/tests-trybuild/tests/riscv-rt/post_init/fail_async.rs new file mode 100644 index 00000000..27c95b93 --- /dev/null +++ b/tests-trybuild/tests/riscv-rt/post_init/fail_async.rs @@ -0,0 +1,4 @@ +#[riscv_rt::post_init] +async fn before_main() {} + +fn main() {} diff --git a/tests-trybuild/tests/riscv-rt/post_init/fail_async.stderr b/tests-trybuild/tests/riscv-rt/post_init/fail_async.stderr new file mode 100644 index 00000000..0dff16ff --- /dev/null +++ b/tests-trybuild/tests/riscv-rt/post_init/fail_async.stderr @@ -0,0 +1,5 @@ +error: `#[post_init]` function must have signature `[unsafe] fn([usize])` + --> tests/riscv-rt/post_init/fail_async.rs:2:1 + | +2 | async fn before_main() {} + | ^^^^^ diff --git a/tests-trybuild/tests/riscv-rt/post_init/pass_empty.rs b/tests-trybuild/tests/riscv-rt/post_init/pass_empty.rs new file mode 100644 index 00000000..0d214c4b --- /dev/null +++ b/tests-trybuild/tests/riscv-rt/post_init/pass_empty.rs @@ -0,0 +1,4 @@ +#[riscv_rt::post_init] +fn before_main() {} + +fn main() {} diff --git a/tests-trybuild/tests/riscv-rt/post_init/pass_safe.rs b/tests-trybuild/tests/riscv-rt/post_init/pass_safe.rs new file mode 100644 index 00000000..70f05797 --- /dev/null +++ b/tests-trybuild/tests/riscv-rt/post_init/pass_safe.rs @@ -0,0 +1,4 @@ +#[riscv_rt::post_init] +fn before_main(_hart_id: usize) {} + +fn main() {} diff --git a/tests-trybuild/tests/riscv-rt/post_init/pass_unsafe.rs b/tests-trybuild/tests/riscv-rt/post_init/pass_unsafe.rs new file mode 100644 index 00000000..d470e699 --- /dev/null +++ b/tests-trybuild/tests/riscv-rt/post_init/pass_unsafe.rs @@ -0,0 +1,4 @@ +#[riscv_rt::post_init] +unsafe fn before_main(_hart_id: usize) {} + +fn main() {}