From d4132fadf7201a4a48b8edf6e1541fc69273537e Mon Sep 17 00:00:00 2001
From: Iris Zhang
Date: Mon, 21 Jul 2025 17:26:08 -0700
Subject: [PATCH 1/2] @parametrize DTensor clipping to test clipping with inf norm, L1, and L2 norm (#3220)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/3220

Differential Revision: D78706038
---
 torchrec/optim/tests/test_clipping.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchrec/optim/tests/test_clipping.py b/torchrec/optim/tests/test_clipping.py
index 0c837ec86..5f311459c 100644
--- a/torchrec/optim/tests/test_clipping.py
+++ b/torchrec/optim/tests/test_clipping.py
@@ -251,7 +251,7 @@ def _get_params_to_pg(
         return {param: [param.device_mesh.get_group()] for param in params}

     @with_comms
-    @parametrize("norm_type", ("inf",))
+    @parametrize("norm_type", ("inf", 1, 2))
     def test_dtensor_clip_all_gradients_norm(
         self, norm_type: Union[float, str]
     ) -> None:

From 837ff803c2bc1df8930a232ff2aa0c9d20f698eb Mon Sep 17 00:00:00 2001
From: Iris Zhang
Date: Mon, 21 Jul 2025 17:26:08 -0700
Subject: [PATCH 2/2] Refactor `_batch_cal_norm` and remove #pyre-ignore (#3200)

Summary:
As title.

1. Parse `norm_type` via `self._norm_type = float(norm_type)` immediately in `GradientClippingOptimizer.__init__()` and only use `self._norm_type` afterwards, so it is not susceptible to the error in D78326114.
2. Remove the repeated `total_grad_norm` calculation.
3. Fix #pyre errors.

Reviewed By: aliafzal

Differential Revision: D78398248
---
 torchrec/optim/clipping.py | 36 ++++++++++++++++++------------------
 1 file changed, 18 insertions(+), 18 deletions(-)

diff --git a/torchrec/optim/clipping.py b/torchrec/optim/clipping.py
index 2ba9a6290..d38f08775 100644
--- a/torchrec/optim/clipping.py
+++ b/torchrec/optim/clipping.py
@@ -59,7 +59,7 @@ def __init__(
         super().__init__(optimizer)
         self._clipping = clipping
         self._max_gradient = max_gradient
-        self._norm_type = norm_type
+        self._norm_type = float(norm_type)
         self._check_meta: bool = True
         self._enable_global_grad_clip = enable_global_grad_clip
         self._step_num = 0
@@ -124,7 +124,7 @@ def step(self, closure: Any = None) -> None:
                 torch.nn.utils.clip_grad_norm_(
                     replicate_params,
                     self._max_gradient,
-                    norm_type=float(self._norm_type),
+                    norm_type=self._norm_type,
                 )
             else:
                 self.clip_grad_norm_()
@@ -139,7 +139,6 @@ def step(self, closure: Any = None) -> None:
     def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
         """Clip the gradient norm of all parameters."""
         max_norm = self._max_gradient
-        norm_type = float(self._norm_type)
         all_grads = []
         total_grad_norm = None

@@ -157,7 +156,7 @@ def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
                 sharded_grad_norm = _batch_cal_norm(
                     sharded_grads,
                     max_norm,
-                    norm_type,
+                    self._norm_type,
                     pgs,
                 )
                 total_grad_norm = (
@@ -165,7 +164,7 @@ def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
                     if total_grad_norm is None
                     else (
                         torch.maximum(total_grad_norm, sharded_grad_norm)
-                        if norm_type == torch.inf
+                        if self._norm_type == torch.inf
                         else total_grad_norm + sharded_grad_norm
                     )
                 )
@@ -184,7 +183,7 @@ def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
             replicated_grad_norm = _batch_cal_norm(
                 replicated_grads,
                 max_norm,
-                norm_type,
+                self._norm_type,
                 None,
             )
             total_grad_norm = (
@@ -192,7 +191,7 @@ def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
                 if total_grad_norm is None
                 else (
                     torch.maximum(total_grad_norm, replicated_grad_norm)
-                    if norm_type == torch.inf
+                    if self._norm_type == torch.inf
                    else total_grad_norm + replicated_grad_norm
                )
            )
@@ -200,11 +199,20 @@ def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
         else:
             square_replicated_grad_norm = 0

+        if total_grad_norm is not None:
+            total_grad_norm = (
+                torch.pow(total_grad_norm, 1.0 / self._norm_type)
+                if self._norm_type != torch.inf
+                else total_grad_norm
+            )
+        else:
+            return None
+
         global log_grad_norm
         if log_grad_norm:
-            if total_grad_norm is not None and norm_type != torch.inf:
+            if total_grad_norm is not None and self._norm_type != torch.inf:
                 # pyre-ignore[58]
-                grad_norm = total_grad_norm ** (1.0 / norm_type)
+                grad_norm = total_grad_norm ** (1.0 / self._norm_type)
             else:
                 grad_norm = total_grad_norm

@@ -213,15 +221,7 @@ def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
                 f"Clipping [rank={rank}, step={self._step_num}]: square_sharded_grad_norm = {square_sharded_grad_norm}, square_replicated_grad_norm = {square_replicated_grad_norm}, total_grad_norm = {grad_norm}"
             )

-        # Aggregation
-        if total_grad_norm is None:
-            return
-
-        if norm_type != torch.inf:
-            # pyre-ignore [58]: ** is not supported for operand types torch._tensor.Tensor and float.
-            total_grad_norm = total_grad_norm ** (1.0 / norm_type)
-        # pyre-ignore [58]: / is not supported for operand types float and Union[float, torch._tensor.Tensor].
-        clip_coef = cast(torch.Tensor, max_norm / (total_grad_norm + 1e-6))
+        clip_coef = torch.tensor(max_norm) / (total_grad_norm + 1e-6)
         clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
         torch._foreach_mul_(all_grads, clip_coef_clamped)
         return total_grad_norm
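
Item 1 of the second patch's summary parses `norm_type` once, in `__init__()`, via `float(norm_type)`. A minimal sketch of why a single `float()` call covers every value the first patch now parametrizes over ("inf", 1, and 2): `float("inf")` compares equal to `torch.inf`, so downstream comparisons need no string special-casing. The helper name `parse_norm_type` is illustrative, not part of torchrec.

```python
import torch


def parse_norm_type(norm_type) -> float:
    # float("inf"), float(1), and float(2) all succeed, and float("inf")
    # compares equal to torch.inf, so later code can test the parsed value
    # against torch.inf regardless of how the caller spelled the norm.
    return float(norm_type)


for raw in ("inf", 1, 2, torch.inf):
    parsed = parse_norm_type(raw)
    print(f"{raw!r:12} -> {parsed} (inf norm: {parsed == torch.inf})")
```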
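The clipping.py diff also hoists the `1.0 / norm_type` root of `total_grad_norm` so it is computed once (with an early return when there are no gradients) instead of being repeated in the logging and aggregation branches. Below is a standalone sketch of the aggregation the refactored `clip_grad_norm_` performs, under simplifying assumptions: plain tensors instead of DTensors, no process groups, and an illustrative function name that is not a torchrec API. Per-group norms combine via `max` for the inf norm, or as a running sum of p-th powers (rooted once at the end) otherwise; gradients are then scaled by `min(1, max_norm / total_norm)`.

```python
from typing import List

import torch


def clip_grad_norm_sketch(
    grad_groups: List[List[torch.Tensor]],
    max_norm: float,
    norm_type: float,
) -> torch.Tensor:
    total_grad_norm = None
    for grads in grad_groups:
        if norm_type == torch.inf:
            # inf norm: combine per-group norms by taking the maximum.
            group_norm = torch.max(
                torch.stack([torch.linalg.vector_norm(g, ord=torch.inf) for g in grads])
            )
            total_grad_norm = (
                group_norm
                if total_grad_norm is None
                else torch.maximum(total_grad_norm, group_norm)
            )
        else:
            # p-norm: accumulate sums of |g|^p; the 1/p root is taken once below.
            group_norm = torch.stack(
                [torch.linalg.vector_norm(g, ord=norm_type) ** norm_type for g in grads]
            ).sum()
            total_grad_norm = (
                group_norm if total_grad_norm is None else total_grad_norm + group_norm
            )

    if norm_type != torch.inf:
        total_grad_norm = torch.pow(total_grad_norm, 1.0 / norm_type)

    # Scale every gradient by min(1, max_norm / total_norm).
    clip_coef = torch.clamp(max_norm / (total_grad_norm + 1e-6), max=1.0)
    for grads in grad_groups:
        torch._foreach_mul_(grads, clip_coef)
    return total_grad_norm


# Example with two groups standing in for sharded and replicated gradients.
groups = [[torch.randn(4, 4), torch.randn(4)], [torch.randn(8)]]
print(clip_grad_norm_sketch(groups, max_norm=1.0, norm_type=2.0))
print(clip_grad_norm_sketch(groups, max_norm=1.0, norm_type=float("inf")))
```

Keeping the running total in p-th-power space is what makes summing across the sharded and replicated groups valid before the single final root, which is why the patch can compute that root in exactly one place.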