Commit 3998b5d

Fix edgecase in batch size scaling tuner alg (#21187)
* fix + tests
* changelog
1 parent cd30ce4 commit 3998b5d
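
The edge case: when every trial succeeds, the scaling loop doubles the batch size one final time after the last trial and returns that doubled, never-tested value. Below is a minimal standalone sketch of the before/after behavior, assuming a hypothetical run_trial callback in place of Lightning's _try_loop_run; it is an illustration of the idea, not the actual implementation (which follows in the diff).

    # Minimal sketch of the power-scaling edge case. `run_trial` is a
    # hypothetical stand-in for Lightning's _try_loop_run; the real OOM
    # handling (continued downscaling) is simplified to a single back-off.

    def scale_power_old(init_size, max_trials, run_trial):
        """Old behavior: the size is doubled even after the final trial,
        so the returned value may never have been tested."""
        size = init_size
        for _ in range(max_trials):
            if not run_trial(size):  # simulated OOM
                return size // 2
            size *= 2  # doubled even on the last trial
        return size  # potentially untested

    def scale_power_fixed(init_size, max_trials, run_trial):
        """Fixed behavior: remember the last size that actually ran and
        return it once max_trials is exhausted."""
        size = init_size
        last_successful = init_size
        for i in range(max_trials):
            if not run_trial(size):  # simulated OOM
                return last_successful
            last_successful = size
            if i + 1 >= max_trials:  # last trial: stop before doubling
                return last_successful
            size *= 2
        return last_successful

    always_fits = lambda size: True
    print(scale_power_old(2, 3, always_fits))    # 16 -- a size that was never run
    print(scale_power_fixed(2, 3, always_fits))  # 8  -- the last tested size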

3 files changed: +71 additions, -7 deletions

src/lightning/pytorch/CHANGELOG.md

Lines changed: 1 addition & 1 deletion

@@ -31,7 +31,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
--
+- Fixed edgecase when `max_trials` is reached in `Tuner.scale_batch_size` ([#21187](https://github.com/Lightning-AI/pytorch-lightning/pull/21187))
 
 
 
---

src/lightning/pytorch/tuner/batch_size_scaling.py

Lines changed: 16 additions & 2 deletions

@@ -178,14 +178,22 @@ def _run_power_scaling(
     # this flag is used to determine whether the previously scaled batch size, right before OOM, was a success or not
     # if it was we exit, else we continue downscaling in case we haven't encountered a single optimal batch size
     any_success = False
-    for _ in range(max_trials):
+    last_successful_size = new_size
+    for i in range(max_trials):
         garbage_collection_cuda()
 
         # reset after each try
         _reset_progress(trainer)
 
         try:
             _try_loop_run(trainer, params)
+            last_successful_size = new_size  # Store the current size before doubling
+
+            # Check if this is the last trial before trying to double
+            if i + 1 >= max_trials:
+                new_size = last_successful_size
+                break
+
             new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc="succeeded")
 
             if not changed:
@@ -224,6 +232,7 @@ def _run_binary_scaling(
     low = 1
     high = None
     count = 0
+    last_successful_size = new_size
     while True:
         garbage_collection_cuda()
 
@@ -233,9 +242,14 @@ def _run_binary_scaling(
         try:
             # run loop
             _try_loop_run(trainer, params)
+            last_successful_size = new_size  # Store the current size before doubling
             count += 1
-            if count > max_trials:
+
+            # Check if we've reached max_trials before trying to adjust batch size
+            if count >= max_trials:
+                new_size = last_successful_size
                 break
+
             # Double in size
             low = new_size
             if high:
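
The practical effect when every trial fits in memory: the tuner now returns the last size it actually ran, init_val * 2 ** (max_trials - 1), instead of the untested init_val * 2 ** max_trials. A quick sanity check of the numbers the updated tests expect (an illustrative helper, not part of the commit):

    # Expected result when no trial ever OOMs: the last *tested* size.
    def expected_scaled(init_val, max_trials):
        return init_val * 2 ** (max_trials - 1) if max_trials > 0 else init_val

    assert expected_scaled(4, 2) == 8    # updated binsearch expectation below
    assert expected_scaled(2, 4) == 16   # before_batch_size**max_trials, was 2**(max_trials+1) == 32
    assert expected_scaled(2, 0) == 2    # max_trials=0 leaves init_val untouched

---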

tests/tests_pytorch/tuner/test_scale_batch_size.py

Lines changed: 54 additions & 4 deletions

@@ -69,7 +69,7 @@ def test_scale_batch_size_method_with_model_or_datamodule(tmp_path, model_bs, dm
 
     tuner = Tuner(trainer)
     new_batch_size = tuner.scale_batch_size(model, mode="binsearch", init_val=4, max_trials=2, datamodule=datamodule)
-    assert new_batch_size == 16
+    assert new_batch_size == 8
 
     if model_bs is not None:
         assert model.batch_size == new_batch_size
@@ -314,7 +314,9 @@ def test_dataloader_reset_with_scale_batch_size(tmp_path, caplog, scale_method,
 
     dataset_len = len(trainer.train_dataloader.dataset)
     assert dataset_len == 64
-    assert caplog.text.count("trying batch size") == (max_trials if init_batch_size < dataset_len else 0)
+    # With our fix, when max_trials is reached, we don't try the doubled batch size, so we get max_trials - 1 messages
+    expected_tries = max_trials - 1 if init_batch_size < dataset_len and max_trials > 0 else 0
+    assert caplog.text.count("trying batch size") == expected_tries
     assert caplog.text.count("greater or equal than the length") == int(new_batch_size == dataset_len)
 
     assert trainer.train_dataloader.batch_size == new_batch_size
@@ -326,7 +328,8 @@ def test_tuner_with_evaluation_methods(tmp_path, trainer_fn):
     """Test batch size tuner with Trainer's evaluation methods."""
     before_batch_size = 2
     max_trials = 4
-    expected_scaled_batch_size = before_batch_size ** (max_trials + 1)
+    # With our fix, we return the last successful batch size, not the doubled untested one
+    expected_scaled_batch_size = before_batch_size**max_trials  # 2^4 = 16, not 2^5 = 32
 
     model = BatchSizeModel(batch_size=before_batch_size)
     trainer = Trainer(default_root_dir=tmp_path, max_epochs=100)
@@ -349,7 +352,8 @@ def test_batch_size_finder_callback(tmp_path, trainer_fn):
     before_batch_size = 2
     max_trials = 4
     max_epochs = 2
-    expected_scaled_batch_size = before_batch_size ** (max_trials + 1)
+    # With our fix, we return the last successful batch size, not the doubled untested one
+    expected_scaled_batch_size = before_batch_size**max_trials  # 2^4 = 16, not 2^5 = 32
 
     model = BatchSizeModel(batch_size=before_batch_size)
     batch_size_finder = BatchSizeFinder(max_trials=max_trials, batch_arg_name="batch_size")
@@ -533,3 +537,49 @@ def train_dataloader(self):
     assert len(scale_checkpoints) == 0, (
         f"scale_batch_size checkpoint files should be cleaned up, but found: {scale_checkpoints}"
     )
+
+
+class AlwaysSucceedingBoringModel(BoringModel):
+    """A BoringModel that never fails with OOM errors for batch size scaling tests."""
+
+    def __init__(self, batch_size=2):
+        super().__init__()
+        self.batch_size = batch_size
+
+
+class FailsAtBatchSizeBoringModel(BoringModel):
+    """A BoringModel that fails when batch size reaches a certain threshold."""
+
+    def __init__(self, batch_size=2, fail_at=16):
+        super().__init__()
+        self.batch_size = batch_size
+        self.fail_at = fail_at
+
+    def training_step(self, batch, batch_idx):
+        # Simulate OOM error when batch size is too large
+        if self.batch_size >= self.fail_at:
+            raise RuntimeError("CUDA error: out of memory")
+        return super().training_step(batch, batch_idx)
+
+
+@pytest.mark.parametrize(
+    ("max_trials", "mode", "init_val", "expected"),
+    [
+        (3, "power", 2, 8),
+        (3, "binsearch", 2, 8),
+        (1, "power", 4, 4),
+        (0, "power", 2, 2),
+    ],
+)
+def test_scale_batch_size_max_trials_modes(tmp_path, max_trials, mode, init_val, expected):
+    model = AlwaysSucceedingBoringModel(batch_size=init_val)
+    trainer = Trainer(default_root_dir=tmp_path, max_epochs=1)
+    tuner = Tuner(trainer)
+    result = tuner.scale_batch_size(
+        model,
+        mode=mode,
+        steps_per_trial=1,
+        max_trials=max_trials,
+        init_val=init_val,
+    )
+    assert result == expected
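
A note on the new helpers: the parametrized test exercises only AlwaysSucceedingBoringModel; FailsAtBatchSizeBoringModel covers the simulated-OOM path. A hypothetical usage (not part of this commit) might look like:

    # Hypothetical test using FailsAtBatchSizeBoringModel (not in this commit):
    # sizes 2, 4, and 8 succeed, 16 raises a simulated OOM, so the tuner
    # should back off and settle below fail_at.
    def test_scale_batch_size_backs_off_before_failing_size(tmp_path):
        model = FailsAtBatchSizeBoringModel(batch_size=2, fail_at=16)
        trainer = Trainer(default_root_dir=tmp_path, max_epochs=1)
        tuner = Tuner(trainer)
        result = tuner.scale_batch_size(model, mode="power", steps_per_trial=1, max_trials=10, init_val=2)
        assert result < 16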
