@@ -76,18 +76,18 @@ def __init__(self, config: TrainConfig):
76
76
config .device if torch .cuda .is_available () else "cpu"
77
77
)
78
78
79
- def get_data_loaders (self , batch_size : int = None ) -> Tuple [DataLoader , DataLoader ]:
79
+ def get_data_loaders (self , batch_size : int = None , seed = None ) -> Tuple [DataLoader , DataLoader ]:
80
80
"""
81
81
Create training and validation data loaders for the chosen dataset.
82
82
"""
83
83
bs = batch_size if batch_size else self .config .batch_size_final
84
84
train_transform = transforms .Compose (
85
85
self .base_transform
86
86
+ [
87
+ transforms .RandomHorizontalFlip (),
88
+ transforms .RandomCrop (32 , padding = 4 ),
87
89
transforms .ToTensor (),
88
90
transforms .Normalize (self .MEAN , self .STD ),
89
- transforms .RandomCrop (32 , padding = 4 ),
90
- transforms .RandomHorizontalFlip (),
91
91
]
92
92
)
93
93
test_transform = transforms .Compose (
@@ -123,6 +123,8 @@ def get_data_loaders(self, batch_size: int = None) -> Tuple[DataLoader, DataLoad
123
123
124
124
num_samples = len (train_data )
125
125
indices = list (range (num_samples ))
126
+ np_seed = seed if seed else self .config .seed
127
+ np .random .seed (np_seed )
126
128
np .random .shuffle (indices )
127
129
128
130
if self .config .evaluate_ensemble_flag :
0 commit comments