@@ -332,14 +332,18 @@ def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
332
332
# in the usually fixed policy and sample magnitude from a normal distribution
333
333
# with mean `magnitude` and std-dev of `magnitude_std`.
334
334
# NOTE This is my own hack, being tested, not in papers or reference impls.
335
+ # If magnitude_std is inf, we sample magnitude from a uniform distribution
335
336
self .magnitude_std = self .hparams .get ('magnitude_std' , 0 )
336
337
337
338
def __call__ (self , img ):
338
339
if self .prob < 1.0 and random .random () > self .prob :
339
340
return img
340
341
magnitude = self .magnitude
341
- if self .magnitude_std and self .magnitude_std > 0 :
342
- magnitude = random .gauss (magnitude , self .magnitude_std )
342
+ if self .magnitude_std :
343
+ if self .magnitude_std == float ('inf' ):
344
+ magnitude = random .uniform (0 , magnitude )
345
+ elif self .magnitude_std > 0 :
346
+ magnitude = random .gauss (magnitude , self .magnitude_std )
343
347
magnitude = min (_MAX_LEVEL , max (0 , magnitude )) # clip to valid range
344
348
level_args = self .level_fn (magnitude , self .hparams ) if self .level_fn is not None else tuple ()
345
349
return self .aug_fn (img , * level_args , ** self .kwargs )
@@ -790,6 +794,7 @@ def augment_and_mix_transform(config_str, hparams):
790
794
depth = - 1
791
795
alpha = 1.
792
796
blended = False
797
+ hparams ['magnitude_std' ] = float ('inf' )
793
798
config = config_str .split ('-' )
794
799
assert config [0 ] == 'augmix'
795
800
config = config [1 :]
0 commit comments