
Commit ad66544

wz337 authored and facebook-github-bot committed
Fix "inf" norm gradient clipping for FSDP2 (#3199)
Summary: The "inf" norm calculation for FSDP2 is incorrect: `total_grad_norm` is always 1.0 regardless of the gradients. This is because `_batch_cal_norm` computes `total_grad_norm = total_grad_norm ** (1.0 / norm_type)` even when the norm type is the infinity norm; since `norm_type` is `torch.inf` in that case, `total_grad_norm ** (1.0 / torch.inf)` always evaluates to 1.0. Before the fix, line 220 compared `self._norm_type != torch.inf`; because `self._norm_type` is either a float or the string `"inf"`, the comparison was always true, so the code always took the `** (1.0 / norm_type)` path, which is incorrect for the infinity-norm case.

This issue was found during Shampoo War Room debugging for HSDP2 (context: https://fb.workplace.com/groups/3095840833991792/permalink/4084834545092411/). In an ablation study, DDP Fused Adam is on par with HSDP2 Fused Adam when the 2-norm is used for clipping (Figure 1), while HSDP2 Fused Adam behaves strangely (high variance between runs) when the infinity norm is used for clipping (Figure 2).

Figure 1: {F1980311095}
Figure 2: {F1980311115}

TODO: As discussed, we should be able to significantly simplify things by leveraging an implementation similar to TorchTitan's (https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/utils.py) and abstracting out the gradient norm computation.

Reviewed By: yoyoyocmu

Differential Revision: D78326114
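A minimal standalone sketch of the failure mode described in the summary (the tensor value is made up for illustration; this is not the torchrec code itself): raising any positive norm to the power `1.0 / torch.inf` collapses it to 1.0, so for the infinity norm the final root must be skipped and shard norms must be combined with `max` rather than summed.

```python
import torch

norm_type = torch.inf
total_grad_norm = torch.tensor(57.0)

# 1.0 / torch.inf == 0.0, so any positive norm ** 0.0 == 1.0.
# This is why the old code always reported total_grad_norm == 1.0.
print(total_grad_norm ** (1.0 / norm_type))  # tensor(1.) -- the bug

# For the infinity norm, the aggregated max over per-shard norms is
# already the global norm; only p-norms need the final (1 / p) root.
if norm_type != torch.inf:
    total_grad_norm = total_grad_norm ** (1.0 / norm_type)
print(total_grad_norm)  # tensor(57.) -- the intended value
```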
1 parent 94fc482 · commit ad66544

File tree: 2 files changed, +112 -5 lines changed


torchrec/optim/clipping.py

Lines changed: 5 additions & 5 deletions
@@ -165,7 +165,7 @@ def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
                 if total_grad_norm is None
                 else (
                     torch.maximum(total_grad_norm, sharded_grad_norm)
-                    if self._norm_type == torch.inf
+                    if norm_type == torch.inf
                     else total_grad_norm + sharded_grad_norm
                 )
             )
@@ -192,7 +192,7 @@ def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
                 if total_grad_norm is None
                 else (
                     torch.maximum(total_grad_norm, replicated_grad_norm)
-                    if self._norm_type == torch.inf
+                    if norm_type == torch.inf
                     else total_grad_norm + replicated_grad_norm
                 )
             )
@@ -202,11 +202,11 @@ def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
 
         global log_grad_norm
         if log_grad_norm:
-            if total_grad_norm is not None and self._norm_type != torch.inf:
+            if total_grad_norm is not None and norm_type != torch.inf:
                 # pyre-ignore[58]
                 grad_norm = total_grad_norm ** (1.0 / norm_type)
             else:
-                grad_norm = 0
+                grad_norm = total_grad_norm
 
             rank = dist.get_rank()
             logger.info(
@@ -217,7 +217,7 @@ def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
         if total_grad_norm is None:
             return
 
-        if self._norm_type != torch.inf:
+        if norm_type != torch.inf:
             # pyre-ignore [58]: ** is not supported for operand types torch._tensor.Tensor and float.
             total_grad_norm = total_grad_norm ** (1.0 / norm_type)
         # pyre-ignore [58]: / is not supported for operand types float and Union[float, torch._tensor.Tensor].
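The TODO in the summary suggests abstracting out the gradient-norm computation along the lines of TorchTitan's distributed utils. A minimal sketch of what such a helper could look like, assuming plain local gradient tensors and an optional process group; the helper name `_dist_grad_norm` and its signature are illustrative, not the actual torchrec or TorchTitan API:

```python
from typing import Iterable, Optional

import torch
import torch.distributed as dist


def _dist_grad_norm(
    grads: Iterable[torch.Tensor],
    norm_type: float,
    process_group: Optional[dist.ProcessGroup] = None,
) -> torch.Tensor:
    """Illustrative global gradient norm over this rank's (possibly sharded) grads."""
    grads = [g for g in grads if g is not None]
    if not grads:
        return torch.tensor(0.0)

    # Per-tensor norms computed locally on this rank.
    local_norms = torch.stack(
        [torch.linalg.vector_norm(g.detach(), norm_type) for g in grads]
    )

    if norm_type == torch.inf:
        # Infinity norm: max locally, MAX across ranks, and no final root.
        total = local_norms.max()
        if dist.is_initialized():
            dist.all_reduce(total, op=dist.ReduceOp.MAX, group=process_group)
        return total

    # p-norm: sum of p-th powers locally, SUM across ranks, then the 1/p root.
    total = (local_norms ** norm_type).sum()
    if dist.is_initialized():
        dist.all_reduce(total, op=dist.ReduceOp.SUM, group=process_group)
    return total ** (1.0 / norm_type)
```

This mirrors the pattern in the diff above: sharded and replicated contributions are combined with `torch.maximum` for the infinity norm and summed (then rooted once at the end) for finite p-norms.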

torchrec/optim/tests/test_clipping.py

Lines changed: 107 additions & 0 deletions
@@ -8,10 +8,21 @@
 # pyre-strict
 
 import unittest
+from typing import Dict, List, Union
 from unittest.mock import MagicMock, patch
 
 import torch
 from torch.autograd import Variable
+from torch.distributed import ProcessGroup
+from torch.distributed.tensor import distribute_tensor, DTensor, init_device_mesh, Shard
+from torch.testing._internal.common_utils import (
+    instantiate_parametrized_tests,
+    parametrize,
+)
+from torch.testing._internal.distributed._tensor.common_dtensor import (
+    DTensorTestBase,
+    with_comms,
+)
 from torchrec.optim.clipping import GradientClipping, GradientClippingOptimizer
 from torchrec.optim.test_utils import DummyKeyedOptimizer
 
@@ -229,3 +240,99 @@ def test_clip_no_gradients_norm_meta_device(
         gradient_clipping_optimizer.step()
 
         mock_clip_grad_norm.assert_not_called()
+
+
+@unittest.skipIf(not torch.cuda.is_available(), "Skip when CUDA is not available")
+@instantiate_parametrized_tests
+class TestGradientClippingDTensor(DTensorTestBase):
+    def _get_params_to_pg(
+        self, params: List[DTensor]
+    ) -> Dict[DTensor, List[ProcessGroup]]:
+        return {param: [param.device_mesh.get_group()] for param in params}
+
+    @with_comms
+    @parametrize("norm_type", ("inf",))
+    def test_dtensor_clip_all_gradients_norm(
+        self, norm_type: Union[float, str]
+    ) -> None:
+        """
+        Test to ensure that the gradient clipping optimizer clips gradients
+        correctly with mixed DTensor and tensor by comparing gradients to its
+        torch.tensor counterpart.
+
+        Note that clipping for DTensor may require communication.
+        """
+
+        # create gradient clipping optimizer containing no dtensor for reference
+        ref_param_1 = torch.nn.Parameter(
+            torch.tensor([1.0, 2.0, 3.0], device=self.device_type)
+        )
+        ref_param_2 = torch.nn.Parameter(
+            torch.tensor([4.0, 5.0, 6.0], device=self.device_type)
+        )
+        ref_keyed_optimizer = DummyKeyedOptimizer(
+            {"param_1": ref_param_1, "param_2": ref_param_2},
+            {},
+            [{"params": [ref_param_1, ref_param_2]}],
+        )
+        ref_gradient_clipping_optimizer = GradientClippingOptimizer(
+            optimizer=ref_keyed_optimizer,
+            clipping=GradientClipping.NORM,
+            max_gradient=10.0,
+            norm_type=norm_type,
+        )
+        ref_gradient_clipping_optimizer.zero_grad()
+        ref_param_1.grad = torch.tensor([12.0, 15.0, 18.0], device=self.device_type)
+        ref_param_2.grad = torch.tensor([20.0, 30.0, 15.0], device=self.device_type)
+        ref_gradient_clipping_optimizer.step()
+
+        # create gradient clipping optimizer containing both DTensor and tensor
+        device_mesh = init_device_mesh(self.device_type, (self.world_size,))
+        param_1 = distribute_tensor(
+            torch.tensor([1.0, 2.0, 3.0], requires_grad=True, device=self.device_type),
+            device_mesh,
+            [Shard(0)],
+        )
+        param_2 = torch.tensor(
+            [4.0, 5.0, 6.0], requires_grad=True, device=self.device_type
+        )
+        param_to_pgs = self._get_params_to_pg([param_1])
+        keyed_optimizer = DummyKeyedOptimizer(
+            {"dtensor_param_1": param_1, "dtensor_param_2": param_2},
+            {},
+            [{"params": [param_1, param_2]}],
+        )
+        gradient_clipping_optimizer = GradientClippingOptimizer(
+            optimizer=keyed_optimizer,
+            clipping=GradientClipping.NORM,
+            max_gradient=10.0,
+            norm_type=norm_type,
+            enable_global_grad_clip=True,
+            param_to_pgs=param_to_pgs,  # pyre-ignore[6]
+        )
+        gradient_clipping_optimizer.zero_grad()
+        param_1.grad = distribute_tensor(
+            torch.tensor([12.0, 15.0, 18.0], device=self.device_type),
+            device_mesh,
+            [Shard(0)],
+        )
+        param_2.grad = torch.tensor([20.0, 30.0, 15.0], device=self.device_type)
+        gradient_clipping_optimizer.step()
+
+        for param_group, ref_param_group in zip(
+            gradient_clipping_optimizer.param_groups,
+            ref_gradient_clipping_optimizer.param_groups,
+        ):
+            for param, ref_param in zip(
+                param_group["params"], ref_param_group["params"]
+            ):
+                param_grad = (
+                    param.grad.full_tensor()  # pyre-ignore[16]
+                    if isinstance(param, DTensor)
+                    else param.grad
+                )
+                self.assertEqual(
+                    param_grad,
+                    ref_param.grad,
+                    f"Expect gradient to be the same. However, found {param_grad=}, {ref_param.grad=}",
+                )
