@@ -4240,6 +4240,74 @@ diopiError_t diopiGroupNormGB(diopiContextHandle_t ctx, diopiTensorHandle_t out,
4240
4240
return diopiSuccess;
4241
4241
}
4242
4242
4243
+ diopiError_t diopiGroupNormGBBackward (diopiContextHandle_t ctx, diopiTensorHandle_t grad_input, diopiTensorHandle_t grad_weight, diopiTensorHandle_t grad_bias,
4244
+ diopiConstTensorHandle_t grad_output, diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight,
4245
+ diopiConstTensorHandle_t mean, diopiConstTensorHandle_t rstd, int64_t num_groups, diopiSize_t reduced_axes, const int64_t channel_axis) {
4246
+ impl::aten::setCurStream (ctx);
4247
+ auto atGradOutput = impl::aten::buildATen (grad_output);
4248
+ auto atInput = impl::aten::buildATen (input);
4249
+ auto atWeight = impl::aten::buildATen (weight);
4250
+ auto atSaveMean = impl::aten::buildATen (mean);
4251
+ auto atSaveVar = impl::aten::buildATen (rstd);
4252
+ auto atGradWeight = impl::aten::buildATen (grad_weight);
4253
+ auto atGradBias = impl::aten::buildATen (grad_bias);
4254
+ auto axisSize = atInput.size (channel_axis);
4255
+ auto k = axisSize / num_groups;
4256
+ at::IntArrayRef atReducedAxes = impl::aten::buildAtIntArray (reduced_axes);
4257
+ std::vector<int64_t > dims;
4258
+ int64_t N = 1 ;
4259
+ for (int i = 0 ; i < atInput.dim (); i++) {
4260
+ if (i == channel_axis) {
4261
+ continue ;
4262
+ } else {
4263
+ bool is_reduced_axis = false ;
4264
+ for (int m = 0 ; m < reduced_axes.len ; m++) {
4265
+ if (i == reduced_axes.data [m]) {
4266
+ is_reduced_axis = true ;
4267
+ break ;
4268
+ }
4269
+ }
4270
+ if (is_reduced_axis) {
4271
+ continue ;
4272
+ } else {
4273
+ dims.push_back (i);
4274
+ N *= atInput.size (i);
4275
+ }
4276
+ }
4277
+ }
4278
+ dims.push_back (channel_axis);
4279
+ int64_t HxW = 1 ;
4280
+ for (auto i = 0 ; i < reduced_axes.len ; i++) {
4281
+ dims.push_back (reduced_axes.data [i]);
4282
+ HxW *= atInput.size (reduced_axes.data [i]);
4283
+ }
4284
+ auto C = atInput.size (channel_axis);
4285
+ auto permutedInput = atInput.permute (dims);
4286
+ auto permutedShape = permutedInput.sizes ();
4287
+ auto reshapedInput = permutedInput.reshape ({N, C, HxW, 1 }).contiguous ();
4288
+
4289
+ std::vector<int64_t > reverse_order (dims.size ());
4290
+ for (auto i = 0 ; i < atInput.dim (); i++) {
4291
+ reverse_order[dims[i]] = i;
4292
+ }
4293
+
4294
+ if (grad_weight && grad_bias) {
4295
+ auto atGradInput = impl::aten::buildATen (grad_input).permute (dims).reshape ({N, C, HxW, 1 });
4296
+
4297
+ at::native_group_norm_backward_out (
4298
+ atGradInput, atGradWeight, atGradBias, atGradOutput.permute (dims).reshape ({N, C, HxW, 1 }), reshapedInput, atSaveMean, atSaveVar, atWeight, N, C, HxW, num_groups, {true , true , true });
4299
+ atGradInput = atGradInput.reshape (permutedShape).permute (reverse_order);
4300
+ } else {
4301
+ auto atOuts = at::native_group_norm_backward (
4302
+ atGradOutput.permute (dims).reshape ({N, C, HxW, 1 }), reshapedInput, atSaveMean, atSaveVar, atWeight, N, C, HxW, num_groups, {true , grad_weight != nullptr , grad_bias != nullptr });
4303
+ impl::aten::updateATen2Tensor (ctx, std::get<0 >(atOuts).reshape (permutedShape).permute (reverse_order), grad_input);
4304
+ impl::aten::updateATen2Tensor (ctx, std::get<1 >(atOuts), grad_weight);
4305
+ impl::aten::updateATen2Tensor (ctx, std::get<2 >(atOuts), grad_bias);
4306
+ }
4307
+
4308
+ return diopiSuccess;
4309
+ }
4310
+
4243
4311
diopiError_t diopiGroupNorm (diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t save_mean, diopiTensorHandle_t save_invstd,
4244
4312
diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight, diopiConstTensorHandle_t bias, int64_t num_groups,
4245
4313
double eps) {
0 commit comments