Skip to content

Commit 470e4ca

Browse files
committed
Move logic to a dedicated enzyme_autodiff intrinsic
1 parent fad0b0c commit 470e4ca

File tree

7 files changed

+181
-46
lines changed

7 files changed

+181
-46
lines changed

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 130 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -331,20 +331,23 @@ mod llvm_enzyme {
331331
.count() as u32;
332332
let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span);
333333

334-
// UNUSED
334+
// TODO(Sa4dUs): Remove this and all the related logic
335335
let _d_body = gen_enzyme_body(
336336
ecx, &x, n_active, &sig, &d_sig, primal, &new_args, span, sig_span, idents, errored,
337337
&generics,
338338
);
339339

340+
let d_body =
341+
call_enzyme_autodiff(ecx, primal, first_ident(&meta_item_vec[0]), span, &d_sig);
342+
340343
// The first element of it is the name of the function to be generated
341344
let asdf = Box::new(ast::Fn {
342345
defaultness: ast::Defaultness::Final,
343346
sig: d_sig,
344347
ident: first_ident(&meta_item_vec[0]),
345-
generics,
348+
generics: generics.clone(),
346349
contract: None,
347-
body: None, // This leads to an error when the ad function is inside a traits
350+
body: Some(d_body),
348351
define_opaque: None,
349352
});
350353
let mut rustc_ad_attr =
@@ -431,18 +434,15 @@ mod llvm_enzyme {
431434
tokens: ts,
432435
});
433436

434-
let rustc_intrinsic_attr =
435-
P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::rustc_intrinsic)));
436-
let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
437-
let intrinsic_attr = outer_normal_attr(&rustc_intrinsic_attr, new_id, span);
437+
let vis_clone = vis.clone();
438438

439439
let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
440440
let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
441441
let d_annotatable = match &item {
442442
Annotatable::AssocItem(_, _) => {
443443
let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf);
444444
let d_fn = P(ast::AssocItem {
445-
attrs: thin_vec![d_attr, intrinsic_attr],
445+
attrs: thin_vec![d_attr],
446446
id: ast::DUMMY_NODE_ID,
447447
span,
448448
vis,
@@ -452,15 +452,13 @@ mod llvm_enzyme {
452452
Annotatable::AssocItem(d_fn, Impl { of_trait: false })
453453
}
454454
Annotatable::Item(_) => {
455-
let mut d_fn =
456-
ecx.item(span, thin_vec![d_attr, intrinsic_attr], ItemKind::Fn(asdf));
455+
let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(asdf));
457456
d_fn.vis = vis;
458457

459458
Annotatable::Item(d_fn)
460459
}
461460
Annotatable::Stmt(_) => {
462-
let mut d_fn =
463-
ecx.item(span, thin_vec![d_attr, intrinsic_attr], ItemKind::Fn(asdf));
461+
let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(asdf));
464462
d_fn.vis = vis;
465463

466464
Annotatable::Stmt(P(ast::Stmt {
@@ -474,7 +472,9 @@ mod llvm_enzyme {
474472
}
475473
};
476474

477-
return vec![orig_annotatable, d_annotatable];
475+
let dummy_const_annotatable = gen_dummy_const(ecx, span, primal, sig, generics, vis_clone);
476+
477+
return vec![orig_annotatable, dummy_const_annotatable, d_annotatable];
478478
}
479479

480480
// shadow arguments (the extra ones which were not in the original (primal) function), in reverse mode must be
@@ -495,6 +495,123 @@ mod llvm_enzyme {
495495
ty
496496
}
497497

498+
// Generate `enzyme_autodiff` intrinsic call
499+
// ```
500+
// std::intrinsics::enzyme_autodiff(source, diff, (args))
501+
// ```
502+
fn call_enzyme_autodiff(
503+
ecx: &ExtCtxt<'_>,
504+
primal: Ident,
505+
diff: Ident,
506+
span: Span,
507+
d_sig: &FnSig,
508+
) -> P<ast::Block> {
509+
let primal_path_expr = ecx.expr_path(ecx.path_ident(span, primal));
510+
let diff_path_expr = ecx.expr_path(ecx.path_ident(span, diff));
511+
512+
let tuple_expr = ecx.expr_tuple(
513+
span,
514+
d_sig
515+
.decl
516+
.inputs
517+
.iter()
518+
.map(|arg| match arg.pat.kind {
519+
PatKind::Ident(_, ident, _) => ecx.expr_path(ecx.path_ident(span, ident)),
520+
_ => todo!(),
521+
})
522+
.collect::<ThinVec<_>>()
523+
.into(),
524+
);
525+
526+
let enzyme_path = ecx.path(
527+
span,
528+
vec![
529+
Ident::from_str("std"),
530+
Ident::from_str("intrinsics"),
531+
Ident::from_str("enzyme_autodiff"),
532+
],
533+
);
534+
let call_expr = ecx.expr_call(
535+
span,
536+
ecx.expr_path(enzyme_path),
537+
vec![primal_path_expr, diff_path_expr, tuple_expr].into(),
538+
);
539+
540+
let block = ecx.block_expr(call_expr);
541+
542+
block
543+
}
544+
545+
// Generate dummy const to prevent primal function
546+
// from being optimized away before applying enzyme
547+
// ```
548+
// const _: () =
549+
// {
550+
// #[used]
551+
// pub static DUMMY_PTR: fn_type = primal_fn;
552+
// };
553+
// ```
554+
fn gen_dummy_const(
555+
ecx: &ExtCtxt<'_>,
556+
span: Span,
557+
primal: Ident,
558+
sig: FnSig,
559+
generics: Generics,
560+
vis: Visibility,
561+
) -> Annotatable {
562+
// #[used]
563+
let used_attr = P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::used)));
564+
let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
565+
let used_attr = outer_normal_attr(&used_attr, new_id, span);
566+
567+
// static DUMMY_PTR: <fn_type> = <primal_ident>
568+
let static_ident = Ident::from_str_and_span("DUMMY_PTR", span);
569+
let fn_ptr_ty = ast::TyKind::BareFn(Box::new(ast::BareFnTy {
570+
safety: sig.header.safety,
571+
ext: sig.header.ext,
572+
generic_params: generics.params,
573+
decl: sig.decl,
574+
decl_span: sig.span,
575+
}));
576+
let static_ty = ecx.ty(span, fn_ptr_ty);
577+
578+
let static_expr = ecx.expr_path(ecx.path(span, vec![primal]));
579+
let static_item_kind = ast::ItemKind::Static(Box::new(ast::StaticItem {
580+
ident: static_ident,
581+
ty: static_ty,
582+
safety: ast::Safety::Default,
583+
mutability: ast::Mutability::Not,
584+
expr: Some(static_expr),
585+
define_opaque: None,
586+
}));
587+
588+
let static_item = ast::Item {
589+
attrs: thin_vec![used_attr],
590+
id: ast::DUMMY_NODE_ID,
591+
span,
592+
vis,
593+
kind: static_item_kind,
594+
tokens: None,
595+
};
596+
597+
let block_expr = ecx.expr_block(Box::new(ast::Block {
598+
stmts: thin_vec![ecx.stmt_item(span, P(static_item))],
599+
id: ast::DUMMY_NODE_ID,
600+
rules: ast::BlockCheckMode::Default,
601+
span,
602+
tokens: None,
603+
}));
604+
605+
let const_item = ecx.item_const(
606+
span,
607+
Ident::from_str_and_span("_", span),
608+
ecx.ty(span, ast::TyKind::Tup(thin_vec![])),
609+
block_expr,
610+
);
611+
612+
Annotatable::Item(const_item)
613+
}
614+
498615
// Will generate a body of the type:
499616
// ```
500617
// {

compiler/rustc_codegen_llvm/src/builder/autodiff.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use rustc_middle::bug;
1010
use tracing::{debug, trace};
1111

1212
use crate::back::write::llvm_err;
13-
use crate::builder::{Builder, OperandRef, PlaceRef, UNNAMED};
13+
use crate::builder::{Builder, PlaceRef, UNNAMED};
1414
use crate::context::SimpleCx;
1515
use crate::declare::declare_simple_fn;
1616
use crate::errors::{AutoDiffWithoutEnable, LlvmError};
@@ -200,7 +200,7 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
200200
fn_to_diff: &'ll Value,
201201
outer_name: &str,
202202
ret_ty: &'ll Type,
203-
fn_args: &[OperandRef<'tcx, &'ll Value>],
203+
fn_args: &[&'ll Value],
204204
attrs: AutoDiffAttrs,
205205
dest: PlaceRef<'tcx, &'ll Value>,
206206
) {
@@ -282,15 +282,13 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
282282
args.push(cx.get_const_int(cx.type_i64(), attrs.width as u64));
283283
}
284284

285-
let outer_args: Vec<&llvm::Value> = fn_args.iter().map(|op| op.immediate()).collect();
286-
287285
match_args_from_caller_to_enzyme(
288286
&cx,
289287
builder,
290288
attrs.width,
291289
&mut args,
292290
&attrs.input_activity,
293-
&outer_args,
291+
fn_args,
294292
);
295293

296294
let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None);

compiler/rustc_codegen_llvm/src/intrinsic.rs

Lines changed: 40 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use std::cmp::Ordering;
33

44
use rustc_abi::{Align, BackendRepr, ExternAbi, Float, HasDataLayout, Primitive, Size};
55
use rustc_codegen_ssa::base::{compare_simd_types, wants_msvc_seh, wants_wasm_eh};
6+
use rustc_codegen_ssa::codegen_attrs::autodiff_attrs;
67
use rustc_codegen_ssa::common::{IntPredicate, TypeKind};
78
use rustc_codegen_ssa::errors::{ExpectedPointerMutability, InvalidMonomorphization};
89
use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue};
@@ -196,48 +197,60 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
196197
&[ptr, args[1].immediate()],
197198
)
198199
}
199-
_ if tcx.has_attr(instance.def_id(), sym::rustc_autodiff) => {
200-
// NOTE(Sa4dUs): This is a hacky way to get the autodiff items
201-
// so we can focus on the lowering of the intrinsic call
202-
let mut source_id = None;
203-
let mut diff_attrs = None;
204-
let items: Vec<_> = tcx.hir_body_owners().map(|i| i.to_def_id()).collect();
205-
206-
// Hacky way of getting primal-diff pair, only works for code with 1 autodiff call
207-
for target_id in &items {
208-
let Some(target_attrs) = &tcx.codegen_fn_attrs(target_id).autodiff_item else {
209-
continue;
210-
};
200+
sym::enzyme_autodiff => {
201+
let val_arr: Vec<&'ll Value> = match args[2].val {
202+
crate::intrinsic::OperandValue::Ref(ref place_value) => {
203+
let mut ret_arr = vec![];
204+
let tuple_place = PlaceRef { val: *place_value, layout: args[2].layout };
211205

212-
if target_attrs.is_source() {
213-
source_id = Some(*target_id);
214-
} else {
215-
diff_attrs = Some(target_attrs);
216-
}
217-
}
206+
for i in 0..tuple_place.layout.layout.0.fields.count() {
207+
let field_place = tuple_place.project_field(self, i);
208+
let field_layout = tuple_place.layout.field(self, i);
209+
let llvm_ty = field_layout.llvm_type(self.cx);
218210

219-
if source_id.is_none() || diff_attrs.is_none() {
220-
bug!("could not find source_id={source_id:?} or diff_attrs={diff_attrs:?}");
221-
}
211+
let field_val =
212+
self.load(llvm_ty, field_place.val.llval, field_place.val.align);
213+
214+
ret_arr.push(field_val)
215+
}
222216

223-
let diff_attrs = diff_attrs.unwrap().clone();
217+
ret_arr
218+
}
219+
crate::intrinsic::OperandValue::Pair(v1, v2) => vec![v1, v2],
220+
OperandValue::Immediate(v) => vec![v],
221+
OperandValue::ZeroSized => bug!("unexpected `ZeroSized` arg"),
222+
};
224223

225-
// Get source fn
226-
let source_id = source_id.unwrap();
227-
let fn_source = Instance::mono(tcx, source_id);
224+
// Get source, diff, and attrs
225+
let source_id = match fn_args.into_type_list(tcx)[0].kind() {
226+
ty::FnDef(def_id, _) => def_id,
227+
_ => bug!("invalid args"),
228+
};
229+
let fn_source = Instance::mono(tcx, *source_id);
228230
let source_symbol =
229231
symbol_name_for_instance_in_crate(tcx, fn_source.clone(), LOCAL_CRATE);
230232
let fn_to_diff: Option<&'ll llvm::Value> = self.cx.get_function(&source_symbol);
231233
let Some(fn_to_diff) = fn_to_diff else { bug!("could not find source function") };
232234

235+
let diff_id = match fn_args.into_type_list(tcx)[1].kind() {
236+
ty::FnDef(def_id, _) => def_id,
237+
_ => bug!("invalid args"),
238+
};
239+
let fn_diff = Instance::mono(tcx, *diff_id);
240+
let diff_symbol =
241+
symbol_name_for_instance_in_crate(tcx, fn_diff.clone(), LOCAL_CRATE);
242+
243+
let diff_attrs = autodiff_attrs(tcx, *diff_id);
244+
let Some(diff_attrs) = diff_attrs else { bug!("could not find autodiff attrs") };
245+
233246
// Build body
234247
generate_enzyme_call(
235248
self,
236249
self.cx,
237250
fn_to_diff,
238-
name.as_str(),
251+
&diff_symbol,
239252
llret_ty,
240-
args,
253+
&val_arr,
241254
diff_attrs.clone(),
242255
result,
243256
);

compiler/rustc_codegen_ssa/src/codegen_attrs.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -720,7 +720,7 @@ impl<'a> MixedExportNameAndNoMangleState<'a> {
720720
/// being differentiated). The other form is #[rustc_autodiff(Mode, ActivityList)] on top of the
721721
/// placeholder functions. We wrote the rustc_autodiff attributes ourselves, so this should never
722722
/// panic, unless we introduced a bug when parsing the autodiff macro.
723-
fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
723+
pub fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
724724
let attrs = tcx.get_attrs(id, sym::rustc_autodiff);
725725

726726
let attrs = attrs.filter(|attr| attr.has_name(sym::rustc_autodiff)).collect::<Vec<_>>();

compiler/rustc_hir_analysis/src/check/intrinsic.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ fn intrinsic_operation_unsafety(tcx: TyCtxt<'_>, intrinsic_id: LocalDefId) -> hi
134134
| sym::round_ties_even_f32
135135
| sym::round_ties_even_f64
136136
| sym::round_ties_even_f128
137+
| sym::enzyme_autodiff
137138
| sym::const_eval_select => hir::Safety::Safe,
138139
_ => hir::Safety::Unsafe,
139140
};
@@ -215,6 +216,7 @@ pub(crate) fn check_intrinsic_type(
215216

216217
(n_tps, n_cts, inputs, output)
217218
}
219+
sym::enzyme_autodiff => (4, 0, vec![param(0), param(1), param(2)], param(3)),
218220
sym::abort => (0, 0, vec![], tcx.types.never),
219221
sym::unreachable => (0, 0, vec![], tcx.types.never),
220222
sym::breakpoint => (0, 0, vec![], tcx.types.unit),

compiler/rustc_span/src/symbol.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -915,6 +915,7 @@ symbols! {
915915
enumerate_method,
916916
env,
917917
env_CFG_RELEASE: env!("CFG_RELEASE"),
918+
enzyme_autodiff,
918919
eprint_macro,
919920
eprintln_macro,
920921
eq,

library/core/src/intrinsics/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3114,6 +3114,10 @@ pub const unsafe fn copysignf64(x: f64, y: f64) -> f64;
31143114
#[rustc_intrinsic]
31153115
pub const unsafe fn copysignf128(x: f128, y: f128) -> f128;
31163116

3117+
/// Compiler intrinsic emitted by the autodiff macro expansion, which generates calls of
/// the form `enzyme_autodiff(source, diff, (args))`.
///
/// - `f`: the function to differentiate (the "source"/primal function).
/// - `df`: the generated derivative function whose body contains this call.
/// - `args`: the arguments of `df`, forwarded together as a single tuple.
///
/// The LLVM backend treats it as a compiler bug if the types of `f` and `df` are not
/// function items, and it reads the autodiff attributes from the definition of `df`.
// NOTE(review): no fallback body is declared here, so presumably only backends that
// lower this intrinsic can compile calls to it; confirm before relying on it.
#[rustc_nounwind]
3118+
#[rustc_intrinsic]
3119+
pub const fn enzyme_autodiff<F, G, T: crate::marker::Tuple, R>(f: F, df: G, args: T) -> R;
3120+
31173121
/// Inform Miri that a given pointer definitely has a certain alignment.
31183122
#[cfg(miri)]
31193123
#[rustc_allow_const_fn_unstable(const_eval_select)]

0 commit comments

Comments
 (0)