Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ jobs:
strategy:
matrix:
python-version: ["3.10", "3.11", "3.12", "3.13"]
env:
UV_SYSTEM_PYTHON: 1

steps:
- uses: actions/checkout@v4
Expand Down Expand Up @@ -123,6 +125,8 @@ jobs:
strategy:
matrix:
python-version: ["3.10", "3.11", "3.12", "3.13"]
env:
UV_SYSTEM_PYTHON: 1

steps:
- uses: actions/checkout@v4
Expand All @@ -144,7 +148,8 @@ jobs:
make install

- name: Run tests
run: python -m unittest discover -s tests
run: |
python -m unittest discover -s tests

cli-tests:
env:
Expand Down
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,4 @@ repos:
args: ["--install-types", "--non-interactive"]
additional_dependencies:
- types-click
- ray
171 changes: 162 additions & 9 deletions clusterscope/job_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,56 @@
# LICENSE file in the root directory of this source tree.
import os
import random
import socket
import subprocess
import time

from contextlib import closing
from functools import lru_cache
from typing import Any, Dict, MutableMapping, Optional

import ray

MIN_MASTER_PORT, MAX_MASTER_PORT = (20_000, 60_000)


class RayCoordinator:
LEADER_MAX_RETRIES = 30
LEADER_RETRY_INTERVAL = 1.0

def __init__(
self,
job_id: int,
world_size: int,
):
self.job_id = job_id
self.worker_info: Dict[int, Dict[str, Any]] = {}
self.leader_port = None
self.ready_workers = 0
self.world_size = world_size

def register_worker(
self, hostname: str, rank: int, free_port: int | None
) -> Dict[str, Any]:
"""Register a worker with its placement group ID and GPU ID"""
self.ready_workers += 1
info = {
"hostname": hostname,
"rank": rank,
"ready_workers": self.ready_workers,
"world_size": self.world_size,
"leader_port": free_port,
}
self.worker_info[rank] = info
return info

def get_leader_info(self) -> Dict[str, Any] | None:
if self.ready_workers == self.world_size:
return self.worker_info[0]
else:
return None


class JobInfo:
"""
This class is used to get information about the current job.
Expand All @@ -33,6 +76,7 @@ def __init__(self):
self.is_torch_run = lambda: "LOCAL_RANK" in os.environ
self.is_torchelastic_run = lambda: "TORCHELASTIC_RUN_ID" in os.environ
self.is_slurm_job = lambda: "SLURM_JOB_ID" in os.environ
self.is_ray_job = lambda: ray.is_initialized()
self.job_id = self.get_job_id()
self.job_name = self.get_job_name()
self.global_rank = self.get_global_rank()
Expand Down Expand Up @@ -95,6 +139,8 @@ def get_world_size(self) -> int:
except ValueError:
raise RuntimeError(f"WORLD_SIZE cannot be parsed. {world_size=}")
return world_size
if self.is_ray_job():
return len(ray.nodes())
if self.is_slurm_job():
return int(os.environ["SLURM_NTASKS"])
return 1
Expand Down Expand Up @@ -136,14 +182,121 @@ def get_master_addr(self) -> str:
)
return "127.0.0.1"

def set_env_if_exists(
    self,
    target_key: str,
    source_key: str,
    source_dict: MutableMapping[str, str] = os.environ,
):
    """Copy ``source_dict[source_key]`` into ``os.environ[target_key]``.

    No-op when ``source_key`` is absent; the value is coerced to ``str``
    before being stored (os.environ only accepts strings).
    """
    try:
        value = source_dict[source_key]
    except KeyError:
        return
    os.environ[target_key] = str(value)

def set_torch_distributed_env(
    self, ray_coordinator_name: Optional[str] = None
) -> None:
    """Populate torch.distributed env vars from Slurm and, optionally, Ray.

    Slurm-derived variables are applied first; when ``ray_coordinator_name``
    is provided, Ray-derived values are applied afterwards and therefore
    take precedence (Ray > Slurm).

    ``ray_coordinator_name``: name of the Ray coordinator actor env handle.
    When falsy, the Ray step is skipped entirely.
    """
    # Slurm first, so the Ray pass (applied second) can override it.
    self.set_torch_distributed_env_from_slurm()

    if not ray_coordinator_name:
        return
    self.set_torch_distributed_env_from_ray(
        ray_coordinator_name=ray_coordinator_name
    )

def set_torch_distributed_env_from_slurm(self) -> None:
    """Derive torch.distributed env vars from their Slurm equivalents.

    No-op outside a Slurm job. Sets WORLD_SIZE, RANK, LOCAL_WORLD_SIZE
    (defaulting to "1" when Slurm does not export a per-node task count),
    LOCAL_RANK, MASTER_ADDR, MASTER_PORT and CUDA_VISIBLE_DEVICES.
    """
    if not self.is_slurm_job():
        return

    self.set_env_if_exists(
        target_key="WORLD_SIZE",
        source_key="SLURM_NTASKS",
    )
    self.set_env_if_exists(
        target_key="RANK",
        source_key="SLURM_PROCID",
    )
    self.set_env_if_exists(
        target_key="LOCAL_WORLD_SIZE",
        source_key="SLURM_NTASKS_PER_NODE",
    )
    # Fallback: assume one task per node when Slurm does not export it.
    os.environ.setdefault("LOCAL_WORLD_SIZE", "1")
    self.set_env_if_exists(
        target_key="LOCAL_RANK",
        source_key="SLURM_LOCALID",
    )
    # BUG FIX: MASTER_ADDR/MASTER_PORT were previously routed through
    # set_env_if_exists with the computed *value* used as the env-var
    # lookup key (e.g. source_key=self.get_master_addr()), so they were
    # effectively never set. Assign them directly instead.
    os.environ["MASTER_ADDR"] = self.get_master_addr()
    os.environ["MASTER_PORT"] = str(self.get_master_port())
    # NOTE(review): mapping CUDA_VISIBLE_DEVICES to SLURM_LOCALID assumes
    # one GPU per task, indexed by local rank — confirm against the
    # cluster's GPU-binding configuration.
    self.set_env_if_exists(
        target_key="CUDA_VISIBLE_DEVICES",
        source_key="SLURM_LOCALID",
    )

def set_torch_distributed_env_from_ray(self, ray_coordinator_name: str) -> None:
    """Derive torch.distributed env vars via a Ray coordinator actor.

    ``ray_coordinator_name`` is the name of an *environment variable* whose
    value identifies the RayCoordinator actor as "<actor_name>:<namespace>".
    Rank 0 reserves a free port and becomes the leader; every worker then
    registers with the coordinator and polls (with mild exponential
    backoff) until all workers are registered, finally adopting the
    leader's hostname/port as MASTER_ADDR / MASTER_PORT.

    Raises:
        RuntimeError: if the coordinator env var is unset.
        TimeoutError: if the leader is not elected within the retry budget.
    """
    if not self.is_ray_job():
        return

    hostname = socket.gethostname()

    coordinator_name = os.environ.get(ray_coordinator_name)
    if coordinator_name is None:
        # BUG FIX: the original message interpolated coordinator_name
        # (always None here) instead of the env-var name that is missing.
        raise RuntimeError(
            f"Ray coordinator name not found in environment variable {ray_coordinator_name=}"
        )
    # BUG FIX: str(os.environ.get("SLURM_LOCALID")) stored the literal
    # string "None" when the variable was absent; only set when present.
    self.set_env_if_exists(
        target_key="LOCAL_RANK",
        source_key="SLURM_LOCALID",
    )
    os.environ["MASTER_ADDR"] = self.get_master_addr()
    os.environ["MASTER_PORT"] = str(self.get_master_port())
    self.set_env_if_exists(
        target_key="CUDA_VISIBLE_DEVICES",
        source_key="SLURM_LOCALID",
    )

    coordinator = ray.get_actor(*coordinator_name.split(":"))

    rank = self.get_global_rank()
    # Only the leader (rank 0) needs to reserve a rendezvous port.
    free_port = self._find_free_port() if rank == 0 else None

    worker_info = ray.get(
        coordinator.register_worker.remote(
            hostname=hostname, rank=rank, free_port=free_port
        )
    )

    # Poll until every worker has registered and the leader is known.
    leader = None
    for attempt in range(RayCoordinator.LEADER_MAX_RETRIES):
        leader = ray.get(coordinator.get_leader_info.remote())
        if leader is not None:
            break
        time.sleep(RayCoordinator.LEADER_RETRY_INTERVAL * (1.1**attempt))

    if not leader:
        raise TimeoutError(f"Worker {rank} timed out waiting for leader")

    self.set_env_if_exists(
        target_key="WORLD_SIZE",
        source_key="world_size",
        source_dict=worker_info,
    )
    # Leader values override the Slurm-derived MASTER_ADDR/MASTER_PORT.
    self.set_env_if_exists(
        target_key="MASTER_ADDR",
        source_key="hostname",
        source_dict=leader,
    )
    self.set_env_if_exists(
        target_key="MASTER_PORT",
        source_key="leader_port",
        source_dict=leader,
    )

def _find_free_port(self) -> int:
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
s.bind(("", 0))
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
port = s.getsockname()[1]
return int(port)
43 changes: 40 additions & 3 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,30 +1,46 @@
# This file was autogenerated by uv via the following command:
# uv pip compile pyproject.toml -o dev-requirements.txt --extra dev
attrs==25.3.0
# via usort
# via
# jsonschema
# referencing
# usort
black==25.9.0
# via
# clusterscope (pyproject.toml)
# ufmt
build==1.3.0
# via clusterscope (pyproject.toml)
certifi==2025.8.3
# via requests
cfgv==3.4.0
# via pre-commit
charset-normalizer==3.4.3
# via requests
click==8.3.0
# via
# clusterscope (pyproject.toml)
# black
# moreorless
# ray
# ufmt
# usort
distlib==0.4.0
# via virtualenv
filelock==3.19.1
# via virtualenv
# via
# ray
# virtualenv
flake8==7.3.0
# via clusterscope (pyproject.toml)
identify==2.6.14
# via pre-commit
idna==3.10
# via requests
jsonschema==4.25.1
# via ray
jsonschema-specifications==2025.9.1
# via jsonschema
libcst==1.8.5
# via
# ufmt
Expand All @@ -35,6 +51,8 @@ moreorless==0.5.0
# via
# ufmt
# usort
msgpack==1.1.1
# via ray
mypy==1.18.2
# via clusterscope (pyproject.toml)
mypy-extensions==1.1.0
Expand All @@ -47,6 +65,7 @@ packaging==25.0
# via
# black
# build
# ray
pathspec==0.12.1
# via
# black
Expand All @@ -58,6 +77,8 @@ platformdirs==4.4.0
# virtualenv
pre-commit==4.3.0
# via clusterscope (pyproject.toml)
protobuf==6.32.1
# via ray
pycodestyle==2.14.0
# via flake8
pyflakes==3.4.0
Expand All @@ -70,6 +91,19 @@ pyyaml==6.0.3
# via
# libcst
# pre-commit
# ray
ray==2.49.2
# via clusterscope (pyproject.toml)
referencing==0.36.2
# via
# jsonschema
# jsonschema-specifications
requests==2.32.5
# via ray
rpds-py==0.27.1
# via
# jsonschema
# referencing
stdlibs==2025.5.10
# via usort
toml==0.10.2
Expand All @@ -79,7 +113,7 @@ tomli==2.2.1
# black
# build
# mypy
tomlkit==0.13.2
tomlkit==0.13.3
# via ufmt
trailrunner==1.4.0
# via
Expand All @@ -91,10 +125,13 @@ typing-extensions==4.15.0
# via
# black
# mypy
# referencing
# ufmt
# virtualenv
ufmt==2.8.0
# via clusterscope (pyproject.toml)
urllib3==2.5.0
# via requests
usort==1.0.8.post1
# via
# clusterscope (pyproject.toml)
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@ authors = [
]
dependencies = [
"click>=8.0.0",
"ray",
]

[project.scripts]
cscope = "clusterscope.cli:main"

[project.optional-dependencies]
dev = [
"ray",
"flake8",
"black",
"ufmt",
Expand Down
Loading