Skip to content

Commit ef16909

Browse files
authored
Fix multi gpu trainer (PrimeIntellect-ai#492)
1 parent ef54176 commit ef16909

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

src/zeroband/training/train.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -57,6 +57,7 @@ def train(config: TrainingConfig):
57 57
# Optionally, sidecar the orchestrator
58 58
orchestrator = None
59 59
if config.orchestrator and world.rank == 0:
60+
config.orchestrator.num_train_workers = world.world_size
60 61
logger.info("Starting orchestrator in a separate process")
61 62

62 63
# Create a queue for orchestrator to signal when setup is complete
@@ -89,6 +90,7 @@ def train(config: TrainingConfig):
89 90
torch._dynamo.config.suppress_errors = True
90 91

91 92
torch.set_float32_matmul_precision("high")
93+
torch.cuda.set_device(world.rank)
92 94

93 95
if config.weights.path and world.rank == 0:
94 96
if envs.SHARDCAST_OUTPUT_DIR is not None:

0 commit comments

Comments (0)