Skip to content

support op dispatch #93

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 136 additions & 0 deletions csrc/extensions.cpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
// Copyright (c) 2023, DeepLink.

#include <cstdint>
#include <iostream>
#include <string>
#include <tuple>
#include <utility>

#include "torch/library.h"
#include <ATen/core/ATen_fwd.h>
#include <ATen/core/Generator.h>
#include <ATen/core/TensorBody.h>
Expand Down Expand Up @@ -363,4 +365,138 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
}
}

std::tuple<at::Tensor&, at::Tensor&, at::Tensor&> adamw(
at::Tensor& param, at::Tensor& exp_avg, at::Tensor& exp_avg_sq,
const c10::optional<at::Tensor>& max_exp_avg_sq_opt, const at::Tensor& grad,
double lr, double beta1, double beta2, double epsilon, double weight_decay,
int64_t step, bool amsgrad) {
// the diopiAdamW func has no "maximize" param
at::Tensor& grad_ref =
const_cast<at::Tensor&>(grad); // todo: grad is const value
at::Tensor max_exp_avg_sq_opt_value =
max_exp_avg_sq_opt.value_or(at::Tensor());
callDiopi(diopiAdamW, param, grad_ref, exp_avg, exp_avg_sq,
max_exp_avg_sq_opt_value, lr, beta1, beta2, epsilon, weight_decay,
step, amsgrad);
return std::tie(param, exp_avg, exp_avg_sq);
}

at::Tensor& apply_penalty(at::Tensor& logits,
const at::Tensor& presence_penalty,
const at::Tensor& frequency_penalty,
const at::Tensor& p_token_ids,
const at::Tensor& p_token_counts,
const at::Tensor& p_cumsum_seq_len,
int64_t p_max_len_in_batch) {
callDiopi(diopiApplyPenalty, logits, presence_penalty, frequency_penalty,
p_token_ids, p_token_counts, p_cumsum_seq_len, p_max_len_in_batch);
return logits;
}

at::Tensor& dest_index_copy_kv(const at::Tensor& k, const at::Tensor& dest_loc,
at::Tensor& out) {
callDiopi(diopiDestIndexCopyKV, out, k, dest_loc);
return out;
}

std::tuple<at::Tensor&, at::Tensor&> rms_norm(
at::Tensor& output, at::Tensor& inv_rms, const at::Tensor& input,
const OptionalIntArray& normalized_shape, const at::Tensor& weight,
const c10::optional<at::Tensor>& bias_opt, double eps) {
callDiopi(diopiRMSNorm, output, inv_rms, input, normalized_shape, weight,
bias_opt, eps);
return std::tie(output, inv_rms);
}

std::tuple<at::Tensor&, at::Tensor&, at::Tensor&> rms_norm_backward(
at::Tensor& grad_input, at::Tensor& grad_weight, at::Tensor& grad_bias_opt,
const at::Tensor& grad_output, const at::Tensor& input,
const at::Tensor& weight, const c10::optional<at::Tensor>& bias_opt,
const at::Tensor& inv_rms, const OptionalIntArray& normalized_shape,
double eps) {
callDiopi(diopiRMSNormBackward, grad_input, grad_weight, grad_bias_opt,
grad_output, input, weight, bias_opt, inv_rms, normalized_shape,
eps);
return std::tie(grad_input, grad_weight, grad_bias_opt);
}

at::Tensor& apply_rotary(at::Tensor& output, const at::Tensor& input,
const at::Tensor& cos, const at::Tensor& sin,
const bool conj, const bool interleaved) {
callDiopi(diopiRotaryEmbedding, output, input, cos, sin, conj, interleaved);
return output;
}

at::Tensor& example_for_all_backend(at::Tensor& inout) {
std::cout << __FUNCTION__ << ": " << inout.options() << "\n";
return inout;
}

at::Tensor& example_only_for_xpu(at::Tensor& inout) {
std::cout << __FUNCTION__ << ": " << inout.options() << "\n";
return inout;
}

// By default, all backends (XPU, AutocastXPU, AutoGradXPU, CUDA, PrivateUse1,
// AutogradPrivateUse1 etc) are registered. If you need to register separately
// for a certain backend, separate registration for a certain backend is also
// supported.
TORCH_LIBRARY(deeplink_ext_, m) {
if (&diopiAdamW != nullptr) {
m.def(
"adamw(Tensor(a!) param, Tensor(b!) exp_avg, Tensor(c!) exp_avg_sq, "
"Tensor? max_exp_avg_sq_opt, Tensor grad, float lr, float beta1, float "
"beta2, float epsilon, float weight_decay, int step, bool "
"amsgrad)->(Tensor(a!), Tensor(b!), Tensor(c!))",
adamw);
}
if (&diopiApplyPenalty != nullptr) {
m.def(
"apply_penalty(Tensor(a!) logits, Tensor presence_penalty, Tensor "
"frequency_penalty, Tensor p_token_ids, Tensor p_token_counts, Tensor "
"p_cumsum_seq_len, int p_max_len_in_batch)->Tensor(a!)",
apply_penalty);
}
if (&diopiDestIndexCopyKV != nullptr) {
m.def(
"dest_index_copy_kv(Tensor(a!) out, Tensor k, Tensor "
"dest_loc)->Tensor(a!)",
dest_index_copy_kv);
}
if (&diopiDestIndexCopyKV != nullptr) {
m.def(
"rms_norm(Tensor(a!) output, Tensor(b!) inv_rms, Tensor input, int[]? "
"normalized_shape, Tensor weight, Tensor? bias_opt, float eps) -> "
"(Tensor(a!), Tensor(b!))",
rms_norm);
}

if (&diopiRMSNormBackward != nullptr) {
m.def(
"rms_norm_backward(Tensor(a!) grad_input, Tensor(b!) grad_weight, "
"Tensor(c!) grad_bias_opt, Tensor grad_output, Tensor input, Tensor "
"weight, Tensor? bias_opt, Tensor inv_rms, int[]? normalized_shape, "
"float eps) -> (Tensor(a!), Tensor(b!), Tensor(c!))",
rms_norm_backward);
}
if (&diopiRotaryEmbedding != nullptr) {
m.def(
"apply_rotary(Tensor(a!) output, Tensor input, Tensor cos, Tensor sin, "
"bool conj, bool interleaved) -> Tensor(a!)",
apply_rotary);
}

m.def("example(Tensor(a!) inout)->Tensor(a!)", example_for_all_backend);
}

// only impl for dipu
TORCH_LIBRARY_IMPL(deeplink_ext_, XPU, m) {
// m.impl("example", example_only_for_xpu);
}

int n = [](){
std::cout << "deeplink_ext_ loaded" << std::endl;
return 0;
}();

} // namespace dipu::dipu_ext
16 changes: 15 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# Copyright (c) 2024, DeepLink.

from setuptools import find_packages, setup, Extension
from torch.utils.cpp_extension import BuildExtension, include_paths, library_paths
from torch.utils.cpp_extension import BuildExtension, CppExtension, include_paths, library_paths

import glob
import os
import subprocess
Expand Down Expand Up @@ -86,3 +87,16 @@ def get_ext():
cmdclass={"build_ext": BuildExtensionWithCompdb},
install_requires=["einops"],
)


setup(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

请教下 这个位置引入cppextension 是为了使用”TORCH_LIBRARY“吗 我记得BuildExtension和CppExtension 是在cuda 编译才用到, 如果在其他设备上用这种方式编译会不会有问题?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

setup里面的只是个草稿,还有待完善。cppextension不光光可以编译cuda kernel,也可以只编译cpp文件的。

name='deeplink_ext_ops',
ext_modules=[
CppExtension(
name='deeplink_ext_ops',
sources=glob.glob("./csrc/*.cpp"),
extra_compile_args=[' -g ']),
],
cmdclass={
'build_ext': BuildExtension
})
26 changes: 26 additions & 0 deletions test_dispatch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import torch
import torch_dipu
import deeplink_ext

so_path = deeplink_ext.__path__[0] + "/cpp_extensions.cpython-39-x86_64-linux-gnu.so"
torch.ops.load_library(so_path)
print(f"torch.ops.loaded_libraries:{torch.ops.loaded_libraries}")

#print(torch.ops.deeplink_ext_.dest_index_copy_kv)

def code_to_profile():
x = torch.randn(3,4)
y = torch.ops.deeplink_ext_.example(x)
y = torch.ops.deeplink_ext_.example(x.cuda())


with torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
]
) as p:
code_to_profile()
print(p.key_averages().table(
sort_by="self_cuda_time_total", row_limit=-1))