diff --git a/compiler/rustc_ast/src/expand/autodiff_attrs.rs b/compiler/rustc_ast/src/expand/autodiff_attrs.rs
index 2f918faaf752b..b615398b4ed09 100644
--- a/compiler/rustc_ast/src/expand/autodiff_attrs.rs
+++ b/compiler/rustc_ast/src/expand/autodiff_attrs.rs
@@ -9,6 +9,7 @@ use std::str::FromStr;
 use crate::expand::{Decodable, Encodable, HashStable_Generic};
 use crate::ptr::P;
 use crate::{Ty, TyKind};
+use crate::expand::typetree::TypeTree;
 
 /// Forward and Reverse Mode are well known names for automatic differentiation implementations.
 /// Enzyme does support both, but with different semantics, see DiffActivity. The First variants
@@ -85,6 +86,9 @@ pub struct AutoDiffItem {
     /// The name of the function being generated
     pub target: String,
     pub attrs: AutoDiffAttrs,
+    // --- TypeTree support ---
+    pub inputs: Vec<TypeTree>,
+    pub output: TypeTree,
 }
 
 #[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
@@ -112,6 +116,10 @@ impl AutoDiffAttrs {
     pub fn has_primal_ret(&self) -> bool {
         matches!(self.ret_activity, DiffActivity::Active | DiffActivity::Dual)
     }
+
+    /// New constructor for type tree support
+    pub fn into_item(self, source: String, target: String, inputs: Vec<TypeTree>, output: TypeTree) -> AutoDiffItem {
+        AutoDiffItem { source, target, attrs: self, inputs, output }
+    }
 }
 
 impl DiffMode {
@@ -284,6 +292,8 @@ impl AutoDiffAttrs {
 impl fmt::Display for AutoDiffItem {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         write!(f, "Differentiating {} -> {}", self.source, self.target)?;
-        write!(f, " with attributes: {:?}", self.attrs)
+        write!(f, " with attributes: {:?}", self.attrs)?;
+        write!(f, " with inputs: {:?}", self.inputs)?;
+        write!(f, " with output: {:?}", self.output)
     }
 }
diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs
index c784477833279..dd5f3d5aa3237 100644
--- a/compiler/rustc_builtin_macros/src/autodiff.rs
+++ b/compiler/rustc_builtin_macros/src/autodiff.rs
@@ -11,7 +11,9 @@ mod llvm_enzyme {
         AutoDiffAttrs, DiffActivity, DiffMode, valid_input_activity, valid_ret_activity,
         valid_ty_for_activity,
     };
+    use rustc_ast::expand::typetree::{TypeTree, Type, Kind};
     use rustc_ast::ptr::P;
+    use crate::typetree::construct_typetree_from_fnsig;
     use rustc_ast::token::{Lit, LitKind, Token, TokenKind};
     use rustc_ast::tokenstream::*;
     use rustc_ast::visit::AssocCtxt::*;
@@ -324,6 +326,17 @@
         }
         let span = ecx.with_def_site_ctxt(expand_span);
 
+        // Construct real type trees from the function signature
+        let (inputs, output) = construct_typetree_from_fnsig(&sig);
+
+        // Use the new into_item method to construct the AutoDiffItem
+        let autodiff_item = x.clone().into_item(
+            primal.to_string(),
+            first_ident(&meta_item_vec[0]).to_string(),
+            inputs,
+            output,
+        );
+
         let n_active: u32 = x
             .input_activity
             .iter()
@@ -1045,5 +1058,3 @@ mod llvm_enzyme {
         (d_sig, new_inputs, idents, false)
     }
 }
-
-pub(crate) use llvm_enzyme::{expand_forward, expand_reverse};
diff --git a/compiler/rustc_builtin_macros/src/lib.rs b/compiler/rustc_builtin_macros/src/lib.rs
index 0594f7e86c333..f5dc409fde4d1 100644
--- a/compiler/rustc_builtin_macros/src/lib.rs
+++ b/compiler/rustc_builtin_macros/src/lib.rs
@@ -51,6 +51,7 @@ mod pattern_type;
 mod source_util;
 mod test;
 mod trace_macros;
+mod typetree;
 
 pub mod asm;
 pub mod cmdline_attrs;
diff --git a/compiler/rustc_builtin_macros/src/typetree.rs b/compiler/rustc_builtin_macros/src/typetree.rs
new file mode 100644
index 0000000000000..f33efe3b22b06
--- /dev/null
+++ b/compiler/rustc_builtin_macros/src/typetree.rs
@@ -0,0 +1,330 @@
+use rustc_ast as ast;
+use rustc_ast::FnRetTy;
+use rustc_ast::expand::typetree::{Type, Kind, TypeTree, FncTree};
+use rustc_middle::ty::{self, Ty, TyCtxt, ParamEnv, ParamEnvAnd, Adt};
+use rustc_middle::ty::layout::{FieldsShape, LayoutOf};
+use rustc_middle::hir;
+use rustc_span::Span;
+use rustc_ast::expand::autodiff_attrs::DiffActivity;
+use tracing::trace;
+
+#[cfg(llvm_enzyme)]
+pub fn typetree_from<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
+    let mut visited = vec![];
+    let ty = typetree_from_ty(ty, tcx, 0, false, &mut visited, None);
+    let tt = Type { offset: -1, kind: Kind::Pointer, size: 8, child: ty };
+    return TypeTree(vec![tt]);
+}
+
+// This function combines three tasks. To avoid traversing each type 3x, we combine them.
+// 1. Create a TypeTree from a Ty. This is the main task.
+// 2. IFF da is not empty, we also want to adjust DiffActivity to account for future MIR->LLVM
+//    lowering. E.g. fat ptrs are going to introduce an extra int.
+// 3. IFF da is not empty, we are creating a TT for a function that is directly differentiated
+//    (has an autodiff macro on top). Here we want to make sure that shadows are mutable
+//    internally. We know the outermost ref/ptr indirection is mutable - we generate it like that.
+//    We now have to make sure that inner ptr/ref are mutable too, or issue a warning.
+//    Not an error, because it only causes issues if they are actually read, which we don't check
+//    yet. We should add such an analysis to reliably either issue an error or accept without
+//    warning. If only there were some research on how to do that...
+#[cfg(llvm_enzyme)]
+pub fn fnc_typetrees<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    fn_ty: Ty<'tcx>,
+    da: &mut Vec<DiffActivity>,
+    span: Option<Span>,
+) -> FncTree {
+    if !fn_ty.is_fn() {
+        return FncTree { args: vec![], ret: TypeTree::new() };
+    }
+    let fnc_binder: ty::Binder<'_, ty::FnSig<'_>> = fn_ty.fn_sig(tcx);
+
+    // If rustc compiles the unmodified primal, we know that this copy of the function
+    // also has correct lifetimes. We know that Enzyme won't free the shadow too early
+    // (or actually at all), so let's strip lifetimes when computing the layout.
+    // Recommended by compiler-errors:
+    // https://discord.com/channels/273534239310479360/957720175619215380/1223454360676208751
+    let x = tcx.instantiate_bound_regions_with_erased(fnc_binder);
+
+    let mut new_activities = vec![];
+    let mut new_positions = vec![];
+    let mut visited = vec![];
+    let mut args = vec![];
+    for (i, ty) in x.inputs().iter().enumerate() {
+        // We care about safety checks if an argument gets duplicated and we write into the
+        // shadow. That's equivalent to Duplicated or DuplicatedOnly.
+        let safety = if !da.is_empty() {
+            assert!(da.len() == x.inputs().len(), "{:?} != {:?}", da.len(), x.inputs().len());
+            // If we have activities, we also have spans
+            assert!(span.is_some());
+            match da[i] {
+                DiffActivity::DuplicatedOnly | DiffActivity::Duplicated => true,
+                _ => false,
+            }
+        } else {
+            false
+        };
+
+        visited.clear();
+        if ty.is_unsafe_ptr() || ty.is_ref() || ty.is_box() {
+            if ty.is_fn_ptr() {
+                unimplemented!("what to do with fn ptr?");
+            }
+            let inner_ty = ty.builtin_deref(true).unwrap().ty;
+            if inner_ty.is_slice() {
+                // We know that the length will be passed as an extra arg.
+                let child = typetree_from_ty(inner_ty, tcx, 1, safety, &mut visited, span);
+                let tt = Type { offset: -1, kind: Kind::Pointer, size: 8, child };
+                args.push(TypeTree(vec![tt]));
+                let i64_tt =
+                    Type { offset: -1, kind: Kind::Integer, size: 8, child: TypeTree::new() };
+                args.push(TypeTree(vec![i64_tt]));
+                if !da.is_empty() {
+                    // We are looking at a slice. The length of that slice will become an
+                    // extra integer on LLVM level. Integers are always const.
+                    // However, if the slice gets duplicated, we want to know to later check the
+                    // size. So we mark the new size argument as FakeActivitySize.
+                    let activity = match da[i] {
+                        DiffActivity::DualOnly
+                        | DiffActivity::Dual
+                        | DiffActivity::DuplicatedOnly
+                        | DiffActivity::Duplicated => DiffActivity::FakeActivitySize,
+                        DiffActivity::Const => DiffActivity::Const,
+                        _ => panic!("unexpected activity for ptr/ref"),
+                    };
+                    new_activities.push(activity);
+                    new_positions.push(i + 1);
+                }
+                trace!("ABI MATCHING!");
+                continue;
+            }
+        }
+        let arg_tt = typetree_from_ty(*ty, tcx, 0, safety, &mut visited, span);
+        args.push(arg_tt);
+    }
+
+    // Now add the extra activities coming from slices.
+    // Reverse order to not invalidate the indices.
+    for _ in 0..new_activities.len() {
+        let pos = new_positions.pop().unwrap();
+        let activity = new_activities.pop().unwrap();
+        da.insert(pos, activity);
+    }
+
+    visited.clear();
+    let ret = typetree_from_ty(x.output(), tcx, 0, false, &mut visited, span);
+
+    FncTree { args, ret }
+}
+
+// Error type for warnings
+#[derive(Debug)]
+pub struct AutodiffUnsafeInnerConstRef {
+    pub span: Span,
+    pub ty: String,
+}
+
+#[cfg(llvm_enzyme)]
+fn typetree_from_ty<'a>(
+    ty: Ty<'a>,
+    tcx: TyCtxt<'a>,
+    depth: usize,
+    safety: bool,
+    visited: &mut Vec<Ty<'a>>,
+    span: Option<Span>,
+) -> TypeTree {
+    if depth > 20 {
+        trace!("depth > 20 for ty: {}", &ty);
+    }
+    if visited.contains(&ty) {
+        // recursive type
+        trace!("recursive type: {}", &ty);
+        return TypeTree::new();
+    }
+    visited.push(ty);
+
+    if ty.is_unsafe_ptr() || ty.is_ref() || ty.is_box() {
+        if ty.is_fn_ptr() {
+            unimplemented!("what to do with fn ptr?");
+        }
+
+        let inner_ty_and_mut = ty.builtin_deref(true).unwrap();
+        let is_mut = inner_ty_and_mut.mutbl == hir::Mutability::Mut;
+        let inner_ty = inner_ty_and_mut.ty;
+
+        // Now account for inner mutability.
+        if !is_mut && depth > 0 && safety {
+            let ptr_ty: String = if ty.is_ref() {
+                "ref"
+            } else if ty.is_unsafe_ptr() {
+                "ptr"
+            } else {
+                assert!(ty.is_box());
+                "box"
+            }
+            .to_string();
+
+            // If we have mutability, we also have a span
+            assert!(span.is_some());
+            let span = span.unwrap();
+
+            tcx.sess
+                .dcx()
+                .emit_warning(AutodiffUnsafeInnerConstRef { span, ty: ptr_ty });
+        }
+
+        let child = typetree_from_ty(inner_ty, tcx, depth + 1, safety, visited, span);
+        let tt = Type { offset: -1, kind: Kind::Pointer, size: 8, child };
+        visited.pop();
+        return TypeTree(vec![tt]);
+    }
+
+    if ty.is_closure() || ty.is_coroutine() || ty.is_fresh() || ty.is_fn() {
+        visited.pop();
+        return TypeTree::new();
+    }
+
+    if ty.is_scalar() {
+        let (kind, size) = if ty.is_integral() || ty.is_char() || ty.is_bool() {
+            (Kind::Integer, ty.primitive_size(tcx).bytes_usize())
+        } else if ty.is_floating_point() {
+            match ty {
+                x if x == tcx.types.f32 => (Kind::Float, 4),
+                x if x == tcx.types.f64 => (Kind::Double, 8),
+                _ => panic!("floatTy scalar that is neither f32 nor f64"),
+            }
+        } else {
+            panic!("scalar that is neither integral nor floating point");
+        };
+        visited.pop();
+        return TypeTree(vec![Type { offset: -1, child: TypeTree::new(), kind, size }]);
+    }
+
+    let param_env_and = ParamEnvAnd { param_env: ParamEnv::empty(), value: ty };
+
+    let layout = tcx.layout_of(param_env_and);
+    assert!(layout.is_ok());
+
+    let layout = layout.unwrap().layout;
+    let fields = layout.fields();
+    let max_size = layout.size();
+
+    if ty.is_adt() && !ty.is_simd() {
+        let adt_def = ty.ty_adt_def().unwrap();
+
+        if adt_def.is_struct() {
+            let (offsets, _memory_index) = match fields {
+                // Manuel TODO:
+                FieldsShape::Arbitrary { offsets: o, memory_index: m } => (o, m),
+                // e.g. core::arch::x86_64::__m128i, TODO: handle later
+                FieldsShape::Array { .. } => return TypeTree::new(),
+                FieldsShape::Union(_) => return TypeTree::new(),
+                FieldsShape::Primitive => return TypeTree::new(),
+            };
+
+            let substs = match ty.kind() {
+                Adt(_, subst_ref) => subst_ref,
+                _ => panic!(""),
+            };
+
+            let fields = adt_def.all_fields();
+            let fields = fields
+                .into_iter()
+                .zip(offsets.into_iter())
+                .filter_map(|(field, offset)| {
+                    let field_ty: Ty<'_> = field.ty(tcx, substs);
+                    let field_ty: Ty<'_> =
+                        tcx.normalize_erasing_regions(ParamEnv::empty(), field_ty);
+
+                    if field_ty.is_phantom_data() {
+                        return None;
+                    }
+
+                    let mut child =
+                        typetree_from_ty(field_ty, tcx, depth + 1, safety, visited, span).0;
+
+                    for c in &mut child {
+                        if c.offset == -1 {
+                            c.offset = offset.bytes() as isize
+                        } else {
+                            c.offset += offset.bytes() as isize;
+                        }
+                    }
+
+                    Some(child)
+                })
+                .flatten()
+                .collect::<Vec<Type>>();
+
+            visited.pop();
+            let ret_tt = TypeTree(fields);
+            return ret_tt;
+        } else if adt_def.is_enum() {
+            // Enzyme can't represent enums, so let it figure it out itself, without seeding
+            // the typetree.
+            //unimplemented!("adt that is an enum");
+        } else {
+            //let ty_name = tcx.def_path_debug_str(adt_def.did());
+            //tcx.sess.emit_fatal(UnsupportedUnion { ty_name });
+        }
+    }
+
+    if ty.is_simd() {
+        trace!("simd");
+        let (_size, inner_ty) = ty.simd_size_and_type(tcx);
+        let _sub_tt = typetree_from_ty(inner_ty, tcx, depth + 1, safety, visited, span);
+        // TODO
+        visited.pop();
+        return TypeTree::new();
+    }
+
+    if ty.is_array() {
+        let (stride, count) = match fields {
+            FieldsShape::Array { stride: s, count: c } => (s, c),
+            _ => panic!(""),
+        };
+        let byte_stride = stride.bytes_usize();
+        let byte_max_size = max_size.bytes_usize();
+
+        assert!(byte_stride * *count as usize == byte_max_size);
+        if (*count as usize) == 0 {
+            return TypeTree::new();
+        }
+        let sub_ty = ty.builtin_index().unwrap();
+        let subtt = typetree_from_ty(sub_ty, tcx, depth + 1, safety, visited, span);
+
+        // Calculate the size of the subtree.
+        let param_env_and = ParamEnvAnd { param_env: ParamEnv::empty(), value: sub_ty };
+        let size = tcx.layout_of(param_env_and).unwrap().size.bytes() as usize;
+        let tt = TypeTree(
+            std::iter::repeat(subtt)
+                .take(*count as usize)
+                .enumerate()
+                .map(|(idx, x)| x.0.into_iter().map(move |x| x.add_offset((idx * size) as isize)))
+                .flatten()
+                .collect(),
+        );
+
+        visited.pop();
+        return tt;
+    }
+
+    if ty.is_slice() {
+        let sub_ty = ty.builtin_index().unwrap();
+        let subtt = typetree_from_ty(sub_ty, tcx, depth + 1, safety, visited, span);
+
+        visited.pop();
+        return subtt;
+    }
+
+    visited.pop();
+    TypeTree::new()
+}
+
+// AST-based type tree construction (simplified fallback)
+#[cfg(llvm_enzyme)]
+pub fn construct_typetree_from_ty(_ty: &ast::Ty) -> TypeTree {
+    // For now, return an empty type tree and let Enzyme figure out the layout.
+    // In a full implementation, we'd need to convert AST types to Ty<'tcx>
+    // and use the layout-based approach from the old code.
+    TypeTree::new()
+}
+
+#[cfg(llvm_enzyme)]
+pub fn construct_typetree_from_fnsig(sig: &ast::FnSig) -> (Vec<TypeTree>, TypeTree) {
+    // For now, return empty type trees.
+    // This will be replaced with proper layout-based construction.
+    let inputs: Vec<TypeTree> = sig.decl.inputs.iter().map(|_| TypeTree::new()).collect();
+
+    let output = match &sig.decl.output {
+        FnRetTy::Default(_) => TypeTree::new(),
+        FnRetTy::Ty(_) => TypeTree::new(),
+    };
+
+    (inputs, output)
+}
diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
index b07d9a5cfca8c..a709f528e9fba 100644
--- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
+++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
@@ -1,6 +1,7 @@
 use std::ptr;
 
 use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivity, DiffMode};
+use rustc_ast::expand::typetree::{FncTree, TypeTree};
 use rustc_codegen_ssa::ModuleCodegen;
 use rustc_codegen_ssa::back::write::ModuleConfig;
 use rustc_codegen_ssa::common::TypeKind;
@@ -16,8 +17,10 @@ use crate::declare::declare_simple_fn;
 use crate::errors::{AutoDiffWithoutEnable, LlvmError};
 use crate::llvm::AttributePlace::Function;
 use crate::llvm::{Metadata, True};
+use crate::typetree::to_enzyme_typetree;
 use crate::value::Value;
-use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, llvm};
+use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, llvm, DiffTypeTree};
+use rustc_data_structures::fx::FxHashMap;
 
 fn get_params(fnc: &Value) -> Vec<&Value> {
     let param_num = llvm::LLVMCountParams(fnc) as usize;
@@ -294,6 +297,7 @@ fn generate_enzyme_call<'ll>(
     fn_to_diff: &'ll Value,
     outer_fn: &'ll Value,
     attrs: AutoDiffAttrs,
+    fnc_tree: Option<FncTree>,
 ) {
     // We have to pick the name depending on whether we want forward or reverse mode autodiff.
     let mut ad_name: String = match attrs.mode {
@@ -361,6 +365,15 @@
     let attr = llvm::AttributeKind::NoInline.create_attr(cx.llcx);
     attributes::apply_to_llfn(ad_fn, Function, &[attr]);
 
+    // TODO(KMJ-007): Add type tree metadata if available.
+    // This requires adding CreateTypeTreeAttribute to the LLVM bindings.
+    // if let Some(tree) = fnc_tree {
+    //     let data_layout = cx.data_layout();
+    //     let enzyme_tree = to_enzyme_typetree(tree, data_layout, cx.llcx);
+    //     let tt_attr = llvm::CreateTypeTreeAttribute(cx.llcx, enzyme_tree);
+    //     attributes::apply_to_llfn(ad_fn, Function, &[tt_attr]);
+    // }
+
     // We add a made-up attribute just such that we can recognize it after AD to update
     // (no)-inline attributes. We'll then also remove this attribute.
     let enzyme_marker_attr = llvm::CreateAttrString(cx.llcx, "enzyme_marker");
@@ -461,6 +474,7 @@ pub(crate) fn differentiate<'ll>(
     module: &'ll ModuleCodegen<ModuleLlvm>,
     cgcx: &CodegenContext<LlvmCodegenBackend>,
     diff_items: Vec<AutoDiffItem>,
+    typetrees: FxHashMap<String, DiffTypeTree>,
     _config: &ModuleConfig,
 ) -> Result<(), FatalError> {
     for item in &diff_items {
@@ -505,7 +519,22 @@
             ));
         };
 
-        generate_enzyme_call(&cx, fn_def, fn_target, item.attrs.clone());
+        // Use type trees from the typetrees map if available, otherwise construct them from the item.
+        let fnc_tree = if let Some(diff_tt) = typetrees.get(&item.source) {
+            Some(FncTree {
+                args: diff_tt.input_tt.clone(),
+                ret: diff_tt.ret_tt.clone(),
+            })
+        } else if !item.inputs.is_empty() || !item.output.0.is_empty() {
+            Some(FncTree {
+                args: item.inputs.clone(),
+                ret: item.output.clone(),
+            })
+        } else {
+            None
+        };
+
+        generate_enzyme_call(&cx, fn_def, fn_target, item.attrs.clone(), fnc_tree);
     }
 
     // FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts
diff --git a/compiler/rustc_codegen_llvm/src/lib.rs b/compiler/rustc_codegen_llvm/src/lib.rs
index cdfffbe47bfa5..541cc6ca8718f 100644
--- a/compiler/rustc_codegen_llvm/src/lib.rs
+++ b/compiler/rustc_codegen_llvm/src/lib.rs
@@ -27,6 +27,7 @@ use back::owned_target_machine::OwnedTargetMachine;
 use back::write::{create_informational_target_machine, create_target_machine};
 use context::SimpleCx;
 use errors::{AutoDiffWithoutLTO, ParseTargetMachineConfig};
+use llvm::TypeTree;
 use llvm_util::target_config;
 use rustc_ast::expand::allocator::AllocatorKind;
 use rustc_ast::expand::autodiff_attrs::AutoDiffItem;
@@ -36,7 +37,7 @@ use rustc_codegen_ssa::back::write::{
 };
 use rustc_codegen_ssa::traits::*;
 use rustc_codegen_ssa::{CodegenResults, CompiledModule, ModuleCodegen, TargetConfig};
-use rustc_data_structures::fx::FxIndexMap;
+use rustc_data_structures::fx::{FxHashMap, FxIndexMap};
 use rustc_errors::{DiagCtxtHandle, FatalError};
 use rustc_metadata::EncodedMetadata;
 use rustc_middle::dep_graph::{WorkProduct, WorkProductId};
@@ -74,6 +75,7 @@ mod llvm_util;
 mod mono_item;
 mod type_;
 mod type_of;
+mod typetree;
 mod va_arg;
 mod value;
 
@@ -159,6 +161,7 @@ impl WriteBackendMethods for LlvmCodegenBackend {
     type TargetMachineError = crate::errors::LlvmError<'static>;
     type ThinData = back::lto::ThinData;
     type ThinBuffer = back::lto::ThinBuffer;
+    type TypeTree = DiffTypeTree;
     fn print_pass_timings(&self) {
         let timings = llvm::build_string(|s| unsafe { llvm::LLVMRustPrintPassTimings(s) }).unwrap();
         print!("{timings}");
     }
@@ -232,13 +235,20 @@ impl WriteBackendMethods for LlvmCodegenBackend {
         cgcx: &CodegenContext<Self>,
         module: &ModuleCodegen<Self::Module>,
         diff_fncs: Vec<AutoDiffItem>,
+        typetrees: FxHashMap<String, Self::TypeTree>,
         config: &ModuleConfig,
     ) -> Result<(), FatalError> {
         if cgcx.lto != Lto::Fat {
             let dcx = cgcx.create_dcx();
             return Err(dcx.handle().emit_almost_fatal(AutoDiffWithoutLTO));
         }
-        builder::autodiff::differentiate(module, cgcx, diff_fncs, config)
+        builder::autodiff::differentiate(module, cgcx, diff_fncs, typetrees, config)
+    }
+
+    // The typetrees contain all information, their order therefore is irrelevant.
+    #[allow(rustc::potential_query_instability)]
+    fn typetrees(module: &mut Self::Module) -> FxHashMap<String, Self::TypeTree> {
+        module.typetrees.drain().collect()
     }
 }
@@ -386,6 +396,13 @@ impl CodegenBackend for LlvmCodegenBackend {
     }
 }
 
+#[derive(Clone, Debug)]
+pub struct DiffTypeTree {
+    pub ret_tt: TypeTree,
+    pub input_tt: Vec<TypeTree>,
+}
+
+#[allow(dead_code)]
 pub struct ModuleLlvm {
     llcx: &'static mut llvm::Context,
     llmod_raw: *const llvm::Module,
@@ -393,6 +410,7 @@ pub struct ModuleLlvm {
     // This field is `ManuallyDrop` because it is important that the `TargetMachine`
     // is disposed prior to the `Context` being disposed otherwise UAFs can occur.
     tm: ManuallyDrop<OwnedTargetMachine>,
+    typetrees: FxHashMap<String, DiffTypeTree>,
 }
 
 unsafe impl Send for ModuleLlvm {}
@@ -407,6 +425,7 @@ impl ModuleLlvm {
             llmod_raw,
             llcx,
             tm: ManuallyDrop::new(create_target_machine(tcx, mod_name)),
+            typetrees: Default::default(),
         }
     }
 }
@@ -418,7 +437,8 @@ impl ModuleLlvm {
             ModuleLlvm {
                 llmod_raw,
                 llcx,
-                tm: ManuallyDrop::new(create_informational_target_machine(tcx.sess, false)),
+                tm: ManuallyDrop::new(create_informational_target_machine(tcx.sess)),
+                typetrees: Default::default(),
             }
         }
     }
@@ -440,7 +460,12 @@ impl ModuleLlvm {
             }
         };
 
-        Ok(ModuleLlvm { llmod_raw, llcx, tm: ManuallyDrop::new(tm) })
+        Ok(ModuleLlvm {
+            llmod_raw,
+            llcx,
+            tm: ManuallyDrop::new(tm),
+            typetrees: Default::default(),
+        })
     }
 }
diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs
index 91ada856d5977..5e0cc8a0b6316 100644
--- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs
+++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs
@@ -2670,4 +2670,7 @@ unsafe extern "C" {
 
     pub(crate) fn LLVMRustSetNoSanitizeAddress(Global: &Value);
     pub(crate) fn LLVMRustSetNoSanitizeHWAddress(Global: &Value);
+
+    // Type Tree Attribute Functions
+    pub fn CreateTypeTreeAttribute<'a>(llcx: &'a Context, typetree: &'a TypeTree) -> &'a Attribute;
 }
diff --git a/compiler/rustc_codegen_llvm/src/typetree.rs b/compiler/rustc_codegen_llvm/src/typetree.rs
new file mode 100644
index 0000000000000..3d688f443524c
--- /dev/null
+++ b/compiler/rustc_codegen_llvm/src/typetree.rs
@@ -0,0 +1,33 @@
+use crate::llvm;
+use rustc_ast::expand::typetree::{Kind, TypeTree};
+
+pub fn to_enzyme_typetree(
+    tree: TypeTree,
+    llvm_data_layout: &str,
+    llcx: &llvm::Context,
+) -> llvm::TypeTree {
+    tree.0.iter().fold(llvm::TypeTree::new(), |obj, x| {
+        let scalar = match x.kind {
+            Kind::Integer => llvm::CConcreteType::DT_Integer,
+            Kind::Float => llvm::CConcreteType::DT_Float,
+            Kind::Double => llvm::CConcreteType::DT_Double,
+            Kind::Pointer => llvm::CConcreteType::DT_Pointer,
+            _ => panic!("Unknown kind {:?}", x.kind),
+        };
+
+        let tt = llvm::TypeTree::from_type(scalar, llcx).only(-1);
+
+        let tt = if !x.child.0.is_empty() {
+            let inner_tt = to_enzyme_typetree(x.child.clone(), llvm_data_layout, llcx);
+            tt.merge(inner_tt.only(-1))
+        } else {
+            tt
+        };
+
+        if x.offset != -1 {
+            obj.merge(tt.shift(llvm_data_layout, 0, x.size as isize, x.offset as usize))
+        } else {
+            obj.merge(tt)
+        }
+    })
+}
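For illustration, here is a minimal, self-contained sketch (separate from the patch) of how the `TypeTree`/`Type` shapes used above nest for a `&[f32]` argument, which `fnc_typetrees` lowers to a data pointer plus an `i64` length. The `Kind`, `Type`, and `TypeTree` definitions below are local stand-ins mirroring the assumed layout of `rustc_ast::expand::typetree` (fields `offset`, `size`, `kind`, `child`), so the snippet compiles on its own outside the compiler tree.

```rust
// Local stand-ins mirroring the assumed rustc_ast::expand::typetree shapes;
// not the real compiler definitions.
#[allow(dead_code)]
#[derive(Clone, Debug, PartialEq)]
enum Kind { Integer, Float, Double, Pointer }

#[derive(Clone, Debug, PartialEq)]
struct Type { offset: isize, size: usize, kind: Kind, child: TypeTree }

#[derive(Clone, Debug, PartialEq, Default)]
struct TypeTree(Vec<Type>);

fn main() {
    // A `&[f32]` argument is lowered to (data pointer, length), so the argument list
    // gets two trees: pointer-to-float for the data, and a plain integer for the length.
    let f32_scalar = TypeTree(vec![Type {
        offset: -1, size: 4, kind: Kind::Float, child: TypeTree::default(),
    }]);
    let data_ptr = TypeTree(vec![Type {
        offset: -1, size: 8, kind: Kind::Pointer, child: f32_scalar,
    }]);
    let len_i64 = TypeTree(vec![Type {
        offset: -1, size: 8, kind: Kind::Integer, child: TypeTree::default(),
    }]);

    let args = vec![data_ptr, len_i64];
    assert_eq!(args.len(), 2); // one Rust argument became two ABI-level arguments
    println!("{args:?}");
}
```

An `offset` of `-1` stands for "applies to the whole level" (the entire pointee), which is why `to_enzyme_typetree` only calls `shift` with the data layout when `offset != -1`.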