Skip to content

Commit 97a987f

Browse files
committed
Auto merge of #142544 - Sa4dUs:prevent-abi-changes, r=ZuseZ4
Prevent ABI changes affect EnzymeAD This PR handles ABI changes for autodiff input arguments to improve Enzyme compatibility. Fundamentally this adjusts activities when a function argument is lowered as an `ScalarPair`, so there's no mismatch between diff activities and args. Also removes activities corresponding to ZSTs. fixes: #144025 r? `@ZuseZ4`
2 parents 4793ef5 + e04567c commit 97a987f

File tree

4 files changed

+271
-3
lines changed

4 files changed

+271
-3
lines changed

compiler/rustc_codegen_llvm/src/builder/autodiff.rs

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@ use std::ptr;
33
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode};
44
use rustc_codegen_ssa::common::TypeKind;
55
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
6-
use rustc_middle::ty::{PseudoCanonicalInput, Ty, TyCtxt, TypingEnv};
6+
use rustc_middle::ty::{Instance, PseudoCanonicalInput, TyCtxt, TypingEnv};
77
use rustc_middle::{bug, ty};
8+
use rustc_target::callconv::PassMode;
89
use tracing::debug;
910

1011
use crate::builder::{Builder, PlaceRef, UNNAMED};
@@ -16,9 +17,12 @@ use crate::value::Value;
1617

1718
pub(crate) fn adjust_activity_to_abi<'tcx>(
1819
tcx: TyCtxt<'tcx>,
19-
fn_ty: Ty<'tcx>,
20+
instance: Instance<'tcx>,
21+
typing_env: TypingEnv<'tcx>,
2022
da: &mut Vec<DiffActivity>,
2123
) {
24+
let fn_ty = instance.ty(tcx, typing_env);
25+
2226
if !matches!(fn_ty.kind(), ty::FnDef(..)) {
2327
bug!("expected fn def for autodiff, got {:?}", fn_ty);
2428
}
@@ -27,8 +31,16 @@ pub(crate) fn adjust_activity_to_abi<'tcx>(
2731
// All we do is decide how to handle the arguments.
2832
let sig = fn_ty.fn_sig(tcx).skip_binder();
2933

34+
// FIXME(Sa4dUs): pass proper varargs once we have support for differentiating variadic functions
35+
let Ok(fn_abi) =
36+
tcx.fn_abi_of_instance(typing_env.as_query_input((instance, ty::List::empty())))
37+
else {
38+
bug!("failed to get fn_abi of instance with empty varargs");
39+
};
40+
3041
let mut new_activities = vec![];
3142
let mut new_positions = vec![];
43+
let mut del_activities = 0;
3244
for (i, ty) in sig.inputs().iter().enumerate() {
3345
if let Some(inner_ty) = ty.builtin_deref(true) {
3446
if inner_ty.is_slice() {
@@ -80,6 +92,34 @@ pub(crate) fn adjust_activity_to_abi<'tcx>(
8092
continue;
8193
}
8294
}
95+
96+
let pci = PseudoCanonicalInput { typing_env: TypingEnv::fully_monomorphized(), value: *ty };
97+
98+
let layout = match tcx.layout_of(pci) {
99+
Ok(layout) => layout.layout,
100+
Err(_) => {
101+
bug!("failed to compute layout for type {:?}", ty);
102+
}
103+
};
104+
105+
let pass_mode = &fn_abi.args[i].mode;
106+
107+
// For ZST, just ignore and don't add its activity, as this arg won't be present
108+
// in the LLVM passed to Enzyme.
109+
// Some targets pass ZST indirectly in the C ABI, in that case, handle it as a normal arg
110+
// FIXME(Sa4dUs): Enforce ZST corresponding diff activity be `Const`
111+
if *pass_mode == PassMode::Ignore {
112+
del_activities += 1;
113+
da.remove(i);
114+
}
115+
116+
// If the argument is lowered as a `ScalarPair`, we need to duplicate its activity.
117+
// Otherwise, the number of activities won't match the number of LLVM arguments and
118+
// this will lead to errors when verifying the Enzyme call.
119+
if let rustc_abi::BackendRepr::ScalarPair(_, _) = layout.backend_repr() {
120+
new_activities.push(da[i].clone());
121+
new_positions.push(i + 1 - del_activities);
122+
}
83123
}
84124
// now add the extra activities coming from slices
85125
// Reverse order to not invalidate the indices

compiler/rustc_codegen_llvm/src/intrinsic.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1208,7 +1208,8 @@ fn codegen_autodiff<'ll, 'tcx>(
12081208

12091209
adjust_activity_to_abi(
12101210
tcx,
1211-
fn_source.ty(tcx, TypingEnv::fully_monomorphized()),
1211+
fn_source,
1212+
TypingEnv::fully_monomorphized(),
12121213
&mut diff_attrs.input_activity,
12131214
);
12141215

Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
//@ revisions: debug release
2+
3+
//@[debug] compile-flags: -Zautodiff=Enable -C opt-level=0 -Clto=fat
4+
//@[release] compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat
5+
//@ no-prefer-dynamic
6+
//@ needs-enzyme
7+
8+
// This test checks that Rust types are lowered to LLVM-IR types in a way
9+
// we expect and Enzyme can handle. We explicitly check release mode to
10+
// ensure that LLVM's O3 pipeline doesn't rewrite function signatures
11+
// into forms that Enzyme can't process correctly.
12+
13+
#![feature(autodiff)]
14+
15+
use std::autodiff::{autodiff_forward, autodiff_reverse};
16+
17+
#[derive(Copy, Clone)]
18+
struct Input {
19+
x: f32,
20+
y: f32,
21+
}
22+
23+
#[derive(Copy, Clone)]
24+
struct Wrapper {
25+
z: f32,
26+
}
27+
28+
#[derive(Copy, Clone)]
29+
struct NestedInput {
30+
x: f32,
31+
y: Wrapper,
32+
}
33+
34+
fn square(x: f32) -> f32 {
35+
x * x
36+
}
37+
38+
// CHECK-LABEL: ; abi_handling::df1
39+
// CHECK-NEXT: Function Attrs
40+
// debug-NEXT: define internal { float, float }
41+
// debug-SAME: (ptr align 4 %x, ptr align 4 %bx_0)
42+
// release-NEXT: define internal fastcc float
43+
// release-SAME: (float %x.0.val, float %x.4.val)
44+
45+
// CHECK-LABEL: ; abi_handling::f1
46+
// CHECK-NEXT: Function Attrs
47+
// debug-NEXT: define internal float
48+
// debug-SAME: (ptr align 4 %x)
49+
// release-NEXT: define internal fastcc noundef float
50+
// release-SAME: (float %x.0.val, float %x.4.val)
51+
#[autodiff_forward(df1, Dual, Dual)]
52+
#[inline(never)]
53+
fn f1(x: &[f32; 2]) -> f32 {
54+
x[0] + x[1]
55+
}
56+
57+
// CHECK-LABEL: ; abi_handling::df2
58+
// CHECK-NEXT: Function Attrs
59+
// debug-NEXT: define internal { float, float }
60+
// debug-SAME: (ptr %f, float %x, float %dret)
61+
// release-NEXT: define internal fastcc float
62+
// release-SAME: (float noundef %x)
63+
64+
// CHECK-LABEL: ; abi_handling::f2
65+
// CHECK-NEXT: Function Attrs
66+
// debug-NEXT: define internal float
67+
// debug-SAME: (ptr %f, float %x)
68+
// release-NEXT: define internal fastcc noundef float
69+
// release-SAME: (float noundef %x)
70+
#[autodiff_reverse(df2, Const, Active, Active)]
71+
#[inline(never)]
72+
fn f2(f: fn(f32) -> f32, x: f32) -> f32 {
73+
f(x)
74+
}
75+
76+
// CHECK-LABEL: ; abi_handling::df3
77+
// CHECK-NEXT: Function Attrs
78+
// debug-NEXT: define internal { float, float }
79+
// debug-SAME: (ptr align 4 %x, ptr align 4 %bx_0, ptr align 4 %y, ptr align 4 %by_0)
80+
// release-NEXT: define internal fastcc { float, float }
81+
// release-SAME: (float %x.0.val)
82+
83+
// CHECK-LABEL: ; abi_handling::f3
84+
// CHECK-NEXT: Function Attrs
85+
// debug-NEXT: define internal float
86+
// debug-SAME: (ptr align 4 %x, ptr align 4 %y)
87+
// release-NEXT: define internal fastcc noundef float
88+
// release-SAME: (float %x.0.val)
89+
#[autodiff_forward(df3, Dual, Dual, Dual)]
90+
#[inline(never)]
91+
fn f3<'a>(x: &'a f32, y: &'a f32) -> f32 {
92+
*x * *y
93+
}
94+
95+
// CHECK-LABEL: ; abi_handling::df4
96+
// CHECK-NEXT: Function Attrs
97+
// debug-NEXT: define internal { float, float }
98+
// debug-SAME: (float %x.0, float %x.1, float %bx_0.0, float %bx_0.1)
99+
// release-NEXT: define internal fastcc { float, float }
100+
// release-SAME: (float noundef %x.0, float noundef %x.1)
101+
102+
// CHECK-LABEL: ; abi_handling::f4
103+
// CHECK-NEXT: Function Attrs
104+
// debug-NEXT: define internal float
105+
// debug-SAME: (float %x.0, float %x.1)
106+
// release-NEXT: define internal fastcc noundef float
107+
// release-SAME: (float noundef %x.0, float noundef %x.1)
108+
#[autodiff_forward(df4, Dual, Dual)]
109+
#[inline(never)]
110+
fn f4(x: (f32, f32)) -> f32 {
111+
x.0 * x.1
112+
}
113+
114+
// CHECK-LABEL: ; abi_handling::df5
115+
// CHECK-NEXT: Function Attrs
116+
// debug-NEXT: define internal { float, float }
117+
// debug-SAME: (float %i.0, float %i.1, float %bi_0.0, float %bi_0.1)
118+
// release-NEXT: define internal fastcc { float, float }
119+
// release-SAME: (float noundef %i.0, float noundef %i.1)
120+
121+
// CHECK-LABEL: ; abi_handling::f5
122+
// CHECK-NEXT: Function Attrs
123+
// debug-NEXT: define internal float
124+
// debug-SAME: (float %i.0, float %i.1)
125+
// release-NEXT: define internal fastcc noundef float
126+
// release-SAME: (float noundef %i.0, float noundef %i.1)
127+
#[autodiff_forward(df5, Dual, Dual)]
128+
#[inline(never)]
129+
fn f5(i: Input) -> f32 {
130+
i.x + i.y
131+
}
132+
133+
// CHECK-LABEL: ; abi_handling::df6
134+
// CHECK-NEXT: Function Attrs
135+
// debug-NEXT: define internal { float, float }
136+
// debug-SAME: (float %i.0, float %i.1, float %bi_0.0, float %bi_0.1)
137+
// release-NEXT: define internal fastcc { float, float }
138+
// release-SAME: float noundef %i.0, float noundef %i.1
139+
// release-SAME: float noundef %bi_0.0, float noundef %bi_0.1
140+
141+
// CHECK-LABEL: ; abi_handling::f6
142+
// CHECK-NEXT: Function Attrs
143+
// debug-NEXT: define internal float
144+
// debug-SAME: (float %i.0, float %i.1)
145+
// release-NEXT: define internal fastcc noundef float
146+
// release-SAME: (float noundef %i.0, float noundef %i.1)
147+
#[autodiff_forward(df6, Dual, Dual)]
148+
#[inline(never)]
149+
fn f6(i: NestedInput) -> f32 {
150+
i.x + i.y.z * i.y.z
151+
}
152+
153+
// CHECK-LABEL: ; abi_handling::df7
154+
// CHECK-NEXT: Function Attrs
155+
// debug-NEXT: define internal { float, float }
156+
// debug-SAME: (ptr align 4 %x.0, ptr align 4 %x.1, ptr align 4 %bx_0.0, ptr align 4 %bx_0.1)
157+
// release-NEXT: define internal fastcc { float, float }
158+
// release-SAME: (float %x.0.0.val, float %x.1.0.val)
159+
160+
// CHECK-LABEL: ; abi_handling::f7
161+
// CHECK-NEXT: Function Attrs
162+
// debug-NEXT: define internal float
163+
// debug-SAME: (ptr align 4 %x.0, ptr align 4 %x.1)
164+
// release-NEXT: define internal fastcc noundef float
165+
// release-SAME: (float %x.0.0.val, float %x.1.0.val)
166+
#[autodiff_forward(df7, Dual, Dual)]
167+
#[inline(never)]
168+
fn f7(x: (&f32, &f32)) -> f32 {
169+
x.0 * x.1
170+
}
171+
172+
fn main() {
173+
let x = std::hint::black_box(2.0);
174+
let y = std::hint::black_box(3.0);
175+
let z = std::hint::black_box(4.0);
176+
static Y: f32 = std::hint::black_box(3.2);
177+
178+
let in_f1 = [x, y];
179+
dbg!(f1(&in_f1));
180+
let res_f1 = df1(&in_f1, &[1.0, 0.0]);
181+
dbg!(res_f1);
182+
183+
dbg!(f2(square, x));
184+
let res_f2 = df2(square, x, 1.0);
185+
dbg!(res_f2);
186+
187+
dbg!(f3(&x, &Y));
188+
let res_f3 = df3(&x, &Y, &1.0, &0.0);
189+
dbg!(res_f3);
190+
191+
let in_f4 = (x, y);
192+
dbg!(f4(in_f4));
193+
let res_f4 = df4(in_f4, (1.0, 0.0));
194+
dbg!(res_f4);
195+
196+
let in_f5 = Input { x, y };
197+
dbg!(f5(in_f5));
198+
let res_f5 = df5(in_f5, Input { x: 1.0, y: 0.0 });
199+
dbg!(res_f5);
200+
201+
let in_f6 = NestedInput { x, y: Wrapper { z: y } };
202+
dbg!(f6(in_f6));
203+
let res_f6 = df6(in_f6, NestedInput { x, y: Wrapper { z } });
204+
dbg!(res_f6);
205+
206+
let in_f7 = (&x, &y);
207+
dbg!(f7(in_f7));
208+
let res_f7 = df7(in_f7, (&1.0, &0.0));
209+
dbg!(res_f7);
210+
}

tests/ui/autodiff/zst.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat
2+
//@ no-prefer-dynamic
3+
//@ needs-enzyme
4+
//@ build-pass
5+
6+
// Check that differentiating functions with ZST args does not break
7+
8+
#![feature(autodiff)]
9+
10+
#[core::autodiff::autodiff_forward(fd_inner, Const, Dual)]
11+
fn f(_zst: (), _x: &mut f64) {}
12+
13+
fn fd(x: &mut f64, xd: &mut f64) {
14+
fd_inner((), x, xd);
15+
}
16+
17+
fn main() {}

0 commit comments

Comments
 (0)