Skip to content

Fix "inf" norm gradient clipping for FSDP2 #3199

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

wz337
Copy link
Contributor

@wz337 wz337 commented Jul 16, 2025

Summary:
"inf" norm calculation for FSDP2 is incorrect, as the total_grad_norm would always be 1.0 regardless of the gradients. This is due to _batch_cal_norm would calculate total_grad_norm = total_grad_norm ** (1.0 / norm_type), even if the norm_type is inf norm. For inf norm, since norm_type is torch.inf, total_grad_norm would become total_grad_norm ** (1.0 / torch.inf) , which would always just be 1.0.

Before the fix, line 220 is comparing self._norm_type != torch.inf, as self._norm_type is either float number of "inf", it would always enter this if statement, which is incorrect for the infinity norm case.

This issue is found during Shampoo War Room debugging for HSDP2 (context: https://fb.workplace.com/groups/3095840833991792/permalink/4084834545092411/).

During ablation study, we found that DDP Fused Adam is on-par with HSDP2 Fused Adam when 2-norm is used for clipping (Figure 1), while HSDP2 Fused Adam behaves strangely (high variance between runs) when infinity norm is used for clipping (Figure 2).

Figure 1
{F1980311095}

Figure 2
{F1980311115}

Differential Revision: D78326114

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jul 16, 2025
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D78326114

@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D78326114

@wz337 wz337 force-pushed the export-D78326114 branch from 5a8a005 to 6a6e03a Compare July 16, 2025 17:32
wz337 added a commit to wz337/torchrec that referenced this pull request Jul 16, 2025
Summary:

"inf" norm calculation for FSDP2 is incorrect, as the `total_grad_norm` would always be 1.0 regardless of the gradients. This is due to `_batch_cal_norm` would calculate `total_grad_norm = total_grad_norm ** (1.0 / norm_type)`, even if the `norm_type` is `inf` norm. For `inf` norm, since `norm_type` is `torch.inf`, `total_grad_norm` would become `total_grad_norm ** (1.0 / torch.inf) `, which would always just be 1.0. 

Before the fix, line 220 is comparing `self._norm_type != torch.inf`, as `self._norm_type` is either float number of `"inf"`, it would always enter this if statement, which is incorrect for the infinity norm case.

This issue is found during Shampoo War Room debugging for HSDP2 (context: https://fb.workplace.com/groups/3095840833991792/permalink/4084834545092411/).

During ablation study, we found that DDP Fused Adam is on-par with HSDP2 Fused Adam when 2-norm is used for clipping (Figure 1), while HSDP2 Fused Adam behaves strangely (high variance between runs) when infinity norm is used for clipping (Figure 2). 

Figure 1
{F1980311095}

Figure 2
{F1980311115}

Differential Revision: D78326114
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D78326114

wz337 added a commit to wz337/torchrec that referenced this pull request Jul 16, 2025
Summary:
Pull Request resolved: pytorch#3199

"inf" norm calculation for FSDP2 is incorrect, as the `total_grad_norm` would always be 1.0 regardless of the gradients. This is due to `_batch_cal_norm` would calculate `total_grad_norm = total_grad_norm ** (1.0 / norm_type)`, even if the `norm_type` is `inf` norm. For `inf` norm, since `norm_type` is `torch.inf`, `total_grad_norm` would become `total_grad_norm ** (1.0 / torch.inf) `, which would always just be 1.0.

Before the fix, line 220 is comparing `self._norm_type != torch.inf`, as `self._norm_type` is either float number of `"inf"`, it would always enter this if statement, which is incorrect for the infinity norm case.

This issue is found during Shampoo War Room debugging for HSDP2 (context: https://fb.workplace.com/groups/3095840833991792/permalink/4084834545092411/).

During ablation study, we found that DDP Fused Adam is on-par with HSDP2 Fused Adam when 2-norm is used for clipping (Figure 1), while HSDP2 Fused Adam behaves strangely (high variance between runs) when infinity norm is used for clipping (Figure 2).

Figure 1
{F1980311095}

Figure 2
{F1980311115}

Differential Revision: D78326114
@wz337 wz337 force-pushed the export-D78326114 branch from 6a6e03a to 6cef1c1 Compare July 16, 2025 17:36
Summary:

"inf" norm calculation for FSDP2 is incorrect, as the `total_grad_norm` would always be 1.0 regardless of the gradients. This is due to `_batch_cal_norm` would calculate `total_grad_norm = total_grad_norm ** (1.0 / norm_type)`, even if the `norm_type` is `inf` norm. For `inf` norm, since `norm_type` is `torch.inf`, `total_grad_norm` would become `total_grad_norm ** (1.0 / torch.inf) `, which would always just be 1.0. 

Before the fix, line 220 is comparing `self._norm_type != torch.inf`, as `self._norm_type` is either float number of `"inf"`, it would always enter this if statement, which is incorrect for the infinity norm case.

This issue is found during Shampoo War Room debugging for HSDP2 (context: https://fb.workplace.com/groups/3095840833991792/permalink/4084834545092411/).

During ablation study, we found that DDP Fused Adam is on-par with HSDP2 Fused Adam when 2-norm is used for clipping (Figure 1), while HSDP2 Fused Adam behaves strangely (high variance between runs) when infinity norm is used for clipping (Figure 2). 

Figure 1
{F1980311095}

Figure 2
{F1980311115}


TODO: As discussed, we should be able to significantly simplify things by leveraging a similar implementation as what's used in TorchTitan (https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/utils.py) and abstract out the gradient norm computation.

Reviewed By: yoyoyocmu

Differential Revision: D78326114
@wz337 wz337 force-pushed the export-D78326114 branch from 6cef1c1 to ad66544 Compare July 17, 2025 06:38
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D78326114

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. fb-exported
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants