Skip to content

Commit 4c1a999

Browse files
author
Yin Hongyun
committed
[feat] add relu_backward
1 parent f61c61c commit 4c1a999

File tree

4 files changed

+36
-3
lines changed

4 files changed

+36
-3
lines changed

diopi_test/python/configs/diopi_configs.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -656,12 +656,11 @@
656656
args=[
657657
{
658658
"ins": ['input'],
659+
"requires_grad": [True],
659660
"shape": ((), (1024,), (2, 4096), (64, 28, 28),
660661
(32, 64, 112, 112), (64, 3, 7, 28, 28),
661662
(0,), (256, 0), (8, 0, 128)),
662-
"dtype": [np.float16, np.float32, np.float64,
663-
np.int16, np.int32, np.int64,
664-
np.uint8, np.int8],
663+
"dtype": [np.float16, np.float32, np.float64],
665664
"gen_fn": 'Genfunc.randn',
666665
},
667666
],

diopi_test/python/conformance/diopi_functions.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,15 @@ def relu(input, inplace=False) -> Tensor:
423423
return unary_op(input, inplace, "diopiRelu")
424424

425425

426+
def relu_backward(input, grad_outputs, **kwargs) -> Tensor:
427+
assert len(grad_outputs) == 1, "only accept 1 gradient to do backward"
428+
grad_input = raw_like(input)
429+
func = check_function("diopiReluBackward")
430+
ret = func(input.context(), grad_input, grad_outputs[0], input)
431+
check_returncode(ret)
432+
return {"input": grad_input} if grad_input.requires_grad else {}
433+
434+
426435
def abs(input, inplace=False) -> Tensor:
427436
return unary_op(input, inplace, "diopiAbs")
428437

impl/torch/functions/functions.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,18 @@ diopiError_t diopiRelu(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiC
158158
return diopiSuccess;
159159
}
160160

161+
diopiError_t diopiReluBackward(diopiContextHandle_t ctx, diopiConstTensorHandle_t grad_in, diopiTensorHandle_t grad_out, diopiConstTensorHandle_t input){
162+
impl::aten::setCurStream(ctx);
163+
164+
auto atGradOut = impl::aten::buildATen(grad_out);
165+
auto atInput = impl::aten::buildATen(input);
166+
auto atGradIn = impl::aten::buildATen(grad_in);
167+
auto mask = (atInput > 0).to(atGradOut.dtype());
168+
atGradIn.copy_(atGradOut * mask);
169+
170+
return diopiSuccess;
171+
}
172+
161173
diopiError_t diopiReluInp(diopiContextHandle_t ctx, diopiTensorHandle_t input) {
162174
impl::aten::setCurStream(ctx);
163175
auto atInput = impl::aten::buildATen(input);
@@ -4001,6 +4013,7 @@ diopiError_t diopiLinspace(diopiContextHandle_t ctx, diopiTensorHandle_t out, co
40014013
return diopiSuccess;
40024014
}
40034015

4016+
40044017
diopiError_t diopiRoll(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiSize_t shifts, diopiSize_t dims) {
40054018
impl::aten::setCurStream(ctx);
40064019
auto atInput = impl::aten::buildATen(input);

proto/include/diopi/functions.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,11 @@ DIOPI_API diopiError_t diopiBatchNormBackward(diopiContextHandle_t ctx, diopiTen
236236
*/
237237
DIOPI_API diopiError_t diopiRelu(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input);
238238

239+
/**
240+
* @brief Computes the gradient of the rectified linear unit function.
241+
*/
242+
DIOPI_API diopiError_t diopiReluBackward(diopiContextHandle_t ctx, diopiConstTensorHandle_t grad_in, diopiTensorHandle_t grad_out, diopiConstTensorHandle_t input);
243+
239244
/**
240245
* @brief The in-place version of diopiRelu().
241246
* @param[in] ctx Context environment.
@@ -701,6 +706,13 @@ DIOPI_API diopiError_t diopiAdaptiveMaxPool2dBackward(diopiContextHandle_t ctx,
701706
*/
702707
DIOPI_API diopiError_t diopiDropout(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t mask, diopiConstTensorHandle_t input, double p,
703708
bool train, diopiGeneratorHandle_t generator);
709+
710+
/**
711+
*@brief Randomly zeroes some of the elements of the input tensor with probability p
712+
*/
713+
DIOPI_API diopiError_t diopiDropout(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t mask, diopiConstTensorHandle_t input, double p,
714+
bool train, diopiGeneratorHandle_t generator);
715+
704716
/**
705717
* @brief The in-place version of diopiDropout().
706718
* @param[in] ctx Context environment.

0 commit comments

Comments
 (0)