2 changes: 1 addition & 1 deletion megatron/logging.py
@@ -576,7 +576,7 @@ def tb_wandb_log(
    all_ranks: bool = False,
):
    # logs to both tb and wandb (if present) from the zeroth rank
    do_log = torch.distributed.get_rank() == 0 or all_ranks
    do_log = all_ranks or torch.distributed.get_rank() == 0
    if do_log and value is not None:
        if tensorboard_writer:
            tensorboard_writer.add_scalar(key, value, iteration_no)
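Putting `all_ranks` first lets it short-circuit the condition, so `torch.distributed.get_rank()` is only consulted when a caller has not already opted into per-rank logging. A minimal call-site sketch follows; the remaining parameter names (`key`, `value`, `iteration_no`, `use_wandb`, `tensorboard_writer`) are assumed from the surrounding function body, and the rank-qualified key and placeholder values are purely illustrative:

```python
import torch

from megatron.logging import tb_wandb_log

# Illustrative call site (placeholder values): every rank can now record its
# own scalar under a rank-qualified key instead of only rank 0 writing it.
rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
tb_wandb_log(
    key=f"timers/iteration_time/rank_{rank}",  # hypothetical per-rank key
    value=0.123,              # placeholder metric value
    iteration_no=100,         # placeholder iteration counter
    use_wandb=True,           # assumed parameter name, not shown in this hunk
    tensorboard_writer=None,  # wandb-only logging in this sketch
    all_ranks=True,           # evaluated first, so the rank check is skipped
)
```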
86 changes: 60 additions & 26 deletions megatron/utils.py
@@ -22,6 +22,7 @@
import time
import socket
from typing import Dict, List
import importlib

import requests

@@ -155,35 +156,68 @@ def get_wandb_api_key(neox_args):


def init_wandb(neox_args):
    # Wandb. (one worker per machine)
    if neox_args.use_wandb == False:
    """
    Initialise wandb once per distributed process so that every rank
    publishes its own system metrics while sharing a single run ID.
    - Requires wandb ≥ 0.19.5.
    - Rank 0 is still the primary writer; other ranks set `x_primary=False`
      to avoid race conditions on artefact uploads / run-state updates.
    """
    if not neox_args.use_wandb:
        return

    if not neox_args.wandb_init_all_ranks:
        use_wandb = is_local_main() and (
            get_wandb_api_key(neox_args=neox_args) is not None
    # If the user didn't define WANDB_RUN_ID, create a fresh run id on rank 0
    # and broadcast it so that all ranks attach to the same run.
    run_id = os.environ.get("WANDB_RUN_ID")
    if torch.distributed.is_initialized():
        if run_id is None and torch.distributed.get_rank() == 0:
            run_id = wandb.util.generate_id()
        # broadcast_object_list mutates the list in place and returns None,
        # so the rank-0 value is read back out of the list after the call.
        run_id_list = [run_id]
        torch.distributed.broadcast_object_list(run_id_list, src=0)
        run_id = run_id_list[0]

    rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
    local_rank_ = local_rank()

    _settings = dict(
        mode="shared",
        x_label=f"rank-{rank}",
        x_stats_sampling_interval=1,
        x_stats_gpu_device_ids=[local_rank_],
    )
    if rank != 0:
        _settings.update(
            x_primary=False,
            x_update_finish_state=False,
        )

    use_wandb = get_wandb_api_key(neox_args) is not None
    neox_args.update_value("use_wandb", use_wandb)
    if not use_wandb:
        return

    try:
        wandb.init(
            project=neox_args.wandb_project,
            entity=neox_args.wandb_team,
            group=neox_args.wandb_group,
            name=neox_args.wandb_run_name,
            id=run_id,
            resume="allow",
            settings=wandb.Settings(**_settings),
            save_code=False,
            force=False,
        )
        if rank == 0:
            wandb.config.update(neox_args.all_config)
    except wandb.UsageError as e:
        neox_args.update_value("use_wandb", False)
        print(e)
        print(
            "Skipping wandb. Execute `wandb login` on all nodes or set WANDB_API_KEY.",
            flush=True,
        )
    neox_args.update_value("use_wandb", use_wandb)
    if neox_args.use_wandb:
        group_name = neox_args.wandb_group
        run_name = neox_args.wandb_run_name
        try:
            wandb.init(
                project=neox_args.wandb_project,
                group=group_name,
                name=run_name,
                save_code=False,
                force=False,
                entity=neox_args.wandb_team,
            )
        except wandb.UsageError as e:
            neox_args.update_value("use_wandb", False)
            print(e)
            print(
                "Skipping wandb. Execute `wandb login` on local or main node machine to enable.",
                flush=True,
            )
        wandb.config.update(neox_args.all_config)


def obtain_resource_pool(
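For reference, the shared-mode wiring that `init_wandb` now performs can be exercised in isolation. The sketch below is a minimal standalone script under stated assumptions (launched via `torchrun`, `WANDB_API_KEY` exported, invented project name); it only reuses settings fields that appear in this diff (`mode="shared"`, `x_label`, `x_primary`):

```python
"""Standalone sketch of wandb shared mode, mirroring init_wandb above.

Assumptions: launched with `torchrun --nproc-per-node=N shared_mode_sketch.py`
and WANDB_API_KEY set; the project name is invented for illustration.
"""
import torch.distributed as dist
import wandb

dist.init_process_group(backend="gloo")
rank = dist.get_rank()

# Rank 0 creates the run id; broadcast_object_list fills the list in place on
# every other rank so they all attach to the same run.
run_id = [wandb.util.generate_id() if rank == 0 else None]
dist.broadcast_object_list(run_id, src=0)

settings = wandb.Settings(
    mode="shared",
    x_label=f"rank-{rank}",
    x_primary=(rank == 0),  # non-zero ranks stay secondary writers
)
run = wandb.init(
    project="shared-mode-sketch",  # invented project name
    id=run_id[0],
    resume="allow",
    settings=settings,
)

# Every rank logs its own metric into the shared run.
run.log({f"rank_{rank}/dummy_metric": float(rank)})
run.finish()
dist.destroy_process_group()
```

As in the patch, secondary ranks could additionally pass `x_update_finish_state=False` if the order in which ranks finish the run is a concern.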
2 changes: 1 addition & 1 deletion requirements/requirements-wandb.txt
@@ -1 +1 @@
wandb>=0.10.28
wandb>=0.19.5
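Because shared mode and the `x_*` settings only exist in newer clients, a defensive runtime check can fail fast on stale environments. This is a hypothetical guard, not part of this diff; it assumes only `importlib.metadata` from the standard library:

```python
import importlib.metadata

# Hypothetical guard: refuse to run with a wandb release that predates shared mode.
# Naive parse: pre-release suffixes such as "0.19.5rc1" are not handled.
MIN_WANDB = (0, 19, 5)
installed = tuple(int(part) for part in importlib.metadata.version("wandb").split(".")[:3])
if installed < MIN_WANDB:
    raise RuntimeError(
        f"wandb>={'.'.join(map(str, MIN_WANDB))} is required for shared-mode logging; "
        f"found {'.'.join(map(str, installed))}."
    )
```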