From ad66544287582eff136fb786e787d93729437282 Mon Sep 17 00:00:00 2001
From: Iris Zhang
Date: Wed, 16 Jul 2025 23:37:57 -0700
Subject: [PATCH] Fix "inf" norm gradient clipping for FSDP2 (#3199)

Summary:
The "inf" norm calculation for FSDP2 is incorrect: `total_grad_norm` is always 1.0 regardless of the gradients. This is because `_batch_cal_norm` computes `total_grad_norm = total_grad_norm ** (1.0 / norm_type)` even when the infinity norm is requested. For the "inf" norm, `norm_type` is `torch.inf`, so the expression becomes `total_grad_norm ** (1.0 / torch.inf)`, which always evaluates to 1.0. (A minimal repro sketch follows the clipping.py hunks below.)

Before the fix, line 220 compared `self._norm_type != torch.inf`. Since `self._norm_type` is either a float or the string `"inf"`, the check is always true, so the code always entered this if statement, which is incorrect for the infinity-norm case.

This issue was found during Shampoo War Room debugging for HSDP2 (context: https://fb.workplace.com/groups/3095840833991792/permalink/4084834545092411/). During an ablation study, we found that DDP Fused Adam is on par with HSDP2 Fused Adam when the 2-norm is used for clipping (Figure 1), while HSDP2 Fused Adam behaves strangely (high variance between runs) when the infinity norm is used for clipping (Figure 2).

Figure 1
{F1980311095}

Figure 2
{F1980311115}

TODO: As discussed, we should be able to significantly simplify things by leveraging an implementation similar to the one used in TorchTitan (https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/utils.py) and abstracting out the gradient norm computation (a rough sketch of such a helper follows the test diff below).

Reviewed By: yoyoyocmu

Differential Revision: D78326114
---
 torchrec/optim/clipping.py            |  10 +--
 torchrec/optim/tests/test_clipping.py | 107 ++++++++++++++++++++++++++
 2 files changed, 112 insertions(+), 5 deletions(-)

diff --git a/torchrec/optim/clipping.py b/torchrec/optim/clipping.py
index 5ffbbbca7..2ba9a6290 100644
--- a/torchrec/optim/clipping.py
+++ b/torchrec/optim/clipping.py
@@ -165,7 +165,7 @@ def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
                 if total_grad_norm is None
                 else (
                     torch.maximum(total_grad_norm, sharded_grad_norm)
-                    if self._norm_type == torch.inf
+                    if norm_type == torch.inf
                     else total_grad_norm + sharded_grad_norm
                 )
             )
@@ -192,7 +192,7 @@ def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
                 if total_grad_norm is None
                 else (
                     torch.maximum(total_grad_norm, replicated_grad_norm)
-                    if self._norm_type == torch.inf
+                    if norm_type == torch.inf
                     else total_grad_norm + replicated_grad_norm
                 )
             )
@@ -202,11 +202,11 @@ def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:

         global log_grad_norm
         if log_grad_norm:
-            if total_grad_norm is not None and self._norm_type != torch.inf:
+            if total_grad_norm is not None and norm_type != torch.inf:
                 # pyre-ignore[58]
                 grad_norm = total_grad_norm ** (1.0 / norm_type)
             else:
-                grad_norm = 0
+                grad_norm = total_grad_norm

             rank = dist.get_rank()
             logger.info(
@@ -217,7 +217,7 @@ def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
         if total_grad_norm is None:
             return

-        if self._norm_type != torch.inf:
+        if norm_type != torch.inf:
             # pyre-ignore [58]: ** is not supported for operand types torch._tensor.Tensor and float.
             total_grad_norm = total_grad_norm ** (1.0 / norm_type)
         # pyre-ignore [58]: / is not supported for operand types float and Union[float, torch._tensor.Tensor].
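For reference, a minimal standalone sketch of the failure mode fixed above. It uses plain tensors only and is not part of the patched files; the variable names are illustrative:

```python
import torch

norm_type = torch.inf                  # what the converted "inf" argument becomes
total_grad_norm = torch.tensor(30.0)   # e.g. max(|g|) accumulated across shards

# Buggy path: 1.0 / torch.inf == 0.0, so the exponent collapses the norm to 1.0
# and the clip coefficient no longer depends on the gradients at all.
buggy = total_grad_norm ** (1.0 / norm_type)
assert buggy.item() == 1.0

# Fixed path: for the infinity norm the accumulated value already *is* the norm,
# so the final exponentiation is skipped (the hunks above check the converted
# `norm_type` instead of `self._norm_type`, which may still be the string "inf").
fixed = total_grad_norm if norm_type == torch.inf else total_grad_norm ** (1.0 / norm_type)
assert fixed.item() == 30.0
```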
diff --git a/torchrec/optim/tests/test_clipping.py b/torchrec/optim/tests/test_clipping.py
index f26fd8884..0c837ec86 100644
--- a/torchrec/optim/tests/test_clipping.py
+++ b/torchrec/optim/tests/test_clipping.py
@@ -8,10 +8,21 @@
 # pyre-strict

 import unittest
+from typing import Dict, List, Union
 from unittest.mock import MagicMock, patch

 import torch
 from torch.autograd import Variable
+from torch.distributed import ProcessGroup
+from torch.distributed.tensor import distribute_tensor, DTensor, init_device_mesh, Shard
+from torch.testing._internal.common_utils import (
+    instantiate_parametrized_tests,
+    parametrize,
+)
+from torch.testing._internal.distributed._tensor.common_dtensor import (
+    DTensorTestBase,
+    with_comms,
+)

 from torchrec.optim.clipping import GradientClipping, GradientClippingOptimizer
 from torchrec.optim.test_utils import DummyKeyedOptimizer
@@ -229,3 +240,99 @@ def test_clip_no_gradients_norm_meta_device(
         gradient_clipping_optimizer.step()

         mock_clip_grad_norm.assert_not_called()
+
+
+@unittest.skipIf(not torch.cuda.is_available(), "Skip when CUDA is not available")
+@instantiate_parametrized_tests
+class TestGradientClippingDTensor(DTensorTestBase):
+    def _get_params_to_pg(
+        self, params: List[DTensor]
+    ) -> Dict[DTensor, List[ProcessGroup]]:
+        return {param: [param.device_mesh.get_group()] for param in params}
+
+    @with_comms
+    @parametrize("norm_type", ("inf",))
+    def test_dtensor_clip_all_gradients_norm(
+        self, norm_type: Union[float, str]
+    ) -> None:
+        """
+        Test to ensure that the gradient clipping optimizer clips gradients
+        correctly with mixed DTensor and tensor by comparing gradients to its
+        torch.tensor counterpart.
+
+        Note that clipping for DTensor may require communication.
+        """
+
+        # create gradient clipping optimizer containing no dtensor for reference
+        ref_param_1 = torch.nn.Parameter(
+            torch.tensor([1.0, 2.0, 3.0], device=self.device_type)
+        )
+        ref_param_2 = torch.nn.Parameter(
+            torch.tensor([4.0, 5.0, 6.0], device=self.device_type)
+        )
+        ref_keyed_optimizer = DummyKeyedOptimizer(
+            {"param_1": ref_param_1, "param_2": ref_param_2},
+            {},
+            [{"params": [ref_param_1, ref_param_2]}],
+        )
+        ref_gradient_clipping_optimizer = GradientClippingOptimizer(
+            optimizer=ref_keyed_optimizer,
+            clipping=GradientClipping.NORM,
+            max_gradient=10.0,
+            norm_type=norm_type,
+        )
+        ref_gradient_clipping_optimizer.zero_grad()
+        ref_param_1.grad = torch.tensor([12.0, 15.0, 18.0], device=self.device_type)
+        ref_param_2.grad = torch.tensor([20.0, 30.0, 15.0], device=self.device_type)
+        ref_gradient_clipping_optimizer.step()
+
+        # create gradient clipping optimizer containing both DTensor and tensor
+        device_mesh = init_device_mesh(self.device_type, (self.world_size,))
+        param_1 = distribute_tensor(
+            torch.tensor([1.0, 2.0, 3.0], requires_grad=True, device=self.device_type),
+            device_mesh,
+            [Shard(0)],
+        )
+        param_2 = torch.tensor(
+            [4.0, 5.0, 6.0], requires_grad=True, device=self.device_type
+        )
+        param_to_pgs = self._get_params_to_pg([param_1])
+        keyed_optimizer = DummyKeyedOptimizer(
+            {"dtensor_param_1": param_1, "dtensor_param_2": param_2},
+            {},
+            [{"params": [param_1, param_2]}],
+        )
+        gradient_clipping_optimizer = GradientClippingOptimizer(
+            optimizer=keyed_optimizer,
+            clipping=GradientClipping.NORM,
+            max_gradient=10.0,
+            norm_type=norm_type,
+            enable_global_grad_clip=True,
+            param_to_pgs=param_to_pgs,  # pyre-ignore[6]
+        )
+        gradient_clipping_optimizer.zero_grad()
+        param_1.grad = distribute_tensor(
+            torch.tensor([12.0, 15.0, 18.0], device=self.device_type),
+            device_mesh,
+            [Shard(0)],
+        )
+        param_2.grad = torch.tensor([20.0, 30.0, 15.0], device=self.device_type)
+        gradient_clipping_optimizer.step()
+
+        for param_group, ref_param_group in zip(
+            gradient_clipping_optimizer.param_groups,
+            ref_gradient_clipping_optimizer.param_groups,
+        ):
+            for param, ref_param in zip(
+                param_group["params"], ref_param_group["params"]
+            ):
+                param_grad = (
+                    param.grad.full_tensor()  # pyre-ignore[16]
+                    if isinstance(param, DTensor)
+                    else param.grad
+                )
+                self.assertEqual(
+                    param_grad,
+                    ref_param.grad,
+                    f"Expect gradient to be the same. However, found {param_grad=}, {ref_param.grad=}",
+                )
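On the TODO above: a rough sketch of what an abstracted, TorchTitan-style gradient-norm helper could look like. The function name and signature are hypothetical and this is not what this patch implements; it only illustrates handling mixed DTensor/Tensor gradients and the inf-norm special case in one place:

```python
from typing import Iterable, Union

import torch
from torch.distributed.tensor import DTensor


def total_grad_norm(
    grads: Iterable[torch.Tensor], norm_type: Union[float, str] = 2.0
) -> torch.Tensor:
    """Hypothetical helper: combine per-parameter gradient norms into one total norm."""
    p = float(norm_type)  # "inf" -> float("inf"), mirroring the converted norm_type
    per_param_norms = []
    for grad in grads:
        norm = torch.linalg.vector_norm(grad, p)
        if isinstance(norm, DTensor):
            # Materializing the sharded norm triggers the cross-rank reduction.
            norm = norm.full_tensor()
        per_param_norms.append(norm)
    stacked = torch.stack(per_param_norms)
    if p == torch.inf:
        # The infinity norm is a max; no final exponentiation is applied.
        return stacked.max()
    return stacked.pow(p).sum().pow(1.0 / p)
```

The clip coefficient would then be `min(1.0, max_gradient / (total_norm + eps))` applied uniformly to every gradient, which would sidestep the per-norm-type branching in `clip_grad_norm_`.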