@@ -586,6 +586,42 @@ fn thin_lto(
586586 }
587587}
588588
589+ fn enable_autodiff_settings ( ad : & [ config:: AutoDiff ] , module : & mut ModuleCodegen < ModuleLlvm > ) {
590+ for & val in ad {
591+ match val {
592+ config:: AutoDiff :: PrintModBefore => {
593+ unsafe { llvm:: LLVMDumpModule ( module. module_llvm . llmod ( ) ) } ;
594+ }
595+ config:: AutoDiff :: PrintPerf => {
596+ llvm:: set_print_perf ( true ) ;
597+ }
598+ config:: AutoDiff :: PrintAA => {
599+ llvm:: set_print_activity ( true ) ;
600+ }
601+ config:: AutoDiff :: PrintTA => {
602+ llvm:: set_print_type ( true ) ;
603+ }
604+ config:: AutoDiff :: Inline => {
605+ llvm:: set_inline ( true ) ;
606+ }
607+ config:: AutoDiff :: LooseTypes => {
608+ llvm:: set_loose_types ( false ) ;
609+ }
610+ config:: AutoDiff :: PrintSteps => {
611+ llvm:: set_print ( true ) ;
612+ }
613+ // We handle this below
614+ config:: AutoDiff :: PrintModAfter => { }
615+ // This is required and already checked
616+ config:: AutoDiff :: Enable => { }
617+ }
618+ }
619+ // This helps with handling enums for now.
620+ llvm:: set_strict_aliasing ( false ) ;
621+ // FIXME(ZuseZ4): Test this, since it was added a long time ago.
622+ llvm:: set_rust_rules ( true ) ;
623+ }
624+
589625pub ( crate ) fn run_pass_manager (
590626 cgcx : & CodegenContext < LlvmCodegenBackend > ,
591627 dcx : DiagCtxtHandle < ' _ > ,
@@ -604,34 +640,37 @@ pub(crate) fn run_pass_manager(
604640 let opt_stage = if thin { llvm:: OptStage :: ThinLTO } else { llvm:: OptStage :: FatLTO } ;
605641 let opt_level = config. opt_level . unwrap_or ( config:: OptLevel :: No ) ;
606642
607- // If this rustc version was build with enzyme/autodiff enabled, and if users applied the
608- // `#[autodiff]` macro at least once, then we will later call llvm_optimize a second time.
609- debug ! ( "running llvm pm opt pipeline" ) ;
643+ // The PostAD behavior is the same that we would have if no autodiff was used.
644+ // It will run the default optimization pipeline. If AD is enabled we select
645+ // the DuringAD stage, which will disable vectorization and loop unrolling, and
646+ // schedule two autodiff optimization + differentiation passes.
647+ // We then run the llvm_optimize function a second time, to optimize the code which we generated
648+ // in the enzyme differentiation pass.
649+ let enable_ad = config. autodiff . contains ( & config:: AutoDiff :: Enable ) ;
650+ let stage =
651+ if enable_ad { write:: AutodiffStage :: DuringAD } else { write:: AutodiffStage :: PostAD } ;
652+
653+ if enable_ad {
654+ enable_autodiff_settings ( & config. autodiff , module) ;
655+ }
656+
610657 unsafe {
611- write:: llvm_optimize (
612- cgcx,
613- dcx,
614- module,
615- config,
616- opt_level,
617- opt_stage,
618- write:: AutodiffStage :: DuringAD ,
619- ) ?;
658+ write:: llvm_optimize ( cgcx, dcx, module, config, opt_level, opt_stage, stage) ?;
620659 }
621- // FIXME(ZuseZ4): Make this more granular
622- if cfg ! ( llvm_enzyme) && !thin {
660+
661+ if cfg ! ( llvm_enzyme) && enable_ad {
662+ let opt_stage = llvm:: OptStage :: FatLTO ;
663+ let stage = write:: AutodiffStage :: PostAD ;
623664 unsafe {
624- write:: llvm_optimize (
625- cgcx,
626- dcx,
627- module,
628- config,
629- opt_level,
630- llvm:: OptStage :: FatLTO ,
631- write:: AutodiffStage :: PostAD ,
632- ) ?;
665+ write:: llvm_optimize ( cgcx, dcx, module, config, opt_level, opt_stage, stage) ?;
666+ }
667+
668+ // This is the final IR, so people should be able to inspect the optimized autodiff output.
669+ if config. autodiff . contains ( & config:: AutoDiff :: PrintModAfter ) {
670+ unsafe { llvm:: LLVMDumpModule ( module. module_llvm . llmod ( ) ) } ;
633671 }
634672 }
673+
635674 debug ! ( "lto done" ) ;
636675 Ok ( ( ) )
637676}
0 commit comments