
Commit 20d83c4 (1 parent: a6dbbb6)
Author: Yin Hongyun

add group norm back

File tree: 4 files changed, +125 -8 lines


diopi_test/python/configs/diopi_configs.py

Lines changed: 8 additions & 8 deletions
@@ -7209,23 +7209,23 @@
         atol=1e-4,
         rtol=1e-5,
         para=dict(
-            num_groups=[32, 4, 5, 1],
-            eps=[1e-05, 1e-05, 1e-05, 1e-05],
-            reduced_axes = [[2, 3], [1, 3], [0, 3], [2, 3]],
-            channel_axis = [1, 2, 1, 0]
+            num_groups=[32],
+            eps=[1e-05],
+            reduced_axes = [[2, 3]],
+            channel_axis = [1]
         ),
         tensor_para=dict(
             args=[
                 {
                     "ins": ["input"],
-                    "shape": ((2, 256, 7, 10), (2, 256, 12, 12),
-                              (12, 15, 8, 9),(3, 6, 9, 0)),
+                    "requires_grad": [True],
+                    "shape": ((2, 256, 12, 10),),
                     "dtype": [np.float32, np.float64, np.float16],
                 },
                 {
                     "ins": ["weight", "bias"],
-                    "shape": ((256,), (12,),
-                              (15,), (3,)),
+                    "requires_grad": [True],
+                    "shape": ((256,),),
                     "dtype": [np.float32, np.float64, np.float16],
                 },
             ]
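
The narrowed config keeps a single forward-and-backward case: an NCHW input of shape (2, 256, 12, 10) with num_groups=32, channel_axis=1, and reduced_axes=[2, 3], which is exactly the layout torch.nn.functional.group_norm assumes. A minimal PyTorch sketch of what this case computes (a reference cross-check, not the DIOPI test harness itself):

    import torch
    import torch.nn.functional as F

    # The single remaining config case: NCHW, channels at dim 1, normalize over dims 2 and 3.
    x = torch.randn(2, 256, 12, 10, requires_grad=True)
    weight = torch.randn(256, requires_grad=True)
    bias = torch.randn(256, requires_grad=True)

    out = F.group_norm(x, num_groups=32, weight=weight, bias=bias, eps=1e-05)
    # requires_grad=[True] in the config means the backward pass is exercised too.
    out.backward(torch.ones_like(out))
    print(x.grad.shape, weight.grad.shape, bias.grad.shape)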

diopi_test/python/conformance/diopi_functions.py

Lines changed: 42 additions & 0 deletions
@@ -5275,6 +5275,48 @@ def group_norm_GB(input, num_groups, weight=None, bias=None, eps=1e-05, reduced_
     GLOBAL_STATE["group_norm_GB_save_invstd"] = save_invstd
     return out
 
+
+def group_norm_GB_backward(
+    input,
+    grad_outputs,
+    num_groups,
+    weight=None,
+    bias=None,
+    eps=1e-05,
+    reduced_axes=[2, 3],
+    channel_axis=1,
+    **kwargs,
+) -> Tensor:
+    assert len(grad_outputs) == 1, "only accept 1 gradient to do backward"
+    save_mean = GLOBAL_STATE.pop("group_norm_GB_save_mean")
+    save_invstd = GLOBAL_STATE.pop("group_norm_GB_save_invstd")
+    grad_input = raw_like(input)
+    grad_weight = raw_like(weight)
+    grad_bias = raw_like(bias)
+    weight = None if weight is None else weight
+    bias = None if bias is None else bias
+
+    out = {"input": grad_input, "weight": grad_weight, "bias": grad_bias}
+    func = check_function("diopiGroupNormGBBackward")
+    reduced_axes = Sizes(reduced_axes)
+    ret = func(
+        input.context(),
+        grad_input,
+        grad_weight,
+        grad_bias,
+        grad_outputs[0],
+        input,
+        weight,
+        save_mean,
+        save_invstd,
+        num_groups,
+        reduced_axes,
+        channel_axis,
+    )
+    check_returncode(ret)
+    return {k: v for k, v in out.items() if v.requires_grad}
+
+
 def group_norm(input, num_groups, weight=None, bias=None, eps=1e-05):
     dim = list(input.size().data)
     save_mean = Tensor((dim[0], num_groups), input.get_dtype())
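
group_norm_GB_backward pairs with group_norm_GB through GLOBAL_STATE: the forward stashes save_mean and save_invstd under the "group_norm_GB_save_*" keys and the backward pops them, so each backward call must follow exactly one forward call. A hedged sketch of the intended call sequence (the exact harness wiring and the ones_like helper are assumptions; the function names and the GLOBAL_STATE handoff come from the diff above):

    # Forward: computes out and stashes save_mean/save_invstd in GLOBAL_STATE.
    out = group_norm_GB(input, num_groups=32, weight=weight, bias=bias,
                        eps=1e-05, reduced_axes=[2, 3], channel_axis=1)
    # Backward: pops the stashed statistics, so it runs once per forward call.
    grads = group_norm_GB_backward(input, [ones_like(out)], num_groups=32,
                                   weight=weight, bias=bias, eps=1e-05,
                                   reduced_axes=[2, 3], channel_axis=1)
    # Only tensors flagged requires_grad appear in the returned dict.
    assert set(grads) <= {"input", "weight", "bias"}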

impl/torch/functions/functions.cpp

Lines changed: 68 additions & 0 deletions
@@ -4240,6 +4240,74 @@ diopiError_t diopiGroupNormGB(diopiContextHandle_t ctx, diopiTensorHandle_t out,
     return diopiSuccess;
 }
 
+diopiError_t diopiGroupNormGBBackward(diopiContextHandle_t ctx, diopiTensorHandle_t grad_input, diopiTensorHandle_t grad_weight, diopiTensorHandle_t grad_bias,
+                                      diopiConstTensorHandle_t grad_output, diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight,
+                                      diopiConstTensorHandle_t mean, diopiConstTensorHandle_t rstd, int64_t num_groups, diopiSize_t reduced_axes, const int64_t channel_axis) {
+    impl::aten::setCurStream(ctx);
+    auto atGradOutput = impl::aten::buildATen(grad_output);
+    auto atInput = impl::aten::buildATen(input);
+    auto atWeight = impl::aten::buildATen(weight);
+    auto atSaveMean = impl::aten::buildATen(mean);
+    auto atSaveVar = impl::aten::buildATen(rstd);
+    auto atGradWeight = impl::aten::buildATen(grad_weight);
+    auto atGradBias = impl::aten::buildATen(grad_bias);
+    auto axisSize = atInput.size(channel_axis);
+    auto k = axisSize / num_groups;
+    at::IntArrayRef atReducedAxes = impl::aten::buildAtIntArray(reduced_axes);
+    std::vector<int64_t> dims;
+    int64_t N = 1;
+    for (int i = 0; i < atInput.dim(); i++) {
+        if (i == channel_axis) {
+            continue;
+        } else {
+            bool is_reduced_axis = false;
+            for (int m = 0; m < reduced_axes.len; m++) {
+                if (i == reduced_axes.data[m]) {
+                    is_reduced_axis = true;
+                    break;
+                }
+            }
+            if (is_reduced_axis) {
+                continue;
+            } else {
+                dims.push_back(i);
+                N *= atInput.size(i);
+            }
+        }
+    }
+    dims.push_back(channel_axis);
+    int64_t HxW = 1;
+    for (auto i = 0; i < reduced_axes.len; i++) {
+        dims.push_back(reduced_axes.data[i]);
+        HxW *= atInput.size(reduced_axes.data[i]);
+    }
+    auto C = atInput.size(channel_axis);
+    auto permutedInput = atInput.permute(dims);
+    auto permutedShape = permutedInput.sizes();
+    auto reshapedInput = permutedInput.reshape({N, C, HxW, 1}).contiguous();
+
+    std::vector<int64_t> reverse_order(dims.size());
+    for (auto i = 0; i < atInput.dim(); i++) {
+        reverse_order[dims[i]] = i;
+    }
+
+    if (grad_weight && grad_bias) {
+        auto atGradInput = impl::aten::buildATen(grad_input).permute(dims).reshape({N, C, HxW, 1});
+
+        at::native_group_norm_backward_out(
+            atGradInput, atGradWeight, atGradBias, atGradOutput.permute(dims).reshape({N, C, HxW, 1}), reshapedInput, atSaveMean, atSaveVar, atWeight, N, C, HxW, num_groups, {true, true, true});
+        atGradInput = atGradInput.reshape(permutedShape).permute(reverse_order);
+    } else {
+        auto atOuts = at::native_group_norm_backward(
+            atGradOutput.permute(dims).reshape({N, C, HxW, 1}), reshapedInput, atSaveMean, atSaveVar, atWeight, N, C, HxW, num_groups, {true, grad_weight != nullptr, grad_bias != nullptr});
+        impl::aten::updateATen2Tensor(ctx, std::get<0>(atOuts).reshape(permutedShape).permute(reverse_order), grad_input);
+        impl::aten::updateATen2Tensor(ctx, std::get<1>(atOuts), grad_weight);
+        impl::aten::updateATen2Tensor(ctx, std::get<2>(atOuts), grad_bias);
+    }
+
+    return diopiSuccess;
+}
+
 diopiError_t diopiGroupNorm(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t save_mean, diopiTensorHandle_t save_invstd,
                             diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight, diopiConstTensorHandle_t bias, int64_t num_groups,
                             double eps) {
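
The new C++ entry point reuses ATen's canonical group-norm backward, which expects channels at dim 1 and a flattened (N, C, HxW) extent. To support an arbitrary channel_axis and reduced_axes, it permutes the tensor to (batch axes, channel axis, reduced axes), reshapes to {N, C, HxW, 1}, runs at::native_group_norm_backward(_out), and applies the inverse permutation reverse_order to bring grad_input back to the original layout. A Python sketch of the same bookkeeping (illustration only; canonicalize is a hypothetical helper, not part of the DIOPI codebase), using one of the generalized-layout cases removed from the config above:

    import torch

    def canonicalize(x, channel_axis, reduced_axes):
        # Batch axes: everything that is neither the channel axis nor a reduced axis.
        batch_axes = [i for i in range(x.dim())
                      if i != channel_axis and i not in reduced_axes]
        dims = batch_axes + [channel_axis] + list(reduced_axes)
        N = 1
        for i in batch_axes:
            N *= x.size(i)
        C = x.size(channel_axis)
        HxW = 1
        for i in reduced_axes:
            HxW *= x.size(i)
        permuted = x.permute(dims)
        # reverse_order is the inverse permutation: it maps each original axis
        # back to its position after the permute, mirroring the C++ loop above.
        reverse_order = [0] * len(dims)
        for pos, axis in enumerate(dims):
            reverse_order[axis] = pos
        return permuted.reshape(N, C, HxW, 1).contiguous(), permuted.shape, reverse_order

    # Mirrors the removed config case: shape (12, 15, 8, 9), channel_axis=1, reduced_axes=[0, 3].
    x = torch.randn(12, 15, 8, 9)
    canon, permuted_shape, inv = canonicalize(x, channel_axis=1, reduced_axes=[0, 3])
    assert canon.shape == (8, 15, 108, 1)  # N=8, C=15, HxW=12*9
    # Round trip: reshape back to the permuted shape, then undo the permutation.
    restored = canon.reshape(permuted_shape).permute(inv)
    assert torch.equal(restored, x)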

proto/include/diopi/functions.h

Lines changed: 7 additions & 0 deletions
@@ -3607,6 +3607,13 @@ DIOPI_API diopiError_t diopiGroupNormGB(diopiContextHandle_t ctx, diopiTensorHan
                                         diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight, diopiConstTensorHandle_t bias, int64_t num_groups,
                                         double eps, diopiSize_t reduced_axes, const int64_t channel_axis);
 
+/**
+ * @brief Compute the backward pass of diopiGroupNormGB().
+ */
+DIOPI_API diopiError_t diopiGroupNormGBBackward(diopiContextHandle_t ctx, diopiTensorHandle_t grad_input, diopiTensorHandle_t grad_weight,
+                                                diopiTensorHandle_t grad_bias, diopiConstTensorHandle_t grad_output, diopiConstTensorHandle_t input,
+                                                diopiConstTensorHandle_t weight, diopiConstTensorHandle_t mean, diopiConstTensorHandle_t rstd,
+                                                int64_t num_groups, diopiSize_t reduced_axes, const int64_t channel_axis);
 /**
  * @brief Compute the backward pass of diopiGroupNorm().
  * @param[in] ctx Context environment.
