Skip to content

Commit 9ac0601

Browse files
committed
fabric: unify CLI with jsonargparse
1 parent 3998b5d commit 9ac0601

File tree

2 files changed

+97
-77
lines changed

2 files changed

+97
-77
lines changed

requirements/fabric/test.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,5 @@ pytest-cov ==6.3.0
55
pytest-timeout ==2.4.0
66
pytest-rerunfailures ==16.0.1
77
pytest-random-order ==1.2.0
8-
click ==8.1.8; python_version < "3.11"
9-
click ==8.2.1; python_version > "3.10"
8+
jsonargparse[signatures,jsonnet] >=4.39.0, <4.41.0
109
tensorboardX >=2.6, <2.7.0 # todo: relax it back to `>=2.2` after fixing tests

src/lightning/fabric/cli.py

Lines changed: 96 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import logging
1515
import os
1616
import re
17+
import sys
1718
from argparse import Namespace
1819
from typing import Any, Optional
1920

@@ -31,9 +32,12 @@
3132

3233
_log = logging.getLogger(__name__)
3334

34-
_CLICK_AVAILABLE = RequirementCache("click")
35+
_JSONARGPARSE_AVAILABLE = RequirementCache("jsonargparse")
3536
_LIGHTNING_SDK_AVAILABLE = RequirementCache("lightning_sdk")
3637

38+
if _JSONARGPARSE_AVAILABLE:
39+
from jsonargparse import ArgumentParser
40+
3741
_SUPPORTED_ACCELERATORS = ("cpu", "gpu", "cuda", "mps", "tpu", "auto")
3842

3943

@@ -45,127 +49,112 @@ def _get_supported_strategies() -> list[str]:
4549
return [strategy for strategy in available_strategies if not re.match(excluded, strategy)]
4650

4751

48-
if _CLICK_AVAILABLE:
49-
import click
52+
def _build_parser() -> "ArgumentParser":
53+
"""Build the jsonargparse-based CLI parser with subcommands."""
54+
if not _JSONARGPARSE_AVAILABLE: # pragma: no cover
55+
raise RuntimeError(
56+
"To use the Lightning Fabric CLI, you must have `jsonargparse` installed. "
57+
"Install it by running `pip install -U jsonargparse`."
58+
)
5059

51-
@click.group()
52-
def _main() -> None:
53-
pass
60+
parser = ArgumentParser(description="Lightning Fabric command line tool")
61+
subcommands = parser.add_subcommands()
5462

55-
@_main.command(
56-
"run",
57-
context_settings={
58-
"ignore_unknown_options": True,
59-
},
60-
)
61-
@click.argument(
62-
"script",
63-
type=click.Path(exists=True),
64-
)
65-
@click.option(
63+
# run subcommand
64+
run_parser = ArgumentParser(description="Run a Lightning Fabric script.")
65+
run_parser.add_argument(
6666
"--accelerator",
67-
type=click.Choice(_SUPPORTED_ACCELERATORS),
67+
type=str,
68+
choices=_SUPPORTED_ACCELERATORS,
6869
default=None,
6970
help="The hardware accelerator to run on.",
7071
)
71-
@click.option(
72+
run_parser.add_argument(
7273
"--strategy",
73-
type=click.Choice(_get_supported_strategies()),
74+
type=str,
75+
choices=_get_supported_strategies(),
7476
default=None,
7577
help="Strategy for how to run across multiple devices.",
7678
)
77-
@click.option(
79+
run_parser.add_argument(
7880
"--devices",
7981
type=str,
8082
default="1",
8183
help=(
82-
"Number of devices to run on (``int``), which devices to run on (``list`` or ``str``), or ``'auto'``."
83-
" The value applies per node."
84+
"Number of devices to run on (int), which devices to run on (list or str), or 'auto'. "
85+
"The value applies per node."
8486
),
8587
)
86-
@click.option(
87-
"--num-nodes",
88+
run_parser.add_argument(
8889
"--num_nodes",
90+
"--num-nodes",
8991
type=int,
9092
default=1,
9193
help="Number of machines (nodes) for distributed execution.",
9294
)
93-
@click.option(
94-
"--node-rank",
95+
run_parser.add_argument(
9596
"--node_rank",
97+
"--node-rank",
9698
type=int,
9799
default=0,
98100
help=(
99-
"The index of the machine (node) this command gets started on. Must be a number in the range"
100-
" 0, ..., num_nodes - 1."
101+
"The index of the machine (node) this command gets started on. Must be a number in the range "
102+
"0, ..., num_nodes - 1."
101103
),
102104
)
103-
@click.option(
104-
"--main-address",
105+
run_parser.add_argument(
105106
"--main_address",
107+
"--main-address",
106108
type=str,
107109
default="127.0.0.1",
108110
help="The hostname or IP address of the main machine (usually the one with node_rank = 0).",
109111
)
110-
@click.option(
111-
"--main-port",
112+
run_parser.add_argument(
112113
"--main_port",
114+
"--main-port",
113115
type=int,
114116
default=29400,
115117
help="The main port to connect to the main machine.",
116118
)
117-
@click.option(
119+
run_parser.add_argument(
118120
"--precision",
119-
type=click.Choice(get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_STR_ALIAS)),
121+
type=str,
122+
choices=list(get_args(_PRECISION_INPUT_STR)) + list(get_args(_PRECISION_INPUT_STR_ALIAS)),
120123
default=None,
121124
help=(
122-
"Double precision (``64-true`` or ``64``), full precision (``32-true`` or ``32``), "
123-
"half precision (``16-mixed`` or ``16``) or bfloat16 precision (``bf16-mixed`` or ``bf16``)"
125+
"Double precision ('64-true' or '64'), full precision ('32-true' or '32'), "
126+
"half precision ('16-mixed' or '16') or bfloat16 precision ('bf16-mixed' or 'bf16')."
124127
),
125128
)
126-
@click.argument("script_args", nargs=-1, type=click.UNPROCESSED)
127-
def _run(**kwargs: Any) -> None:
128-
"""Run a Lightning Fabric script.
129-
130-
SCRIPT is the path to the Python script with the code to run. The script must contain a Fabric object.
131-
132-
SCRIPT_ARGS are the remaining arguments that you can pass to the script itself and are expected to be parsed
133-
there.
134-
135-
"""
136-
script_args = list(kwargs.pop("script_args", []))
137-
main(args=Namespace(**kwargs), script_args=script_args)
129+
run_parser.add_argument(
130+
"script",
131+
type=str,
132+
help="Path to the Python script with the code to run. The script must contain a Fabric object.",
133+
)
134+
subcommands.add_subcommand("run", run_parser, help="Run a Lightning Fabric script")
138135

139-
@_main.command(
140-
"consolidate",
141-
context_settings={
142-
"ignore_unknown_options": True,
143-
},
136+
# consolidate subcommand
137+
con_parser = ArgumentParser(
138+
description="Convert a distributed/sharded checkpoint into a single file that can be loaded with torch.load()."
144139
)
145-
@click.argument(
140+
con_parser.add_argument(
146141
"checkpoint_folder",
147-
type=click.Path(exists=True),
142+
type=str,
143+
help="Path to the checkpoint folder to consolidate.",
148144
)
149-
@click.option(
145+
con_parser.add_argument(
150146
"--output_file",
151-
type=click.Path(exists=True),
147+
type=str,
152148
default=None,
153149
help=(
154-
"Path to the file where the converted checkpoint should be saved. The file should not already exist."
155-
" If no path is provided, the file will be saved next to the input checkpoint folder with the same name"
156-
" and a '.consolidated' suffix."
150+
"Path to the file where the converted checkpoint should be saved. The file should not already exist. "
151+
"If not provided, the file will be saved next to the input checkpoint folder with the same name and a "
152+
"'.consolidated' suffix."
157153
),
158154
)
159-
def _consolidate(checkpoint_folder: str, output_file: Optional[str]) -> None:
160-
"""Convert a distributed/sharded checkpoint into a single file that can be loaded with `torch.load()`.
161-
162-
Only supports FSDP sharded checkpoints at the moment.
155+
subcommands.add_subcommand("consolidate", con_parser, help="Consolidate a distributed checkpoint")
163156

164-
"""
165-
args = Namespace(checkpoint_folder=checkpoint_folder, output_file=output_file)
166-
config = _process_cli_args(args)
167-
checkpoint = _load_distributed_checkpoint(config.checkpoint_folder)
168-
torch.save(checkpoint, config.output_file)
157+
return parser
169158

170159

171160
def _set_env_variables(args: Namespace) -> None:
@@ -234,12 +223,44 @@ def main(args: Namespace, script_args: Optional[list[str]] = None) -> None:
234223
_torchrun_launch(args, script_args or [])
235224

236225

237-
if __name__ == "__main__":
238-
if not _CLICK_AVAILABLE: # pragma: no cover
226+
def _run_command(cfg: Namespace, script_args: list[str]) -> None:
227+
"""Execute the 'run' subcommand with the provided config and extra script args."""
228+
main(args=Namespace(**cfg), script_args=script_args)
229+
230+
231+
def _consolidate_command(cfg: Namespace) -> None:
232+
"""Execute the 'consolidate' subcommand with the provided config."""
233+
args = Namespace(checkpoint_folder=cfg.checkpoint_folder, output_file=cfg.output_file)
234+
config = _process_cli_args(args)
235+
checkpoint = _load_distributed_checkpoint(config.checkpoint_folder)
236+
torch.save(checkpoint, config.output_file)
237+
238+
239+
def cli_main(argv: Optional[list[str]] = None) -> None:
240+
"""Entry point for the Fabric CLI using jsonargparse."""
241+
if not _JSONARGPARSE_AVAILABLE: # pragma: no cover
239242
_log.error(
240-
"To use the Lightning Fabric CLI, you must have `click` installed."
241-
" Install it by running `pip install -U click`."
243+
"To use the Lightning Fabric CLI, you must have `jsonargparse` installed."
244+
" Install it by running `pip install -U jsonargparse`."
242245
)
243246
raise SystemExit(1)
244247

245-
_run()
248+
parser = _build_parser()
249+
# parse_known_args so that for 'run' we can forward unknown args to the user script
250+
cfg, unknown = parser.parse_known_args(argv)
251+
252+
if not getattr(cfg, "subcommand", None):
253+
parser.print_help()
254+
return
255+
256+
if cfg.subcommand == "run":
257+
# unknown contains the script's own args
258+
_run_command(cfg.run, unknown)
259+
elif cfg.subcommand == "consolidate":
260+
_consolidate_command(cfg.consolidate)
261+
else: # pragma: no cover
262+
parser.print_help()
263+
264+
265+
if __name__ == "__main__":
266+
cli_main(sys.argv[1:])

0 commit comments

Comments
 (0)