diff --git a/torchrec/distributed/benchmark/benchmark_utils.py b/torchrec/distributed/benchmark/benchmark_utils.py index 950e3bd6f..b954dd0c1 100644 --- a/torchrec/distributed/benchmark/benchmark_utils.py +++ b/torchrec/distributed/benchmark/benchmark_utils.py @@ -524,6 +524,9 @@ def wrapper() -> Any: if origin in (list, List): elem_type = get_args(ftype)[0] arg_kwargs.update(nargs="*", type=elem_type) + elif ftype is bool: + # Special handling for boolean arguments + arg_kwargs.update(type=lambda x: x.lower() in ["true", "1", "yes"]) else: arg_kwargs.update(type=ftype)