Add Float8BlockwiseLinear for training #2618


Merged: 1 commit merged into main on Aug 1, 2025

Conversation

@danielvegamyhre (Contributor) commented Jul 28, 2025

Stacked PRs:


Add Float8BlockwiseLinear for training

  • Add an autograd func wrapping the triton kernels for the fp8 blockwise linear layer (see the structural sketch after this list)
  • Add tests validating numerics
  • Validated e2e training with torchtitan; the loss curve looks good.
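
For context, a structural sketch of what such a layer looks like. This is an illustration only, not the PR's actual code: the block size, helper names, and nn.Linear wiring are assumptions, and the real layer dispatches to triton quantization and blockwise-scaled GEMM kernels where this sketch merely emulates the fp8 numerics with quantize/dequantize in plain PyTorch.

# Illustrative sketch, not the PR's implementation. Emulates blockwise fp8
# numerics with quantize -> dequantize; the real layer calls triton kernels.
import torch
import torch.nn as nn

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

def _quant_dequant_1x128(x, block_size=128):
    # Emulated 1 x block_size blockwise fp8 round trip along the last dim
    # (last dim assumed divisible by block_size).
    xb = x.reshape(-1, block_size)
    scale = xb.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / FP8_MAX
    return ((xb / scale).to(torch.float8_e4m3fn).to(x.dtype) * scale).reshape_as(x)

class _BlockwiseFp8LinearFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight):
        # Real kernel path: quantize x (1x128) and weight (128x128) to fp8,
        # then run a blockwise-scaled GEMM.
        ctx.save_for_backward(x, weight)
        return _quant_dequant_1x128(x) @ _quant_dequant_1x128(weight).t()

    @staticmethod
    def backward(ctx, grad_out):
        # grad_out is also quantized blockwise before the dgrad/wgrad GEMMs.
        x, weight = ctx.saved_tensors
        g = _quant_dequant_1x128(grad_out)
        grad_x = g @ _quant_dequant_1x128(weight)
        grad_w = g.transpose(-2, -1) @ _quant_dequant_1x128(x)  # 2-D input assumed
        return grad_x, grad_w

class EmulatedFloat8BlockwiseLinear(nn.Linear):
    # Bias handling omitted for brevity.
    def forward(self, x):
        return _BlockwiseFp8LinearFn.apply(x, self.weight)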

Test plan

  • pytest test/prototype/blockwise_fp8_training/test_blockwise_linear.py (a sketch of the numerics test follows this list)
  • e2e training in torchtitan for 100 steps; the loss curve matches bf16 (fp8 logs, bf16 logs)
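
The numerics test is roughly of this shape. A sketch only: the constructor arguments for Float8BlockwiseLinear, the SQNR threshold, and the shapes are assumptions and may not match the actual test file.

import copy
import pytest
import torch
import torch.nn as nn
from torchao.prototype.blockwise_fp8_training.linear import Float8BlockwiseLinear

def sqnr(ref, actual):
    # Signal-to-quantization-noise ratio in dB.
    return 20 * torch.log10(ref.norm() / (ref - actual).norm())

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_blockwise_linear_matches_bf16():
    torch.random.manual_seed(0)
    ref = nn.Linear(1024, 1024, bias=False, dtype=torch.bfloat16, device="cuda")
    # Hypothetical construction; the real test may build/convert the layer differently.
    fp8_linear = Float8BlockwiseLinear(1024, 1024, bias=False, dtype=torch.bfloat16, device="cuda")
    fp8_linear.weight = copy.deepcopy(ref.weight)

    x_ref = torch.randn(16, 1024, dtype=torch.bfloat16, device="cuda", requires_grad=True)
    x = x_ref.detach().clone().requires_grad_()

    out_ref = ref(x_ref)
    out = fp8_linear(x)
    assert sqnr(out_ref, out) > 25.0  # loose placeholder threshold

    out_ref.sum().backward()
    out.sum().backward()
    assert sqnr(x_ref.grad, x.grad) > 25.0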

Limitations

  • Only FSDP is currently supported among parallelisms
  • torch.compile not supported yet

Performance and next steps

Perf is poor (4.4k TPS vs 6.3k TPS for bf16), largely due to slow GEMMs. As mentioned in #2617, if we want to improve perf, I can make the changes necessary for compatibility with torch._scaled_mm and see if perf improves (very likely, I'd say).
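
For reference, the torch._scaled_mm call path being referred to looks like this with per-tensor scales. A sketch only: the 1x128 / 128x128 blockwise scale layouts this PR would actually need have additional layout requirements and depend on the PyTorch version and GPU architecture, so they are not shown.

# Per-tensor-scaled torch._scaled_mm sketch; requires an fp8-capable GPU (sm89+).
import torch

M, K, N = 256, 512, 128
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
b = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)

fp8_max = torch.finfo(torch.float8_e4m3fn).max
a_scale = a.abs().amax().float() / fp8_max  # dequantization scales (fp32)
b_scale = b.abs().amax().float() / fp8_max
a_fp8 = (a / a_scale).to(torch.float8_e4m3fn)
b_fp8 = (b / b_scale).to(torch.float8_e4m3fn)

# _scaled_mm expects the second operand in column-major layout; b_fp8.t() is a
# (K, N) column-major view of the row-major (N, K) tensor.
out = torch._scaled_mm(
    a_fp8, b_fp8.t(),
    scale_a=a_scale, scale_b=b_scale,
    out_dtype=torch.bfloat16,
)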

stack-info: PR: #2618, branch: danielvegamyhre/stack/18
pytorch-bot bot commented Jul 28, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2618

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 90376ea with merge base 6b82931:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

danielvegamyhre added a commit that referenced this pull request Jul 28, 2025
stack-info: PR: #2618, branch: danielvegamyhre/stack/18
@danielvegamyhre force-pushed the danielvegamyhre/stack/17 branch from 4c0250f to 89357e5 on July 28, 2025 16:06
@danielvegamyhre force-pushed the danielvegamyhre/stack/18 branch from c1683b6 to 21a36dc on July 28, 2025 16:06
@facebook-github-bot added the CLA Signed label on Jul 28, 2025
@danielvegamyhre added the topic: not user facing label on Jul 28, 2025
@danielvegamyhre changed the base branch from danielvegamyhre/stack/17 to main on July 28, 2025 16:14
danielvegamyhre added a commit that referenced this pull request Jul 28, 2025
stack-info: PR: #2618, branch: danielvegamyhre/stack/18
@danielvegamyhre force-pushed the danielvegamyhre/stack/18 branch from 21a36dc to eca0126 on July 28, 2025 16:15
@danielvegamyhre changed the base branch from main to danielvegamyhre/stack/17 on July 28, 2025 16:15
@danielvegamyhre changed the base branch from danielvegamyhre/stack/17 to main on July 28, 2025 17:21
danielvegamyhre added a commit that referenced this pull request Jul 28, 2025
stack-info: PR: #2618, branch: danielvegamyhre/stack/18
@danielvegamyhre force-pushed the danielvegamyhre/stack/18 branch from eca0126 to 5bfc200 on July 28, 2025 17:21
@danielvegamyhre changed the base branch from main to danielvegamyhre/stack/17 on July 28, 2025 17:21
@vkuzo (Contributor) commented Jul 28, 2025

the integration looks good, but IMO we should use pytorch native code for quantization of weights/activations, and only add triton kernels when we're confident they match the pytorch native code with bitwise accuracy. Without this, it's hard to trust and debug the accuracy of this prototype.

@danielvegamyhre changed the base branch from danielvegamyhre/stack/17 to main on July 28, 2025 18:15
danielvegamyhre added a commit that referenced this pull request Jul 28, 2025
stack-info: PR: #2618, branch: danielvegamyhre/stack/18
@danielvegamyhre force-pushed the danielvegamyhre/stack/18 branch from 5bfc200 to cb92b94 on July 28, 2025 18:15
@danielvegamyhre changed the base branch from main to danielvegamyhre/stack/17 on July 28, 2025 18:15
@danielvegamyhre (Contributor, Author) replied:
> the integration looks good, but IMO we should use pytorch native code for quantization of weights/activations, and only add triton kernels when we're confident they match the pytorch native code with bitwise accuracy. Without this, it's hard to trust and debug the accuracy of this prototype.

Makes sense, I updated the prior PR to assert bitwise equivalence; let me know what you think.
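
The shape of that check, sketched in plain PyTorch (the reference helper below is illustrative, and triton_act_quant is a placeholder name for the kernel wrapper in the prior PR, not the repo's actual identifier):

# PyTorch-native 1x128 blockwise activation quantization reference, used to
# assert bitwise equivalence with the triton kernel output.
import torch

FP8_DTYPE = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8_DTYPE).max

def ref_blockwise_act_quant(x, block_size=128):
    # Scale each 1 x block_size block of the last dim independently (fp32 math).
    xb = x.reshape(-1, block_size).to(torch.float32)
    scales = xb.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / FP8_MAX
    x_fp8 = (xb / scales).to(FP8_DTYPE).reshape_as(x)
    return x_fp8, scales.reshape(*x.shape[:-1], -1)

# In the test, the triton output would be asserted bitwise equal, e.g.:
#   x_fp8_ref, scales_ref = ref_blockwise_act_quant(x)
#   x_fp8_tri, scales_tri = triton_act_quant(x, block_size=128)   # placeholder
#   assert torch.equal(x_fp8_ref.to(torch.float32), x_fp8_tri.to(torch.float32))
#   assert torch.equal(scales_ref, scales_tri)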

from torchao.prototype.blockwise_fp8_training.linear import Float8BlockwiseLinear


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
A Contributor commented on the snippet above:
sm90?

if in_features % block_size != 0 or out_features % block_size != 0:
    pytest.skip(f"Dimensions must be divisible by block_size={block_size}")

torch.random.manual_seed(0)
A Contributor commented on the snippet above:

usually people do this next to the imports, and then use copy.deepcopy to create copies of models. It's unusual to set the seed twice to get the same effect, even if it works.

@danielvegamyhre (Author) replied:
Makes sense, updated.
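
Roughly the pattern being suggested (illustrative names; not the repo's test code):

import copy
import torch
import torch.nn as nn

# Seed once, next to the imports, instead of re-seeding inside each test.
torch.random.manual_seed(0)

def make_model_pair(in_features=256, out_features=256):
    ref = nn.Linear(in_features, out_features, bias=False)
    # deepcopy gives an identical starting point without relying on the RNG
    # producing the same weights twice.
    test = copy.deepcopy(ref)
    return ref, test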

@danielvegamyhre changed the base branch from danielvegamyhre/stack/17 to main on July 28, 2025 19:49
danielvegamyhre added a commit that referenced this pull request Jul 28, 2025
stack-info: PR: #2618, branch: danielvegamyhre/stack/18
@danielvegamyhre force-pushed the danielvegamyhre/stack/18 branch from cb92b94 to 2ffbc4f on July 28, 2025 19:49
@danielvegamyhre changed the base branch from main to danielvegamyhre/stack/17 on July 28, 2025 19:49
@danielvegamyhre changed the base branch from danielvegamyhre/stack/17 to main on July 28, 2025 19:54
danielvegamyhre added a commit that referenced this pull request Jul 28, 2025
stack-info: PR: #2618, branch: danielvegamyhre/stack/18
@danielvegamyhre force-pushed the danielvegamyhre/stack/18 branch from 2ffbc4f to 33255a2 on July 28, 2025 19:54
@danielvegamyhre changed the base branch from main to danielvegamyhre/stack/17 on July 28, 2025 19:54
@danielvegamyhre changed the base branch from danielvegamyhre/stack/17 to main on July 29, 2025 15:56
danielvegamyhre added a commit that referenced this pull request Jul 29, 2025
stack-info: PR: #2618, branch: danielvegamyhre/stack/18
@danielvegamyhre force-pushed the danielvegamyhre/stack/18 branch from 33255a2 to ef21071 on July 29, 2025 15:56
@danielvegamyhre changed the base branch from main to danielvegamyhre/stack/17 on July 29, 2025 15:56
@danielvegamyhre force-pushed the danielvegamyhre/stack/18 branch from ef21071 to 90376ea on July 30, 2025 15:50
@danielvegamyhre changed the base branch from danielvegamyhre/stack/17 to main on July 30, 2025 15:50
@danielvegamyhre merged commit 3c466f8 into main Aug 1, 2025
24 checks passed
Labels: CLA Signed, topic: not user facing
Projects: None yet
3 participants