@@ -4,10 +4,9 @@ use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivit
44use rustc_codegen_ssa:: ModuleCodegen ;
55use rustc_codegen_ssa:: back:: write:: ModuleConfig ;
66use rustc_errors:: FatalError ;
7- use rustc_session:: config:: Lto ;
87use tracing:: { debug, trace} ;
98
10- use crate :: back:: write:: { llvm_err, llvm_optimize } ;
9+ use crate :: back:: write:: llvm_err;
1110use crate :: builder:: SBuilder ;
1211use crate :: context:: SimpleCx ;
1312use crate :: declare:: declare_simple_fn;
@@ -53,8 +52,6 @@ fn generate_enzyme_call<'ll>(
5352 let mut ad_name: String = match attrs. mode {
5453 DiffMode :: Forward => "__enzyme_fwddiff" ,
5554 DiffMode :: Reverse => "__enzyme_autodiff" ,
56- DiffMode :: ForwardFirst => "__enzyme_fwddiff" ,
57- DiffMode :: ReverseFirst => "__enzyme_autodiff" ,
5855 _ => panic ! ( "logic bug in autodiff, unrecognized mode" ) ,
5956 }
6057 . to_string ( ) ;
@@ -153,7 +150,7 @@ fn generate_enzyme_call<'ll>(
153150 _ => { }
154151 }
155152
156- trace ! ( "matching autodiff arguments" ) ;
153+ debug ! ( "matching autodiff arguments" ) ;
157154 // We now handle the issue that Rust level arguments not always match the llvm-ir level
158155 // arguments. A slice, `&[f32]`, for example, is represented as a pointer and a length on
159156 // llvm-ir level. The number of activities matches the number of Rust level arguments, so we
@@ -164,10 +161,10 @@ fn generate_enzyme_call<'ll>(
164161 let mut activity_pos = 0 ;
165162 let outer_args: Vec < & llvm:: Value > = get_params ( outer_fn) ;
166163 while activity_pos < inputs. len ( ) {
167- let activity = inputs[ activity_pos as usize ] ;
164+ let diff_activity = inputs[ activity_pos as usize ] ;
168165 // Duplicated arguments received a shadow argument, into which enzyme will write the
169166 // gradient.
170- let ( activity, duplicated) : ( & Metadata , bool ) = match activity {
167+ let ( activity, duplicated) : ( & Metadata , bool ) = match diff_activity {
171168 DiffActivity :: None => panic ! ( "not a valid input activity" ) ,
172169 DiffActivity :: Const => ( enzyme_const, false ) ,
173170 DiffActivity :: Active => ( enzyme_out, false ) ,
@@ -222,7 +219,15 @@ fn generate_enzyme_call<'ll>(
222219 // A duplicated pointer will have the following two outer_fn arguments:
223220 // (..., ptr, ptr, ...). We add the following llvm-ir to our __enzyme call:
224221 // (..., metadata! enzyme_dup, ptr, ptr, ...).
225- assert ! ( llvm:: LLVMRustGetTypeKind ( next_outer_ty) == llvm:: TypeKind :: Pointer ) ;
222+ if matches ! (
223+ diff_activity,
224+ DiffActivity :: Duplicated | DiffActivity :: DuplicatedOnly
225+ ) {
226+ assert ! (
227+ llvm:: LLVMRustGetTypeKind ( next_outer_ty) == llvm:: TypeKind :: Pointer
228+ ) ;
229+ }
230+ // In the case of Dual we don't have assumptions, e.g. f32 would be valid.
226231 args. push ( next_outer_arg) ;
227232 outer_pos += 2 ;
228233 activity_pos += 1 ;
@@ -277,7 +282,7 @@ pub(crate) fn differentiate<'ll>(
277282 module : & ' ll ModuleCodegen < ModuleLlvm > ,
278283 cgcx : & CodegenContext < LlvmCodegenBackend > ,
279284 diff_items : Vec < AutoDiffItem > ,
280- config : & ModuleConfig ,
285+ _config : & ModuleConfig ,
281286) -> Result < ( ) , FatalError > {
282287 for item in & diff_items {
283288 trace ! ( "{}" , item) ;
@@ -318,29 +323,6 @@ pub(crate) fn differentiate<'ll>(
318323
319324 // FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts
320325
321- if let Some ( opt_level) = config. opt_level {
322- let opt_stage = match cgcx. lto {
323- Lto :: Fat => llvm:: OptStage :: PreLinkFatLTO ,
324- Lto :: Thin | Lto :: ThinLocal => llvm:: OptStage :: PreLinkThinLTO ,
325- _ if cgcx. opts . cg . linker_plugin_lto . enabled ( ) => llvm:: OptStage :: PreLinkThinLTO ,
326- _ => llvm:: OptStage :: PreLinkNoLTO ,
327- } ;
328- // This is our second opt call, so now we run all opts,
329- // to make sure we get the best performance.
330- let skip_size_increasing_opts = false ;
331- trace ! ( "running Module Optimization after differentiation" ) ;
332- unsafe {
333- llvm_optimize (
334- cgcx,
335- diag_handler. handle ( ) ,
336- module,
337- config,
338- opt_level,
339- opt_stage,
340- skip_size_increasing_opts,
341- ) ?
342- } ;
343- }
344326 trace ! ( "done with differentiate()" ) ;
345327
346328 Ok ( ( ) )
0 commit comments