@@ -2244,7 +2244,7 @@ def generate_thread(args, vllm_engines, resume_training_step, stop_event, genera
 def calculate_mfu_mbu(
     model_dims: utils.ModelDims, packed_data: PackedData, train_duration: float, args: Args
 ) -> dict[str, float]:
-    """Calculate Model FLOPs Utilization (MFU) and Model Bandwidth Utilization (MBU).
+    """Calculate Model FLOPs Utilization (MFU) for training.
 
     Args:
         model_dims: Model dimensions for FLOPs/memory calculations
@@ -2253,7 +2253,7 @@ def calculate_mfu_mbu(
         args: Training arguments
 
     Returns:
-        Dictionary with 'mfu' and 'mbu' keys as percentages
+        Dictionary with 'mfu' key as percentage (MBU not reported for training)
     """
     assert model_dims is not None, "model_dims must not be None"
     assert packed_data is not None, "packed_data must not be None"
@@ -2264,42 +2264,43 @@ def calculate_mfu_mbu(
     # Get GPU specifications
     device_name = utils.get_device_name(torch.cuda.get_device_name(0))
     device_flops = utils.GPU_SPECS[device_name]["flops"]
-    device_memory_bandwidth = utils.GPU_SPECS[device_name]["memory_bandwidth"]
-
-    # For GRPO, we have multiple samples per prompt
-    # prompt_lengths contains lengths for unique prompts
-    # response_lengths contains lengths for all samples (num_prompts * samples_per_prompt)
-    assert (
-        len(packed_data.response_lengths) == len(packed_data.prompt_lengths) * args.num_samples_per_prompt_rollout
-    ), (
-        f"Expected {len(packed_data.prompt_lengths) * args.num_samples_per_prompt_rollout} response lengths, "
-        f"got {len(packed_data.response_lengths)}"
+
+    # For training, we need to calculate total sequence lengths (prompt + response)
+    # This represents the full sequence that the model is trained on
+    total_sequence_lengths = []
+    response_idx = 0
+    for prompt_len in packed_data.prompt_lengths:
+        # For each unique prompt, get all its response lengths
+        for _ in range(args.num_samples_per_prompt_rollout):
+            response_len = packed_data.response_lengths[response_idx]
+            total_sequence_lengths.append(prompt_len + response_len)
+            response_idx += 1
+
+    # Create a new ModelDims instance with is_training=True
+    training_model_dims = utils.ModelDims(
+        num_layers=model_dims.num_layers,
+        hidden_size=model_dims.hidden_size,
+        intermediate_size=model_dims.intermediate_size,
+        vocab_size=model_dims.vocab_size,
+        num_attn_heads=model_dims.num_attn_heads,
+        num_kv_heads=model_dims.num_kv_heads,
+        is_training=True,
     )
 
-    # Calculate FLOPs with proper handling of samples_per_prompt
-    # Note: prompt prefill is only done once per unique prompt, not per sample
-    total_flops = model_dims.flops(
-        prompt_lengths=packed_data.prompt_lengths,
-        response_lengths=packed_data.response_lengths,
-        samples_per_prompt=args.num_samples_per_prompt_rollout,
+    # Calculate FLOPs for training (forward + backward + gradient)
+    # Pass the total sequence lengths as prompt_lengths, with response_lengths=None
+    total_flops = training_model_dims.flops(
+        prompt_lengths=total_sequence_lengths,
+        response_lengths=None,  # None for training mode
+        samples_per_prompt=1,  # Each sequence is treated independently for training
     )
 
     # MFU = (FLOPs / time) / peak_FLOPS * 100
     flops_per_second = total_flops / train_duration
     mfu = 100 * flops_per_second / device_flops
 
-    # Calculate memory bandwidth utilization
-    total_memory_bytes = model_dims.memory_bytes(
-        prompt_lengths=packed_data.prompt_lengths,
-        response_lengths=packed_data.response_lengths,
-        samples_per_prompt=args.num_samples_per_prompt_rollout,
-    )
-
-    # MBU = (Memory bytes / time) / peak_bandwidth * 100
-    bytes_per_second = total_memory_bytes / train_duration
-    mbu = 100 * bytes_per_second / device_memory_bandwidth
-
-    return {"mfu": mfu, "mbu": mbu}
+    # MBU is not reported during training as requested
+    return {"mfu": mfu}
 
 
 def one_training_step(
@@ -2383,7 +2384,6 @@ def one_training_step(
23832384 "epoch" : episode / args .num_samples_per_prompt_rollout / len (train_dataset ),
23842385 "learner_tokens_per_second" : num_total_tokens / total_time ,
23852386 "learner_mfu" : utilization_metrics ["mfu" ],
2386- "learner_mbu" : utilization_metrics ["mbu" ],
23872387 "time/total" : total_time ,
23882388 "time/training" : train_timer .duration ,
23892389 "time/saving" : save_time ,