From 6e4cc339bf0d558e76d397c59167391896eb0a31 Mon Sep 17 00:00:00 2001 From: Yernar Sadybekov Date: Wed, 16 Jul 2025 10:50:27 -0700 Subject: [PATCH 1/2] Fix boolean support in cmd_conf util (#3198) Summary: Fixed and extended the boolean argument parsing in the cmd_conf decorator utility. Differential Revision: D78303765 --- torchrec/distributed/benchmark/benchmark_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torchrec/distributed/benchmark/benchmark_utils.py b/torchrec/distributed/benchmark/benchmark_utils.py index 950e3bd6f..b954dd0c1 100644 --- a/torchrec/distributed/benchmark/benchmark_utils.py +++ b/torchrec/distributed/benchmark/benchmark_utils.py @@ -524,6 +524,9 @@ def wrapper() -> Any: if origin in (list, List): elem_type = get_args(ftype)[0] arg_kwargs.update(nargs="*", type=elem_type) + elif ftype is bool: + # Special handling for boolean arguments + arg_kwargs.update(type=lambda x: x.lower() in ["true", "1", "yes"]) else: arg_kwargs.update(type=ftype) From 2d8f87452ada54bad72449b8416ca3ceb8cd5780 Mon Sep 17 00:00:00 2001 From: Yernar Sadybekov Date: Wed, 16 Jul 2025 10:50:27 -0700 Subject: [PATCH 2/2] Inference mode configuration support for benchmarking Summary: Added a configuration option to run end-to-end pipeline benchmark in inference mode (`.eval()` mode). Previously the models were in training mode by default. Models in training mode have additional overhead of loss computation, backward pass, etc... Differential Revision: D78303867 --- .../benchmark/benchmark_train_pipeline.py | 7 ++++ .../train_pipeline/train_pipelines.py | 35 ++++++++++++++++--- 2 files changed, 37 insertions(+), 5 deletions(-) diff --git a/torchrec/distributed/benchmark/benchmark_train_pipeline.py b/torchrec/distributed/benchmark/benchmark_train_pipeline.py index a13158352..adc729ff5 100644 --- a/torchrec/distributed/benchmark/benchmark_train_pipeline.py +++ b/torchrec/distributed/benchmark/benchmark_train_pipeline.py @@ -91,6 +91,9 @@ class RunOptions: Default is "EXACT_ADAGRAD". sparse_lr (float): Learning rate for sparse parameters. Default is 0.1. + training_mode (bool): Whether to run the model in training mode. + If True, model remains in training mode. If False, model.eval() is called. + Default is True. """ world_size: int = 2 @@ -110,6 +113,7 @@ class RunOptions: sparse_lr: float = 0.1 sparse_momentum: Optional[float] = None sparse_weight_decay: Optional[float] = None + training_mode: bool = True @dataclass @@ -343,6 +347,9 @@ def runner( planner=planner, ) + if not run_option.training_mode: + sharded_model.eval() + def _func_to_benchmark( bench_inputs: List[ModelInput], model: nn.Module, diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py index 9007f55bd..850f915ab 100644 --- a/torchrec/distributed/train_pipeline/train_pipelines.py +++ b/torchrec/distributed/train_pipeline/train_pipelines.py @@ -95,6 +95,20 @@ has_2d_support = False +# Returns (losses, output) from model forward pass. Losses is None if model is in eval() mode +# pyre-ignore[3] +def unpack_model_fwd( + model_fwd_fn: Callable[[Any], Any], batch: Any, training: bool # pyre-ignore[2] +) -> Tuple[torch.Tensor, Any]: + result = model_fwd_fn(batch) + if training: + # result expected to be (losses, output) + return result + else: + # result expected to be output only + return None, result # pyre-ignore[7] + + class ModelDetachedException(Exception): pass @@ -667,7 +681,9 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out: # forward with record_function("## forward ##"): - losses, output = self._model_fwd(self.batches[0]) + losses, output = unpack_model_fwd( + self._model_fwd, self.batches[0], self._model.training + ) if self._enqueue_batch_after_forward: # batch i+2: load data and copy to gpu, the dataload iter will first exhaust here. @@ -1030,7 +1046,9 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out: # forward with record_function("## forward ##"): - losses, output = self._model_fwd(self.batches[0]) + losses, output = unpack_model_fwd( + self._model_fwd, self.batches[0], self._model.training + ) if len(self.batches) >= 2: # invoke data (values, lengths, etc.) all_to_all comms (second part of input_dist) @@ -1254,7 +1272,10 @@ def _mlp_forward( _wait_for_events( batch, context, torch.get_device_module(self._device).current_stream() ) - return self._model_fwd(batch) + losses, output = unpack_model_fwd( + self._model_fwd, batch, self._model.training + ) + return losses, output def embedding_backward(self, context: EmbeddingTrainPipelineContext) -> None: assert len(context.embedding_features) == len(context.embedding_tensors) @@ -1466,7 +1487,9 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out: self._wait_sparse_data_dist() # forward with record_function("## forward ##"): - losses, output = self._model_fwd(self._batch_i) + losses, output = unpack_model_fwd( + self._model_fwd, self._batch_i, self._model.training + ) self._prefetch(self._batch_ip1) @@ -2024,7 +2047,9 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out: # forward ctx = self.get_compiled_autograd_ctx() with ctx, torchrec_use_sync_collectives(), record_function("## forward ##"): - losses, output = self._model_fwd(self.batches[0]) + losses, output = unpack_model_fwd( + self._model_fwd, self.batches[0], self._model.training + ) if len(self.batches) >= 2: self.wait_sparse_data_dist(self.contexts[1])