GRPOTrainer vLLM colocate hardcodes MASTER_PORT=12345 so no parallel runs possible #3979

@ruggsea

Description

Reproduction

When using GRPOTrainer with vLLM in colocate mode, launching two or more training runs on the same machine at the same time makes every run after the first crash with a port-in-use error from torch.distributed. The cause appears to be that TRL sets a fixed default MASTER_PORT=12345, so concurrent runs collide on the same rendezvous port.

Evidence in grpo_trainer.py:

                os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "localhost")
                os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "12345")

The error:

torch.distributed.DistNetworkError: The server socket has failed to listen on any local network address. port: 12345, useIpv6: false, code: -98, name: EADDRINUSE, message: address already in use

Steps to Reproduce

1. Start a training run on one GPU of a multi-GPU node using GRPO with use_vllm=True and vllm_mode="colocate" (a minimal launch sketch follows these steps).
2. In another shell on the same host, start a second run on another GPU.
3. The second run fails with EADDRINUSE on port 12345.
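
A minimal launch sketch; the model, dataset, and toy reward function are placeholders in the style of the TRL docs examples, not the exact script I used:

    # Minimal single-GPU colocate run; model, dataset, and reward are placeholders.
    from datasets import load_dataset
    from trl import GRPOConfig, GRPOTrainer

    def reward_len(completions, **kwargs):
        # Toy reward: prefer completions around 20 characters long.
        return [-abs(20 - len(completion)) for completion in completions]

    dataset = load_dataset("trl-lib/tldr", split="train")

    training_args = GRPOConfig(
        output_dir="grpo-colocate-run1",  # change per run
        use_vllm=True,
        vllm_mode="colocate",
    )

    trainer = GRPOTrainer(
        model="Qwen/Qwen2.5-0.5B-Instruct",
        reward_funcs=reward_len,
        args=training_args,
        train_dataset=dataset,
    )
    trainer.train()

Launch this once with CUDA_VISIBLE_DEVICES=0, then start an identical script (with a different output_dir) under CUDA_VISIBLE_DEVICES=1; the second process dies with the EADDRINUSE traceback shown above.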

What I would expect

Multiple colocated GRPO runs on the same machine should not contend for a fixed rendezvous port, or there should be a supported, documented way to set distinct ports per run (a workaround sketch is below).
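
As a stopgap that works today: the fallback to 12345 only applies when MASTER_PORT is unset, so giving each run its own port before the trainer is constructed avoids the collision. A minimal sketch (the port number is an arbitrary example):

    import os

    # Set a distinct rendezvous port for this run before GRPOTrainer
    # initializes torch.distributed; 29601 is an arbitrary unused port.
    os.environ.setdefault("MASTER_PORT", "29601")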

Environment

TRL: 0.19.x
vLLM: 0.10.1.1
CUDA, single node
Mode: use_vllm=True, vllm_mode="colocate"

How I would fix this

Replace the hardcoded default with safer behavior and expose configuration options. For example, add two fields to GRPOConfig:

    distributed_master_addr: Optional[str] = None
    distributed_master_port: Optional[int] = None
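
Then, when neither the config field nor the MASTER_PORT environment variable is set, fall back to an OS-assigned free port instead of a fixed one. A rough sketch of the colocate init path (the field names above and the helper below are my proposal, not existing TRL API):

    import os
    import socket

    def _find_free_port() -> int:
        # Ask the OS for an ephemeral port; still slightly racy, but far
        # better than every run defaulting to 12345.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            return s.getsockname()[1]

    # Inside the colocate initialization (sketch):
    master_addr = args.distributed_master_addr or os.environ.get("MASTER_ADDR", "localhost")
    master_port = (
        args.distributed_master_port
        or os.environ.get("MASTER_PORT")
        or _find_free_port()
    )
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = str(master_port)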

System Info

  • Platform: Linux-6.14.0-27-generic-x86_64-with-glibc2.41
  • Python version: 3.10.16
  • TRL version: 0.21.0
  • PyTorch version: 2.7.1
  • accelerator(s): NVIDIA H200, NVIDIA H200, NVIDIA H200, NVIDIA H200, NVIDIA H200, NVIDIA H200, NVIDIA H200, NVIDIA H200
  • Transformers version: 4.55.4
  • Accelerate version: 1.10.1
  • Accelerate config: not found
  • Datasets version: 3.2.0
  • HF Hub version: 0.34.4
  • bitsandbytes version: 0.45.1
  • DeepSpeed version: not installed
  • Diffusers version: not installed
  • Liger-Kernel version: not installed
  • LLM-Blender version: not installed
  • OpenAI version: 1.102.0
  • PEFT version: 0.14.0
  • vLLM version: 0.10.1.1

Checklist

  • I have checked that my issue isn't already filed (see open issues)
  • I have included my system information
  • Any code provided is minimal, complete, and reproducible (more on MREs)
  • Any code provided is properly formatted in code blocks (no screenshots; more on code blocks)
  • Any traceback provided is complete
