
Commit 1f0439a

Fix OOM bug in lm eval
1 parent f111aea

File tree: 4 files changed (+36, -13 lines)


algoperf/random_utils.py

Lines changed: 2 additions & 2 deletions
@@ -35,13 +35,13 @@ def _signed_to_unsigned(seed: SeedType) -> SeedType:
 
 def _fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]:
   rng = np.random.RandomState(seed=_signed_to_unsigned(seed))
-  new_seed = rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32)
+  new_seed = rng.randint(MIN_INT32, MAX_INT32, dtype=np.uint32)
   return [new_seed, data]
 
 
 def _split(seed: SeedType, num: int = 2) -> SeedType:
   rng = np.random.RandomState(seed=_signed_to_unsigned(seed))
-  return rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32, size=[num, 2])
+  return rng.randint(MIN_INT32, MAX_INT32, dtype=np.uint32, size=[num, 2])
 
 
 def _PRNGKey(seed: SeedType) -> SeedType:  # pylint: disable=invalid-name
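
The change above only swaps the dtype of the derived seeds from int32 to uint32. For context, here is a minimal sketch of the seed-folding pattern this file implements; signed_to_unsigned below is a hypothetical stand-in for the repo's _signed_to_unsigned, and the bounds are illustrative, not algoperf's MIN_INT32/MAX_INT32:

import numpy as np

def signed_to_unsigned(seed: int) -> int:
  # np.random.RandomState only accepts seeds in [0, 2**32 - 1], so map a
  # possibly negative signed seed into that range first.
  return seed % (2**32)

def fold_in(seed: int, data: int) -> list:
  # Deterministically derive a new seed from (seed, data): the same
  # inputs always produce the same output seed.
  rng = np.random.RandomState(seed=signed_to_unsigned(seed))
  new_seed = rng.randint(0, 2**32 - 1, dtype=np.uint32)
  return [new_seed, data]

print(fold_in(-42, 7))  # reproducible: prints the same pair every run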

algoperf/workloads/lm/lm_pytorch/workload.py

Lines changed: 21 additions & 7 deletions
@@ -1,5 +1,6 @@
 """LM workload implemented in PyTorch."""
 
+import contextlib
 from itertools import islice
 from typing import Any, Dict, Iterator, Optional, Tuple
 
@@ -8,7 +9,7 @@
 import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel as DDP
 
-from algoperf import data_utils, param_utils, pytorch_utils, spec
+from algoperf import param_utils, pytorch_utils, spec
 from algoperf.workloads.lm.lm_pytorch.plainlm_model import (
   ModelConfig,
   Transformer,
@@ -72,12 +73,23 @@ def model_fn(
     del model_state, rng, update_batch_norm, dropout_rate
     model = params
 
-    # Convert one-hot inputs to token IDs if needed
-    inputs = augmented_and_preprocessed_input_batch['inputs']
-    if inputs.dim() == 3:  # one-hot encoded
+    # Set model to eval or train mode based on the mode parameter
+    if mode == spec.ForwardPassMode.EVAL:
+      model.eval()
+    elif mode == spec.ForwardPassMode.TRAIN:
+      model.train()
+    contexts = {
+      spec.ForwardPassMode.EVAL: torch.no_grad,
+      spec.ForwardPassMode.TRAIN: contextlib.nullcontext,
+    }
+    with contexts[mode]():
+      # Convert one-hot inputs to token IDs if needed
+      inputs = augmented_and_preprocessed_input_batch['inputs']
+      if inputs.dim() == 3:  # one-hot encoded
         inputs = inputs.argmax(dim=-1)
 
-    logits = model(inputs)
+      logits = model(inputs)
+
     return logits, None
 
   def _build_input_queue(
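
This hunk is the main part of the OOM fix: evaluation forward passes previously ran with autograd recording every intermediate activation. A minimal sketch of the same mode-plus-context pattern; run_forward is an illustrative name, not from the repo:

import contextlib
import torch
import torch.nn as nn

def run_forward(model: nn.Module, inputs: torch.Tensor, is_eval: bool):
  # Pick the module mode and a matching autograd context. Under
  # torch.no_grad() no computation graph is built, so activations are
  # freed as soon as the forward pass finishes.
  model.eval() if is_eval else model.train()
  ctx = torch.no_grad if is_eval else contextlib.nullcontext
  with ctx():
    return model(inputs)

model = nn.Linear(8, 2)
logits = run_forward(model, torch.randn(4, 8), is_eval=True)
print(logits.requires_grad)  # False: nothing to backprop, nothing retained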
@@ -90,12 +102,14 @@ def _build_input_queue(
       repeat_final_dataset: bool = False) -> Iterator[Dict[str, spec.Tensor]]:
     """Build an input queue for the given split."""
     local_batch_size = global_batch_size // N_GPUS
+    # In DDP mode, pass local_device_count=1 to prevent shard_and_maybe_pad_np
+    # from seeing all GPUs via torch.cuda.device_count()
     loader = get_data_iter(
       data_rng=data_rng,
       split=split,
       data_dir=data_dir,
       global_batch_size=local_batch_size,
-      num_batches=num_batches
+      num_batches=num_batches,
     )
     if USE_PYTORCH_DDP:
       loader = islice(loader, RANK, None, N_GPUS)
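
The islice call above is a round-robin shard for DDP: rank r consumes batches r, r + N_GPUS, r + 2 * N_GPUS, and so on from the same underlying iterator. A minimal sketch with illustrative values:

from itertools import islice

def shard_batches(loader, rank: int, world_size: int):
  # Rank r sees batches r, r + world_size, r + 2 * world_size, ...
  return islice(loader, rank, None, world_size)

batches = iter(range(10))  # stand-in for a batch iterator
print(list(shard_batches(batches, rank=1, world_size=4)))  # [1, 5, 9]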
@@ -104,7 +118,7 @@
       batch = {
         'inputs': torch.tensor(batch['inputs'], device=DEVICE, dtype=dtype),
         'targets': torch.tensor(batch['targets'], device=DEVICE, dtype=torch.int64),
-        'weights': None,
+        'weights': torch.tensor(batch['weights'], device=DEVICE, dtype=torch.float32) if batch['weights'] is not None else None,
       }
       yield batch
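
The batch's weights are now materialized as a tensor instead of being dropped as None, so the eval loss can weight out invalid positions. A minimal sketch of why such weights matter, assuming a simple per-token cross-entropy; the workload's compute_weighted_cross_entropy may differ:

import torch
import torch.nn.functional as F

logits = torch.randn(2, 3, 5)           # (batch, seq_len, vocab)
targets = torch.randint(0, 5, (2, 3))
weights = torch.tensor([[1., 1., 0.],   # 0 marks a padded position
                        [1., 1., 1.]])

# cross_entropy expects (batch, vocab, seq_len) for per-position losses.
per_token = F.cross_entropy(logits.transpose(1, 2), targets, reduction='none')
summed = (per_token * weights).sum()    # padded positions contribute 0
n_valid = weights.sum()
print((summed / n_valid).item())        # mean loss over valid tokens only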

algoperf/workloads/lm/workload.py

Lines changed: 11 additions & 4 deletions
@@ -73,7 +73,7 @@ def num_test_examples(self) -> int:
 
   @property
   def eval_batch_size(self) -> int:
-    return 64
+    return 256
 
   @property
   def train_mean(self):
@@ -138,6 +138,11 @@ def _eval_model_on_split(
   ) -> Dict[str, float]:
     """Run a full evaluation of the model."""
     num_batches = int(math.ceil(num_examples / global_batch_size))
+
+    # Handle edge case where num_batches is 0 (e.g., test split with 0 examples)
+    if num_batches == 0:
+      return {'loss': 0.0, 'ppl': 1.0}
+
     if split not in self._eval_iters:
       # These iterators will repeat indefinitely.
       self._eval_iters[split] = self._build_input_queue(
@@ -159,7 +164,7 @@ def _eval_model_on_split(
         eval_metrics[metric_name] += metric_value
 
     eval_results = self._normalize_eval_metrics(num_examples, eval_metrics)
-    eval_results['ppl'] = np.exp(eval_results['loss']).item()
+    eval_results['ppl'] = np.exp(eval_results['loss']).item()
     return eval_results
 
 
@@ -173,9 +178,11 @@ def _eval_batch(self,
       params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False)
     # Calculate cross-entropy loss
     metrics = self.compute_weighted_cross_entropy(logits, batch['targets'], batch['weights'])
+    # CRITICAL: Detach tensors to free computation graph and activations
+    # Without this, all intermediate activations are kept in memory!
    return {
-      'loss': metrics['summed'],
-      'denominator': metrics['n_valid_examples'],
+      'loss': metrics['summed'].detach(),
+      'denominator': metrics['n_valid_examples'].detach(),
     }
algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py

Lines changed: 2 additions & 0 deletions
@@ -372,6 +372,8 @@ def get_batch_size(workload_name):
     return 128
   elif workload_name == 'mnist':
     return 16
+  elif workload_name == 'lm':
+    return 64
   else:
     raise ValueError(f'Unsupported workload name: {workload_name}.')
