Skip to content

Refactor _batch_cal_norm and remove #pyre-ignore #3200

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 11 additions & 16 deletions torchrec/optim/clipping.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,28 +200,23 @@ def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
else:
square_replicated_grad_norm = 0

if total_grad_norm is not None:
total_grad_norm = (
torch.pow(total_grad_norm, 1.0 / norm_type)
if norm_type != torch.inf
else total_grad_norm
)
else:
return None

global log_grad_norm
if log_grad_norm:
if total_grad_norm is not None and self._norm_type != torch.inf:
# pyre-ignore[58]
grad_norm = total_grad_norm ** (1.0 / norm_type)
else:
grad_norm = 0

rank = dist.get_rank()
logger.info(
f"Clipping [rank={rank}, step={self._step_num}]: square_sharded_grad_norm = {square_sharded_grad_norm}, square_replicated_grad_norm = {square_replicated_grad_norm}, total_grad_norm = {grad_norm}"
f"Clipping [rank={rank}, step={self._step_num}]: square_sharded_grad_norm = {square_sharded_grad_norm}, square_replicated_grad_norm = {square_replicated_grad_norm}, total_grad_norm = {total_grad_norm}"
)

# Aggregation
if total_grad_norm is None:
return

if self._norm_type != torch.inf:
# pyre-ignore [58]: ** is not supported for operand types torch._tensor.Tensor and float.
total_grad_norm = total_grad_norm ** (1.0 / norm_type)
# pyre-ignore [58]: / is not supported for operand types float and Union[float, torch._tensor.Tensor].
clip_coef = cast(torch.Tensor, max_norm / (total_grad_norm + 1e-6))
clip_coef = torch.tensor(max_norm) / (total_grad_norm + 1e-6)
clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
torch._foreach_mul_(all_grads, clip_coef_clamped)
return total_grad_norm
Expand Down
107 changes: 107 additions & 0 deletions torchrec/optim/tests/test_clipping.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,21 @@
# pyre-strict

import unittest
from typing import Any, Dict, List, Union
from unittest.mock import MagicMock, patch

import torch
from torch.autograd import Variable
from torch.distributed import ProcessGroup
from torch.distributed.tensor import distribute_tensor, DTensor, init_device_mesh, Shard
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
with_comms,
)
from torchrec.optim.clipping import GradientClipping, GradientClippingOptimizer
from torchrec.optim.test_utils import DummyKeyedOptimizer

Expand Down Expand Up @@ -229,3 +240,99 @@ def test_clip_no_gradients_norm_meta_device(
gradient_clipping_optimizer.step()

mock_clip_grad_norm.assert_not_called()


@unittest.skipIf(not torch.cuda.is_available(), "Skip when CUDA is not available")
@instantiate_parametrized_tests
class TestGradientClippingDTensor(DTensorTestBase):
def _get_params_to_pg(
self, params: List[DTensor]
) -> Dict[DTensor, List[ProcessGroup]]:
return {param: [param.device_mesh.get_group()] for param in params}

@with_comms
@parametrize("norm_type", ("inf", 2))
def test_dtensor_clip_all_gradients_norm(
self, norm_type: Union[float, str]
) -> None:
"""
Test to ensure that the gradient clipping optimizer clips
dtensor gradients correctly by comparing gradients to its
torch.tensor counterpart.

Note that clipping for DTensor may require communication.
"""

# create gradient clipping optimizer containing no dtensor for reference
ref_param_1 = torch.nn.Parameter(
torch.tensor([1.0, 2.0, 3.0], device=self.device_type)
)
ref_param_2 = torch.nn.Parameter(
torch.tensor([4.0, 5.0, 6.0], device=self.device_type)
)
ref_keyed_optimizer = DummyKeyedOptimizer(
{"param_1": ref_param_1, "param_2": ref_param_2},
{},
[{"params": [ref_param_1, ref_param_2]}],
)
ref_gradient_clipping_optimizer = GradientClippingOptimizer(
optimizer=ref_keyed_optimizer,
clipping=GradientClipping.NORM,
max_gradient=10.0,
norm_type=norm_type,
)
ref_gradient_clipping_optimizer.zero_grad()
ref_param_1.grad = torch.tensor([12.0, 12.0, 12.0], device=self.device_type)
ref_param_2.grad = torch.tensor([30.0, 30.0, 30.0], device=self.device_type)
ref_gradient_clipping_optimizer.step()

# create gradient clipping optimizer containing DTensor
device_mesh = init_device_mesh(self.device_type, (self.world_size,))
param_1 = distribute_tensor(
torch.tensor([1.0, 2.0, 3.0], requires_grad=True, device=self.device_type),
device_mesh,
[Shard(0)],
)
param_2 = distribute_tensor(
torch.tensor([4.0, 5.0, 6.0], requires_grad=True, device=self.device_type),
device_mesh,
[Shard(0)],
)
param_to_pgs = self._get_params_to_pg([param_1, param_2])
keyed_optimizer = DummyKeyedOptimizer(
{"dtensor_param_1": param_1, "dtensor_param_2": param_2},
{},
[{"params": [param_1, param_2]}],
)
gradient_clipping_optimizer = GradientClippingOptimizer(
optimizer=keyed_optimizer,
clipping=GradientClipping.NORM,
max_gradient=10.0,
norm_type=norm_type,
enable_global_grad_clip=True,
param_to_pgs=param_to_pgs, # pyre-ignore[6]
)
gradient_clipping_optimizer.zero_grad()
param_1.grad = distribute_tensor(
torch.tensor([12.0, 12.0, 12.0], device=self.device_type),
device_mesh,
[Shard(0)],
)
param_2.grad = distribute_tensor(
torch.tensor([30.0, 30.0, 30.0], device=self.device_type),
device_mesh,
[Shard(0)],
)
gradient_clipping_optimizer.step()

for param_group, ref_param_group in zip(
gradient_clipping_optimizer.param_groups,
ref_gradient_clipping_optimizer.param_groups,
):
for param, ref_param in zip(
param_group["params"], ref_param_group["params"]
):
self.assertTrue(
torch.equal(param.grad.full_tensor(), ref_param.grad),
f"Expect gradient to be the same. However, found {param.grad=}, {ref_param.grad=}",
)
Loading