diff --git a/clusterscope/cli.py b/clusterscope/cli.py index 4bbcb43..6883468 100644 --- a/clusterscope/cli.py +++ b/clusterscope/cli.py @@ -213,6 +213,7 @@ def task(): default=None, help="Time limit for the job (format: HH:MM:SS or days-HH:MM:SS, optional)", ) +@click.option("--slurm-cmd", default=None, help="Command to run on Slurm") def slurm( num_gpus: int, num_tasks_per_node: int, @@ -221,6 +222,7 @@ def slurm( account: str, qos: str, time: str, + slurm_cmd: str, ): """Generate job requirements for a task of a Slurm job.""" partitions = get_partition_info() @@ -239,6 +241,7 @@ def slurm( account=account, qos=qos, time=time, + slurm_cmd=slurm_cmd, ) # Route to the correct format method based on CLI option diff --git a/clusterscope/cluster_info.py b/clusterscope/cluster_info.py index 51ee28b..855c646 100644 --- a/clusterscope/cluster_info.py +++ b/clusterscope/cluster_info.py @@ -27,6 +27,7 @@ class ResourceShape(NamedTuple): tasks_per_node: int gpus_per_node: int slurm_partition: str + slurm_cmd: Optional[str] = None account: Optional[str] = None qos: Optional[str] = None time: Optional[str] = None @@ -59,6 +60,8 @@ def to_sbatch(self) -> str: value = getattr(self, attr_name) if value is not None: lines.append(f"#SBATCH --{attr_name}={value}") + if self.slurm_cmd is not None: + lines.append(self.slurm_cmd) return "\n".join(lines) def to_srun(self) -> str: @@ -79,6 +82,8 @@ def to_srun(self) -> str: value = getattr(self, attr_name) if value is not None: cmd_parts.append(f"--{attr_name}={value}") + if self.slurm_cmd is not None: + cmd_parts.append(self.slurm_cmd) return " ".join(cmd_parts) def to_submitit(self) -> str: