From a82a96a3235a8181e63b9f784305c25a02ed5ec2 Mon Sep 17 00:00:00 2001 From: Yernar Sadybekov Date: Wed, 16 Jul 2025 11:03:52 -0700 Subject: [PATCH] YAML config support for pipeline benchmarking (#3180) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/3180 Added a support for YAML file configuration of the pipeline benchmarking. This feature makes easier to reproduce complex configurations without the need to CLI arguments passing. Example `.yaml ` file should look like: ``` RunOptions: world_size: 2 PipelineConfig: pipeline: "sparse" ``` Also, configs can be listed in a 'flat' way as well: ``` world_size: 2 pipeline: "sparse" ``` To run, add the `--yaml_config` flag with the `.yaml` file path. Additional flags can overwrite the `yaml` file configs as well if desired. Differential Revision: D78127340 --- .../distributed/benchmark/benchmark_utils.py | 39 ++++++++++++++++--- 1 file changed, 34 insertions(+), 5 deletions(-) diff --git a/torchrec/distributed/benchmark/benchmark_utils.py b/torchrec/distributed/benchmark/benchmark_utils.py index 950e3bd6f..a94ad0cfc 100644 --- a/torchrec/distributed/benchmark/benchmark_utils.py +++ b/torchrec/distributed/benchmark/benchmark_utils.py @@ -40,6 +40,7 @@ import click import torch +import yaml from torch import multiprocessing as mp from torch.autograd.profiler import record_function from torchrec.distributed import DistributedModelParallel @@ -477,6 +478,13 @@ def wrapper() -> Any: sig = inspect.signature(func) parser = argparse.ArgumentParser(func.__doc__) + parser.add_argument( + "--yaml_config", + type=str, + default=None, + help="YAML config file for benchmarking", + ) + # Add loglevel argument with current logger level as default parser.add_argument( "--loglevel", @@ -485,6 +493,21 @@ def wrapper() -> Any: help="Set the logging level (e.g. info, debug, warning, error)", ) + pre_args, _ = parser.parse_known_args() + + yaml_defaults: Dict[str, Any] = {} + if pre_args.yaml_config: + try: + with open(pre_args.yaml_config, "r") as f: + yaml_defaults = yaml.safe_load(f) or {} + logger.info( + f"Loaded YAML config from {pre_args.yaml_config}: {yaml_defaults}" + ) + except Exception as e: + logger.warning( + f"Failed to load YAML config because {e}. Proceeding without it." + ) + seen_args = set() # track all -- we've added for _name, param in sig.parameters.items(): @@ -509,11 +532,17 @@ def wrapper() -> Any: ftype = non_none[0] origin = get_origin(ftype) - # Handle default_factory value - default_value = ( - f.default_factory() # pyre-ignore [29] - if f.default_factory is not MISSING - else f.default + # Handle default_factory value and allow YAML config to override it + default_value = yaml_defaults.get( + arg_name, # flat lookup + yaml_defaults.get(cls.__name__, {}).get( # hierarchy lookup + arg_name, + ( + f.default_factory() # pyre-ignore [29] + if f.default_factory is not MISSING + else f.default + ), + ), ) arg_kwargs = {