
Commit 6a6e03a

wz337 authored and facebook-github-bot committed
Fix "inf" norm gradient clipping for FSDP2 (#3199)
Summary: "inf" norm calculation for FSDP2 is incorrect, as the `total_grad_norm` would always be 1.0 regardless of the gradients. This is due to `_batch_cal_norm` would calculate `total_grad_norm = total_grad_norm ** (1.0 / norm_type)`, even if the `norm_type` is `inf` norm. For `inf` norm, since `norm_type` is `torch.inf`, `total_grad_norm` would become `total_grad_norm ** (1.0 / torch.inf) `, which would always just be 1.0. Before the fix, line 220 is comparing `self._norm_type != torch.inf`, as `self._norm_type` is either float number of `"inf"`, it would always enter this if statement, which is incorrect for the infinity norm case. This issue is found during Shampoo War Room debugging for HSDP2 (context: https://fb.workplace.com/groups/3095840833991792/permalink/4084834545092411/). During ablation study, we found that DDP Fused Adam is on-par with HSDP2 Fused Adam when 2-norm is used for clipping (Figure 1), while HSDP2 Fused Adam behaves strangely (high variance between runs) when infinity norm is used for clipping (Figure 2). Figure 1 {F1980311095} Figure 2 {F1980311115} Differential Revision: D78326114
1 parent 0d6ef90 commit 6a6e03a
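
The failure mode is easy to reproduce outside of torchrec. Below is a minimal sketch (not the torchrec implementation) of why the final root step must be skipped for the infinity norm, using the same gradient values as the new test:

    import torch

    # Any positive number raised to (1.0 / inf) collapses to 1.0, which is why
    # the pre-fix code always reported total_grad_norm == 1.0 for the inf norm.
    assert 3132.0 ** (1.0 / torch.inf) == 1.0

    grads = [torch.tensor([12.0, 12.0, 12.0]), torch.tensor([30.0, 30.0, 30.0])]

    def total_norm(grads, norm_type):
        if norm_type == torch.inf:
            # Infinity norm: take the max of the per-tensor maxima; no final root.
            return max(g.abs().max() for g in grads)
        # Finite p-norm: sum the p-th powers of per-tensor norms, then take the 1/p root.
        return sum(g.norm(norm_type) ** norm_type for g in grads) ** (1.0 / norm_type)

    print(total_norm(grads, 2.0))        # ~55.96
    print(total_norm(grads, torch.inf))  # 30.0 -- not 1.0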

File tree

2 files changed (+111 additions, -2 deletions)


torchrec/optim/clipping.py

Lines changed: 2 additions & 2 deletions
@@ -202,7 +202,7 @@ def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
 
         global log_grad_norm
         if log_grad_norm:
-            if total_grad_norm is not None and self._norm_type != torch.inf:
+            if total_grad_norm is not None and norm_type != torch.inf:
                 # pyre-ignore[58]
                 grad_norm = total_grad_norm ** (1.0 / norm_type)
             else:
@@ -217,7 +217,7 @@ def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
         if total_grad_norm is None:
            return
 
-        if self._norm_type != torch.inf:
+        if norm_type != torch.inf:
            # pyre-ignore [58]: ** is not supported for operand types torch._tensor.Tensor and float.
            total_grad_norm = total_grad_norm ** (1.0 / norm_type)
        # pyre-ignore [58]: / is not supported for operand types float and Union[float, torch._tensor.Tensor].
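
Note that the fix is simply swapping `self._norm_type` for the local `norm_type`. Assuming `norm_type` holds the value already coerced to a float (as the summary implies), the string `"inf"` that can be stored in `self._norm_type` never compares equal to `torch.inf`, so the old guard always took the finite-norm branch. A quick illustration:

    import torch

    # self._norm_type may hold the string "inf"; a string never equals torch.inf,
    # so the pre-fix guard applied the 1/p root even for the infinity norm.
    print("inf" != torch.inf)         # True  -> root step applied (bug)
    print(float("inf") != torch.inf)  # False -> root step skipped (fixed behavior)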

torchrec/optim/tests/test_clipping.py

Lines changed: 109 additions & 0 deletions
@@ -8,10 +8,21 @@
 # pyre-strict
 
 import unittest
+from typing import Any, Dict, List, Union
 from unittest.mock import MagicMock, patch
 
 import torch
 from torch.autograd import Variable
+from torch.distributed import ProcessGroup
+from torch.distributed.tensor import distribute_tensor, DTensor, init_device_mesh, Shard
+from torch.testing._internal.common_utils import (
+    instantiate_parametrized_tests,
+    parametrize,
+)
+from torch.testing._internal.distributed._tensor.common_dtensor import (
+    DTensorTestBase,
+    with_comms,
+)
 from torchrec.optim.clipping import GradientClipping, GradientClippingOptimizer
 from torchrec.optim.test_utils import DummyKeyedOptimizer
 
@@ -229,3 +240,101 @@ def test_clip_no_gradients_norm_meta_device(
         gradient_clipping_optimizer.step()
 
         mock_clip_grad_norm.assert_not_called()
+
+
+@unittest.skipIf(not torch.cuda.is_available(), "Skip when CUDA is not available")
+@instantiate_parametrized_tests
+class TestGradientClippingDTensor(DTensorTestBase):
+    def _get_params_to_pg(
+        self, params: List[DTensor]
+    ) -> Dict[DTensor, List[ProcessGroup]]:
+        return {param: [param.device_mesh.get_group()] for param in params}
+
+    @with_comms
+    @parametrize("norm_type", ("inf", 2))
+    def test_dtensor_clip_all_gradients_norm(
+        self, norm_type: Union[float, str]
+    ) -> None:
+        """
+        Test to ensure that the gradient clipping optimizer clips
+        dtensor gradients correctly by comparing gradients to its
+        torch.tensor counterpart.
+
+        Note that clipping for DTensor may require communication.
+        """
+
+        # create gradient clipping optimizer containing no dtensor for reference
+        ref_param_1 = torch.nn.Parameter(
+            torch.tensor([1.0, 2.0, 3.0], device=self.device_type)
+        )
+        ref_param_2 = torch.nn.Parameter(
+            torch.tensor([4.0, 5.0, 6.0], device=self.device_type)
+        )
+        ref_keyed_optimizer = DummyKeyedOptimizer(
+            {"param_1": ref_param_1, "param_2": ref_param_2},
+            {},
+            [{"params": [ref_param_1, ref_param_2]}],
+        )
+        ref_gradient_clipping_optimizer = GradientClippingOptimizer(
+            optimizer=ref_keyed_optimizer,
+            clipping=GradientClipping.NORM,
+            max_gradient=10.0,
+            norm_type=norm_type,
+        )
+        ref_gradient_clipping_optimizer.zero_grad()
+        ref_param_1.grad = torch.tensor([12.0, 12.0, 12.0], device=self.device_type)
+        ref_param_2.grad = torch.tensor([30.0, 30.0, 30.0], device=self.device_type)
+        ref_gradient_clipping_optimizer.step()
+
+        # create gradient clipping optimizer containing DTensor
+        device_mesh = init_device_mesh(self.device_type, (self.world_size,))
+        param_1 = distribute_tensor(
+            torch.tensor([1.0, 2.0, 3.0], requires_grad=True, device=self.device_type),
+            device_mesh,
+            [Shard(0)],
+        )
+        param_2 = distribute_tensor(
+            torch.tensor([4.0, 5.0, 6.0], requires_grad=True, device=self.device_type),
+            device_mesh,
+            [Shard(0)],
+        )
+        param_to_pgs = self._get_params_to_pg([param_1, param_2])
+        keyed_optimizer = DummyKeyedOptimizer(
+            {"dtensor_param_1": param_1, "dtensor_param_2": param_2},
+            {},
+            [{"params": [param_1, param_2]}],
+        )
+        gradient_clipping_optimizer = GradientClippingOptimizer(
+            optimizer=keyed_optimizer,
+            clipping=GradientClipping.NORM,
+            max_gradient=10.0,
+            norm_type=norm_type,
+            enable_global_grad_clip=True,
+            param_to_pgs=param_to_pgs,  # pyre-ignore[6]
+        )
+        gradient_clipping_optimizer.zero_grad()
+        param_1.grad = distribute_tensor(
+            torch.tensor([12.0, 12.0, 12.0], device=self.device_type),
+            device_mesh,
+            [Shard(0)],
+        )
+        param_2.grad = distribute_tensor(
+            torch.tensor([30.0, 30.0, 30.0], device=self.device_type),
+            device_mesh,
+            [Shard(0)],
+        )
+        gradient_clipping_optimizer.step()
+
+        for param_group, ref_param_group in zip(
+            gradient_clipping_optimizer.param_groups,
+            ref_gradient_clipping_optimizer.param_groups,
+        ):
+            for param, ref_param in zip(
+                param_group["params"], ref_param_group["params"]
+            ):
+                param_grad = param.grad.full_tensor()
+                self.assertEqual(
+                    param_grad,
+                    ref_param.grad,
+                    f"Expect gradient to be the same. However, found {param_grad=}, {ref_param.grad=}",
+                )
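
For reference, the clipped gradients this test effectively checks can be worked out by hand, assuming the optimizer follows the usual `torch.nn.utils.clip_grad_norm_` semantics of scaling all gradients by `max_gradient / total_norm` when the total norm exceeds `max_gradient`:

    import torch

    max_gradient = 10.0
    g1 = torch.tensor([12.0, 12.0, 12.0])
    g2 = torch.tensor([30.0, 30.0, 30.0])
    all_grads = torch.cat([g1, g2])

    # 2-norm: total = sqrt(3 * 12^2 + 3 * 30^2) = sqrt(3132) ~= 55.96, coef ~= 0.179
    coef_2 = max_gradient / all_grads.norm(2)
    print(g1 * coef_2, g2 * coef_2)  # ~[2.14, 2.14, 2.14], ~[5.36, 5.36, 5.36]

    # Infinity norm: total = max |g| = 30, coef = 1/3; before the fix the "total"
    # was reported as 1.0, so no clipping would have happened at all.
    coef_inf = max_gradient / all_grads.norm(torch.inf)
    print(g1 * coef_inf, g2 * coef_inf)  # [4, 4, 4], [10, 10, 10]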
