Add Float8BlockwiseLinear for training #2618
Conversation
stack-info: PR: #2618, branch: danielvegamyhre/stack/18
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2618
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 90376ea with merge base 6b82931.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed: 4c0250f → 89357e5
Force-pushed: c1683b6 → 21a36dc
Force-pushed: 21a36dc → eca0126
Force-pushed: eca0126 → 5bfc200
The integration looks good, but IMO we should use PyTorch native code for quantization of weights/activations, and only add Triton kernels when we're confident they match the PyTorch native code with bitwise accuracy. Without this, it's hard to trust and debug the accuracy of this prototype.
Force-pushed: 5bfc200 → cb92b94
Makes sense. I updated the prior PR to assert bitwise equivalence, let me know what you think.
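For illustration, a minimal sketch (not the actual test in this PR) of asserting bitwise equivalence between a Triton blockwise-fp8 quantization kernel and a plain-PyTorch reference; the kernel name `triton_quantize_fp8_blockwise` and the scaling convention below are assumptions:

```python
import torch

def ref_quantize_fp8_blockwise(x: torch.Tensor, block_size: int = 128):
    # Plain-PyTorch reference: scale each (block_size x block_size) tile by its
    # absmax so the tile fits the float8_e4m3fn range.
    M, K = x.shape
    tiles = x.reshape(M // block_size, block_size, K // block_size, block_size)
    amax = tiles.abs().amax(dim=(1, 3), keepdim=True).clamp(min=1e-12)
    scale = torch.finfo(torch.float8_e4m3fn).max / amax.float()
    x_fp8 = (tiles.float() * scale).to(torch.float8_e4m3fn).reshape(M, K)
    return x_fp8, scale.squeeze()

x = torch.randn(256, 512, device="cuda", dtype=torch.bfloat16)
ref_fp8, ref_scale = ref_quantize_fp8_blockwise(x)

# With a Triton kernel exposed as e.g. triton_quantize_fp8_blockwise (hypothetical name),
# bitwise equivalence can be asserted by comparing the raw fp8 bytes:
# kernel_fp8, kernel_scale = triton_quantize_fp8_blockwise(x)
# assert torch.equal(kernel_fp8.view(torch.uint8), ref_fp8.view(torch.uint8))
# assert torch.equal(kernel_scale, ref_scale)
```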
from torchao.prototype.blockwise_fp8_training.linear import Float8BlockwiseLinear

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
sm90?
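A minimal sketch of what an SM90 gate could look like, assuming `torch.cuda.get_device_capability()` is used for the check (the test name is illustrative):

```python
import pytest
import torch

# Skip unless an SM90 (e.g. H100) GPU is present, since the fp8 kernels target that arch.
is_sm90 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)

@pytest.mark.skipif(not is_sm90, reason="requires CUDA capability 9.0+ (e.g. H100)")
def test_fp8_blockwise_linear_smoke():
    ...
```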
if in_features % block_size != 0 or out_features % block_size != 0:
    pytest.skip(f"Dimensions must be divisible by block_size={block_size}")

torch.random.manual_seed(0)
Usually people do this once next to the imports, and then use copy.deepcopy to create copies of models. It's unusual to set the seed twice to get the same effect, even if it works.
Makes sense, updated.
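For reference, a minimal sketch of the suggested pattern (seed set once next to the imports, copy.deepcopy for the model copies); the layer shape here is illustrative:

```python
import copy

import torch

# Seed once, near the imports, so every test sees the same RNG state.
torch.random.manual_seed(0)

ref_model = torch.nn.Linear(4096, 4096, bias=False, device="cuda", dtype=torch.bfloat16)
# Identical starting weights without re-seeding.
test_model = copy.deepcopy(ref_model)
```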
Force-pushed: cb92b94 → 2ffbc4f
Force-pushed: 2ffbc4f → 33255a2
Force-pushed: 33255a2 → ef21071
Force-pushed: ef21071 → 90376ea
Stacked PRs:
Add Float8BlockwiseLinear for training
Test plan
pytest test/prototype/blockwise_fp8_training/test_blockwise_linear.py
Limitations
Performance and next steps
Perf is poor (4.4k TPS vs 6.3k TPS for bf16), largely due to slow GEMMs. As mentioned in #2617, if we want to improve perf, I can make the changes necessary for compatibility with torch._scaled_mm and see if perf improves (very likely, I'd say).
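For context, a minimal sketch of a torch._scaled_mm call with per-tensor scales; the blockwise-scaled variant would pass non-scalar scale tensors, and the exact scale layouts required for blockwise scaling are an assumption here, not something this PR implements:

```python
import torch

M, K, N = 256, 512, 1024
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)

# _scaled_mm expects a row-major fp8 LHS and a column-major fp8 RHS.
a_fp8 = a.to(torch.float8_e4m3fn)
b_fp8 = b.t().contiguous().t().to(torch.float8_e4m3fn)

# Per-tensor scales for illustration; real code would compute these from amax.
scale_a = torch.tensor(1.0, device="cuda")
scale_b = torch.tensor(1.0, device="cuda")

out = torch._scaled_mm(a_fp8, b_fp8, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16)
```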