Skip to content

Commit 72a4d3a

Browse files
committed
macro and basic attribute added
Signed-off-by: Karan Janthe <[email protected]>
1 parent 1b61d43 commit 72a4d3a

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

compiler/rustc_ast/src/expand/autodiff_attrs.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use std::str::FromStr;
99
use crate::expand::{Decodable, Encodable, HashStable_Generic};
1010
use crate::ptr::P;
1111
use crate::{Ty, TyKind};
12+
use crate::expand::typetree::TypeTree;
1213

1314
/// Forward and Reverse Mode are well known names for automatic differentiation implementations.
1415
/// Enzyme does support both, but with different semantics, see DiffActivity. The First variants
@@ -85,6 +86,9 @@ pub struct AutoDiffItem {
8586
/// The name of the function being generated
8687
pub target: String,
8788
pub attrs: AutoDiffAttrs,
89+
// --- TypeTree support ---
90+
pub inputs: Vec<TypeTree>,
91+
pub output: TypeTree,
8892
}
8993

9094
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
@@ -112,6 +116,10 @@ impl AutoDiffAttrs {
112116
pub fn has_primal_ret(&self) -> bool {
113117
matches!(self.ret_activity, DiffActivity::Active | DiffActivity::Dual)
114118
}
119+
/// New constructor for type tree support
120+
pub fn into_item(self, source: String, target: String, inputs: Vec<TypeTree>, output: TypeTree) -> AutoDiffItem {
121+
AutoDiffItem { source, target, attrs: self, inputs, output }
122+
}
115123
}
116124

117125
impl DiffMode {
@@ -284,6 +292,8 @@ impl AutoDiffAttrs {
284292
impl fmt::Display for AutoDiffItem {
285293
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
286294
write!(f, "Differentiating {} -> {}", self.source, self.target)?;
287-
write!(f, " with attributes: {:?}", self.attrs)
295+
write!(f, " with attributes: {:?}", self.attrs)?;
296+
write!(f, " with inputs: {:?}", self.inputs)?;
297+
write!(f, " with output: {:?}", self.output)
288298
}
289299
}

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ mod llvm_enzyme {
2525
use tracing::{debug, trace};
2626

2727
use crate::errors;
28+
use crate::expand::typetree::TypeTree;
2829

2930
pub(crate) fn outer_normal_attr(
3031
kind: &P<rustc_ast::NormalAttr>,
@@ -324,6 +325,18 @@ mod llvm_enzyme {
324325
}
325326
let span = ecx.with_def_site_ctxt(expand_span);
326327

328+
// Prepare placeholder type trees for now
329+
let num_args = sig.decl.inputs.len();
330+
let inputs = vec![TypeTree::new(); num_args];
331+
let output = TypeTree::new();
332+
// Use the new into_item method to construct the AutoDiffItem
333+
let autodiff_item = x.clone().into_item(
334+
primal.to_string(),
335+
first_ident(&meta_item_vec[0]).to_string(),
336+
inputs,
337+
output,
338+
);
339+
327340
let n_active: u32 = x
328341
.input_activity
329342
.iter()

0 commit comments

Comments
 (0)