|
27 | 27 |
|
28 | 28 | #include "diopi_helper.h"
|
29 | 29 | #include "pybind_type_cast.h"
|
| 30 | +#include "torch/library.h" |
30 | 31 |
|
31 | 32 | namespace dipu::dipu_ext {
|
32 | 33 |
|
@@ -363,4 +364,57 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
363 | 364 | }
|
364 | 365 | }
|
365 | 366 |
|
| 367 | +at::Tensor& apply_penalty(at::Tensor& logits, const at::Tensor& presence_penalty, |
| 368 | + const at::Tensor& frequency_penalty, |
| 369 | + const at::Tensor& p_token_ids, |
| 370 | + const at::Tensor& p_token_counts, |
| 371 | + const at::Tensor& p_cumsum_seq_len, |
| 372 | + int64_t p_max_len_in_batch) { |
| 373 | + callDiopi(diopiApplyPenalty, logits, presence_penalty, frequency_penalty, |
| 374 | + p_token_ids, p_token_counts, p_cumsum_seq_len, p_max_len_in_batch); |
| 375 | + return logits; |
| 376 | +} |
| 377 | + |
| 378 | +at::Tensor& dest_index_copy_kv(const at::Tensor& k, const at::Tensor& dest_loc, |
| 379 | + at::Tensor& out) { |
| 380 | + callDiopi(diopiDestIndexCopyKV, out, k, dest_loc); |
| 381 | + return out; |
| 382 | +} |
| 383 | + |
| 384 | +TORCH_LIBRARY(ops, m) { |
| 385 | + //m.def("adamw(Tensor(a!) input, Tensor(b!) grad, Tensor(c!) exp_avg, Tensor(d!) exp_avg_sq, Tensor(e!) max_exp_avg_sq, float lr, float beta1, float beta2, float eps, float weight_decay, int step, bool amsgrad)->(Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!))"); |
| 386 | + m.def("apply_penalty(Tensor(a!) logits, Tensor presence_penalty, Tensor frequency_penalty, Tensor p_token_ids, Tensor p_token_counts, Tensor p_cumsum_seq_len, int p_max_len_in_batch)->Tensor(a!)"); |
| 387 | + m.def("dest_index_copy_kv(Tensor(a!) out, Tensor k, Tensor dest_loc)->Tensor(a!)"); |
| 388 | +} |
| 389 | + |
| 390 | +// impl for dipu |
| 391 | +TORCH_LIBRARY_IMPL(ops, XPU, m) { |
| 392 | + if (reinterpret_cast<void*>(diopiApplyPenalty) != nullptr) { |
| 393 | + m.impl("apply_penalty", apply_penalty); |
| 394 | + } |
| 395 | + if (reinterpret_cast<void*>(diopiDestIndexCopyKV) != nullptr) { |
| 396 | + m.impl("dest_index_copy_kv", dest_index_copy_kv); |
| 397 | + } |
| 398 | +} |
| 399 | + |
| 400 | +// impl for torch |
| 401 | +TORCH_LIBRARY_IMPL(ops, CUDA, m) { |
| 402 | + if (reinterpret_cast<void*>(diopiApplyPenalty) != nullptr) { |
| 403 | + m.impl("apply_penalty", apply_penalty); |
| 404 | + } |
| 405 | + if (reinterpret_cast<void*>(diopiDestIndexCopyKV) != nullptr) { |
| 406 | + m.impl("dest_index_copy_kv", dest_index_copy_kv); |
| 407 | + } |
| 408 | +} |
| 409 | + |
| 410 | +// impl for torch_npu |
| 411 | +TORCH_LIBRARY_IMPL(ops, PrivateUse1, m) { |
| 412 | + if (reinterpret_cast<void*>(diopiApplyPenalty) != nullptr) { |
| 413 | + m.impl("apply_penalty", apply_penalty); |
| 414 | + } |
| 415 | + if (reinterpret_cast<void*>(diopiDestIndexCopyKV) != nullptr) { |
| 416 | + m.impl("dest_index_copy_kv", dest_index_copy_kv); |
| 417 | + } |
| 418 | +} |
| 419 | + |
366 | 420 | } // namespace dipu::dipu_ext
|
0 commit comments