@@ -4183,8 +4183,66 @@ diopiError_t diopiForeachnormScalar(diopiContextHandle_t ctx, diopiTensorHandle_
4183
4183
return diopiSuccess;
4184
4184
}
4185
4185
4186
+ diopiError_t diopiGroupNormGB (diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t save_mean, diopiTensorHandle_t save_invstd,
4187
+ diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight, diopiConstTensorHandle_t bias, int64_t num_groups,
4188
+ double eps, diopiSize_t reduced_axes, const int64_t channel_axis) {
4189
+ impl::aten::setCurStream (ctx);
4190
+ auto atInput = impl::aten::buildATen (input);
4191
+ auto axisSize = atInput.size (channel_axis);
4192
+ auto k = axisSize / num_groups;
4193
+ at::IntArrayRef atReducedAxes = impl::aten::buildAtIntArray (reduced_axes);
4194
+ std::vector<int64_t > dims;
4195
+ int64_t N = 1 ;
4196
+ for (int i = 0 ; i < atInput.dim (); i++) {
4197
+ if (i == channel_axis) {
4198
+ continue ;
4199
+ } else {
4200
+ bool is_reduced_axis = false ;
4201
+ for (int m = 0 ; m < reduced_axes.len ; m++) {
4202
+ if (i == reduced_axes.data [m]) {
4203
+ is_reduced_axis = true ;
4204
+ break ;
4205
+ }
4206
+ }
4207
+ if (is_reduced_axis) {
4208
+ continue ;
4209
+ } else {
4210
+ dims.push_back (i);
4211
+ N *= atInput.size (i);
4212
+ }
4213
+ }
4214
+ }
4215
+ dims.push_back (channel_axis);
4216
+ int64_t HxW = 1 ;
4217
+ for (auto i = 0 ; i < reduced_axes.len ; i++) {
4218
+ dims.push_back (reduced_axes.data [i]);
4219
+ HxW *= atInput.size (reduced_axes.data [i]);
4220
+ }
4221
+ auto C = atInput.size (channel_axis);
4222
+ auto permutedInput = atInput.permute (dims);
4223
+ auto permutedShape = permutedInput.sizes ();
4224
+ auto reshapedInput = permutedInput.reshape ({N, C, HxW, 1 }).contiguous ();
4225
+
4226
+ auto atWeight = impl::aten::buildATen (weight);
4227
+ auto atBias = impl::aten::buildATen (bias);
4228
+ auto atOut = impl::aten::buildATen (out);
4229
+ auto atSaveMean = impl::aten::buildATen (save_mean);
4230
+ auto atSaveInvstd = impl::aten::buildATen (save_invstd);
4231
+
4232
+ std::vector<int64_t > reverse_order (dims.size ());
4233
+ for (auto i = 0 ; i < atInput.dim (); i++) {
4234
+ reverse_order[dims[i]] = i;
4235
+ }
4236
+ auto tempOut = CALL_ATEN_CUDA_FUNC (native_group_norm, reshapedInput, atWeight, atBias, N, C, HxW, num_groups, eps);
4237
+ at::native::copy_ (atOut, std::get<0 >(tempOut).reshape (permutedShape).permute (reverse_order), true );
4238
+ at::native::copy_ (atSaveMean, std::get<1 >(tempOut), true );
4239
+ at::native::copy_ (atSaveInvstd, std::get<2 >(tempOut), true );
4240
+ return diopiSuccess;
4241
+ }
4242
+
4186
4243
diopiError_t diopiGroupNorm (diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t save_mean, diopiTensorHandle_t save_invstd,
4187
- diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight, diopiConstTensorHandle_t bias, int64_t num_groups, double eps) {
4244
+ diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight, diopiConstTensorHandle_t bias, int64_t num_groups,
4245
+ double eps) {
4188
4246
impl::aten::setCurStream (ctx);
4189
4247
auto atInput = impl::aten::buildATen (input);
4190
4248
auto atWeight = impl::aten::buildATen (weight);
0 commit comments