Skip to content

Commit fe68509

Browse files
authored
[Ascend] Wx/reimpl multinomial op (#1218)
* reimpl multinomial op
1 parent 4ee62cd commit fe68509

File tree

2 files changed

+4
-11
lines changed

2 files changed

+4
-11
lines changed

impl/ascend/functions/multinomial.cpp

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,15 @@
44
* @copyright (c) 2023, DeepLink.
55
*/
66

7+
#include "../aclnn/adaptor.hpp"
78
#include "../common/acloprunner.hpp"
89

910
namespace impl {
1011
namespace ascend {
1112
diopiError_t diopiMultinomial(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, int64_t numSamples, bool replacement,
1213
diopiGeneratorHandle_t generator) {
13-
auto pair = getSeedAndOffset(ctx, generator, 10);
14-
AclOpRunner<3, 1>("MultinomialWithReplacement", ctx)
15-
.addInput(input)
16-
.addConstInput(pair.first, diopi_dtype_int64)
17-
.addConstInput(pair.second, diopi_dtype_int64)
18-
.setAttr("numsamples", numSamples)
19-
.setAttr("replacement", replacement)
20-
.addOutput(out)
21-
.run();
14+
std::pair<uint64_t, int64_t> pair = getSeedAndOffset(ctx, generator, 10);
15+
DIOPI_ASCEND_CALL_ACLNN(aclnnMultinomial, ctx, input, numSamples, replacement, static_cast<int64_t>(pair.first), pair.second, out);
2216
return diopiSuccess;
2317
}
2418

impl/ascend_npu/ascend_config.yaml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ ascend:
154154
- diopiMulInp
155155
- diopiMulInpScalar
156156
- diopiMulScalar
157+
- diopiMultinomial
157158
- diopiMin
158159
- diopiMinAll
159160
- diopiMinimum
@@ -235,9 +236,7 @@ ascend:
235236
- diopiZeros
236237
ascend_npu:
237238
- diopiNonzero
238-
- diopiCat
239239
- diopiCopyInp
240-
- diopiMultinomial
241240
- diopiRotaryEmbedding
242241
- diopiIndexPut
243242
- diopiIndexPutInp

0 commit comments

Comments
 (0)