Skip to content

Commit 3a62e50

Browse files
POI-WX and yewentao256 authored
[Ascend] Wx/fix rope and repeat (#1156)
* fix rope for ascend speed
* add dim check for rope
* optimize rope using cat instead of repeat
* reimplement repeat and fix a corner case
* update device config for sum on ascend

---------

Co-authored-by: yewentao <[email protected]>
1 parent b01c8b5 commit 3a62e50

File tree

3 files changed: 26 additions (+26) and 15 deletions (-15).

impl/ascend/device_configs.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -267,6 +267,12 @@
267267
),
268268
),
269269

270+
'reduce_partial_op': dict(
271+
name=['sum'],
272+
atol=1e-3,
273+
rtol=1e-4,
274+
),
275+
270276
'reduce_partial_op_1': dict(
271277
name=['std'],
272278
tensor_para=dict(

impl/ascend_npu/diopi_impl/functions_ext/rotary_embedding.cpp

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,8 @@ DIOPI_API diopiError_t diopiRotaryEmbedding(diopiContextHandle_t ctx, diopiTenso
3535
}
3636

3737
BEGIN_CALL_ACL_OP(out, x, cos, sin);
38+
TORCH_CHECK(xAt.size(-1) == 2 * cosAt.size(-1) && xAt.size(-1) == 2 * sinAt.size(-1),
39+
"The size of the last dimension of x must be twice the size of the corresponding dimensions of cos and sin!");
3840
if (xAt.numel() == 0) {
3941
END_CALL_ACL_OP();
4042
}
@@ -49,10 +51,12 @@ DIOPI_API diopiError_t diopiRotaryEmbedding(diopiContextHandle_t ctx, diopiTenso
4951
at::Tensor outView = viewAs4D(outAt);
5052
at::Tensor cosView = viewAs4D(cosAt);
5153
at::Tensor sinView = viewAs4D(sinAt);
52-
at::Tensor cosRepeated = op_api::repeat(cosView, {1, 1, 1, 2});
53-
at::Tensor sinRepeated = op_api::repeat(sinView, {1, 1, 1, 2});
54+
// The Ascend kernel requires the last dimension of cos and sin to have the same size as the last dimension of the input, so use the cat op to
55+
// concatenate in the last dimension.
56+
at::Tensor cosCat = op_api::cat({cosView, cosView}, -1);
57+
at::Tensor sinCat = op_api::cat({sinView, sinView}, -1);
5458
if (conj) {
55-
op_api::neg_(sinRepeated);
59+
op_api::neg_(sinCat);
5660
}
5761

5862
// According to API document
@@ -63,7 +67,7 @@ DIOPI_API diopiError_t diopiRotaryEmbedding(diopiContextHandle_t ctx, diopiTenso
6367

6468
std::vector<at::Tensor> chunkResult = xView.chunk(2, -1);
6569
at::Tensor xNew = op_api::cat({chunkResult[1] * (-1), chunkResult[0]}, -1);
66-
at::Tensor result = op_api::mul(cosRepeated, xView) + op_api::mul(sinRepeated, xNew);
70+
at::Tensor result = op_api::mul(cosCat, xView) + op_api::mul(sinCat, xNew);
6771
outView.copy_(result);
6872

6973
END_CALL_ACL_OP();

impl/ascend_npu/diopi_impl/repeat.cpp

Lines changed: 12 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -5,23 +5,24 @@
55
*/
66

77
#include "helper.hpp"
8-
#include "op_plugin/AclOpsInterface.h"
8+
#include "op_plugin/OpApiInterface.h"
9+
#include "op_plugin/utils/op_api_common.h"
910

1011
extern "C" {
1112
diopiError_t diopiRepeat(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiSize_t repeatSize) {
1213
BEGIN_CALL_ACL_OP(out, input, repeatSize);
13-
std::vector<int64_t> inputShape(inputAt.sizes().cbegin(), inputAt.sizes().cend());
14-
15-
if (inputShape.size() < repeatSize.len) {
16-
while (inputShape.size() < repeatSize.len) {
17-
inputShape.insert(inputShape.begin(), 1);
18-
}
19-
20-
inputAt = impl::aten::viewStorage(inputAt, inputShape);
14+
TORCH_CHECK(inputAt.dim() <= repeatSize.len, "repeats size should not be smaller than input tensor dim on ascend!");
15+
// When repeatSize.len is equal to 0, out is the same as input.
16+
if (repeatSize.len == 0) {
17+
outAt.copy_(inputAt);
18+
END_CALL_ACL_OP();
2119
}
2220

23-
at_npu::native::OpPreparation::markAsOutputForApplyTensor(outAt);
24-
outAt = acl_op::repeat(inputAt, repeatSizeAt);
21+
std::vector<int64_t> inputShape = inputAt.sizes().vec();
22+
inputShape.insert(inputShape.begin(), repeatSize.len - inputAt.dim(), 1);
23+
inputAt = impl::aten::viewStorage(inputAt, inputShape);
24+
25+
EXEC_NPU_CMD(aclnnRepeat, inputAt, repeatSizeAt, outAt);
2526
END_CALL_ACL_OP();
2627
}
2728

0 commit comments

Comments
 (0)