Skip to content

Commit 1e98d59

Browse files
author
Yin Hongyun
committed
[feat] add erf back
1 parent 4c1a999 commit 1e98d59

File tree

4 files changed

+46
-0
lines changed

4 files changed

+46
-0
lines changed

diopi_test/python/configs/diopi_configs.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -666,6 +666,24 @@
666666
],
667667
),
668668
),
669+
670+
'erf': dict(
671+
name=['erf'],
672+
interface=['torch'],
673+
dtype=[np.float16, np.float32, np.float64],
674+
tensor_para=dict(
675+
gen_fn='Genfunc.randn',
676+
args=[
677+
{
678+
"ins": ['input'],
679+
"requires_grad": [True],
680+
"shape": ((), (1, ), (1024,), (364800, 4), (2, 128, 3072),
681+
(256, 128, 3, 3),
682+
(2, 31, 512, 6, 40), (0,), (16, 0)),
683+
},
684+
],
685+
),
686+
),
669687

670688
'relu_no_contiguous': dict(
671689
name=["relu"],

diopi_test/python/conformance/diopi_functions.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -540,6 +540,15 @@ def log1p(input, inplace=False) -> Tensor:
540540
return unary_op(input, inplace, "diopiLog1p", promote_type(input, Dtype.float32))
541541

542542

543+
def erf_backward(input, grad_outputs, **kwargs) -> Tensor:
544+
assert len(grad_outputs) == 1, "only accept 1 gradient to do backward"
545+
grad_input = raw_like(input)
546+
func = check_function("diopiErfBackward")
547+
ret = func(input.context(), grad_input, grad_outputs[0], input)
548+
check_returncode(ret)
549+
return {"input": grad_input} if grad_input.requires_grad else {}
550+
551+
543552
def erf(input, inplace=False) -> Tensor:
544553
return unary_op(input, inplace, "diopiErf", promote_type(input, Dtype.float32))
545554

impl/torch/functions/functions.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1600,6 +1600,19 @@ diopiError_t diopiErf(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiCo
16001600
return diopiSuccess;
16011601
}
16021602

1603+
diopiError_t diopiErfBackward(diopiContextHandle_t ctx, diopiConstTensorHandle_t grad_in, diopiTensorHandle_t grad_out,
1604+
diopiConstTensorHandle_t input){
1605+
impl::aten::setCurStream(ctx);
1606+
auto atGradIn = impl::aten::buildATen(grad_in);
1607+
auto atGradOut = impl::aten::buildATen(grad_out);
1608+
auto atInput = impl::aten::buildATen(input);
1609+
auto local_grad = (2.0 / std::sqrt(M_PI)) * at::exp(-atInput * atInput);
1610+
atGradIn.copy_(atGradOut * local_grad);
1611+
1612+
return diopiSuccess;
1613+
}
1614+
1615+
16031616
diopiError_t diopiErfInp(diopiContextHandle_t ctx, diopiTensorHandle_t input) {
16041617
impl::aten::setCurStream(ctx);
16051618
auto atInput = impl::aten::buildATen(input);

proto/include/diopi/functions.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,12 @@ DIOPI_API diopiError_t diopiRelu(diopiContextHandle_t ctx, diopiTensorHandle_t o
241241
*/
242242
DIOPI_API diopiError_t diopiReluBackward(diopiContextHandle_t ctx, diopiConstTensorHandle_t grad_in, diopiTensorHandle_t grad_out, diopiConstTensorHandle_t input);
243243

244+
/**
245+
* @brief Comput the gradient of the error function.
246+
*/
247+
DIOPI_API diopiError_t diopiErfBackward(diopiContextHandle_t ctx, diopiConstTensorHandle_t grad_in, diopiTensorHandle_t grad_out, diopiConstTensorHandle_t input);
248+
249+
244250
/**
245251
* @brief The in-place version of diopiRelu().
246252
* @param[in] ctx Context environment.

0 commit comments

Comments
 (0)