@@ -11,6 +11,7 @@ mod llvm_enzyme {
11
11
AutoDiffAttrs , DiffActivity , DiffMode , valid_input_activity, valid_ret_activity,
12
12
valid_ty_for_activity,
13
13
} ;
14
+ use rustc_ast:: expand:: typetree:: { TypeTree , Type , Kind } ;
14
15
use rustc_ast:: ptr:: P ;
15
16
use rustc_ast:: token:: { Lit , LitKind , Token , TokenKind } ;
16
17
use rustc_ast:: tokenstream:: * ;
@@ -25,7 +26,6 @@ mod llvm_enzyme {
25
26
use tracing:: { debug, trace} ;
26
27
27
28
use crate :: errors;
28
- use crate :: expand:: typetree:: TypeTree ;
29
29
30
30
pub ( crate ) fn outer_normal_attr (
31
31
kind : & P < rustc_ast:: NormalAttr > ,
@@ -325,10 +325,9 @@ mod llvm_enzyme {
325
325
}
326
326
let span = ecx. with_def_site_ctxt ( expand_span) ;
327
327
328
- // Prepare placeholder type trees for now
329
- let num_args = sig. decl . inputs . len ( ) ;
330
- let inputs = vec ! [ TypeTree :: new( ) ; num_args] ;
331
- let output = TypeTree :: new ( ) ;
328
+ // Construct real type trees from function signature
329
+ let ( inputs, output) = construct_typetree_from_fnsig ( & sig) ;
330
+
332
331
// Use the new into_item method to construct the AutoDiffItem
333
332
let autodiff_item = x. clone ( ) . into_item (
334
333
primal. to_string ( ) ,
@@ -1059,4 +1058,105 @@ mod llvm_enzyme {
1059
1058
}
1060
1059
}
1061
1060
1061
+ #[ cfg( llvm_enzyme) ]
1062
+ fn construct_typetree_from_ty ( ty : & ast:: Ty ) -> TypeTree {
1063
+ match & ty. kind {
1064
+ TyKind :: Path ( ..) => {
1065
+ // Handle basic types like f32, f64, i32, etc.
1066
+ // For now, we'll use a simple heuristic based on the path
1067
+ // In a full implementation, this would need to be more sophisticated
1068
+ TypeTree ( vec ! [ Type {
1069
+ offset: 0 ,
1070
+ size: 8 , // Default size, should be computed properly
1071
+ kind: Kind :: Float , // Default to float, should be determined from type
1072
+ child: TypeTree :: new( ) ,
1073
+ } ] )
1074
+ }
1075
+ TyKind :: Ptr ( ptr_ty) => {
1076
+ TypeTree ( vec ! [ Type {
1077
+ offset: 0 ,
1078
+ size: 8 , // Pointer size
1079
+ kind: Kind :: Pointer ,
1080
+ child: construct_typetree_from_ty( & ptr_ty. ty) ,
1081
+ } ] )
1082
+ }
1083
+ TyKind :: Ref ( _, ref_ty) => {
1084
+ TypeTree ( vec ! [ Type {
1085
+ offset: 0 ,
1086
+ size: 8 , // Reference size
1087
+ kind: Kind :: Pointer ,
1088
+ child: construct_typetree_from_ty( & ref_ty. ty) ,
1089
+ } ] )
1090
+ }
1091
+ TyKind :: Slice ( slice_ty) => {
1092
+ TypeTree ( vec ! [ Type {
1093
+ offset: 0 ,
1094
+ size: 16 , // Slice is (ptr, len)
1095
+ kind: Kind :: Pointer ,
1096
+ child: construct_typetree_from_ty( & slice_ty. ty) ,
1097
+ } ] )
1098
+ }
1099
+ TyKind :: Array ( array_ty) => {
1100
+ // For arrays, we need to handle the size properly
1101
+ let elem_ty = construct_typetree_from_ty ( & array_ty. ty ) ;
1102
+ TypeTree ( vec ! [ Type {
1103
+ offset: 0 ,
1104
+ size: 8 , // Array size depends on element type and count
1105
+ kind: Kind :: Pointer ,
1106
+ child: elem_ty,
1107
+ } ] )
1108
+ }
1109
+ TyKind :: Tup ( tuple_types) => {
1110
+ let mut types = Vec :: new ( ) ;
1111
+ let mut offset = 0 ;
1112
+ for ( i, tuple_ty) in tuple_types. iter ( ) . enumerate ( ) {
1113
+ let elem_ty = construct_typetree_from_ty ( tuple_ty) ;
1114
+ // For tuples, we need to handle alignment and padding
1115
+ // This is a simplified version
1116
+ types. push ( Type {
1117
+ offset : offset as isize ,
1118
+ size : 8 , // Should be computed based on actual type
1119
+ kind : Kind :: Float , // Default
1120
+ child : elem_ty,
1121
+ } ) ;
1122
+ offset += 8 ; // Simplified alignment
1123
+ }
1124
+ TypeTree ( types)
1125
+ }
1126
+ _ => {
1127
+ // Default case for unknown types
1128
+ TypeTree ( vec ! [ Type {
1129
+ offset: 0 ,
1130
+ size: 8 ,
1131
+ kind: Kind :: Float ,
1132
+ child: TypeTree :: new( ) ,
1133
+ } ] )
1134
+ }
1135
+ }
1136
+ }
1137
+
1138
+ #[ cfg( llvm_enzyme) ]
1139
+ fn construct_typetree_from_fnsig ( sig : & ast:: FnSig ) -> ( Vec < TypeTree > , TypeTree ) {
1140
+ // Construct type trees for input arguments
1141
+ let inputs: Vec < TypeTree > = sig. decl . inputs . iter ( )
1142
+ . map ( |param| construct_typetree_from_ty ( & param. ty ) )
1143
+ . collect ( ) ;
1144
+
1145
+ // Construct type tree for return type
1146
+ let output = match & sig. decl . output {
1147
+ FnRetTy :: Default ( span) => {
1148
+ // Unit type ()
1149
+ TypeTree ( vec ! [ Type {
1150
+ offset: 0 ,
1151
+ size: 0 ,
1152
+ kind: Kind :: Integer ,
1153
+ child: TypeTree :: new( ) ,
1154
+ } ] )
1155
+ }
1156
+ FnRetTy :: Ty ( ty) => construct_typetree_from_ty ( ty) ,
1157
+ } ;
1158
+
1159
+ ( inputs, output)
1160
+ }
1161
+
1062
1162
pub ( crate ) use llvm_enzyme:: { expand_forward, expand_reverse} ;
0 commit comments