Skip to content

Commit 5e27618

Browse files
committed
Fixed ece + saving models
1 parent 99a118f commit 5e27618

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

code/train_models.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def get_data_loaders(self, batch_size: int = None) -> Tuple[DataLoader, DataLoad
8787
"""
8888
Create training and validation data loaders for the chosen dataset.
8989
"""
90-
bs = batch_size or self.config.batch_size
90+
bs = batch_size or self.config.batch_size_final
9191
transform = transforms.Compose(
9292
self.base_transform
9393
+ [
@@ -201,7 +201,8 @@ def train_model(
201201
evaluator.fit(model)
202202
self.models.append(model)
203203

204-
torch.save(model.state_dict(), f'{self.config.best_models_save_path}_trained/model_{model_id}.pth')
204+
print(f"Saving model {model_id} to {self.config.best_models_save_path}_trained")
205+
torch.save(model.state_dict(), Path(self.config.best_models_save_path) / f'model_{model_id}.pth')
205206

206207
return model
207208

@@ -329,6 +330,8 @@ def evaluate_ensemble(self, test_loader):
329330
bin_idx = torch.bucketize(conf, bin_boundaries, right=True) - 1
330331
if bin_idx < 0: # Handle case when confidence is exactly 0.0
331332
bin_idx = 0
333+
if bin_idx == n_bins:
334+
bin_idx = n_bins - 1
332335
bin_counts[bin_idx] += 1
333336
bin_conf_sums[bin_idx] += conf
334337
bin_acc_sums[bin_idx] += correct

0 commit comments

Comments
 (0)