-
Notifications
You must be signed in to change notification settings - Fork 310
[moe training] add fp8 rowwise kernels for expert weights #2696
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: danielvegamyhre/stack/29
Are you sure you want to change the base?
Conversation
stack-info: PR: #2696, branch: danielvegamyhre/stack/30
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2696
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEVsThere are 1 currently active SEVs. If your PR is affected, please view them below: This comment was automatically generated by Dr. CI and updates every 15 minutes. |
af159db
to
f6688be
Compare
f6688be
to
a6f8cbb
Compare
a6f8cbb
to
ef4e25c
Compare
ef4e25c
to
2ea3573
Compare
2ea3573
to
6704fd3
Compare
tl.float32 | ||
) | ||
if round_scales_to_power_of_2: | ||
scales = tl.exp2(tl.floor(tl.log2(scales))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this seems expensive, can we just extract the bits?
|
||
# Apply scales to tensor and convert to float8. | ||
tensor_scaled = input_hp_t.to(torch.float32) * scales | ||
float8_tensor = to_fp8_saturated(tensor_scaled, target_dtype) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is confusing because it sounds like Float8TrainingTensor
, maybe name it float8_data
?
fad9062
to
241e9b7
Compare
6704fd3
to
c789281
Compare
Stacked PRs:
[moe training] add fp8 rowwise kernels for expert weights
Summary
Test plan
pytest test/prototype/moe_training/test_kernels.py