Skip to content

Commit ba7cd1e

Browse files
committed
basic type tree added
Signed-off-by: Karan Janthe <[email protected]>
1 parent 72a4d3a commit ba7cd1e

File tree

4 files changed

+162
-6
lines changed

4 files changed

+162
-6
lines changed

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 105 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ mod llvm_enzyme {
1111
AutoDiffAttrs, DiffActivity, DiffMode, valid_input_activity, valid_ret_activity,
1212
valid_ty_for_activity,
1313
};
14+
use rustc_ast::expand::typetree::{TypeTree, Type, Kind};
1415
use rustc_ast::ptr::P;
1516
use rustc_ast::token::{Lit, LitKind, Token, TokenKind};
1617
use rustc_ast::tokenstream::*;
@@ -25,7 +26,6 @@ mod llvm_enzyme {
2526
use tracing::{debug, trace};
2627

2728
use crate::errors;
28-
use crate::expand::typetree::TypeTree;
2929

3030
pub(crate) fn outer_normal_attr(
3131
kind: &P<rustc_ast::NormalAttr>,
@@ -325,10 +325,9 @@ mod llvm_enzyme {
325325
}
326326
let span = ecx.with_def_site_ctxt(expand_span);
327327

328-
// Prepare placeholder type trees for now
329-
let num_args = sig.decl.inputs.len();
330-
let inputs = vec![TypeTree::new(); num_args];
331-
let output = TypeTree::new();
328+
// Construct real type trees from function signature
329+
let (inputs, output) = construct_typetree_from_fnsig(&sig);
330+
332331
// Use the new into_item method to construct the AutoDiffItem
333332
let autodiff_item = x.clone().into_item(
334333
primal.to_string(),
@@ -1059,4 +1058,105 @@ mod llvm_enzyme {
10591058
}
10601059
}
10611060

1061+
#[cfg(llvm_enzyme)]
1062+
fn construct_typetree_from_ty(ty: &ast::Ty) -> TypeTree {
1063+
match &ty.kind {
1064+
TyKind::Path(..) => {
1065+
// Handle basic types like f32, f64, i32, etc.
1066+
// For now, we'll use a simple heuristic based on the path
1067+
// In a full implementation, this would need to be more sophisticated
1068+
TypeTree(vec![Type {
1069+
offset: 0,
1070+
size: 8, // Default size, should be computed properly
1071+
kind: Kind::Float, // Default to float, should be determined from type
1072+
child: TypeTree::new(),
1073+
}])
1074+
}
1075+
TyKind::Ptr(ptr_ty) => {
1076+
TypeTree(vec![Type {
1077+
offset: 0,
1078+
size: 8, // Pointer size
1079+
kind: Kind::Pointer,
1080+
child: construct_typetree_from_ty(&ptr_ty.ty),
1081+
}])
1082+
}
1083+
TyKind::Ref(_, ref_ty) => {
1084+
TypeTree(vec![Type {
1085+
offset: 0,
1086+
size: 8, // Reference size
1087+
kind: Kind::Pointer,
1088+
child: construct_typetree_from_ty(&ref_ty.ty),
1089+
}])
1090+
}
1091+
TyKind::Slice(slice_ty) => {
1092+
TypeTree(vec![Type {
1093+
offset: 0,
1094+
size: 16, // Slice is (ptr, len)
1095+
kind: Kind::Pointer,
1096+
child: construct_typetree_from_ty(&slice_ty.ty),
1097+
}])
1098+
}
1099+
TyKind::Array(array_ty) => {
1100+
// For arrays, we need to handle the size properly
1101+
let elem_ty = construct_typetree_from_ty(&array_ty.ty);
1102+
TypeTree(vec![Type {
1103+
offset: 0,
1104+
size: 8, // Array size depends on element type and count
1105+
kind: Kind::Pointer,
1106+
child: elem_ty,
1107+
}])
1108+
}
1109+
TyKind::Tup(tuple_types) => {
1110+
let mut types = Vec::new();
1111+
let mut offset = 0;
1112+
for (i, tuple_ty) in tuple_types.iter().enumerate() {
1113+
let elem_ty = construct_typetree_from_ty(tuple_ty);
1114+
// For tuples, we need to handle alignment and padding
1115+
// This is a simplified version
1116+
types.push(Type {
1117+
offset: offset as isize,
1118+
size: 8, // Should be computed based on actual type
1119+
kind: Kind::Float, // Default
1120+
child: elem_ty,
1121+
});
1122+
offset += 8; // Simplified alignment
1123+
}
1124+
TypeTree(types)
1125+
}
1126+
_ => {
1127+
// Default case for unknown types
1128+
TypeTree(vec![Type {
1129+
offset: 0,
1130+
size: 8,
1131+
kind: Kind::Float,
1132+
child: TypeTree::new(),
1133+
}])
1134+
}
1135+
}
1136+
}
1137+
1138+
#[cfg(llvm_enzyme)]
1139+
fn construct_typetree_from_fnsig(sig: &ast::FnSig) -> (Vec<TypeTree>, TypeTree) {
1140+
// Construct type trees for input arguments
1141+
let inputs: Vec<TypeTree> = sig.decl.inputs.iter()
1142+
.map(|param| construct_typetree_from_ty(&param.ty))
1143+
.collect();
1144+
1145+
// Construct type tree for return type
1146+
let output = match &sig.decl.output {
1147+
FnRetTy::Default(span) => {
1148+
// Unit type ()
1149+
TypeTree(vec![Type {
1150+
offset: 0,
1151+
size: 0,
1152+
kind: Kind::Integer,
1153+
child: TypeTree::new(),
1154+
}])
1155+
}
1156+
FnRetTy::Ty(ty) => construct_typetree_from_ty(ty),
1157+
};
1158+
1159+
(inputs, output)
1160+
}
1161+
10621162
pub(crate) use llvm_enzyme::{expand_forward, expand_reverse};

compiler/rustc_codegen_llvm/src/builder/autodiff.rs

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use std::ptr;
22

33
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivity, DiffMode};
4+
use rustc_ast::expand::typetree::{FncTree, TypeTree};
45
use rustc_codegen_ssa::ModuleCodegen;
56
use rustc_codegen_ssa::back::write::ModuleConfig;
67
use rustc_codegen_ssa::common::TypeKind;
@@ -16,6 +17,7 @@ use crate::declare::declare_simple_fn;
1617
use crate::errors::{AutoDiffWithoutEnable, LlvmError};
1718
use crate::llvm::AttributePlace::Function;
1819
use crate::llvm::{Metadata, True};
20+
use crate::typetree::to_enzyme_typetree;
1921
use crate::value::Value;
2022
use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, llvm};
2123

@@ -294,6 +296,7 @@ fn generate_enzyme_call<'ll>(
294296
fn_to_diff: &'ll Value,
295297
outer_fn: &'ll Value,
296298
attrs: AutoDiffAttrs,
299+
fnc_tree: Option<FncTree>,
297300
) {
298301
// We have to pick the name depending on whether we want forward or reverse mode autodiff.
299302
let mut ad_name: String = match attrs.mode {
@@ -361,6 +364,15 @@ fn generate_enzyme_call<'ll>(
361364
let attr = llvm::AttributeKind::NoInline.create_attr(cx.llcx);
362365
attributes::apply_to_llfn(ad_fn, Function, &[attr]);
363366

367+
// TODO(KMJ-007): Add type tree metadata if available
368+
// This requires adding CreateTypeTreeAttribute to LLVM bindings
369+
// if let Some(tree) = fnc_tree {
370+
// let data_layout = cx.data_layout();
371+
// let enzyme_tree = to_enzyme_typetree(tree, data_layout, cx.llcx);
372+
// let tt_attr = llvm::CreateTypeTreeAttribute(cx.llcx, enzyme_tree);
373+
// attributes::apply_to_llfn(ad_fn, Function, &[tt_attr]);
374+
// }
375+
364376
// We add a made-up attribute just such that we can recognize it after AD to update
365377
// (no)-inline attributes. We'll then also remove this attribute.
366378
let enzyme_marker_attr = llvm::CreateAttrString(cx.llcx, "enzyme_marker");
@@ -505,7 +517,17 @@ pub(crate) fn differentiate<'ll>(
505517
));
506518
};
507519

508-
generate_enzyme_call(&cx, fn_def, fn_target, item.attrs.clone());
520+
// Construct function type tree from item's type trees
521+
let fnc_tree = if !item.inputs.is_empty() || !item.output.0.is_empty() {
522+
Some(FncTree {
523+
inputs: item.inputs.clone(),
524+
output: item.output.clone(),
525+
})
526+
} else {
527+
None
528+
};
529+
530+
generate_enzyme_call(&cx, fn_def, fn_target, item.attrs.clone(), fnc_tree);
509531
}
510532

511533
// FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts

compiler/rustc_codegen_llvm/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ mod llvm_util;
7474
mod mono_item;
7575
mod type_;
7676
mod type_of;
77+
mod typetree;
7778
mod va_arg;
7879
mod value;
7980

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
use crate::llvm;
2+
use rustc_ast::expand::typetree::{Kind, TypeTree};
3+
4+
pub fn to_enzyme_typetree(
5+
tree: TypeTree,
6+
llvm_data_layout: &str,
7+
llcx: &llvm::Context,
8+
) -> llvm::TypeTree {
9+
tree.0.iter().fold(llvm::TypeTree::new(), |obj, x| {
10+
let scalar = match x.kind {
11+
Kind::Integer => llvm::CConcreteType::DT_Integer,
12+
Kind::Float => llvm::CConcreteType::DT_Float,
13+
Kind::Double => llvm::CConcreteType::DT_Double,
14+
Kind::Pointer => llvm::CConcreteType::DT_Pointer,
15+
_ => panic!("Unknown kind {:?}", x.kind),
16+
};
17+
18+
let tt = llvm::TypeTree::from_type(scalar, llcx).only(-1);
19+
20+
let tt = if !x.child.0.is_empty() {
21+
let inner_tt = to_enzyme_typetree(x.child.clone(), llvm_data_layout, llcx);
22+
tt.merge(inner_tt.only(-1))
23+
} else {
24+
tt
25+
};
26+
27+
if x.offset != -1 {
28+
obj.merge(tt.shift(llvm_data_layout, 0, x.size as isize, x.offset as usize))
29+
} else {
30+
obj.merge(tt)
31+
}
32+
})
33+
}

0 commit comments

Comments
 (0)