Skip to content

Commit 9a1bd35

Browse files
authored
Merge pull request #571 from normster/augmix-fix
Enable uniform augmentation magnitude sampling and set AugMix default
2 parents c1cf971 + 79640fc commit 9a1bd35

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

timm/data/auto_augment.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -332,14 +332,18 @@ def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
332332
# in the usually fixed policy and sample magnitude from a normal distribution
333333
# with mean `magnitude` and std-dev of `magnitude_std`.
334334
# NOTE This is my own hack, being tested, not in papers or reference impls.
335+
# If magnitude_std is inf, we sample magnitude from a uniform distribution
335336
self.magnitude_std = self.hparams.get('magnitude_std', 0)
336337

337338
def __call__(self, img):
338339
if self.prob < 1.0 and random.random() > self.prob:
339340
return img
340341
magnitude = self.magnitude
341-
if self.magnitude_std and self.magnitude_std > 0:
342-
magnitude = random.gauss(magnitude, self.magnitude_std)
342+
if self.magnitude_std:
343+
if self.magnitude_std == float('inf'):
344+
magnitude = random.uniform(0, magnitude)
345+
elif self.magnitude_std > 0:
346+
magnitude = random.gauss(magnitude, self.magnitude_std)
343347
magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range
344348
level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple()
345349
return self.aug_fn(img, *level_args, **self.kwargs)
@@ -790,6 +794,7 @@ def augment_and_mix_transform(config_str, hparams):
790794
depth = -1
791795
alpha = 1.
792796
blended = False
797+
hparams['magnitude_std'] = float('inf')
793798
config = config_str.split('-')
794799
assert config[0] == 'augmix'
795800
config = config[1:]

0 commit comments

Comments
 (0)