@@ -9,6 +9,7 @@ use std::str::FromStr;
9
9
use crate :: expand:: { Decodable , Encodable , HashStable_Generic } ;
10
10
use crate :: ptr:: P ;
11
11
use crate :: { Ty , TyKind } ;
12
+ use crate :: expand:: typetree:: TypeTree ;
12
13
13
14
/// Forward and Reverse Mode are well known names for automatic differentiation implementations.
14
15
/// Enzyme does support both, but with different semantics, see DiffActivity. The First variants
@@ -85,6 +86,9 @@ pub struct AutoDiffItem {
85
86
/// The name of the function being generated
86
87
pub target : String ,
87
88
pub attrs : AutoDiffAttrs ,
89
+ // --- TypeTree support ---
90
+ pub inputs : Vec < TypeTree > ,
91
+ pub output : TypeTree ,
88
92
}
89
93
90
94
#[ derive( Clone , Eq , PartialEq , Encodable , Decodable , Debug , HashStable_Generic ) ]
@@ -112,6 +116,10 @@ impl AutoDiffAttrs {
112
116
pub fn has_primal_ret ( & self ) -> bool {
113
117
matches ! ( self . ret_activity, DiffActivity :: Active | DiffActivity :: Dual )
114
118
}
119
+ /// New constructor for type tree support
120
+ pub fn into_item ( self , source : String , target : String , inputs : Vec < TypeTree > , output : TypeTree ) -> AutoDiffItem {
121
+ AutoDiffItem { source, target, attrs : self , inputs, output }
122
+ }
115
123
}
116
124
117
125
impl DiffMode {
@@ -284,6 +292,8 @@ impl AutoDiffAttrs {
284
292
impl fmt:: Display for AutoDiffItem {
285
293
fn fmt ( & self , f : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
286
294
write ! ( f, "Differentiating {} -> {}" , self . source, self . target) ?;
287
- write ! ( f, " with attributes: {:?}" , self . attrs)
295
+ write ! ( f, " with attributes: {:?}" , self . attrs) ?;
296
+ write ! ( f, " with inputs: {:?}" , self . inputs) ?;
297
+ write ! ( f, " with output: {:?}" , self . output)
288
298
}
289
299
}
0 commit comments