@@ -290,9 +290,8 @@ def testSeparableConv1DQuantize_(self, kwargs):
290290 @parameterized .named_parameters (
291291 ('padding_valid' , {'padding' : 'valid' }),
292292 ('padding_same' , {'padding' : 'same' }),
293- # TODO(b/186666265): tighten the tolerance to 1e-5.
294293 ('padding_same_dilation_2' ,
295- {'padding' : 'same' , 'dilation_rate' : 2 }, 0.19 ),
294+ {'padding' : 'same' , 'dilation_rate' : 2 }),
296295 ('strides' , {'strides' : 2 }),
297296 ('dilation_rate' , {'dilation_rate' : 2 }),
298297 ('depth_multiplier' , {'depth_multiplier' : 2 }),
@@ -307,7 +306,7 @@ def testSeparableConv1DQuantize_(self, kwargs):
307306 'pointwise_constraint' : tf .keras .constraints .min_max_norm (0. , 2. ),
308307 'bias_constraint' : tf .keras .constraints .unit_norm ()})
309308 )
310- def testSeparableConvQuantize_ (self , kwargs , tolerance = 1e-5 ):
309+ def testSeparableConvQuantize_ (self , kwargs ):
311310 kwargs ['filters' ] = 2
312311 kwargs ['kernel_size' ] = 3
313312 num_samples = 2
@@ -338,17 +337,20 @@ def testSeparableConvQuantize_(self, kwargs, tolerance=1e-5):
338337
339338 # Ensure model is equivalent, and training results are the same.
340339 sepconv_model .compile (loss = 'categorical_crossentropy' , optimizer = 'sgd' )
341- sepconv_model .fit (x , y , epochs = 100 )
342340 transformed_model .compile (loss = 'categorical_crossentropy' , optimizer = 'sgd' )
343- transformed_model .fit (x , y , epochs = 100 )
344341
345- # Over a long training cycle with constraints and regularizers, the model
346- # can build very minute differences.
347- self .assertAllClose (
348- sepconv_model .predict (x ),
349- transformed_model .predict (x ),
350- atol = tolerance ,
351- rtol = tolerance )
342+ epochs = 100
343+ for _ in range (epochs ):
344+ sepconv_model .fit (x , y , epochs = 1 , verbose = 2 )
345+ transformed_model .fit (x , y , epochs = 1 , verbose = 2 )
346+ self .assertAllClose (
347+ sepconv_model .get_weights (),
348+ transformed_model .get_weights ())
349+ # To prevent accumulated numerical errors.
350+ transformed_model .set_weights (sepconv_model .get_weights ())
351+ self .assertAllClose (
352+ sepconv_model .predict (x ),
353+ transformed_model .predict (x ))
352354
353355 # TODO(pulkitb): Add individual tests for the following transforms.
354356 # Conv2DReshapeBatchNormQuantize, Conv2DReshapeBatchNormReLUQuantize
0 commit comments