
Commit 34d8e64

xinyazhang authored and pytorchmergebot committed
[ROCm] Bump AOTriton to 0.10b (pytorch#156290)
Notable new features/optimizations for SDPA operators on AMD systems from AOTriton 0.10b:

* Official support of gfx950/gfx1201.
* Experimental support of gfx1101/gfx1151/gfx1150/gfx1200.
* Reduce the `libaotriton.so` binary size by over 80%.
  + Without this optimization the binary size of `libaotriton.so` could be over 100MiB, due to 2x more supported architectures compared with 0.9b. Now it is only about 11MiB.
* Support sliding window attention (SWA) in `_flash_attention_forward/backward`. Should fix pytorch#154582.

See https://github.com/ROCm/aotriton/releases/tag/0.10b for full details, including Known Problems.

Notable changes to the SDPA backend:

* `std::optional<int64_t>` `window_size_left/right` are passed directly to ROCm's SDPA backend, because the default value `-1` is meaningful to AOTriton's backend and the bottom-right aligned causal mask is implemented with negative `window_size_left/right`.
* Some code cleanup around `USE_CK_FLASH_ATTENTION`.

Pull Request resolved: pytorch#156290
Approved by: https://github.com/jithunnair-amd, https://github.com/jeffdaily
1 parent 3644b41 commit 34d8e64
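
For readers unfamiliar with the sliding-window parameters this release enables, the sketch below illustrates the FlashAttention-style window predicate that `window_size_left/right` conventionally encode, including the bottom-right alignment mentioned in the commit message. It is an illustration only, not code from this commit or from AOTriton; the `in_window` helper and its alignment convention are assumptions based on the common upstream FlashAttention semantics.

// Illustrative sketch only: the usual sliding-window attention predicate.
// Not taken from this commit; does not call into AOTriton.
#include <cstdint>
#include <optional>

// Returns true if query position q may attend to key position k.
// std::nullopt (or a negative bound) means "unbounded on that side".
bool in_window(int64_t q, int64_t k,
               int64_t seqlen_q, int64_t seqlen_k,
               std::optional<int64_t> window_left,
               std::optional<int64_t> window_right) {
  // Bottom-right alignment: the last query row lines up with the last key column.
  const int64_t shift = seqlen_k - seqlen_q;
  if (window_left && *window_left >= 0 && k < q + shift - *window_left) {
    return false;  // too far in the past
  }
  if (window_right && *window_right >= 0 && k > q + shift + *window_right) {
    return false;  // too far in the future (window_right == 0 gives a causal mask)
  }
  return true;
}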

7 files changed: +368 -241 lines


aten/src/ATen/native/transformers/cuda/attention.cu

Lines changed: 14 additions & 2 deletions
@@ -1113,8 +1113,10 @@ _flash_attention_forward(
   std::optional<Tensor> alibi_slopes = _alibi_slopes;
   const float softcap = 0.0;
 
-  const int non_null_window_left = window_size_left.has_value() ? window_size_left.value() : -1;
-  const int non_null_window_right = window_size_right.has_value() ? window_size_right.value() : -1;
+#ifndef USE_ROCM // ROCM backend accepts std::optional for window_size_left/right directly.
+  const int non_null_window_left = window_size_left.value_or(-1);
+  const int non_null_window_right = window_size_right.value_or(-1);
+#endif
 
   // We are going to have two paths:
   // 1. The standard MHA path for dense tensors
@@ -1151,8 +1153,13 @@ _flash_attention_forward(
           softmax_scale,
           false /*zero_tensors*/,
           is_causal,
+#ifdef USE_ROCM
+          window_size_left,
+          window_size_right,
+#else
           non_null_window_left,
           non_null_window_right,
+#endif
           softcap,
           return_debug_mask,
           std::nullopt /*gen_*/);
@@ -1175,8 +1182,13 @@ _flash_attention_forward(
           dropout_p,
           softmax_scale,
           is_causal,
+#ifdef USE_ROCM
+          window_size_left,
+          window_size_right,
+#else
           non_null_window_left,
           non_null_window_right,
+#endif
           softcap,
           return_debug_mask, /*return_softmax (this is used for testing)*/
           std::nullopt);
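
The hunks above compile out the `-1` coercion under `USE_ROCM` and forward the raw optionals instead. The stand-alone snippet below (not a PyTorch call site; the `main` function is purely illustrative) shows the behavioral difference the guard preserves: `value_or(-1)` erases the distinction between an unset window and an explicit `-1`, while the forwarded `std::optional` keeps it, which matters because the commit message states `-1` is meaningful to AOTriton.

#include <cstdint>
#include <iostream>
#include <optional>

int main() {
  std::optional<int64_t> window_size_left;        // caller did not set a window
  std::optional<int64_t> window_size_right = 0;   // causal-style right bound

  // Non-ROCm path (as in the hunk above): collapse "unset" to the -1 sentinel.
  const int64_t non_null_window_left  = window_size_left.value_or(-1);
  const int64_t non_null_window_right = window_size_right.value_or(-1);
  std::cout << non_null_window_left << " " << non_null_window_right << "\n";  // -1 0

  // ROCm path: the optionals themselves are forwarded, so the backend can still
  // tell "unset" (nullopt) apart from an explicit -1.
  std::cout << window_size_left.has_value() << " "
            << window_size_right.has_value() << "\n";                         // 0 1
  return 0;
}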

aten/src/ATen/native/transformers/cuda/attention_backward.cu

Lines changed: 12 additions & 0 deletions
@@ -87,8 +87,10 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
   auto contiguous_grad_out = grad_out.contiguous();
   auto contiguous_out = out.contiguous();
 
+#ifndef USE_ROCM // ROCM backend accepts std::optional for window_size_left/right directly.
   const int non_null_window_left = window_size_left.has_value() ? window_size_left.value() : -1;
   const int non_null_window_right = window_size_right.has_value() ? window_size_right.value() : -1;
+#endif
 
   std::optional<at::Tensor> dq{std::nullopt};
   std::optional<at::Tensor> dk{std::nullopt};
@@ -136,8 +138,13 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
           softmax_scale,
           false /*zero_tensors*/,
           is_causal,
+#ifdef USE_ROCM
+          window_size_left,
+          window_size_right,
+#else
           non_null_window_left,
           non_null_window_right,
+#endif
           softcap,
           determinisitic,
           philox_seed,
@@ -159,8 +166,13 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
           dropout_p,
           softmax_scale,
           is_causal,
+#ifdef USE_ROCM
+          window_size_left,
+          window_size_right,
+#else
           non_null_window_left,
           non_null_window_right,
+#endif
           softcap,
           determinisitic,
           philox_seed,
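
The backward path follows the same pattern as the forward path. As a hypothetical illustration of why the coercion is compiled out rather than merely skipped, the sketch below gives one call site two placeholder backends with different window-size parameter types; the `fake_cuda_backend`/`fake_rocm_backend` namespaces and signatures are invented for this example and are not the real Flash-Attention or AOTriton entry points.

// Hypothetical illustration only: two placeholder backends with the two
// parameter styles the #ifdef above has to reconcile.
#include <cstdint>
#include <optional>

namespace fake_cuda_backend {
void mha_backward(int64_t window_size_left, int64_t window_size_right) { /* ... */ }
}  // namespace fake_cuda_backend

namespace fake_rocm_backend {
void mha_backward(std::optional<int64_t> window_size_left,
                  std::optional<int64_t> window_size_right) { /* ... */ }
}  // namespace fake_rocm_backend

void call_backward(std::optional<int64_t> window_size_left,
                   std::optional<int64_t> window_size_right) {
#ifdef USE_ROCM
  // ROCm: pass the optionals through; nullopt stays distinguishable from -1.
  fake_rocm_backend::mha_backward(window_size_left, window_size_right);
#else
  // Non-ROCm: the backend takes plain ints, with -1 meaning "unbounded".
  fake_cuda_backend::mha_backward(window_size_left.value_or(-1),
                                  window_size_right.value_or(-1));
#endif
}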
