We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent d23596b commit 96f8b96Copy full SHA for 96f8b96
tests/bn_folding_test.py
@@ -464,17 +464,13 @@ def test_same_training_and_prediction(model_name):
464
if model_name == "conv2d":
465
x_shape = (2, 2, 1)
466
kernel = np.array([[[[1., 1.]], [[1., 0.]]], [[[1., 1.]], [[0., 1.]]]])
467
- gamma = np.array([2., 1.])
468
- beta = np.array([0., 1.])
469
- moving_mean = np.array([1., 1.])
470
- moving_variance = np.array([1., 2.])
471
elif model_name == "dense":
472
x_shape = (4,)
473
kernel = np.array([[1., 1.], [1., 0.], [1., 1.], [0., 1.]])
474
475
476
477
+ gamma = np.array([2., 1.])
+ beta = np.array([0., 1.])
+ moving_mean = np.array([1., 1.])
+ moving_variance = np.array([1., 2.])
478
iteration = np.array(-1)
479
480
train_ds = generate_dataset(train_size=10, batch_size=10, input_shape=x_shape,
0 commit comments