From 50ac2cc2eb6a0698486c4171047701b2508a11cb Mon Sep 17 00:00:00 2001 From: youn17 Date: Thu, 18 Sep 2025 02:12:38 +0900 Subject: [PATCH 1/3] Summary: This PR adds sparsify overhead benchmark, omitted in ICLR workshop paper: https://arxiv.org/abs/2503.16672 In the paper, there are two parts for the benchmark: 1) Sparsify operation overhead, 2) Sparse-GEMM kernel performance. Part 1) was omitted from the original benchmark, so this PR adds the missing sparsify-only benchmark comparing `torchao.sparse24_sm90_sparsify` against `torch._cslt_compress` (cuSPARSELt) baseline. Test plan: CI --- benchmarks/benchmark_e2e_fp8_sparse_linear.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/benchmarks/benchmark_e2e_fp8_sparse_linear.py b/benchmarks/benchmark_e2e_fp8_sparse_linear.py index fbab8c0671..bae0e708fe 100644 --- a/benchmarks/benchmark_e2e_fp8_sparse_linear.py +++ b/benchmarks/benchmark_e2e_fp8_sparse_linear.py @@ -40,6 +40,20 @@ def benchmark(num_tokens, hidden_size=8192, intermediate_size=8192): input_tensor = torch.randn(num_tokens, hidden_size).to(torch.bfloat16).cuda() fp16_time = benchmark_microseconds(ffn_ref, input_tensor) + # Sparsify-only benchmarks + X_scale = torch.empty([num_tokens, 1], device="cuda", dtype=torch.float32) + ao_cusparse_time = benchmark_microseconds( + lambda: torch.ops.torchao.sparse24_sm90_sparsify( + input_tensor, + "cutlass", + "srelu", + "largest", + dtype=torch.float8_e4m3fn, + scale=X_scale, + ) + ) + cusparse_time = benchmark_microseconds(lambda: torch._cslt_compress(input_tensor)) + # bf16 ffn_clone = ( nn.Sequential( @@ -117,7 +131,10 @@ def benchmark(num_tokens, hidden_size=8192, intermediate_size=8192): "fp8_c_time (us)": fp8_c_time, "fp8_c_sparse_time (us)": fp8_c_sparse_time, "fp8_c_activation_sparse_time (us)": fp8_c_activation_sparse_time, + "ao_cusparse_time (us)": ao_cusparse_time, + "cusparse_compress_time (us)": cusparse_time, "speedup": fp8_c_time / fp8_c_activation_sparse_time, + 
"sparsify_speedup": cusparse_time / ao_cusparse_time, } From f9f2f8d06ef014d0f50f72343a3eae06c5c5ea97 Mon Sep 17 00:00:00 2001 From: youn17 Date: Fri, 19 Sep 2025 15:29:50 +0900 Subject: [PATCH 2/3] remove lambda, scale for fair comparison --- benchmarks/benchmark_e2e_fp8_sparse_linear.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/benchmarks/benchmark_e2e_fp8_sparse_linear.py b/benchmarks/benchmark_e2e_fp8_sparse_linear.py index bae0e708fe..2b1a442cba 100644 --- a/benchmarks/benchmark_e2e_fp8_sparse_linear.py +++ b/benchmarks/benchmark_e2e_fp8_sparse_linear.py @@ -41,18 +41,16 @@ def benchmark(num_tokens, hidden_size=8192, intermediate_size=8192): fp16_time = benchmark_microseconds(ffn_ref, input_tensor) # Sparsify-only benchmarks - X_scale = torch.empty([num_tokens, 1], device="cuda", dtype=torch.float32) - ao_cusparse_time = benchmark_microseconds( - lambda: torch.ops.torchao.sparse24_sm90_sparsify( + ao_fast_sparsification_time = benchmark_microseconds( + torch.ops.torchao.sparse24_sm90_sparsify( input_tensor, "cutlass", "srelu", "largest", dtype=torch.float8_e4m3fn, - scale=X_scale, ) ) - cusparse_time = benchmark_microseconds(lambda: torch._cslt_compress(input_tensor)) + cusparse_time = benchmark_microseconds(torch._cslt_compress, input_tensor) # bf16 ffn_clone = ( @@ -131,10 +129,10 @@ def benchmark(num_tokens, hidden_size=8192, intermediate_size=8192): "fp8_c_time (us)": fp8_c_time, "fp8_c_sparse_time (us)": fp8_c_sparse_time, "fp8_c_activation_sparse_time (us)": fp8_c_activation_sparse_time, - "ao_cusparse_time (us)": ao_cusparse_time, - "cusparse_compress_time (us)": cusparse_time, + "ao_fast_sparsification_time (us)": ao_fast_sparsification_time, + "cusparse*_compress_time (us)": cusparse_time, "speedup": fp8_c_time / fp8_c_activation_sparse_time, - "sparsify_speedup": cusparse_time / ao_cusparse_time, + "sparsify_speedup": cusparse_time / ao_fast_sparsification_time, } From cfbeabfcf5533b0adb40785f5d502a11165470f9 
Mon Sep 17 00:00:00 2001 From: youn17 Date: Sun, 21 Sep 2025 20:40:06 +0900 Subject: [PATCH 3/3] rename attributes to prevent duplicate naming --- benchmarks/benchmark_e2e_fp8_sparse_linear.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/benchmarks/benchmark_e2e_fp8_sparse_linear.py b/benchmarks/benchmark_e2e_fp8_sparse_linear.py index 2b1a442cba..f2b25a9202 100644 --- a/benchmarks/benchmark_e2e_fp8_sparse_linear.py +++ b/benchmarks/benchmark_e2e_fp8_sparse_linear.py @@ -45,12 +45,12 @@ def benchmark(num_tokens, hidden_size=8192, intermediate_size=8192): torch.ops.torchao.sparse24_sm90_sparsify( input_tensor, "cutlass", - "srelu", + "identity", "largest", dtype=torch.float8_e4m3fn, ) ) - cusparse_time = benchmark_microseconds(torch._cslt_compress, input_tensor) + cusparselt_time = benchmark_microseconds(torch._cslt_compress, input_tensor) # bf16 ffn_clone = ( @@ -130,9 +130,9 @@ def benchmark(num_tokens, hidden_size=8192, intermediate_size=8192): "fp8_c_sparse_time (us)": fp8_c_sparse_time, "fp8_c_activation_sparse_time (us)": fp8_c_activation_sparse_time, "ao_fast_sparsification_time (us)": ao_fast_sparsification_time, - "cusparse*_compress_time (us)": cusparse_time, + "cusparselt_compress_time (us)": cusparselt_time, "speedup": fp8_c_time / fp8_c_activation_sparse_time, - "sparsify_speedup": cusparse_time / ao_fast_sparsification_time, + "sparsify_speedup": cusparselt_time / ao_fast_sparsification_time, }