Skip to content

Commit 42cf795

Browse files
committed
fix augmentations order
1 parent d78c4da commit 42cf795

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

code/DARTS_baseline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ def run(self):
8484
self.config.selected_archs = []
8585
shutil.rmtree(self.config.output_path, ignore_errors=True)
8686

87-
for _ in tqdm(range(self.config.n_ensemble_models), desc="Finding best architectures"):
88-
train_loader, valid_loader, test_loader = self.get_data_loaders()
87+
for idx, _ in enumerate(tqdm(range(self.config.n_ensemble_models), desc="Finding best architectures")):
88+
train_loader, valid_loader, test_loader = self.get_data_loaders(seed=self.config.seed + idx * 10)
8989
self.config.selected_archs.append(self.get_best_models(train_loader, valid_loader))
9090

9191
for idx, arch in enumerate(tqdm(self.config.selected_archs, desc="Training models")):

code/train_models.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,18 +76,18 @@ def __init__(self, config: TrainConfig):
7676
config.device if torch.cuda.is_available() else "cpu"
7777
)
7878

79-
def get_data_loaders(self, batch_size: int = None) -> Tuple[DataLoader, DataLoader]:
79+
def get_data_loaders(self, batch_size: int = None, seed=None) -> Tuple[DataLoader, DataLoader]:
8080
"""
8181
Create training and validation data loaders for the chosen dataset.
8282
"""
8383
bs = batch_size if batch_size else self.config.batch_size_final
8484
train_transform = transforms.Compose(
8585
self.base_transform
8686
+ [
87+
transforms.RandomHorizontalFlip(),
88+
transforms.RandomCrop(32, padding=4),
8789
transforms.ToTensor(),
8890
transforms.Normalize(self.MEAN, self.STD),
89-
transforms.RandomCrop(32, padding=4),
90-
transforms.RandomHorizontalFlip(),
9191
]
9292
)
9393
test_transform = transforms.Compose(
@@ -123,6 +123,8 @@ def get_data_loaders(self, batch_size: int = None) -> Tuple[DataLoader, DataLoad
123123

124124
num_samples = len(train_data)
125125
indices = list(range(num_samples))
126+
np_seed = seed if seed else self.config.seed
127+
np.random.seed(np_seed)
126128
np.random.shuffle(indices)
127129

128130
if self.config.evaluate_ensemble_flag:

0 commit comments

Comments
 (0)