Check numerical equivalence / closeness between different kernel preferences #2651


Open
wants to merge 1 commit into base: jerryzh168/stack/14

Conversation

jerryzh168
Contributor

@jerryzh168 jerryzh168 commented Aug 1, 2025

Stacked PRs:


Check numerical equivalence / closeness between different kernel preferences

Summary:
This PR checks that the different kernel preferences for Float8Tensor (AUTO, TORCH and FBGEMM) produce similar numerics.

The triton implementation and the torchao implementation are currently a bit different; we need to decide whether to align them. See the sketch after the list below.

1. Difference in the quantize op

The main difference seems to be that the triton implementation uses:

```
a_scale = MAX_FP8 / max_abs
# then
a_scale = 1.0 / a_scale
a_fp8 = a * a_scale
```

while torch does:

```
a_scale = max_abs / MAX_FP8
a_fp8 = a / a_scale
```

Also, the hp_value_lb and hp_value_ub settings are slightly different.

triton choose scale and quantize code: https://github.com/pytorch/FBGEMM/blob/a4286c01ef01dad435b2ec8798605127d3032cd8/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py#L2382-L2392

torchao choose scale and quantize code:
_choose_scale_float8: https://github.com/pytorch/ao/blob/3c466f844684af0fb80014094f2ca8663881eb33/torchao/quantization/quant_primitives.py#L2183
_quantize_affine_float8: https://github.com/pytorch/ao/blob/3c466f844684af0fb80014094f2ca8663881eb33/torchao/quantization/quant_primitives.py#L2283

2. (Potentially) a difference in the matrix multiplication ops

TORCH uses a different quantized mm op than AUTO/FBGEMM.
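
To make the ordering difference in (1) concrete, here is a minimal sketch in plain PyTorch (not the actual triton or torchao kernels; the clamping and the function names are illustrative, and the hp_value_lb / hp_value_ub bounds are ignored):

```
import torch

MAX_FP8 = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

def quantize_triton_style(a: torch.Tensor):
    # reciprocal-then-multiply: compute MAX_FP8 / max_abs, quantize by multiplying
    # with it, and store its reciprocal as the scale
    max_abs = a.abs().max().to(torch.float32)
    inv_scale = MAX_FP8 / max_abs
    a_fp8 = (a * inv_scale).clamp(-MAX_FP8, MAX_FP8).to(torch.float8_e4m3fn)
    a_scale = 1.0 / inv_scale  # stored scale, numerically close to max_abs / MAX_FP8
    return a_fp8, a_scale

def quantize_torch_style(a: torch.Tensor):
    # divide-by-scale: compute the stored scale directly and quantize by division
    max_abs = a.abs().max().to(torch.float32)
    a_scale = max_abs / MAX_FP8
    a_fp8 = (a / a_scale).clamp(-MAX_FP8, MAX_FP8).to(torch.float8_e4m3fn)
    return a_fp8, a_scale

a = torch.randn(32, 128, dtype=torch.bfloat16)
q1, s1 = quantize_triton_style(a)
q2, s2 = quantize_torch_style(a)
# the scales and a few quantized elements can differ by one rounding step,
# which is enough to break exact equivalence between kernel preferences
print(torch.equal(s1, s2), (q1.to(torch.float32) != q2.to(torch.float32)).sum())
```

Even when the stored scales end up bit-identical, `x * (1 / s)` and `x / s` can round to different fp8 values for some elements, so exact equivalence between the kernel preferences is not expected without aligning the two paths (and the clamping bounds).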

Added a reverse option to bring the SQNR closer:

```
granularity: PerTensor()  sizes: ((128,), 256, 128)  kp: KernelPreference.AUTO tensor(inf, device='cuda:0', dtype=torch.bfloat16)
granularity: PerTensor()  sizes: ((128,), 256, 128)  kp: KernelPreference.FBGEMM tensor(inf, device='cuda:0', dtype=torch.bfloat16)
.granularity: PerTensor()  sizes: ((32, 128), 64, 256)  kp: KernelPreference.AUTO tensor(inf, device='cuda:0', dtype=torch.bfloat16)
granularity: PerTensor()  sizes: ((32, 128), 64, 256)  kp: KernelPreference.FBGEMM tensor(inf, device='cuda:0', dtype=torch.bfloat16)
.granularity: PerRow()  sizes: ((128,), 256, 128)  kp: KernelPreference.AUTO tensor(inf, device='cuda:0', dtype=torch.bfloat16)
granularity: PerRow()  sizes: ((128,), 256, 128)  kp: KernelPreference.FBGEMM tensor(inf, device='cuda:0', dtype=torch.bfloat16)
.granularity: PerRow()  sizes: ((32, 128), 64, 256)  kp: KernelPreference.AUTO tensor(64.5000, device='cuda:0', dtype=torch.bfloat16)
granularity: PerRow()  sizes: ((32, 128), 64, 256)  kp: KernelPreference.FBGEMM tensor(68., device='cuda:0', dtype=torch.bfloat16)
```
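
For reference, the SQNR values above (and the > 28 dB threshold asserted in the test) come from compute_error in torchao.quantization.utils, which reports 20 * log10(||x|| / ||x - y||) in dB, so bitwise-identical outputs show up as inf. A minimal sketch of the comparison, with random tensors standing in for the model outputs under different kernel preferences:

```
import torch
from torchao.quantization.utils import compute_error  # SQNR in dB

# stand-ins for the outputs of the same linear layer quantized with two
# different kernel preferences; the real test gets these from the quantized model
ref = torch.randn(32, 128, 256)
same = ref.clone()
close = ref + 1e-3 * torch.randn_like(ref)

print(compute_error(ref, same))   # inf: identical outputs, like the PerTensor rows above
print(compute_error(ref, close))  # a finite value; the PerRow rows above land around 64-68 dB
```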

Test Plan:
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_kernel_preference_numerical_equivalence

Reviewers:

Subscribers:

Tasks:

Tags:


pytorch-bot bot commented Aug 1, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2651

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 11 New Failures

As of commit bd0faa2 with merge base b757fb9:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/15 branch from 2534529 to c608b78 Compare August 1, 2025 00:53
@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/14 branch from e19cb46 to 5ae457c Compare August 1, 2025 00:53
@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Aug 1, 2025
@jerryzh168 jerryzh168 added the topic: not user facing Use this tag if you don't want this PR to show up in release notes label Aug 1, 2025
for i in range(1, len(kp_and_res)):
    kp, res = kp_and_res[i]
    self.assertTrue(
        compute_error(res, kp_and_res[0][1]) > 28,
jerryzh168 (Contributor, Author) commented:

cc @vkuzo: we don't have equivalence yet due to some differences in implementation. Do you think we should make the torchao quant primitives (choose_scale_float8 + quantize_float8) match the triton ones?

Contributor commented:

Do we know what the differences are?

IMO we should also choose either TORCH or FBGEMM (but not AUTO) as the reference, and match others to the reference

jerryzh168 (Contributor, Author) replied:

Yeah, see the PR summary for the differences.

I can update it to use TORCH as the reference.

@jerryzh168 jerryzh168 changed the base branch from jerryzh168/stack/14 to main August 1, 2025 00:56
@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/15 branch from c608b78 to 42a767c Compare August 1, 2025 00:56
@jerryzh168 jerryzh168 changed the base branch from main to jerryzh168/stack/14 August 1, 2025 00:56
@jerryzh168 jerryzh168 changed the base branch from jerryzh168/stack/14 to main August 1, 2025 03:38
@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/15 branch from 42a767c to 65a4f84 Compare August 1, 2025 03:38
@jerryzh168 jerryzh168 changed the base branch from main to jerryzh168/stack/14 August 1, 2025 03:38
@jerryzh168 jerryzh168 requested a review from vkuzo August 1, 2025 04:43
@jerryzh168 jerryzh168 changed the base branch from jerryzh168/stack/14 to main August 1, 2025 21:12
@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/15 branch from 65a4f84 to ba8efe2 Compare August 1, 2025 21:13
@jerryzh168 jerryzh168 changed the base branch from main to jerryzh168/stack/14 August 1, 2025 21:13
@jerryzh168 jerryzh168 changed the base branch from jerryzh168/stack/14 to main August 2, 2025 01:31
@jerryzh168 jerryzh168 changed the base branch from main to jerryzh168/stack/14 August 2, 2025 01:31
    hp_value_ub: Optional[float] = None,
    reverse: bool = False,
Contributor commented:

I think we should make the callsites match instead of having a flag

jerryzh168 (Contributor, Author) replied:

Yeah, this is not the final state; I was just trying things out to check whether this helps.

@jerryzh168 jerryzh168 changed the base branch from jerryzh168/stack/14 to main August 4, 2025 17:30
@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/15 branch from ba8efe2 to 1720743 Compare August 4, 2025 17:30
@jerryzh168 jerryzh168 changed the base branch from main to jerryzh168/stack/14 August 4, 2025 17:30
@jerryzh168 jerryzh168 changed the base branch from jerryzh168/stack/14 to main August 4, 2025 18:15
@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/15 branch from 1720743 to 36fce5e Compare August 4, 2025 18:15
@jerryzh168 jerryzh168 changed the base branch from main to jerryzh168/stack/14 August 4, 2025 18:15
@jerryzh168 jerryzh168 changed the base branch from jerryzh168/stack/14 to main August 4, 2025 22:14
@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/15 branch from 36fce5e to 9824504 Compare August 4, 2025 22:14
@jerryzh168 jerryzh168 changed the base branch from main to jerryzh168/stack/14 August 4, 2025 22:15
@jerryzh168 jerryzh168 changed the base branch from jerryzh168/stack/14 to main August 4, 2025 23:51
@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/15 branch from 9824504 to e818cb0 Compare August 4, 2025 23:51
@jerryzh168 jerryzh168 changed the base branch from main to jerryzh168/stack/14 August 4, 2025 23:51
@jerryzh168 jerryzh168 changed the base branch from jerryzh168/stack/14 to main August 5, 2025 01:20
@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/15 branch from e818cb0 to 1c3a47d Compare August 5, 2025 01:20
@jerryzh168 jerryzh168 changed the base branch from main to jerryzh168/stack/14 August 5, 2025 01:21
@jerryzh168 jerryzh168 changed the base branch from jerryzh168/stack/14 to main August 5, 2025 01:31
@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/15 branch from 1c3a47d to d15e4ae Compare August 5, 2025 01:31
@jerryzh168 jerryzh168 changed the base branch from main to jerryzh168/stack/14 August 5, 2025 01:31
@jerryzh168 jerryzh168 changed the base branch from jerryzh168/stack/14 to main August 5, 2025 03:24
@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/15 branch from d15e4ae to 2ea6fbe Compare August 5, 2025 03:24
@jerryzh168 jerryzh168 changed the base branch from main to jerryzh168/stack/14 August 5, 2025 03:25
@jerryzh168 jerryzh168 changed the base branch from jerryzh168/stack/14 to main August 5, 2025 18:39
@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/15 branch from 2ea6fbe to 3a71e28 Compare August 5, 2025 18:39
@jerryzh168 jerryzh168 changed the base branch from main to jerryzh168/stack/14 August 5, 2025 18:39
@jerryzh168 jerryzh168 changed the base branch from jerryzh168/stack/14 to main August 5, 2025 23:29
@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/15 branch from 3a71e28 to 11d2143 Compare August 5, 2025 23:30
@jerryzh168 jerryzh168 changed the base branch from main to jerryzh168/stack/14 August 5, 2025 23:30
Check numerical equivalence / closeness between different kernel preferences

stack-info: PR: #2651, branch: jerryzh168/stack/15
@jerryzh168 jerryzh168 changed the base branch from jerryzh168/stack/14 to main August 6, 2025 01:07
@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/15 branch from 11d2143 to bd0faa2 Compare August 6, 2025 01:08
@jerryzh168 jerryzh168 changed the base branch from main to jerryzh168/stack/14 August 6, 2025 01:08