Skip to content

Commit a6dbbb6

Browse files
author
Yin Hongyun
committed
add group_norm_GB
1 parent f09e78e commit a6dbbb6

File tree

5 files changed

+168
-4
lines changed

5 files changed

+168
-4
lines changed

diopi_test/python/configs/diopi_configs.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7202,6 +7202,35 @@
72027202
]
72037203
),
72047204
),
7205+
7206+
# Test configuration for the `group_norm_GB` operator, run through the
# 'CustomizedTest' interface.
# NOTE(review): the four entries of every `para` list and of every `shape`
# tuple look like they are paired by position (case 0..3) — confirm against
# the config loader.
'group_norm_GB': dict(
7207+
name=['group_norm_GB'],
7208+
interface=['CustomizedTest'],
7209+
atol=1e-4,
7210+
rtol=1e-5,
7211+
para=dict(
7212+
num_groups=[32, 4, 5, 1],
7213+
eps=[1e-05, 1e-05, 1e-05, 1e-05],
7214+
reduced_axes = [[2, 3], [1, 3], [0, 3], [2, 3]],  # axes folded into the per-channel "spatial" extent
7215+
channel_axis = [1, 2, 1, 0]  # axis holding the channels; differs per case
7216+
),
7217+
tensor_para=dict(
7218+
args=[
7219+
{
7220+
"ins": ["input"],
7221+
"shape": ((2, 256, 7, 10), (2, 256, 12, 12),
7222+
(12, 15, 8, 9),(3, 6, 9, 0)),  # last shape has a zero-sized axis — presumably an empty-tensor case; confirm
7223+
"dtype": [np.float32, np.float64, np.float16],
7224+
},
7225+
{
7226+
"ins": ["weight", "bias"],
7227+
# Read position by position, each length equals the input size along
# that case's channel_axis: 256, 12, 15, 3.
"shape": ((256,), (12,),
7228+
(15,), (3,)),
7229+
"dtype": [np.float32, np.float64, np.float16],
7230+
},
7231+
]
7232+
),
7233+
),
72057234

72067235
'unique': dict(
72077236
name=['unique'],

diopi_test/python/conformance/customized_test.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -908,4 +908,42 @@ def batch_norm_GB(input, running_mean, running_var, weight, bias, training=False
908908
eps=eps,
909909
)
910910
out = out.permute(dims)
911-
return out
911+
return out
912+
913+
def group_norm_GB(input, num_groups, weight=None, bias=None, eps=1e-05, reduced_axes=[2, 3], channel_axis=1):
    """Reference group normalization with a configurable channel axis.

    The input is permuted to ``(non-reduced axes..., channel_axis,
    reduced_axes...)``, collapsed to ``[N, C, HxW, 1]``, passed through
    ``torch.nn.functional.group_norm`` and then restored to the original
    axis order.

    Args:
        input: tensor to normalize.
        num_groups: number of groups the channel axis is split into.
        weight: optional per-channel scale forwarded to ``group_norm``.
        bias: optional per-channel shift forwarded to ``group_norm``.
        eps: value forwarded to ``group_norm`` as ``eps``.
        reduced_axes: sequence (list or tuple) of axes that are flattened
            into the ``HxW`` extent. Negative indices are allowed.
        channel_axis: axis that holds the channels. Negative indices are
            allowed.

    Returns:
        A tensor with the same shape and axis order as ``input``.
    """
    ndim = input.dim()

    def _wrap(axis):
        # Negative indices must be wrapped before the membership tests
        # below, otherwise the axis is never recognised and ends up in the
        # permutation twice.
        return axis + ndim if axis < 0 else axis

    channel_axis = _wrap(channel_axis)
    # Build a fresh list: accepts tuples and never touches the shared default.
    reduced_axes = [_wrap(axis) for axis in reduced_axes]
    reduced_axes_set = set(reduced_axes)

    non_reduced_dims = [
        i for i in range(ndim)
        if i != channel_axis and i not in reduced_axes_set
    ]

    N = 1
    for i in non_reduced_dims:
        N = N * input.size(i)
    HxW = 1
    for i in reduced_axes:
        HxW = HxW * input.size(i)
    C = input.size(channel_axis)

    dims = non_reduced_dims + [channel_axis] + reduced_axes
    permuted_input = input.permute(dims)
    reshaped_input = permuted_input.reshape([N, C, HxW, 1]).contiguous()
    out = torch.nn.functional.group_norm(
        reshaped_input,
        num_groups,
        weight=weight,
        bias=bias,
        eps=eps
    )

    # Inverse permutation: original axis `axis` currently sits at `position`.
    reversed_order = [0] * len(dims)
    for position, axis in enumerate(dims):
        reversed_order[axis] = position
    return out.reshape(permuted_input.shape).permute(reversed_order)
949+

diopi_test/python/conformance/diopi_functions.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2864,8 +2864,8 @@ def batch_norm_GB(
28642864
)
28652865

28662866
check_returncode(ret)
2867-
GLOBAL_STATE["batch_norm_save_mean"] = save_mean
2868-
GLOBAL_STATE["batch_norm_save_invstd"] = save_invstd
2867+
GLOBAL_STATE["batch_norm_GB_save_mean"] = save_mean
2868+
GLOBAL_STATE["batch_norm_GB_save_invstd"] = save_invstd
28692869
return out
28702870

28712871

@@ -5242,6 +5242,38 @@ def norm_backward(grad_outputs, input, p, dim, keepdim=False, dtype=None):
52425242

52435243
return {k: v for k, v in out.items() if v.requires_grad}
52445244

5245+
def group_norm_GB(input, num_groups, weight=None, bias=None, eps=1e-05, reduced_axes=[2, 3], channel_axis=1):
    """Call ``diopiGroupNormGB`` and return the normalized tensor.

    Args:
        input: tensor to normalize.
        num_groups: number of groups, forwarded to the backend.
        weight: optional tensor forwarded unchanged (``None`` is passed as is).
        bias: optional tensor forwarded unchanged (``None`` is passed as is).
        eps: forwarded to the backend.
        reduced_axes: sequence of axes forwarded (wrapped in ``Sizes``) as the
            reduced axes. Negative indices are allowed.
        channel_axis: axis forwarded as the channel axis. Negative indices
            are allowed.

    Returns:
        The output tensor allocated with ``raw_like(input)``.

    Side effects:
        Stores the ``save_mean`` / ``save_invstd`` buffers filled by the
        backend in ``GLOBAL_STATE`` under ``"group_norm_GB_save_mean"`` and
        ``"group_norm_GB_save_invstd"``.
    """
    dim = list(input.size().data)
    ndim = len(dim)

    # Wrap negative indices first: the membership tests below compare plain
    # ints, so an unwrapped negative axis would be counted into N and the
    # statistics buffers would get the wrong shape.
    channel_axis = channel_axis + ndim if channel_axis < 0 else channel_axis
    reduced_axes = [axis + ndim if axis < 0 else axis for axis in reduced_axes]

    # N is the product of every axis that is neither reduced nor the channel.
    N = 1
    for i in range(ndim):
        if i not in reduced_axes and i != channel_axis:
            N = N * dim[i]
    save_mean = Tensor((N, num_groups), input.get_dtype())
    save_invstd = raw_like(save_mean)

    out = raw_like(input)
    func = check_function("diopiGroupNormGB")
    ret = func(
        input.context(),
        out,
        save_mean,
        save_invstd,
        input,
        weight,
        bias,
        num_groups,
        eps,
        Sizes(reduced_axes),
        channel_axis
    )
    check_returncode(ret)
    GLOBAL_STATE["group_norm_GB_save_mean"] = save_mean
    GLOBAL_STATE["group_norm_GB_save_invstd"] = save_invstd
    return out
52455277

52465278
def group_norm(input, num_groups, weight=None, bias=None, eps=1e-05):
52475279
dim = list(input.size().data)

impl/torch/functions/functions.cpp

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4183,8 +4183,66 @@ diopiError_t diopiForeachnormScalar(diopiContextHandle_t ctx, diopiTensorHandle_
41834183
return diopiSuccess;
41844184
}
41854185

4186+
/**
 * @brief Group normalization with a caller-chosen channel axis and reduced axes.
 *
 * The input is permuted to (non-reduced axes..., channel_axis, reduced_axes...),
 * collapsed to {N, C, HxW, 1}, run through native_group_norm, and the result is
 * permuted back to the original axis order before being copied into @p out.
 *
 * @param[in] ctx Context environment.
 * @param[out] out receives the normalized tensor in the input's axis order.
 * @param[out] save_mean receives the second result of native_group_norm.
 * @param[out] save_invstd receives the third result of native_group_norm.
 * @param[in] input the tensor to normalize.
 * @param[in] weight forwarded to native_group_norm.
 * @param[in] bias forwarded to native_group_norm.
 * @param[in] num_groups number of groups, forwarded to native_group_norm.
 * @param[in] eps forwarded to native_group_norm.
 * @param[in] reduced_axes axes folded into the HxW extent; negative indices are wrapped.
 * @param[in] channel_axis axis holding the channels; a negative index is wrapped.
 * @return diopiSuccess once the three copies have been issued.
 */
diopiError_t diopiGroupNormGB(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t save_mean, diopiTensorHandle_t save_invstd,
                              diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight, diopiConstTensorHandle_t bias, int64_t num_groups,
                              double eps, diopiSize_t reduced_axes, const int64_t channel_axis) {
    impl::aten::setCurStream(ctx);
    auto atInput = impl::aten::buildATen(input);
    const int64_t ndim = atInput.dim();

    // Tensor::size() accepts negative indices, but the membership tests below
    // compare raw integers, so wrap every axis once up front.
    const auto wrapAxis = [ndim](int64_t axis) { return axis < 0 ? axis + ndim : axis; };
    const int64_t channelAxis = wrapAxis(channel_axis);
    std::vector<int64_t> reducedAxes;
    reducedAxes.reserve(reduced_axes.len);
    for (int64_t m = 0; m < reduced_axes.len; ++m) {
        reducedAxes.push_back(wrapAxis(reduced_axes.data[m]));
    }
    const auto isReducedAxis = [&reducedAxes](int64_t axis) {
        for (const int64_t reduced : reducedAxes) {
            if (reduced == axis) {
                return true;
            }
        }
        return false;
    };

    // Permutation order: non-reduced axes, then the channel axis, then the reduced axes.
    std::vector<int64_t> dims;
    dims.reserve(ndim);
    int64_t N = 1;
    for (int64_t i = 0; i < ndim; ++i) {
        if (i == channelAxis || isReducedAxis(i)) {
            continue;
        }
        dims.push_back(i);
        N *= atInput.size(i);
    }
    dims.push_back(channelAxis);
    int64_t HxW = 1;
    for (const int64_t axis : reducedAxes) {
        dims.push_back(axis);
        HxW *= atInput.size(axis);
    }
    const int64_t C = atInput.size(channelAxis);

    auto permutedInput = atInput.permute(dims);
    // View into permutedInput's sizes; permutedInput stays alive until the end of the function.
    auto permutedShape = permutedInput.sizes();
    auto reshapedInput = permutedInput.reshape({N, C, HxW, 1}).contiguous();

    auto atWeight = impl::aten::buildATen(weight);
    auto atBias = impl::aten::buildATen(bias);
    auto atOut = impl::aten::buildATen(out);
    auto atSaveMean = impl::aten::buildATen(save_mean);
    auto atSaveInvstd = impl::aten::buildATen(save_invstd);

    // Inverse permutation: original axis dims[i] currently sits at position i.
    std::vector<int64_t> reverseOrder(dims.size());
    for (size_t i = 0; i < dims.size(); ++i) {
        reverseOrder[dims[i]] = static_cast<int64_t>(i);
    }

    auto tempOut = CALL_ATEN_CUDA_FUNC(native_group_norm, reshapedInput, atWeight, atBias, N, C, HxW, num_groups, eps);
    at::native::copy_(atOut, std::get<0>(tempOut).reshape(permutedShape).permute(reverseOrder), true);
    at::native::copy_(atSaveMean, std::get<1>(tempOut), true);
    at::native::copy_(atSaveInvstd, std::get<2>(tempOut), true);
    return diopiSuccess;
}
4242+
41864243
diopiError_t diopiGroupNorm(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t save_mean, diopiTensorHandle_t save_invstd,
4187-
diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight, diopiConstTensorHandle_t bias, int64_t num_groups, double eps) {
4244+
diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight, diopiConstTensorHandle_t bias, int64_t num_groups,
4245+
double eps) {
41884246
impl::aten::setCurStream(ctx);
41894247
auto atInput = impl::aten::buildATen(input);
41904248
auto atWeight = impl::aten::buildATen(weight);

proto/include/diopi/functions.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3600,6 +3600,13 @@ DIOPI_API diopiError_t diopiGroupNorm(diopiContextHandle_t ctx, diopiTensorHandl
36003600
diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight, diopiConstTensorHandle_t bias, int64_t num_groups,
36013601
double eps);
36023602

3603+
/**
3604+
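* @brief Applies Group Normalization to the input, with a caller-specified channel axis and set of reduced axes.
* @param[in] ctx Context environment.
* @param[out] out the normalized output tensor.
* @param[out] save_mean tensor that receives the mean computed during normalization.
* @param[out] save_invstd tensor that receives the inverse standard deviation computed during normalization.
* @param[in] input the input tensor.
* @param[in] weight the weight (scale) tensor.
* @param[in] bias the bias (shift) tensor.
* @param[in] num_groups number of groups to separate the channels into.
* @param[in] eps a value added to the denominator for numerical stability.
* @param[in] reduced_axes the axes of input that are reduced together for each channel.
* @param[in] channel_axis the axis of input that holds the channels.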
3605+
*/
3606+
DIOPI_API diopiError_t diopiGroupNormGB(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t save_mean, diopiTensorHandle_t save_invstd,
3607+
diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight, diopiConstTensorHandle_t bias, int64_t num_groups,
3608+
double eps, diopiSize_t reduced_axes, const int64_t channel_axis);
3609+
36033610
/**
36043611
* @brief Compute the backward pass of diopiGroupNorm().
36053612
* @param[in] ctx Context environment.

0 commit comments

Comments
 (0)