@@ -8,6 +8,7 @@ use libc::{c_char, c_int, c_void, size_t};
88use llvm:: {
99 LLVMRustLLVMHasZlibCompressionForDebugSymbols , LLVMRustLLVMHasZstdCompressionForDebugSymbols ,
1010} ;
11+ use rustc_ast:: expand:: autodiff_attrs:: AutoDiffItem ;
1112use rustc_codegen_ssa:: back:: link:: ensure_removed;
1213use rustc_codegen_ssa:: back:: versioned_llvm_target;
1314use rustc_codegen_ssa:: back:: write:: {
@@ -28,7 +29,7 @@ use rustc_session::config::{
2829use rustc_span:: symbol:: sym;
2930use rustc_span:: { BytePos , InnerSpan , Pos , SpanData , SyntaxContext } ;
3031use rustc_target:: spec:: { CodeModel , RelocModel , SanitizerSet , SplitDebuginfo , TlsModel } ;
31- use tracing:: debug;
32+ use tracing:: { debug, trace } ;
3233
3334use crate :: back:: lto:: ThinBuffer ;
3435use crate :: back:: owned_target_machine:: OwnedTargetMachine ;
@@ -530,9 +531,38 @@ pub(crate) unsafe fn llvm_optimize(
530531 config : & ModuleConfig ,
531532 opt_level : config:: OptLevel ,
532533 opt_stage : llvm:: OptStage ,
534+ skip_size_increasing_opts : bool ,
533535) -> Result < ( ) , FatalError > {
534- let unroll_loops =
535- opt_level != config:: OptLevel :: Size && opt_level != config:: OptLevel :: SizeMin ;
536+ // Enzyme:
537+ // The whole point of compiler based AD is to differentiate optimized IR instead of unoptimized
538+ // source code. However, benchmarks show that optimizations increasing the code size
539+ // tend to reduce AD performance. Therefore deactivate them before AD, then differentiate the code
540+ // and finally re-optimize the module, now with all optimizations available.
541+ // FIXME(ZuseZ4): In a future update we could figure out how to only optimize individual functions getting
542+ // differentiated.
543+
544+ let unroll_loops;
545+ let vectorize_slp;
546+ let vectorize_loop;
547+
548+ // When we build rustc with enzyme/autodiff support, we want to postpone size-increasing
549+ // optimizations until after differentiation. FIXME(ZuseZ4): Before shipping on nightly,
550+ // we should make this more granular, or at least check that the user has at least one autodiff
551+ // call in their code, to justify altering the compilation pipeline.
552+ if skip_size_increasing_opts && cfg ! ( llvm_enzyme) {
553+ unroll_loops = false ;
554+ vectorize_slp = false ;
555+ vectorize_loop = false ;
556+ } else {
557+ unroll_loops =
558+ opt_level != config:: OptLevel :: Size && opt_level != config:: OptLevel :: SizeMin ;
559+ vectorize_slp = config. vectorize_slp ;
560+ vectorize_loop = config. vectorize_loop ;
561+ }
562+ trace ! (
563+ "Enzyme: Running with unroll_loops: {}, vectorize_slp: {}, vectorize_loop: {}" ,
564+ unroll_loops, vectorize_slp, vectorize_loop
565+ ) ;
536566 let using_thin_buffers = opt_stage == llvm:: OptStage :: PreLinkThinLTO || config. bitcode_needed ( ) ;
537567 let pgo_gen_path = get_pgo_gen_path ( config) ;
538568 let pgo_use_path = get_pgo_use_path ( config) ;
@@ -596,8 +626,8 @@ pub(crate) unsafe fn llvm_optimize(
596626 using_thin_buffers,
597627 config. merge_functions ,
598628 unroll_loops,
599- config . vectorize_slp ,
600- config . vectorize_loop ,
629+ vectorize_slp,
630+ vectorize_loop,
601631 config. no_builtins ,
602632 config. emit_lifetime_markers ,
603633 sanitizer_options. as_ref ( ) ,
@@ -619,6 +649,83 @@ pub(crate) unsafe fn llvm_optimize(
619649 result. into_result ( ) . map_err ( |( ) | llvm_err ( dcx, LlvmError :: RunLlvmPasses ) )
620650}
621651
652+ pub ( crate ) fn differentiate (
653+ module : & ModuleCodegen < ModuleLlvm > ,
654+ cgcx : & CodegenContext < LlvmCodegenBackend > ,
655+ diff_items : Vec < AutoDiffItem > ,
656+ config : & ModuleConfig ,
657+ ) -> Result < ( ) , FatalError > {
658+ for item in & diff_items {
659+ trace ! ( "{}" , item) ;
660+ }
661+
662+ let llmod = module. module_llvm . llmod ( ) ;
663+ let llcx = & module. module_llvm . llcx ;
664+ let diag_handler = cgcx. create_dcx ( ) ;
665+
666+ // Before dumping the module, we want all the tt to become part of the module.
667+ for item in diff_items. iter ( ) {
668+ let name = CString :: new ( item. source . clone ( ) ) . unwrap ( ) ;
669+ let fn_def: Option < & llvm:: Value > =
670+ unsafe { llvm:: LLVMGetNamedFunction ( llmod, name. as_ptr ( ) ) } ;
671+ let fn_def = match fn_def {
672+ Some ( x) => x,
673+ None => {
674+ return Err ( llvm_err ( diag_handler. handle ( ) , LlvmError :: PrepareAutoDiff {
675+ src : item. source . clone ( ) ,
676+ target : item. target . clone ( ) ,
677+ error : "could not find source function" . to_owned ( ) ,
678+ } ) ) ;
679+ }
680+ } ;
681+ let target_name = CString :: new ( item. target . clone ( ) ) . unwrap ( ) ;
682+ debug ! ( "target name: {:?}" , & target_name) ;
683+ let fn_target: Option < & llvm:: Value > =
684+ unsafe { llvm:: LLVMGetNamedFunction ( llmod, target_name. as_ptr ( ) ) } ;
685+ let fn_target = match fn_target {
686+ Some ( x) => x,
687+ None => {
688+ return Err ( llvm_err ( diag_handler. handle ( ) , LlvmError :: PrepareAutoDiff {
689+ src : item. source . clone ( ) ,
690+ target : item. target . clone ( ) ,
691+ error : "could not find target function" . to_owned ( ) ,
692+ } ) ) ;
693+ }
694+ } ;
695+
696+ crate :: builder:: generate_enzyme_call ( llmod, llcx, fn_def, fn_target, item. attrs . clone ( ) ) ;
697+ }
698+
699+ // FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts
700+
701+ if let Some ( opt_level) = config. opt_level {
702+ let opt_stage = match cgcx. lto {
703+ Lto :: Fat => llvm:: OptStage :: PreLinkFatLTO ,
704+ Lto :: Thin | Lto :: ThinLocal => llvm:: OptStage :: PreLinkThinLTO ,
705+ _ if cgcx. opts . cg . linker_plugin_lto . enabled ( ) => llvm:: OptStage :: PreLinkThinLTO ,
706+ _ => llvm:: OptStage :: PreLinkNoLTO ,
707+ } ;
708+ // This is our second opt call, so now we run all opts,
709+ // to make sure we get the best performance.
710+ let skip_size_increasing_opts = false ;
711+ trace ! ( "running Module Optimization after differentiation" ) ;
712+ unsafe {
713+ llvm_optimize (
714+ cgcx,
715+ diag_handler. handle ( ) ,
716+ module,
717+ config,
718+ opt_level,
719+ opt_stage,
720+ skip_size_increasing_opts,
721+ ) ?
722+ } ;
723+ }
724+ trace ! ( "done with differentiate()" ) ;
725+
726+ Ok ( ( ) )
727+ }
728+
622729// Unsafe due to LLVM calls.
623730pub ( crate ) unsafe fn optimize (
624731 cgcx : & CodegenContext < LlvmCodegenBackend > ,
@@ -641,14 +748,29 @@ pub(crate) unsafe fn optimize(
641748 unsafe { llvm:: LLVMWriteBitcodeToFile ( llmod, out. as_ptr ( ) ) } ;
642749 }
643750
751+ // FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts
752+
644753 if let Some ( opt_level) = config. opt_level {
645754 let opt_stage = match cgcx. lto {
646755 Lto :: Fat => llvm:: OptStage :: PreLinkFatLTO ,
647756 Lto :: Thin | Lto :: ThinLocal => llvm:: OptStage :: PreLinkThinLTO ,
648757 _ if cgcx. opts . cg . linker_plugin_lto . enabled ( ) => llvm:: OptStage :: PreLinkThinLTO ,
649758 _ => llvm:: OptStage :: PreLinkNoLTO ,
650759 } ;
651- return unsafe { llvm_optimize ( cgcx, dcx, module, config, opt_level, opt_stage) } ;
760+
761+ // If we know that we will later run AD, then we disable vectorization and loop unrolling
762+ let skip_size_increasing_opts = cfg ! ( llvm_enzyme) ;
763+ return unsafe {
764+ llvm_optimize (
765+ cgcx,
766+ dcx,
767+ module,
768+ config,
769+ opt_level,
770+ opt_stage,
771+ skip_size_increasing_opts,
772+ )
773+ } ;
652774 }
653775 Ok ( ( ) )
654776}
0 commit comments