Skip to content

Commit e0d9424

Browse files
authored
[grpo] fix async generate (#5634)
1 parent 07d54af commit e0d9424

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

swift/trainers/rlhf_trainer/grpo_trainer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2345,7 +2345,7 @@ def _server_rollout(self, inputs: DataType, request_config: RequestConfig,
23452345

23462346
else:
23472347
# For global inputs, only main process keeps outputs
2348-
outputs = outputs if self.accelerator.is_main_process else []
2348+
outputs = all_outputs if self.accelerator.is_main_process else []
23492349

23502350
return outputs
23512351

@@ -2542,6 +2542,9 @@ def merge_output_input_data(input_data: Dict[str, Union[torch.Tensor, Any]], out
25422542
return input_data
25432543

25442544
if not self.dynamic_num_samples:
2545+
if self.async_generate and not outputs:
2546+
# In async generation, only the main process receives outputs; non-main ranks get an empty list.
2547+
return outputs
25452548
assert len(inputs) == len(outputs)
25462549
return [
25472550
merge_output_input_data(deepcopy(input_data), output) for input_data, output in zip(inputs, outputs)

0 commit comments

Comments
 (0)