Skip to content

Commit 4bae507

Browse files
support op dispatch
1 parent 50efa07 commit 4bae507

File tree

1 file changed

+54
-0
lines changed

1 file changed

+54
-0
lines changed

csrc/extensions.cpp

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
#include "diopi_helper.h"
2929
#include "pybind_type_cast.h"
30+
#include "torch/library.h"
3031

3132
namespace dipu::dipu_ext {
3233

@@ -363,4 +364,57 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
363364
}
364365
}
365366

367+
// Dispatcher kernel for the `apply_penalty` operator.
//
// Forwards every argument, in the same order as the parameter list, to
// diopiApplyPenalty through callDiopi and then returns the `logits`
// reference that was passed in (the operator schema registered in this file
// marks `logits` as the mutated, aliased result).
//
// NOTE(review): what the penalty computation does with each tensor is defined
// by the DIOPI implementation and is not visible here.
at::Tensor& apply_penalty(at::Tensor& logits,
                          const at::Tensor& presence_penalty,
                          const at::Tensor& frequency_penalty,
                          const at::Tensor& p_token_ids,
                          const at::Tensor& p_token_counts,
                          const at::Tensor& p_cumsum_seq_len,
                          int64_t p_max_len_in_batch) {
  callDiopi(diopiApplyPenalty,
            logits,
            presence_penalty,
            frequency_penalty,
            p_token_ids,
            p_token_counts,
            p_cumsum_seq_len,
            p_max_len_in_batch);
  return logits;
}
377+
378+
// Dispatcher kernel for the `dest_index_copy_kv` operator.
//
// Parameters arrive as (k, dest_loc, out); diopiDestIndexCopyKV is invoked
// through callDiopi with the output first, i.e. (out, k, dest_loc). The
// function returns the same `out` reference it was given.
//
// NOTE(review): the copy semantics (how `dest_loc` indexes into `out`) are
// defined by the DIOPI implementation and are not visible here.
at::Tensor& dest_index_copy_kv(const at::Tensor& k,
                               const at::Tensor& dest_loc,
                               at::Tensor& out) {
  callDiopi(diopiDestIndexCopyKV, out, k, dest_loc);
  return out;
}
383+
384+
// Declares the operator schemas of the `ops` library. Backend kernels are
// attached separately in the TORCH_LIBRARY_IMPL blocks of this file.
//
// Each schema must list its arguments in exactly the order of the C++ kernel
// registered for it: arguments are bound to the kernel by position, and since
// every tensor argument has the same type, a permuted schema is not rejected
// at registration time.
TORCH_LIBRARY(ops, m) {
  // Disabled adamw schema, kept for reference.
  //m.def("adamw(Tensor(a!) input, Tensor(b!) grad, Tensor(c!) exp_avg, Tensor(d!) exp_avg_sq, Tensor(e!) max_exp_avg_sq, float lr, float beta1, float beta2, float eps, float weight_decay, int step, bool amsgrad)->(Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!))");
  m.def("apply_penalty(Tensor(a!) logits, Tensor presence_penalty, Tensor frequency_penalty, Tensor p_token_ids, Tensor p_token_counts, Tensor p_cumsum_seq_len, int p_max_len_in_batch)->Tensor(a!)");
  // Matches the kernel dest_index_copy_kv(k, dest_loc, out): `out` is the
  // third argument and is the tensor that is mutated and returned, so it is
  // the one carrying the (a!) alias annotation. The previous schema listed
  // `out` first, which mislabelled the arguments for keyword callers and put
  // the alias annotation on the wrong tensor.
  m.def("dest_index_copy_kv(Tensor k, Tensor dest_loc, Tensor(a!) out)->Tensor(a!)");
}
389+
390+
// impl for dipu
391+
TORCH_LIBRARY_IMPL(ops, XPU, m) {
392+
if (reinterpret_cast<void*>(diopiApplyPenalty) != nullptr) {
393+
m.impl("apply_penalty", apply_penalty);
394+
}
395+
if (reinterpret_cast<void*>(diopiDestIndexCopyKV) != nullptr) {
396+
m.impl("dest_index_copy_kv", dest_index_copy_kv);
397+
}
398+
}
399+
400+
// impl for torch
401+
TORCH_LIBRARY_IMPL(ops, CUDA, m) {
402+
if (reinterpret_cast<void*>(diopiApplyPenalty) != nullptr) {
403+
m.impl("apply_penalty", apply_penalty);
404+
}
405+
if (reinterpret_cast<void*>(diopiDestIndexCopyKV) != nullptr) {
406+
m.impl("dest_index_copy_kv", dest_index_copy_kv);
407+
}
408+
}
409+
410+
// impl for torch_npu
411+
TORCH_LIBRARY_IMPL(ops, PrivateUse1, m) {
412+
if (reinterpret_cast<void*>(diopiApplyPenalty) != nullptr) {
413+
m.impl("apply_penalty", apply_penalty);
414+
}
415+
if (reinterpret_cast<void*>(diopiDestIndexCopyKV) != nullptr) {
416+
m.impl("dest_index_copy_kv", dest_index_copy_kv);
417+
}
418+
}
419+
366420
} // namespace dipu::dipu_ext

0 commit comments

Comments
 (0)