
Conversation

@namgyu-youn (Contributor) commented Sep 17, 2025

Summary:
This PR adds the sparsify overhead benchmark from the ICLR workshop paper (https://arxiv.org/abs/2503.16672) that was missing from the existing benchmark script.

In the paper, the benchmark has two parts: 1) sparsify operation overhead, and 2) sparse-GEMM kernel performance. Part 1) was omitted from the original benchmark script, so this PR adds the missing sparsify-only benchmark, comparing `torchao.sparse24_sm90_sparsify` against the `torch._cslt_compress` (cuSPARSELt) baseline.

Test plan: CI
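
For reference, a minimal sketch of the added comparison (assuming the script's `benchmark_microseconds` timing helper; the shapes, dtypes, and `"identity"` activation here are illustrative, not the script's exact configuration):

```python
import torch

# Illustrative shapes; the real benchmark script defines its own sizes.
num_tokens, hidden_size = 4096, 8192
input_tensor = torch.randn(
    num_tokens, hidden_size, device="cuda", dtype=torch.bfloat16
)
X_scale = torch.empty([num_tokens, 1], device="cuda", dtype=torch.float32)

# Part 1a: torchao fast sparsification kernel (SM90)
ao_fast_sparsification_time = benchmark_microseconds(
    lambda: torch.ops.torchao.sparse24_sm90_sparsify(
        input_tensor,
        "cutlass",
        "identity",
        "largest",
        dtype=torch.float8_e4m3fn,
        scale=X_scale,
    )
)

# Part 1b: cuSPARSELt compression baseline
cusparselt_time = benchmark_microseconds(torch._cslt_compress, input_tensor)
```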


@pytorch-bot (bot) commented Sep 17, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3021

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla (bot) added the CLA Signed label Sep 17, 2025
@namgyu-youn (Contributor Author) commented

@jcaip Please review this PR, thanks.

@jcaip (Contributor) commented Sep 18, 2025

@namgyu-youn Can you share the results of your benchmark script?

@namgyu-youn (Contributor Author) commented

> @namgyu-youn Can you share the results of your benchmark script?

@jcaip Unfortunately I don't have access to an H100, so please feel free to add the benchmark results.

@jcaip (Contributor) left a comment

A couple of nits but otherwise looks good - thanks for adding!

lambda: torch.ops.torchao.sparse24_sm90_sparsify(
    input_tensor,
    "cutlass",
    "srelu",
@jcaip (Contributor) commented

this should be "identity" here instead

@jcaip (Contributor) commented

please update this :)

@namgyu-youn (Contributor Author) commented

Thanks for the reminder; I missed it.

        scale=X_scale,
    )
)
cusparse_time = benchmark_microseconds(lambda: torch._cslt_compress(input_tensor))
@jcaip (Contributor) commented

do you need this lambda? Can you just pass it in like we do on L41 above:

cusparse_time = benchmark_microseconds(torch._cslt_compress, input_tensor)

# Sparsify-only benchmarks
X_scale = torch.empty([num_tokens, 1], device="cuda", dtype=torch.float32)
ao_cusparse_time = benchmark_microseconds(
    lambda: torch.ops.torchao.sparse24_sm90_sparsify(
@jcaip (Contributor) commented

same nit as below

"srelu",
"largest",
dtype=torch.float8_e4m3fn,
scale=X_scale,
@jcaip (Contributor) commented

I think you can pass in None to scale for the fairest comparison.
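
For example (a sketch; the other arguments stay as in the diff, with the "identity" nit from the other thread folded in):

```python
# Suggested change: scale=None, so the comparison against
# torch._cslt_compress (which applies no scale) is as fair as possible.
torch.ops.torchao.sparse24_sm90_sparsify(
    input_tensor,
    "cutlass",
    "identity",
    "largest",
    dtype=torch.float8_e4m3fn,
    scale=None,
)
```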

"fp8_c_time (us)": fp8_c_time,
"fp8_c_sparse_time (us)": fp8_c_sparse_time,
"fp8_c_activation_sparse_time (us)": fp8_c_activation_sparse_time,
"ao_cusparse_time (us)": ao_cusparse_time,
@jcaip (Contributor) commented

nit: I think something like ao_fast_sparsification_time is a better var name.

"fp8_c_sparse_time (us)": fp8_c_sparse_time,
"fp8_c_activation_sparse_time (us)": fp8_c_activation_sparse_time,
"ao_cusparse_time (us)": ao_cusparse_time,
"cusparse_compress_time (us)": cusparse_time,
@jcaip (Contributor) commented

cusparselt* instead of cusparse so we don't get confused :)

@jcaip added the sparsity and topic: improvement labels Sep 18, 2025
@namgyu-youn requested a review from jcaip September 19, 2025 06:31
lambda: torch.ops.torchao.sparse24_sm90_sparsify(
    input_tensor,
    "cutlass",
    "srelu",
@jcaip (Contributor) commented

please update this :)

"fp8_c_sparse_time (us)": fp8_c_sparse_time,
"fp8_c_activation_sparse_time (us)": fp8_c_activation_sparse_time,
"ao_fast_sparsification_time (us)": ao_fast_sparsification_time,
"cusparse*_compress_time (us)": cusparse_time,
@jcaip (Contributor) commented

`cusparse_time` isn't a good name for this, because there is a separate cuSPARSE library, aside from cuSPARSELt. Please use `cusparselt` here instead.

Also, it looks like there's a typo in the string?
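
For example, something like this (a sketch; the dict name and elided entries are illustrative):

```python
results = {
    # ... other timings ...
    "ao_fast_sparsification_time (us)": ao_fast_sparsification_time,
    # cuSPARSELt, not cuSPARSE (a separate library); stray "*" removed
    "cusparselt_compress_time (us)": cusparselt_time,
}
```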

@namgyu-youn (Contributor Author) commented

Thanks, I wasn't aware of the distinction.

@namgyu-youn requested a review from jcaip September 21, 2025 11:41
Labels
CLA Signed, sparsity, topic: improvement
Development

Successfully merging this pull request may close these issues.

Missing benchmark for sparse24_sm90_sparsify overhead
2 participants