diff --git a/torchrec/optim/clipping.py b/torchrec/optim/clipping.py
index 5ffbbbca7..5b4b222fc 100644
--- a/torchrec/optim/clipping.py
+++ b/torchrec/optim/clipping.py
@@ -200,28 +200,23 @@ def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
         else:
             square_replicated_grad_norm = 0
 
+        if total_grad_norm is not None:
+            total_grad_norm = (
+                torch.pow(total_grad_norm, 1.0 / norm_type)
+                if norm_type != torch.inf
+                else total_grad_norm
+            )
+        else:
+            return None
+
         global log_grad_norm
         if log_grad_norm:
-            if total_grad_norm is not None and self._norm_type != torch.inf:
-                # pyre-ignore[58]
-                grad_norm = total_grad_norm ** (1.0 / norm_type)
-            else:
-                grad_norm = 0
-
             rank = dist.get_rank()
             logger.info(
-                f"Clipping [rank={rank}, step={self._step_num}]: square_sharded_grad_norm = {square_sharded_grad_norm}, square_replicated_grad_norm = {square_replicated_grad_norm}, total_grad_norm = {grad_norm}"
+                f"Clipping [rank={rank}, step={self._step_num}]: square_sharded_grad_norm = {square_sharded_grad_norm}, square_replicated_grad_norm = {square_replicated_grad_norm}, total_grad_norm = {total_grad_norm}"
             )
 
-        # Aggregation
-        if total_grad_norm is None:
-            return
-
-        if self._norm_type != torch.inf:
-            # pyre-ignore [58]: ** is not supported for operand types torch._tensor.Tensor and float.
-            total_grad_norm = total_grad_norm ** (1.0 / norm_type)
-        # pyre-ignore [58]: / is not supported for operand types float and Union[float, torch._tensor.Tensor].
-        clip_coef = cast(torch.Tensor, max_norm / (total_grad_norm + 1e-6))
+        clip_coef = torch.tensor(max_norm) / (total_grad_norm + 1e-6)
         clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
         torch._foreach_mul_(all_grads, clip_coef_clamped)
         return total_grad_norm
diff --git a/torchrec/optim/tests/test_clipping.py b/torchrec/optim/tests/test_clipping.py
index f26fd8884..961bcacb0 100644
--- a/torchrec/optim/tests/test_clipping.py
+++ b/torchrec/optim/tests/test_clipping.py
@@ -8,10 +8,21 @@
 # pyre-strict
 
 import unittest
+from typing import Any, Dict, List, Union
 from unittest.mock import MagicMock, patch
 
 import torch
 from torch.autograd import Variable
+from torch.distributed import ProcessGroup
+from torch.distributed.tensor import distribute_tensor, DTensor, init_device_mesh, Shard
+from torch.testing._internal.common_utils import (
+    instantiate_parametrized_tests,
+    parametrize,
+)
+from torch.testing._internal.distributed._tensor.common_dtensor import (
+    DTensorTestBase,
+    with_comms,
+)
 
 from torchrec.optim.clipping import GradientClipping, GradientClippingOptimizer
 from torchrec.optim.test_utils import DummyKeyedOptimizer
@@ -229,3 +240,99 @@ def test_clip_no_gradients_norm_meta_device(
         gradient_clipping_optimizer.step()
 
         mock_clip_grad_norm.assert_not_called()
+
+
+@unittest.skipIf(not torch.cuda.is_available(), "Skip when CUDA is not available")
+@instantiate_parametrized_tests
+class TestGradientClippingDTensor(DTensorTestBase):
+    def _get_params_to_pg(
+        self, params: List[DTensor]
+    ) -> Dict[DTensor, List[ProcessGroup]]:
+        return {param: [param.device_mesh.get_group()] for param in params}
+
+    @with_comms
+    @parametrize("norm_type", ("inf", 2))
+    def test_dtensor_clip_all_gradients_norm(
+        self, norm_type: Union[float, str]
+    ) -> None:
+        """
+        Test to ensure that the gradient clipping optimizer clips
+        dtensor gradients correctly by comparing gradients to its
+        torch.tensor counterpart.
+
+        Note that clipping for DTensor may require communication.
+        """
+
+        # create gradient clipping optimizer containing no dtensor for reference
+        ref_param_1 = torch.nn.Parameter(
+            torch.tensor([1.0, 2.0, 3.0], device=self.device_type)
+        )
+        ref_param_2 = torch.nn.Parameter(
+            torch.tensor([4.0, 5.0, 6.0], device=self.device_type)
+        )
+        ref_keyed_optimizer = DummyKeyedOptimizer(
+            {"param_1": ref_param_1, "param_2": ref_param_2},
+            {},
+            [{"params": [ref_param_1, ref_param_2]}],
+        )
+        ref_gradient_clipping_optimizer = GradientClippingOptimizer(
+            optimizer=ref_keyed_optimizer,
+            clipping=GradientClipping.NORM,
+            max_gradient=10.0,
+            norm_type=norm_type,
+        )
+        ref_gradient_clipping_optimizer.zero_grad()
+        ref_param_1.grad = torch.tensor([12.0, 12.0, 12.0], device=self.device_type)
+        ref_param_2.grad = torch.tensor([30.0, 30.0, 30.0], device=self.device_type)
+        ref_gradient_clipping_optimizer.step()
+
+        # create gradient clipping optimizer containing DTensor
+        device_mesh = init_device_mesh(self.device_type, (self.world_size,))
+        param_1 = distribute_tensor(
+            torch.tensor([1.0, 2.0, 3.0], requires_grad=True, device=self.device_type),
+            device_mesh,
+            [Shard(0)],
+        )
+        param_2 = distribute_tensor(
+            torch.tensor([4.0, 5.0, 6.0], requires_grad=True, device=self.device_type),
+            device_mesh,
+            [Shard(0)],
+        )
+        param_to_pgs = self._get_params_to_pg([param_1, param_2])
+        keyed_optimizer = DummyKeyedOptimizer(
+            {"dtensor_param_1": param_1, "dtensor_param_2": param_2},
+            {},
+            [{"params": [param_1, param_2]}],
+        )
+        gradient_clipping_optimizer = GradientClippingOptimizer(
+            optimizer=keyed_optimizer,
+            clipping=GradientClipping.NORM,
+            max_gradient=10.0,
+            norm_type=norm_type,
+            enable_global_grad_clip=True,
+            param_to_pgs=param_to_pgs,  # pyre-ignore[6]
+        )
+        gradient_clipping_optimizer.zero_grad()
+        param_1.grad = distribute_tensor(
+            torch.tensor([12.0, 12.0, 12.0], device=self.device_type),
+            device_mesh,
+            [Shard(0)],
+        )
+        param_2.grad = distribute_tensor(
+            torch.tensor([30.0, 30.0, 30.0], device=self.device_type),
+            device_mesh,
+            [Shard(0)],
+        )
+        gradient_clipping_optimizer.step()
+
+        for param_group, ref_param_group in zip(
+            gradient_clipping_optimizer.param_groups,
+            ref_gradient_clipping_optimizer.param_groups,
+        ):
+            for param, ref_param in zip(
+                param_group["params"], ref_param_group["params"]
+            ):
+                self.assertTrue(
+                    torch.equal(param.grad.full_tensor(), ref_param.grad),
+                    f"Expect gradient to be the same. However, found {param.grad=}, {ref_param.grad=}",
+                )