@@ -331,20 +331,23 @@ mod llvm_enzyme {
331
331
. count ( ) as u32 ;
332
332
let ( d_sig, new_args, idents, errored) = gen_enzyme_decl ( ecx, & sig, & x, span) ;
333
333
334
- // UNUSED
334
+ // TODO(Sa4dUs): Remove this and all the related logic
335
335
let _d_body = gen_enzyme_body (
336
336
ecx, & x, n_active, & sig, & d_sig, primal, & new_args, span, sig_span, idents, errored,
337
337
& generics,
338
338
) ;
339
339
340
+ let d_body =
341
+ call_enzyme_autodiff ( ecx, primal, first_ident ( & meta_item_vec[ 0 ] ) , span, & d_sig) ;
342
+
340
343
// The first element of it is the name of the function to be generated
341
344
let asdf = Box :: new ( ast:: Fn {
342
345
defaultness : ast:: Defaultness :: Final ,
343
346
sig : d_sig,
344
347
ident : first_ident ( & meta_item_vec[ 0 ] ) ,
345
- generics,
348
+ generics : generics . clone ( ) ,
346
349
contract : None ,
347
- body : None , // This leads to an error when the ad function is inside a traits
350
+ body : Some ( d_body ) ,
348
351
define_opaque : None ,
349
352
} ) ;
350
353
let mut rustc_ad_attr =
@@ -431,18 +434,15 @@ mod llvm_enzyme {
431
434
tokens : ts,
432
435
} ) ;
433
436
434
- let rustc_intrinsic_attr =
435
- P ( ast:: NormalAttr :: from_ident ( Ident :: with_dummy_span ( sym:: rustc_intrinsic) ) ) ;
436
- let new_id = ecx. sess . psess . attr_id_generator . mk_attr_id ( ) ;
437
- let intrinsic_attr = outer_normal_attr ( & rustc_intrinsic_attr, new_id, span) ;
437
+ let vis_clone = vis. clone ( ) ;
438
438
439
439
let new_id = ecx. sess . psess . attr_id_generator . mk_attr_id ( ) ;
440
440
let d_attr = outer_normal_attr ( & rustc_ad_attr, new_id, span) ;
441
441
let d_annotatable = match & item {
442
442
Annotatable :: AssocItem ( _, _) => {
443
443
let assoc_item: AssocItemKind = ast:: AssocItemKind :: Fn ( asdf) ;
444
444
let d_fn = P ( ast:: AssocItem {
445
- attrs : thin_vec ! [ d_attr, intrinsic_attr ] ,
445
+ attrs : thin_vec ! [ d_attr] ,
446
446
id : ast:: DUMMY_NODE_ID ,
447
447
span,
448
448
vis,
@@ -452,15 +452,13 @@ mod llvm_enzyme {
452
452
Annotatable :: AssocItem ( d_fn, Impl { of_trait : false } )
453
453
}
454
454
Annotatable :: Item ( _) => {
455
- let mut d_fn =
456
- ecx. item ( span, thin_vec ! [ d_attr, intrinsic_attr] , ItemKind :: Fn ( asdf) ) ;
455
+ let mut d_fn = ecx. item ( span, thin_vec ! [ d_attr] , ItemKind :: Fn ( asdf) ) ;
457
456
d_fn. vis = vis;
458
457
459
458
Annotatable :: Item ( d_fn)
460
459
}
461
460
Annotatable :: Stmt ( _) => {
462
- let mut d_fn =
463
- ecx. item ( span, thin_vec ! [ d_attr, intrinsic_attr] , ItemKind :: Fn ( asdf) ) ;
461
+ let mut d_fn = ecx. item ( span, thin_vec ! [ d_attr] , ItemKind :: Fn ( asdf) ) ;
464
462
d_fn. vis = vis;
465
463
466
464
Annotatable :: Stmt ( P ( ast:: Stmt {
@@ -474,7 +472,9 @@ mod llvm_enzyme {
474
472
}
475
473
} ;
476
474
477
- return vec ! [ orig_annotatable, d_annotatable] ;
475
+ let dummy_const_annotatable = gen_dummy_const ( ecx, span, primal, sig, generics, vis_clone) ;
476
+
477
+ return vec ! [ orig_annotatable, dummy_const_annotatable, d_annotatable] ;
478
478
}
479
479
480
480
// shadow arguments (the extra ones which were not in the original (primal) function), in reverse mode must be
@@ -495,6 +495,123 @@ mod llvm_enzyme {
495
495
ty
496
496
}
497
497
498
+ // Generate `enzyme_autodiff` intrinsic call
499
+ // ```
500
+ // std::intrinsics::enzyme_autodiff(source, diff, (args))
501
+ // ```
502
+ fn call_enzyme_autodiff (
503
+ ecx : & ExtCtxt < ' _ > ,
504
+ primal : Ident ,
505
+ diff : Ident ,
506
+ span : Span ,
507
+ d_sig : & FnSig ,
508
+ ) -> P < ast:: Block > {
509
+ let primal_path_expr = ecx. expr_path ( ecx. path_ident ( span, primal) ) ;
510
+ let diff_path_expr = ecx. expr_path ( ecx. path_ident ( span, diff) ) ;
511
+
512
+ let tuple_expr = ecx. expr_tuple (
513
+ span,
514
+ d_sig
515
+ . decl
516
+ . inputs
517
+ . iter ( )
518
+ . map ( |arg| match arg. pat . kind {
519
+ PatKind :: Ident ( _, ident, _) => ecx. expr_path ( ecx. path_ident ( span, ident) ) ,
520
+ _ => todo ! ( ) ,
521
+ } )
522
+ . collect :: < ThinVec < _ > > ( )
523
+ . into ( ) ,
524
+ ) ;
525
+
526
+ let enzyme_path = ecx. path (
527
+ span,
528
+ vec ! [
529
+ Ident :: from_str( "std" ) ,
530
+ Ident :: from_str( "intrinsics" ) ,
531
+ Ident :: from_str( "enzyme_autodiff" ) ,
532
+ ] ,
533
+ ) ;
534
+ let call_expr = ecx. expr_call (
535
+ span,
536
+ ecx. expr_path ( enzyme_path) ,
537
+ vec ! [ primal_path_expr, diff_path_expr, tuple_expr] . into ( ) ,
538
+ ) ;
539
+
540
+ let block = ecx. block_expr ( call_expr) ;
541
+
542
+ block
543
+ }
544
+
545
+ // Generate dummy const to prevent primal function
546
+ // from being optimized away before applying enzyme
547
+ // ```
548
+ // const _: () =
549
+ // {
550
+ // #[used]
551
+ // pub static DUMMY_PTR: fn_type = primal_fn;
552
+ // };
553
+ // ```
554
+ fn gen_dummy_const (
555
+ ecx : & ExtCtxt < ' _ > ,
556
+ span : Span ,
557
+ primal : Ident ,
558
+ sig : FnSig ,
559
+ generics : Generics ,
560
+ vis : Visibility ,
561
+ ) -> Annotatable {
562
+ // #[used]
563
+ let used_attr = P ( ast:: NormalAttr :: from_ident ( Ident :: with_dummy_span ( sym:: used) ) ) ;
564
+ let new_id = ecx. sess . psess . attr_id_generator . mk_attr_id ( ) ;
565
+ let used_attr = outer_normal_attr ( & used_attr, new_id, span) ;
566
+
567
+ // static DUMMY_PTR: <fn_type> = <primal_ident>
568
+ let static_ident = Ident :: from_str_and_span ( "DUMMY_PTR" , span) ;
569
+ let fn_ptr_ty = ast:: TyKind :: BareFn ( Box :: new ( ast:: BareFnTy {
570
+ safety : sig. header . safety ,
571
+ ext : sig. header . ext ,
572
+ generic_params : generics. params ,
573
+ decl : sig. decl ,
574
+ decl_span : sig. span ,
575
+ } ) ) ;
576
+ let static_ty = ecx. ty ( span, fn_ptr_ty) ;
577
+
578
+ let static_expr = ecx. expr_path ( ecx. path ( span, vec ! [ primal] ) ) ;
579
+ let static_item_kind = ast:: ItemKind :: Static ( Box :: new ( ast:: StaticItem {
580
+ ident : static_ident,
581
+ ty : static_ty,
582
+ safety : ast:: Safety :: Default ,
583
+ mutability : ast:: Mutability :: Not ,
584
+ expr : Some ( static_expr) ,
585
+ define_opaque : None ,
586
+ } ) ) ;
587
+
588
+ let static_item = ast:: Item {
589
+ attrs : thin_vec ! [ used_attr] ,
590
+ id : ast:: DUMMY_NODE_ID ,
591
+ span,
592
+ vis,
593
+ kind : static_item_kind,
594
+ tokens : None ,
595
+ } ;
596
+
597
+ let block_expr = ecx. expr_block ( Box :: new ( ast:: Block {
598
+ stmts : thin_vec ! [ ecx. stmt_item( span, P ( static_item) ) ] ,
599
+ id : ast:: DUMMY_NODE_ID ,
600
+ rules : ast:: BlockCheckMode :: Default ,
601
+ span,
602
+ tokens : None ,
603
+ } ) ) ;
604
+
605
+ let const_item = ecx. item_const (
606
+ span,
607
+ Ident :: from_str_and_span ( "_" , span) ,
608
+ ecx. ty ( span, ast:: TyKind :: Tup ( thin_vec ! [ ] ) ) ,
609
+ block_expr,
610
+ ) ;
611
+
612
+ Annotatable :: Item ( const_item)
613
+ }
614
+
498
615
// Will generate a body of the type:
499
616
// ```
500
617
// {
0 commit comments