14
14
import logging
15
15
import os
16
16
import re
17
+ import sys
17
18
from argparse import Namespace
18
19
from typing import Any , Optional
19
20
31
32
32
33
_log = logging .getLogger (__name__ )
33
34
34
- _CLICK_AVAILABLE = RequirementCache ("click " )
35
+ _JSONARGPARSE_AVAILABLE = RequirementCache ("jsonargparse " )
35
36
_LIGHTNING_SDK_AVAILABLE = RequirementCache ("lightning_sdk" )
36
37
38
+ if _JSONARGPARSE_AVAILABLE :
39
+ from jsonargparse import ArgumentParser
40
+
37
41
_SUPPORTED_ACCELERATORS = ("cpu" , "gpu" , "cuda" , "mps" , "tpu" , "auto" )
38
42
39
43
@@ -45,127 +49,112 @@ def _get_supported_strategies() -> list[str]:
45
49
return [strategy for strategy in available_strategies if not re .match (excluded , strategy )]
46
50
47
51
48
- if _CLICK_AVAILABLE :
49
- import click
52
def _make_run_parser() -> "ArgumentParser":
    """Create the parser backing the ``run`` subcommand.

    All options are registered first and the positional ``script`` argument last, so the generated help
    lists them in that order.

    """
    run = ArgumentParser(description="Run a Lightning Fabric script.")

    # (option strings, keyword arguments) pairs, registered in declaration order.
    option_table = (
        (
            ("--accelerator",),
            dict(type=str, choices=_SUPPORTED_ACCELERATORS, default=None, help="The hardware accelerator to run on."),
        ),
        (
            ("--strategy",),
            dict(
                type=str,
                choices=_get_supported_strategies(),
                default=None,
                help="Strategy for how to run across multiple devices.",
            ),
        ),
        (
            ("--devices",),
            dict(
                type=str,
                default="1",
                help=(
                    "Number of devices to run on (int), which devices to run on (list or str), or 'auto'. "
                    "The value applies per node."
                ),
            ),
        ),
        (
            # The underscore spelling comes first, so argparse derives ``dest`` from it.
            ("--num_nodes", "--num-nodes"),
            dict(type=int, default=1, help="Number of machines (nodes) for distributed execution."),
        ),
        (
            ("--node_rank", "--node-rank"),
            dict(
                type=int,
                default=0,
                help=(
                    "The index of the machine (node) this command gets started on. Must be a number in the range "
                    "0, ..., num_nodes - 1."
                ),
            ),
        ),
        (
            ("--main_address", "--main-address"),
            dict(
                type=str,
                default="127.0.0.1",
                help="The hostname or IP address of the main machine (usually the one with node_rank = 0).",
            ),
        ),
        (
            ("--main_port", "--main-port"),
            dict(type=int, default=29400, help="The main port to connect to the main machine."),
        ),
        (
            ("--precision",),
            dict(
                type=str,
                choices=[*get_args(_PRECISION_INPUT_STR), *get_args(_PRECISION_INPUT_STR_ALIAS)],
                default=None,
                help=(
                    "Double precision (' 64-true' or '64' ), full precision (' 32-true' or '32' ), "
                    "half precision (' 16-mixed' or '16' ) or bfloat16 precision (' bf16-mixed' or ' bf16'). "
                ),
            ),
        ),
    )
    for option_strings, settings in option_table:
        run.add_argument(*option_strings, **settings)

    run.add_argument(
        "script",
        type=str,
        help="Path to the Python script with the code to run. The script must contain a Fabric object.",
    )
    return run


def _make_consolidate_parser() -> "ArgumentParser":
    """Create the parser backing the ``consolidate`` subcommand."""
    consolidate = ArgumentParser(
        description="Convert a distributed/sharded checkpoint into a single file that can be loaded with torch.load()."
    )
    consolidate.add_argument(
        "checkpoint_folder",
        type=str,
        help="Path to the checkpoint folder to consolidate.",
    )
    consolidate.add_argument(
        "--output_file",
        type=str,
        default=None,
        help=(
            "Path to the file where the converted checkpoint should be saved. The file should not already exist. "
            "If not provided, the file will be saved next to the input checkpoint folder with the same name and a "
            "'.consolidated' suffix."
        ),
    )
    return consolidate


def _build_parser() -> "ArgumentParser":
    """Assemble the Fabric command line parser with its ``run`` and ``consolidate`` subcommands.

    Returns:
        The top-level parser; each subcommand is backed by its own nested parser.

    Raises:
        RuntimeError: If ``jsonargparse`` is not installed.

    """
    if not _JSONARGPARSE_AVAILABLE:  # pragma: no cover
        raise RuntimeError(
            "To use the Lightning Fabric CLI, you must have `jsonargparse` installed. "
            "Install it by running `pip install -U jsonargparse`."
        )

    root = ArgumentParser(description="Lightning Fabric command line tool")
    commands = root.add_subcommands()
    commands.add_subcommand("run", _make_run_parser(), help="Run a Lightning Fabric script")
    commands.add_subcommand("consolidate", _make_consolidate_parser(), help="Consolidate a distributed checkpoint")
    return root
169
158
170
159
171
160
def _set_env_variables (args : Namespace ) -> None :
@@ -234,12 +223,44 @@ def main(args: Namespace, script_args: Optional[list[str]] = None) -> None:
234
223
_torchrun_launch (args , script_args or [])
235
224
236
225
237
- if __name__ == "__main__" :
238
- if not _CLICK_AVAILABLE : # pragma: no cover
226
def _run_command(cfg: Namespace, script_args: list[str]) -> None:
    """Execute the ``run`` subcommand.

    Args:
        cfg: The parsed options of the ``run`` subcommand.
        script_args: Extra command line arguments that are forwarded to the user script.

    """
    # ``vars`` works for any ``argparse.Namespace`` (including subclasses), whereas ``**cfg`` requires the
    # namespace to implement the mapping protocol and raises ``TypeError`` for a plain ``argparse.Namespace``.
    main(args=Namespace(**vars(cfg)), script_args=script_args)
229
+
230
+
231
def _consolidate_command(cfg: Namespace) -> None:
    """Execute the ``consolidate`` subcommand.

    Copies ``checkpoint_folder`` and ``output_file`` from ``cfg`` into a fresh namespace, passes it through
    ``_process_cli_args``, loads the distributed checkpoint from the resulting folder and saves it with
    ``torch.save`` to the resulting output file.

    Args:
        cfg: The parsed options of the ``consolidate`` subcommand.

    """
    cli_args = Namespace(checkpoint_folder=cfg.checkpoint_folder, output_file=cfg.output_file)
    resolved = _process_cli_args(cli_args)
    torch.save(_load_distributed_checkpoint(resolved.checkpoint_folder), resolved.output_file)
237
+
238
+
239
# Options of the ``run`` subcommand; each one consumes exactly one value.
# NOTE: keep in sync with the options registered for ``run`` in ``_build_parser``.
_RUN_VALUE_OPTIONS = frozenset({
    "--accelerator",
    "--strategy",
    "--devices",
    "--num_nodes",
    "--num-nodes",
    "--node_rank",
    "--node-rank",
    "--main_address",
    "--main-address",
    "--main_port",
    "--main-port",
    "--precision",
})


def _split_run_args(tokens: list[str]) -> tuple[list[str], list[str]]:
    """Separate the arguments of ``fabric run`` from the ones meant for the user script.

    Known options (with their value) and the first bare token (the script path) belong to Fabric, wherever
    they appear. Every other token is forwarded to the script. A literal ``--`` ends option processing:
    whatever follows is handed to the script untouched (after the script path, if it was not seen yet).

    Args:
        tokens: The command line tokens following the ``run`` subcommand name.

    Returns:
        A ``(fabric_args, script_args)`` tuple.

    """
    fabric_args: list[str] = []
    script_args: list[str] = []
    script_seen = False
    stream = iter(tokens)
    for token in stream:
        if token == "--":
            remaining = list(stream)
            if not script_seen and remaining:
                # keep the separator so a script path starting with '-' is still read as a positional
                fabric_args += ["--", remaining.pop(0)]
            script_args += remaining
            break
        if token in ("-h", "--help"):
            fabric_args.append(token)
        elif token.split("=", 1)[0] in _RUN_VALUE_OPTIONS:
            fabric_args.append(token)
            if "=" not in token:
                # space-separated form: the next token is the option's value
                value = next(stream, None)
                if value is not None:
                    fabric_args.append(value)
        elif not script_seen and not token.startswith("-"):
            fabric_args.append(token)
            script_seen = True
        else:
            script_args.append(token)
    return fabric_args, script_args


def cli_main(argv: Optional[list[str]] = None) -> None:
    """Entry point for the Fabric CLI using jsonargparse.

    Args:
        argv: The command line arguments without the program name. Defaults to ``sys.argv[1:]``.

    Raises:
        SystemExit: If ``jsonargparse`` is not installed, or if the parser rejects the arguments.

    """
    if not _JSONARGPARSE_AVAILABLE:  # pragma: no cover
        _log.error(
            "To use the Lightning Fabric CLI, you must have `jsonargparse` installed."
            " Install it by running `pip install -U jsonargparse`."
        )
        raise SystemExit(1)

    args = sys.argv[1:] if argv is None else list(argv)
    parser = _build_parser()

    # Subcommands are required by the parser, so a bare invocation has to be handled before parsing
    # if it should show the help text rather than a usage error.
    if not args:
        parser.print_help()
        return

    # jsonargparse deliberately does not support ``parse_known_args`` for user code (it raises
    # ``NotImplementedError``), so the script's own arguments are split off before strict parsing.
    script_args: list[str] = []
    if args[0] == "run":
        run_args, script_args = _split_run_args(args[1:])
        args = ["run", *run_args]

    cfg = parser.parse_args(args)
    subcommand = getattr(cfg, "subcommand", None)

    if subcommand == "run":
        _run_command(cfg.run, script_args)
    elif subcommand == "consolidate":
        _consolidate_command(cfg.consolidate)
    else:  # pragma: no cover
        parser.print_help()
263
+
264
+
265
if __name__ == "__main__":
    # Executed as a script: forward everything after the program name to the CLI entry point.
    command_line = sys.argv[1:]
    cli_main(command_line)
0 commit comments