Skip to content

Commit f09e78e

Browse files
author
Yin Hongyun
committed
add batch_norm_GB
1 parent 1e98d59 commit f09e78e

File tree

6 files changed

+136
-0
lines changed

6 files changed

+136
-0
lines changed

diopi_test/python/configs/diopi_configs.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,46 @@
152152
),
153153
),
154154

155+
# Test config for the CustomizedTest op "batch_norm_GB": batch normalization
# with a configurable channel axis. Three cases, identical 4-D input shape,
# differing only in which axis is normalized (0, 1 or 2).
"batch_norm_GB": dict(
    name=["batch_norm_GB"],
    interface=['CustomizedTest'],
    dtype=[np.float32, np.float16, np.float64],
    atol=1e-3,
    rtol=1e-4,
    # looser tolerances for the float16 runs
    atol_half=1e-1,
    rtol_half=1e-2,
    para=dict(
        # per-case scalars: always training mode, fixed momentum/eps,
        # only `axis` varies across the three cases
        training=[True, True, True],
        momentum=[0.01, 0.01, 0.01],
        axis=[0, 1, 2],
        eps=[1e-4, 1e-4, 1e-4],
    ),
    tensor_para=dict(
        args=[
            {
                "ins": ["input"],
                "shape": ((2, 64, 32, 32), (2, 64, 32, 32), (2, 64, 32, 32)),
                "gen_fn": "Genfunc.randn",
            },
            {
                "ins": ["running_mean"],
                # 1-D, length == input.shape[axis] for each case: 2, 64, 32
                "shape": ((2,), (64,), (32,)),
                "gen_fn": "Genfunc.zeros",
            },
            {
                "ins": ["running_var"],
                # variance starts at ones (standard batch-norm initialization)
                "shape": ((2,), (64,), (32,)),
                "gen_fn": "Genfunc.ones",
            },
            {
                "ins": ["weight", "bias"],
                # affine parameters, same per-axis length as the running stats
                "shape": ((2,), (64,), (32,)),
                "gen_fn": "Genfunc.randn",
            },
        ]
    ),
),
194+
155195
# FIXME batch_norm输入0size的张量报错
156196
'batch_norm': dict(
157197
name=["batch_norm"],

diopi_test/python/conformance/customized_test.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -891,3 +891,21 @@ def pool3d(input, kernel_size, stride, padding, dilation, ceil_mode, count_inclu
891891
def layer_normGB(input, weight, bias, eps, normalized_shape):
892892
return torch.nn.functional.layer_norm(input=input, weight=weight, bias=bias, eps=eps, normalized_shape=normalized_shape)
893893

894+
def batch_norm_GB(input, running_mean, running_var, weight, bias, training=False, momentum=0.1, eps=1e-05, axis=1):
    """Batch normalization over an arbitrary ``axis`` (reference implementation).

    Permutes the input so that ``axis`` becomes dimension 1, applies
    ``torch.nn.functional.batch_norm``, then restores the original layout.
    ``running_mean``/``running_var`` are updated in place when ``training``
    is True, exactly as ``F.batch_norm`` does.

    Args:
        input: tensor of rank >= 2.
        running_mean, running_var: 1-D tensors of length ``input.shape[axis]``.
        weight, bias: optional 1-D affine parameters of the same length.
        training: use batch statistics (and update running stats) if True.
        momentum: running-statistics update factor.
        eps: numerical-stability constant added to the variance.
        axis: the dimension treated as the channel dimension.

    Returns:
        Normalized tensor with the same shape as ``input``.
    """
    dim = input.dim()
    # Move `axis` to position 1, keeping the other dims in their original order.
    dims = list(range(dim))
    dims.remove(axis)
    dims.insert(1, axis)
    permuted_input = input.permute(dims)
    out = torch.nn.functional.batch_norm(
        permuted_input,
        running_mean,
        running_var,
        weight=weight,
        bias=bias,
        training=training,
        momentum=momentum,
        eps=eps,
    )
    # Undo the permutation. `dims` is generally NOT its own inverse (e.g.
    # dim=4, axis=3 gives [0, 3, 1, 2], whose inverse is [0, 2, 3, 1]), so
    # build the inverse permutation explicitly; permuting with `dims` again
    # (the previous behavior) produced a wrongly-laid-out result for axis >= 3.
    inverse_dims = [0] * dim
    for position, d in enumerate(dims):
        inverse_dims[d] = position
    return out.permute(inverse_dims)

diopi_test/python/conformance/diopi_functions.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2823,6 +2823,52 @@ def batch_norm(
28232823
return out
28242824

28252825

2826+
def batch_norm_GB(
    input,
    running_mean,
    running_var,
    weight,
    bias,
    training=False,
    momentum=0.1,
    eps=1e-05,
    axis=1
) -> Tensor:
    """Batch normalization over an arbitrary ``axis`` via diopiBatchNormGB.

    Mirrors ``batch_norm`` but forwards an extra ``axis`` argument to the
    backend so the normalized (channel) dimension is configurable.

    Args:
        input: input Tensor.
        running_mean, running_var: running statistics (required in eval mode).
        weight, bias: affine parameters.
        training: use batch statistics (and update running stats) if True.
        momentum: running-statistics update factor.
        eps: numerical-stability constant.
        axis: dimension treated as the channel dimension.

    Returns:
        The normalized output Tensor.

    Raises:
        AssertionError: if training is False and running stats are missing.
    """
    # NOTE(review): `Sizes.len` is assumed to be the tensor rank — confirm
    # against the runtime's Sizes definition.
    ndim = input.size().len
    # Reduce over every dimension except `axis`, so save_mean/save_invstd
    # have one entry per channel.
    reduce_dims = [i for i in range(ndim) if i != axis]
    # float16 accumulates statistics in float32
    dtype = Dtype.float32 if input.get_dtype() == Dtype.float16 else None
    _, save_mean = reduce_op_process(input, reduce_dims, dtype=dtype)
    save_invstd = raw_like(save_mean)

    if not training:
        assert (
            running_mean is not None and running_var is not None
        ), "if not training, running_mean and running_var must be defined"

    out = raw_like(input)
    func = check_function("diopiBatchNormGB")
    ret = func(
        input.context(),
        out,
        save_mean,
        save_invstd,
        input,
        weight,
        bias,
        running_mean,
        running_var,
        training,
        momentum,
        eps,
        axis
    )

    check_returncode(ret)
    # Stash the batch statistics for the backward pass (same keys as batch_norm).
    GLOBAL_STATE["batch_norm_save_mean"] = save_mean
    GLOBAL_STATE["batch_norm_save_invstd"] = save_invstd
    return out
2870+
2871+
28262872
def batch_norm_stats(input, eps):
28272873
func = check_function("diopiBatchNormStats")
28282874
# cuda accumulate dtype mapping

diopi_test/python/conformance/global_op_list.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
"conv2d": ["2d", "input", "weight"],
1212
"conv3d": ["3d", "input", "weight"],
1313
"batch_norm": ["input"],
14+
"batch_norm_GB": ["input", "running_mean", "running_var"],
1415
"adaptive_avg_pool2d": ["2d", "input"],
1516
"adaptive_max_pool2d": ["2d", "input"],
1617
"adaptive_avg_pool3d": ["3d", "input"],
@@ -64,6 +65,7 @@
6465

6566
ops_with_states = {
6667
"batch_norm": {"running_mean", "running_var"},
68+
"batch_norm_GB": {"running_mean", "running_var"},
6769
"sgd": {"buf", "param"},
6870
"fill_": {"input"},
6971
"zero_": {"input"},

impl/torch/functions/functions.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2557,6 +2557,29 @@ diopiError_t diopiBatchNorm(diopiContextHandle_t ctx, diopiTensorHandle_t out, d
25572557
return diopiSuccess;
25582558
}
25592559

2560+
diopiError_t diopiBatchNormGB(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t save_mean, diopiTensorHandle_t save_invstd,
                              diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight, diopiConstTensorHandle_t bias, diopiTensorHandle_t running_mean,
                              diopiTensorHandle_t running_var, bool training, double momentum, double eps, int64_t axis) {
    // Batch norm with a configurable channel dimension: move `axis` to dim 1,
    // run the regular batch-norm kernel, then restore the original layout.
    impl::aten::setCurStream(ctx);
    auto atInput = impl::aten::buildATen(input);
    auto atWeight = impl::aten::buildATen(weight);
    auto atBias = impl::aten::buildATen(bias);
    auto atRunningMean = impl::aten::buildATen(running_mean);
    auto atRunningVar = impl::aten::buildATen(running_var);
    auto atOut = impl::aten::buildATen(out);
    auto atSaveMean = impl::aten::buildATen(save_mean);
    auto atSaveInvstd = impl::aten::buildATen(save_invstd);

    // Swap dims 1 and `axis`. A single transposition is its own inverse, so
    // permuting with `dims` a second time restores the original layout.
    std::vector<int64_t> dims(atInput.dim());
    std::iota(dims.begin(), dims.end(), 0);
    std::swap(dims[1], dims[axis]);
    auto permutedInput = atInput.permute(dims).contiguous();

    // Compute into a scratch tensor with the permuted shape. `atOut` has the
    // original (un-permuted) shape, so letting the kernel write/resize it
    // directly would leave `out` holding permuted data. The previous code
    // additionally discarded the final permute entirely — `atOut =
    // atOut.permute(dims)` only rebound the local variable to a view and
    // never wrote back through the `out` handle.
    auto atPermutedOut = at::empty_like(permutedInput);
    CALL_ATEN_CUDA_FUNC(
        native_batch_norm_out, atPermutedOut, atSaveMean, atSaveInvstd, permutedInput, atWeight, atBias, atRunningMean, atRunningVar, training, momentum, eps);
    // Un-permute and copy into the storage shared with `out`.
    atOut.copy_(atPermutedOut.permute(dims));
    return diopiSuccess;
}
2582+
25602583
diopiError_t diopiSlice(diopiContextHandle_t ctx, diopiTensorHandle_t null_out, diopiConstTensorHandle_t input, int64_t dim, int64_t start, int64_t end,
25612584
int64_t step) {
25622585
impl::aten::setCurStream(ctx);

proto/include/diopi/functions.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,13 @@ DIOPI_API diopiError_t diopiBatchNorm(diopiContextHandle_t ctx, diopiTensorHandl
120120
diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight, diopiConstTensorHandle_t bias,
121121
diopiTensorHandle_t running_mean, diopiTensorHandle_t running_var, bool training, double momentum, double eps);
122122

123+
/**
 * @brief Applies Batch Normalization over a configurable dimension of the input —
 *        a generalization of diopiBatchNorm, whose channel dimension is fixed at 1.
 * @param[in] ctx Context environment.
 * @param[out] out The normalized output tensor.
 * @param[out] save_mean Per-channel mean computed over the batch.
 * @param[out] save_invstd Per-channel inverse standard deviation computed over the batch.
 * @param[in] input The input tensor.
 * @param[in] weight Affine scale parameter (length input.shape[axis]; may be null).
 * @param[in] bias Affine shift parameter (length input.shape[axis]; may be null).
 * @param[in,out] running_mean Running mean; updated when training is true.
 * @param[in,out] running_var Running variance; updated when training is true.
 * @param[in] training Whether to use (and update) batch statistics.
 * @param[in] momentum Factor used for the running-statistics update.
 * @param[in] eps Value added to the denominator for numerical stability.
 * @param[in] axis The dimension treated as the channel dimension.
 */
DIOPI_API diopiError_t diopiBatchNormGB(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t save_mean, diopiTensorHandle_t save_invstd,
                                        diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight, diopiConstTensorHandle_t bias,
                                        diopiTensorHandle_t running_mean, diopiTensorHandle_t running_var, bool training, double momentum, double eps, int64_t axis);
129+
123130
/**
124131
* @brief Computes the mean and inverse standard deviation across a batch of data for Synchronized Batch Normalization (SyncBN).
125132
* @param[in] ctx Context environment.

0 commit comments

Comments
 (0)