Skip to content

Commit 8fcc8e9

Browse files
author
Yin Hongyun
committed
[feat] add isinf、trunc、round、hardsigmoid、elu、threshold
1 parent 3e82955 commit 8fcc8e9

File tree

4 files changed

+236
-0
lines changed

4 files changed

+236
-0
lines changed

diopi_test/python/configs/diopi_configs.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,110 @@
33

44

55
diopi_configs = {
6+
'has_inf': dict(
7+
name=["isinf"],
8+
interface=["torch"],
9+
atol=1e-3,
10+
rtol=1e-4,
11+
tensor_para=dict(
12+
args=[
13+
{
14+
"ins": ['input'],
15+
"shape": ((), (1024,), (2, 4096), (64, 28, 28),
16+
(32, 64, 112, 112), (64, 3, 7, 28, 28),
17+
(0,), (256, 0), (8, 0, 128)),
18+
"dtype": [np.float16, np.float32, np.float64,
19+
np.int16, np.int32, np.int64,
20+
np.uint8, np.int8],
21+
},
22+
],
23+
),
24+
),
25+
26+
'trunc': dict(
27+
name=["trunc"],
28+
interface=["torch"],
29+
atol=1e-3,
30+
rtol=1e-4,
31+
tensor_para=dict(
32+
args=[
33+
{
34+
"ins": ['input'],
35+
"shape": ((2, 16, 32, 56, 56), (2, 64, 32, 32), (2, 96, 28), (2, 16)),
36+
"dtype": [np.float32, np.float16, np.float64],
37+
},
38+
],
39+
),
40+
),
41+
42+
'round': dict(
43+
name=["round"],
44+
interface=["torch"],
45+
atol=1e-3,
46+
rtol=1e-4,
47+
tensor_para=dict(
48+
args=[
49+
{
50+
"ins": ['input'],
51+
"shape": ((2, 16, 32, 56, 56), (2, 64, 32, 32), (2, 96, 28), (2, 16)),
52+
"dtype": [np.float32, np.float16, np.float64],
53+
},
54+
],
55+
),
56+
),
57+
58+
'round': dict(
59+
name=["hardsigmoid"],
60+
atol=1e-3,
61+
rtol=1e-4,
62+
tensor_para=dict(
63+
args=[
64+
{
65+
"ins": ['input'],
66+
"shape": ((2, 16, 32, 56, 56), (2, 64, 32, 32), (2, 96, 28), (2, 16)),
67+
"dtype": [np.float32, np.float16, np.float64],
68+
},
69+
],
70+
),
71+
),
72+
73+
'elu': dict(
74+
name=["elu"],
75+
atol=1e-3,
76+
rtol=1e-4,
77+
para=dict(
78+
alpha=[0.234, 4.8, -10, 1.0],
79+
),
80+
tensor_para=dict(
81+
args=[
82+
{
83+
"ins": ['input'],
84+
"shape": ((2, 16, 32, 56, 56), (2, 64, 32, 32), (2, 96, 28), (2, 16)),
85+
"dtype": [np.float32, np.float16, np.float64],
86+
},
87+
],
88+
),
89+
),
90+
91+
'threshold_relu': dict(
92+
name=["threshold"],
93+
atol=1e-3,
94+
rtol=1e-4,
95+
para=dict(
96+
threshold=[0.234, 4.8, -10, 1.0],
97+
value=[0.2, 4.2, -10, 2.0],
98+
),
99+
tensor_para=dict(
100+
args=[
101+
{
102+
"ins": ['input'],
103+
"shape": ((2, 16, 32, 56, 56), (2, 64, 32, 32), (2, 96, 28), (2, 16)),
104+
"dtype": [np.float32, np.float16, np.float64],
105+
},
106+
],
107+
),
108+
),
109+
6110
# FIXME batch_norm输入0size的张量报错
7111
'batch_norm': dict(
8112
name=["batch_norm"],

diopi_test/python/conformance/diopi_functions.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,51 @@ def promote_type(input: Tensor, promoted_dtype: Dtype) -> Dtype:
224224
]
225225
return dtype1 if dtype1 not in need_promote_types else promoted_dtype
226226

227+
def isinf(input) -> Tensor:
228+
func = check_function("diopiHasInf")
229+
out = Tensor(size=input.size(), dtype=Dtype.bool)
230+
ret = func(input.context(), out, input)
231+
check_returncode(ret)
232+
return out
233+
234+
def trunc(input) -> Tensor:
235+
func = check_function("diopiTrunc")
236+
out = Tensor(size=input.size(), dtype=input.get_dtype())
237+
ret = func(input.context(), out, input)
238+
check_returncode(ret)
239+
return out
240+
241+
def round(input) -> Tensor:
242+
func = check_function("diopiTRound")
243+
out = Tensor(size=input.size(), dtype=input.get_dtype())
244+
ret = func(input.context(), out, input)
245+
check_returncode(ret)
246+
return out
247+
248+
def hardsigmoid(input) -> Tensor:
249+
func = check_function("diopiHardSigmoid")
250+
out = Tensor(size=input.size(), dtype=input.get_dtype())
251+
ret = func(input.context(), out, input)
252+
check_returncode(ret)
253+
return out
254+
255+
def elu(input, alpha) -> Tensor:
256+
func = check_function("diopiElu")
257+
out = Tensor(size=input.size(), dtype=input.get_dtype())
258+
value = Scalar(alpha)
259+
ret = func(input.context(), out, input, value)
260+
check_returncode(ret)
261+
return out
262+
263+
264+
def threshold(input, threshold, value) -> Tensor:
265+
func = check_function("diopiThresholdRelu")
266+
out = Tensor(size=input.size(), dtype=input.get_dtype())
267+
threshold = Scalar(threshold)
268+
value = Scalar(value)
269+
ret = func(input.context(), out, input, threshold, value)
270+
check_returncode(ret)
271+
return out
227272

228273
def fill_(input, value):
229274
func = check_function("diopiFill")

impl/torch/functions/functions.cpp

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,63 @@ const char* diopiGetImplVersion() {
6565
return version;
6666
}
6767

68+
diopiError_t diopiHasInf(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input) {
69+
impl::aten::setCurStream(ctx);
70+
71+
auto atInput = impl::aten::buildATen(input);
72+
auto atOut = impl::aten::buildATen(out);
73+
CALL_ATEN_FUNC(isinf_out, atOut, atInput);
74+
75+
return diopiSuccess;
76+
}
77+
78+
diopiError_t diopiTrunc(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input) {
79+
impl::aten::setCurStream(ctx);
80+
auto atInput = impl::aten::buildATen(input);
81+
auto atOut = impl::aten::buildATen(out);
82+
CALL_ATEN_FUNC(trunc_out, atOut, atInput);
83+
return diopiSuccess;
84+
}
85+
86+
diopiError_t diopiRound(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input) {
87+
impl::aten::setCurStream(ctx);
88+
auto atInput = impl::aten::buildATen(input);
89+
auto atOut = impl::aten::buildATen(out);
90+
CALL_ATEN_FUNC(round_out, atOut, atInput);
91+
92+
return diopiSuccess;
93+
}
94+
95+
diopiError_t diopiHardSigmoid(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input) {
96+
impl::aten::setCurStream(ctx);
97+
auto atInput = impl::aten::buildATen(input);
98+
auto atOut = impl::aten::buildATen(out);
99+
CALL_ATEN_FUNC(hardsigmoid_out, atOut, atInput);
100+
101+
return diopiSuccess;
102+
}
103+
104+
diopiError_t diopiThresholdRelu(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, const diopiScalar_t* threshold,
105+
const diopiScalar_t* value) {
106+
impl::aten::setCurStream(ctx);
107+
auto atInput = impl::aten::buildATen(input);
108+
auto atOut = impl::aten::buildATen(out);
109+
auto atThreshold = impl::aten::buildAtScalar(threshold);
110+
auto atValue = impl::aten::buildAtScalar(value);
111+
CALL_ATEN_FUNC(threshold_out, atOut, atInput, atThreshold, atValue);
112+
113+
return diopiSuccess;
114+
}
115+
116+
diopiError_t diopiElu(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, const diopiScalar_t* alpha) {
117+
impl::aten::setCurStream(ctx);
118+
auto atInput = impl::aten::buildATen(input);
119+
auto atOut = impl::aten::buildATen(out);
120+
auto atAlpha = impl::aten::buildAtScalar(alpha);
121+
CALL_ATEN_FUNC(elu_out, atOut, atInput, atAlpha);
122+
return diopiSuccess;
123+
}
124+
68125
diopiError_t diopiRelu(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input) {
69126
impl::aten::setCurStream(ctx);
70127
auto atOut = impl::aten::buildATen(out);

proto/include/diopi/functions.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,36 @@ extern "C" {
1919
DIOPI_RT_API DIOPI_ATTR_WEAK const char* diopiGetVendorName();
2020
DIOPI_RT_API DIOPI_ATTR_WEAK const char* diopiGetImplVersion();
2121
DIOPI_RT_API DIOPI_ATTR_WEAK const char* diopiGetLastErrorString();
22+
/**
23+
* @brief Returns whether the input tensor contains any Inf values.
24+
*/
25+
DIOPI_API diopiError_t diopiHasInf(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input);
26+
27+
/**
28+
* @brief Truncates the input tensor to an integer value.
29+
*/
30+
DIOPI_API diopiError_t diopiTrunc(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input);
31+
32+
/**
33+
* @brief Rounds the input tensor to the nearest integer value.
34+
*/
35+
DIOPI_API diopiError_t diopiRound(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input);
36+
37+
/**
38+
* @brief Applies the hard sigmoid activation function to an input tensor.
39+
*/
40+
DIOPI_API diopiError_t diopiHardSigmoid(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input);
41+
42+
/**
43+
* @brief Applies a thresholded rectified linear unit (ReLU) activation function to an input tensor.
44+
*/
45+
DIOPI_API diopiError_t diopiThresholdRelu(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, const diopiScalar_t* threshold,
46+
const diopiScalar_t* value);
47+
48+
/**
49+
* @brief Applies the exponential linear unit (ELU) activation function to an input tensor.
50+
*/
51+
DIOPI_API diopiError_t diopiElu(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, const diopiScalar_t* alpha);
2252

2353
/**
2454
* @brief Applies a 2D convolution over an input image composed of several input planes.

0 commit comments

Comments
 (0)