
Commit 52bfcc9

Fixed MBU/MFU calculations for the learner.
1 parent df735bc · commit 52bfcc9

File tree

2 files changed: +42 −31 lines


open_instruct/grpo_fast.py

Lines changed: 31 additions & 31 deletions
@@ -2244,7 +2244,7 @@ def generate_thread(args, vllm_engines, resume_training_step, stop_event, genera
 def calculate_mfu_mbu(
     model_dims: utils.ModelDims, packed_data: PackedData, train_duration: float, args: Args
 ) -> dict[str, float]:
-    """Calculate Model FLOPs Utilization (MFU) and Model Bandwidth Utilization (MBU).
+    """Calculate Model FLOPs Utilization (MFU) for training.
 
     Args:
         model_dims: Model dimensions for FLOPs/memory calculations
@@ -2253,7 +2253,7 @@ def calculate_mfu_mbu(
         args: Training arguments
 
     Returns:
-        Dictionary with 'mfu' and 'mbu' keys as percentages
+        Dictionary with 'mfu' key as percentage (MBU not reported for training)
     """
     assert model_dims is not None, "model_dims must not be None"
     assert packed_data is not None, "packed_data must not be None"
@@ -2264,42 +2264,43 @@ def calculate_mfu_mbu(
     # Get GPU specifications
     device_name = utils.get_device_name(torch.cuda.get_device_name(0))
     device_flops = utils.GPU_SPECS[device_name]["flops"]
-    device_memory_bandwidth = utils.GPU_SPECS[device_name]["memory_bandwidth"]
-
-    # For GRPO, we have multiple samples per prompt
-    # prompt_lengths contains lengths for unique prompts
-    # response_lengths contains lengths for all samples (num_prompts * samples_per_prompt)
-    assert (
-        len(packed_data.response_lengths) == len(packed_data.prompt_lengths) * args.num_samples_per_prompt_rollout
-    ), (
-        f"Expected {len(packed_data.prompt_lengths) * args.num_samples_per_prompt_rollout} response lengths, "
-        f"got {len(packed_data.response_lengths)}"
+
+    # For training, we need to calculate total sequence lengths (prompt + response)
+    # This represents the full sequence that the model is trained on
+    total_sequence_lengths = []
+    response_idx = 0
+    for prompt_len in packed_data.prompt_lengths:
+        # For each unique prompt, get all its response lengths
+        for _ in range(args.num_samples_per_prompt_rollout):
+            response_len = packed_data.response_lengths[response_idx]
+            total_sequence_lengths.append(prompt_len + response_len)
+            response_idx += 1
+
+    # Create a new ModelDims instance with is_training=True
+    training_model_dims = utils.ModelDims(
+        num_layers=model_dims.num_layers,
+        hidden_size=model_dims.hidden_size,
+        intermediate_size=model_dims.intermediate_size,
+        vocab_size=model_dims.vocab_size,
+        num_attn_heads=model_dims.num_attn_heads,
+        num_kv_heads=model_dims.num_kv_heads,
+        is_training=True,
     )
 
-    # Calculate FLOPs with proper handling of samples_per_prompt
-    # Note: prompt prefill is only done once per unique prompt, not per sample
-    total_flops = model_dims.flops(
-        prompt_lengths=packed_data.prompt_lengths,
-        response_lengths=packed_data.response_lengths,
-        samples_per_prompt=args.num_samples_per_prompt_rollout,
+    # Calculate FLOPs for training (forward + backward + gradient)
+    # Pass the total sequence lengths as prompt_lengths, with response_lengths=None
+    total_flops = training_model_dims.flops(
+        prompt_lengths=total_sequence_lengths,
+        response_lengths=None,  # None for training mode
+        samples_per_prompt=1,  # Each sequence is treated independently for training
     )
 
     # MFU = (FLOPs / time) / peak_FLOPS * 100
     flops_per_second = total_flops / train_duration
     mfu = 100 * flops_per_second / device_flops
 
-    # Calculate memory bandwidth utilization
-    total_memory_bytes = model_dims.memory_bytes(
-        prompt_lengths=packed_data.prompt_lengths,
-        response_lengths=packed_data.response_lengths,
-        samples_per_prompt=args.num_samples_per_prompt_rollout,
-    )
-
-    # MBU = (Memory bytes / time) / peak_bandwidth * 100
-    bytes_per_second = total_memory_bytes / train_duration
-    mbu = 100 * bytes_per_second / device_memory_bandwidth
-
-    return {"mfu": mfu, "mbu": mbu}
+    # MBU is not reported during training as requested
+    return {"mfu": mfu}
 
 
 def one_training_step(
@@ -2383,7 +2384,6 @@ def one_training_step(
         "epoch": episode / args.num_samples_per_prompt_rollout / len(train_dataset),
         "learner_tokens_per_second": num_total_tokens / total_time,
         "learner_mfu": utilization_metrics["mfu"],
-        "learner_mbu": utilization_metrics["mbu"],
        "time/total": total_time,
        "time/training": train_timer.duration,
        "time/saving": save_time,

open_instruct/utils.py

Lines changed: 11 additions & 0 deletions
@@ -1661,6 +1661,8 @@ def check_oe_eval_internal():
 # Approximate softmax cost per attention score:
 # ~4 scalar ops/score: exp + subtract max (stabilization) + sum + divide.
 SOFTMAX_FLOPS_PER_SCORE = 4
+# Training multiplier: forward + backward + gradient computation
+TRAINING_FLOP_MULT = 3
 
 
 @dataclasses.dataclass
@@ -1671,6 +1673,7 @@ class ModelDims:
     vocab_size: int
     num_attn_heads: int
     num_kv_heads: Optional[int] = None
+    is_training: bool = False
 
     def __post_init__(self):
         if self.num_kv_heads is None:
@@ -1740,6 +1743,7 @@ def decode_flops(
 
         Embedding lookups are ignored by design.
         """
+        assert not self.is_training, "decode_flops should not be called when is_training=True"
        assert len(response_lengths) == len(prompt_lengths) * samples_per_prompt, (
            f"Expected {len(prompt_lengths) * samples_per_prompt} response lengths, got {len(response_lengths)}"
        )
@@ -1771,9 +1775,16 @@ def flops(
             response_lengths: List of response lengths (samples_per_prompt * len(prompt_lengths) total)
             samples_per_prompt: Number of samples generated per prompt
         """
+        if self.is_training:
+            assert response_lengths is None, "response_lengths should be None when is_training=True"
+
         total = self.prefill_flops(prompt_lengths)
         if response_lengths is not None:
             total += self.decode_flops(prompt_lengths, response_lengths, samples_per_prompt)
+
+        if self.is_training:
+            total *= TRAINING_FLOP_MULT
+
         return total
 
     def weight_memory_bytes(self, num_tokens: int, dtype_bytes: int = 2) -> int:
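
The new TRAINING_FLOP_MULT = 3 encodes the usual rule of thumb that one training step costs roughly three forward passes per token: about 2·N FLOPs forward and 4·N backward for an N-parameter dense model, i.e. ~6·N total. A tiny standalone sanity check of that accounting, stated under those assumptions and independent of ModelDims:

# Rule-of-thumb accounting behind TRAINING_FLOP_MULT = 3 (per-token cost for a dense
# model with n_params parameters; attention and activation-recompute terms ignored).
def train_step_flops_per_token(n_params: float) -> float:
    forward = 2 * n_params       # one multiply-accumulate per parameter
    backward = 2 * forward       # grads w.r.t. activations plus grads w.r.t. weights
    return forward + backward    # == 3 * forward, i.e. TRAINING_FLOP_MULT * forward


assert train_step_flops_per_token(7e9) == 3 * (2 * 7e9)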

0 commit comments
