@@ -69,7 +69,7 @@ def test_scale_batch_size_method_with_model_or_datamodule(tmp_path, model_bs, dm
 
     tuner = Tuner(trainer)
     new_batch_size = tuner.scale_batch_size(model, mode="binsearch", init_val=4, max_trials=2, datamodule=datamodule)
-    assert new_batch_size == 16
+    assert new_batch_size == 8
 
     if model_bs is not None:
         assert model.batch_size == new_batch_size
@@ -314,7 +314,9 @@ def test_dataloader_reset_with_scale_batch_size(tmp_path, caplog, scale_method,
 
     dataset_len = len(trainer.train_dataloader.dataset)
     assert dataset_len == 64
-    assert caplog.text.count("trying batch size") == (max_trials if init_batch_size < dataset_len else 0)
+    # With our fix, when max_trials is reached, we don't try the doubled batch size, so we get max_trials - 1 messages
+    expected_tries = max_trials - 1 if init_batch_size < dataset_len and max_trials > 0 else 0
+    assert caplog.text.count("trying batch size") == expected_tries
     assert caplog.text.count("greater or equal than the length") == int(new_batch_size == dataset_len)
 
     assert trainer.train_dataloader.batch_size == new_batch_size
@@ -326,7 +328,8 @@ def test_tuner_with_evaluation_methods(tmp_path, trainer_fn):
     """Test batch size tuner with Trainer's evaluation methods."""
     before_batch_size = 2
     max_trials = 4
-    expected_scaled_batch_size = before_batch_size ** (max_trials + 1)
+    # With our fix, we return the last successful batch size, not the doubled untested one
+    expected_scaled_batch_size = before_batch_size ** max_trials  # 2^4 = 16, not 2^5 = 32
 
     model = BatchSizeModel(batch_size=before_batch_size)
     trainer = Trainer(default_root_dir=tmp_path, max_epochs=100)
@@ -349,7 +352,8 @@ def test_batch_size_finder_callback(tmp_path, trainer_fn):
     before_batch_size = 2
     max_trials = 4
     max_epochs = 2
-    expected_scaled_batch_size = before_batch_size ** (max_trials + 1)
+    # With our fix, we return the last successful batch size, not the doubled untested one
+    expected_scaled_batch_size = before_batch_size ** max_trials  # 2^4 = 16, not 2^5 = 32
 
     model = BatchSizeModel(batch_size=before_batch_size)
     batch_size_finder = BatchSizeFinder(max_trials=max_trials, batch_arg_name="batch_size")
@@ -533,3 +537,49 @@ def train_dataloader(self):
     assert len(scale_checkpoints) == 0, (
         f"scale_batch_size checkpoint files should be cleaned up, but found: {scale_checkpoints}"
     )
+
+
+class AlwaysSucceedingBoringModel(BoringModel):
+    """A BoringModel that never fails with OOM errors for batch size scaling tests."""
+
+    def __init__(self, batch_size=2):
+        super().__init__()
+        self.batch_size = batch_size
+
+
+class FailsAtBatchSizeBoringModel(BoringModel):
+    """A BoringModel that fails when batch size reaches a certain threshold."""
+
+    def __init__(self, batch_size=2, fail_at=16):
+        super().__init__()
+        self.batch_size = batch_size
+        self.fail_at = fail_at
+
+    def training_step(self, batch, batch_idx):
+        # Simulate OOM error when batch size is too large
+        if self.batch_size >= self.fail_at:
+            raise RuntimeError("CUDA error: out of memory")
+        return super().training_step(batch, batch_idx)
+
+
+@pytest.mark.parametrize(
+    ("max_trials", "mode", "init_val", "expected"),
+    [
+        (3, "power", 2, 8),
+        (3, "binsearch", 2, 8),
+        (1, "power", 4, 4),
+        (0, "power", 2, 2),
+    ],
+)
+def test_scale_batch_size_max_trials_modes(tmp_path, max_trials, mode, init_val, expected):
+    model = AlwaysSucceedingBoringModel(batch_size=init_val)
+    trainer = Trainer(default_root_dir=tmp_path, max_epochs=1)
+    tuner = Tuner(trainer)
+    result = tuner.scale_batch_size(
+        model,
+        mode=mode,
+        steps_per_trial=1,
+        max_trials=max_trials,
+        init_val=init_val,
+    )
+    assert result == expected