Remove double baseline calculations for CI microbenchmarks #2613

Status: Open. Wants to merge 9 commits into base: main.

147 changes: 118 additions & 29 deletions benchmarks/microbenchmarks/benchmark_inference.py
@@ -13,6 +13,7 @@
import os
from copy import deepcopy
from pathlib import Path
from typing import Dict, Tuple

import torch

@@ -34,15 +35,70 @@
create_model_and_input_data,
)

# -----------------------------------------------------------------------------
# Baseline caching
#
# ``_BASELINE_CACHE`` maps a unique key to a tuple
# ``(eager_baseline_time, compile_baseline_time)``. See ``_make_cache_key`` for
# the key construction. The cache is internal to this module and should not be
# accessed directly. Only the baseline timings are cached; the base model is
# recreated for every run, so quantized versions can be derived without
# mutating any cached state.

_BASELINE_CACHE: Dict[Tuple, Tuple[float, float]] = {}

Contributor comment: nit: add comment for what key is and maybe give an example


def _make_cache_key(config: BenchmarkConfig) -> Tuple:
"""Create a key for caching based on benchmark configuration.

Parameters that affect baseline performance are included:

* model type (e.g. ``linear`` or ``transformer_block``)
* shape dimensions (m, k, n)
* high precision dtype (bf16, fp16, etc.)
* device (cuda, cpu, mps)
* compile settings (whether compile is enabled and compile mode)

Sparsity and quantization settings are deliberately excluded
because the baseline (non‑quantized, non‑sparse) performance is
independent of those attributes.
"""
return (
config.model_type,
config.m,
config.k,
config.n,
config.high_precision_dtype,
config.device,
config.torch_compile_mode,
)
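
A minimal illustrative sketch of the key and the cached entry, picking up the reviewer's nit above; the shape, dtype, and timing values are hypothetical:

# Hypothetical key for a bf16 linear benchmark on CUDA, in the same field
# order as _make_cache_key: (model_type, m, k, n, dtype, device, compile mode).
_example_key = (
    "linear",
    1024,
    1024,
    1024,
    torch.bfloat16,
    "cuda",
    "max-autotune",
)

# The first run with this shape/dtype/device misses the cache and stores the
# measured baselines; later runs that differ only in quantization or sparsity
# map to the same key and reuse them.
if _example_key not in _BASELINE_CACHE:
    _BASELINE_CACHE[_example_key] = (12.3, 7.9)  # (eager_ms, compile_ms), made-up numbers
eager_ms, compile_ms = _BASELINE_CACHE[_example_key]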


def run(config: BenchmarkConfig) -> BenchmarkResult:
"""Run inference benchmarks"""
"""
Run inference benchmarks.

The function first checks whether a baseline for the given configuration
already exists in the internal cache. If not, it measures the eager and
compiled baseline inference times and stores them. When the baseline is
already cached, the cached times are reused to calculate speedup metrics.

Args:
config (BenchmarkConfig): Benchmark configuration.

Returns:
BenchmarkResult: Result of the benchmark.
"""
try:
clean_caches() # Clean caches

# Create output directory if it doesn't exist
Path(config.output_dir).mkdir(parents=True, exist_ok=True)

# Prepare result container
result = BenchmarkResult(config=config)

# Create model and input data
base_model, input_data = create_model_and_input_data(
config.model_type,
config.m,
@@ -51,28 +107,46 @@ def run(config: BenchmarkConfig) -> BenchmarkResult:
high_precision_dtype=config.high_precision_dtype,
device=config.device,
)
# Copy base model for quantizing
m_copy = deepcopy(base_model)

# Run benchmarks
result = BenchmarkResult(config=config)
# Generate a cache key for the current configuration
cache_key = _make_cache_key(config)

# Store result in model for memory profiling
base_model._benchmark_result = result
# Check if the baseline for this configuration has been computed
if cache_key not in _BASELINE_CACHE:
# Switch model to eval and move to device
base_model = base_model.eval().to(config.device)
print("Benchmarking eager baseline inference.....")
eager_baseline_time = model_inference_time_in_ms(
model=base_model, input_data=input_data
)

# Run baseline benchmarking
base_model = base_model.eval().to(config.device)
if config.use_torch_compile:
print("Compiling baseline model....")
print("Benchmarking compile baseline inference.....")
base_model = torch.compile(
base_model, mode=config.torch_compile_mode, fullgraph=True
)
# Benchmark time to run an inference call for baseline model
print("Benchmarking baseline inference.....")
result.baseline_inference_time_in_ms = model_inference_time_in_ms(
model=base_model, input_data=input_data
)
compile_baseline_time = model_inference_time_in_ms(
model=base_model, input_data=input_data
)

# Cache the eager and compile baseline times for this configuration
_BASELINE_CACHE[cache_key] = (eager_baseline_time, compile_baseline_time)

result.eager_baseline_inference_time_in_ms = eager_baseline_time
result.compile_baseline_inference_time_in_ms = compile_baseline_time
else:
# Retrieve cached values
cached_eager_time, cached_compile_time = _BASELINE_CACHE[cache_key]
result.eager_baseline_inference_time_in_ms = cached_eager_time
result.compile_baseline_inference_time_in_ms = cached_compile_time

# At this point, ``base_model`` is an uncompiled model ready for quantization,

Contributor comment: base_model could be compiled in L124 right?

# and ``input_data`` is the corresponding input tensor. The baseline times
# have been stored in ``result.eager_baseline_inference_time_in_ms`` and
# ``result.compile_baseline_inference_time_in_ms``.

# Copy base model for quantizing/sparsifying
m_copy = deepcopy(base_model)

# Determine quantization/sparsity configuration
ao_base_config = string_to_config(
config.quantization,
config.sparsity,
@@ -101,24 +175,39 @@ def run(config: BenchmarkConfig) -> BenchmarkResult:
m_copy = m_copy.eval().to(config.device)
quantize_(m_copy, ao_base_config)

if config.use_torch_compile:
print("Compiling quantized model....")
m_copy = torch.compile(
m_copy, mode=config.torch_compile_mode, fullgraph=True
)

# Store result in model for memory profiling
m_copy._benchmark_result = result

# Benchmark time to run an inference call for quantized model
# Measure inference time for quantized model
print("Benchmarking eager quantized model.....")
result.eager_model_inference_time_in_ms = model_inference_time_in_ms(
Contributor comment: nit: add quantized somewhere in the name?
model=m_copy, input_data=input_data
)

# Measure inference time for compiled quantized model
print("Benchmarking quantized model.....")
result.model_inference_time_in_ms = model_inference_time_in_ms(
m_copy = torch.compile(m_copy, mode=config.torch_compile_mode, fullgraph=True)
result.compile_model_inference_time_in_ms = model_inference_time_in_ms(
Contributor comment: same for this one
model=m_copy, input_data=input_data
)

# Calculate speedup w.r.t. baseline
result.speedup = round(
result.baseline_inference_time_in_ms / result.model_inference_time_in_ms, 2
# Compute eager speedup relative to baseline
result.eager_speedup_on_baseline = round(
result.eager_baseline_inference_time_in_ms
/ result.eager_model_inference_time_in_ms,
2,
Contributor comment: nit: pass by keyword arg to show what this is

)
# Compute compile speedup relative to baseline
result.compile_speedup_on_baseline = round(
result.compile_baseline_inference_time_in_ms
/ result.compile_model_inference_time_in_ms,
2,
)
# Compute compile speedup for quantized model relative to eager quantized model

Contributor comment: do we need to do this comparison? I think it might be more useful to just compare eager quantized vs. eager baseline and compile quantized vs. compile baseline, since these show the speedup in different serving environments.

result.compile_speedup_on_eager = round(
result.eager_model_inference_time_in_ms
/ result.compile_model_inference_time_in_ms,
2,
)
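
As a sanity check on how the three metrics relate, a short worked example with made-up timings:

# Hypothetical timings in ms: eager baseline 10.0, compile baseline 6.0,
# eager quantized 8.0, compile quantized 4.0.
eager_speedup_on_baseline = round(10.0 / 8.0, 2)   # 1.25: eager quantized vs. eager baseline
compile_speedup_on_baseline = round(6.0 / 4.0, 2)  # 1.5: compile quantized vs. compile baseline
compile_speedup_on_eager = round(8.0 / 4.0, 2)     # 2.0: compile quantized vs. eager quantized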

# Run profiler if enabled
@@ -165,9 +254,9 @@ def run(config: BenchmarkConfig) -> BenchmarkResult:
result.memory_profile_path
)
except ValueError as e:
if "not enough values to unpack" in e:
if "not enough values to unpack" in str(e):
print(
"Failed due to existing bugs, re-run the code to generate memory profile. Please raise an issue if it persists."
"Failed due to existing bugs, rerun the code to generate memory profile. Please raise an issue if it persists."
)
except Exception as e:
print(f"Error running memory profiler: {e}")
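
Tying this back to the run() docstring: a minimal usage sketch of the intended cache reuse, assuming config_a and config_b are BenchmarkConfig objects (construction elided) that differ only in their quantization recipe:

# config_a and config_b are assumed to share model type, shape, dtype, device,
# and compile mode, so _make_cache_key maps them to the same key.
result_a = run(config_a)  # cache miss: eager and compile baselines are measured
result_b = run(config_b)  # cache hit: the cached baseline times are reused

# Both results carry identical baseline numbers; only the quantized-model
# timings and the speedup metrics differ.
assert (result_a.eager_baseline_inference_time_in_ms
        == result_b.eager_baseline_inference_time_in_ms)
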
3 changes: 0 additions & 3 deletions benchmarks/microbenchmarks/benchmark_runner.py
@@ -139,9 +139,6 @@ def get_quantization_sparsity_recipes(
"""
config_recipes = set()

# Always include baseline without sparsity
config_recipes.add(("baseline", None))

# Add all quantization techniques without sparsity
for quant_config in quantization_recipes:
config_recipes.add((quant_config, None))
4 changes: 0 additions & 4 deletions benchmarks/microbenchmarks/test/benchmark_config.yml
@@ -13,7 +13,6 @@ model_params:
min_power: 14
max_power: 16
high_precision_dtype: "torch.bfloat16"
use_torch_compile: true
torch_compile_mode: "max-autotune"
device: "cuda"
model_type: "linear"
@@ -27,7 +26,6 @@ model_params:
[2048, 4096, 1024],
]
high_precision_dtype: "torch.bfloat16"
use_torch_compile: true
torch_compile_mode: "max-autotune"
device: "cuda"
model_type: "ln_linear_sigmoid"
@@ -41,7 +39,6 @@ model_params:
[2048, 4096, 1024], # For transformer_block, k is the hidden dimension
]
high_precision_dtype: "torch.bfloat16"
use_torch_compile: true
torch_compile_mode: "max-autotune"
device: "cuda"
model_type: "transformer_block" # TODO: Add a custom model (Figure out how to do this, maybe pass a .py file with model definition)
@@ -58,7 +55,6 @@ model_params:
min_power: 10 # 1024
max_power: 11 # 2048
high_precision_dtype: "torch.bfloat16"
use_torch_compile: true
torch_compile_mode: "max-autotune"
device: "cuda"
model_type: "linear"
9 changes: 3 additions & 6 deletions benchmarks/microbenchmarks/test/test_benchmark_inference.py
@@ -21,7 +21,6 @@ def setUp(self):
sparsity="semi-sparse",
params={
"high_precision_dtype": "torch.float32",
"use_torch_compile": False,
"device": "cpu",
"model_type": "linear",
},
@@ -46,7 +45,7 @@ def test_run_inference(self, mock_string_to_config):

result = run(self.config)
self.assertIsInstance(result, BenchmarkResult)
self.assertTrue(hasattr(result, "model_inference_time_in_ms"))
self.assertTrue(hasattr(result, "compile_model_inference_time_in_ms"))

@patch("benchmarks.microbenchmarks.benchmark_inference.string_to_config")
def test_run_inference_with_semi_sparse_marlin(self, mock_string_to_config):
@@ -64,7 +63,6 @@ def test_run_inference_with_semi_sparse_marlin(self, mock_string_to_config):
sparsity="semi-sparse",
params={
"high_precision_dtype": "torch.float32",
"use_torch_compile": False,
"device": "cpu",
"model_type": "linear",
},
@@ -75,7 +73,7 @@ def test_run_inference_with_semi_sparse_marlin(self, mock_string_to_config):
)
result = run(config)
self.assertIsInstance(result, BenchmarkResult)
self.assertTrue(hasattr(result, "model_inference_time_in_ms"))
self.assertTrue(hasattr(result, "compile_model_inference_time_in_ms"))

@patch("benchmarks.microbenchmarks.benchmark_inference.string_to_config")
def test_run_inference_with_block_sparsity(self, mock_string_to_config):
@@ -92,7 +90,6 @@ def test_run_inference_with_block_sparsity(self, mock_string_to_config):
sparsity="block",
params={
"high_precision_dtype": "torch.float32",
"use_torch_compile": False,
"device": "cpu",
"model_type": "linear",
},
@@ -103,7 +100,7 @@ def test_run_inference_with_block_sparsity(self, mock_string_to_config):
)
result = run(config)
self.assertIsInstance(result, BenchmarkResult)
self.assertTrue(hasattr(result, "model_inference_time_in_ms"))
self.assertTrue(hasattr(result, "compile_model_inference_time_in_ms"))


if __name__ == "__main__":
11 changes: 5 additions & 6 deletions benchmarks/microbenchmarks/test/test_benchmark_profiler.py
@@ -270,13 +270,12 @@ def test_memory_profiler_cuda_unavailable(self):
f"{config.name}_{self.m}_{self.k}_{self.n}_memory_profile.json",
)

# Generate memory profile
result, memory_stats = generate_memory_profile(
self.model, self.input_data, memory_profile_path
)

# Should return None when CUDA is unavailable
self.assertIsNone(result)
self.assertIsNone(
generate_memory_profile(
self.model, self.input_data, memory_profile_path
)
)

# Should not create file when CUDA is unavailable
self.assertFalse(os.path.exists(memory_profile_path))
2 changes: 0 additions & 2 deletions benchmarks/microbenchmarks/test/test_benchmark_runner.py
@@ -39,7 +39,6 @@ def setUp(self):
}
],
"high_precision_dtype": "torch.bfloat16",
"use_torch_compile": True,
"torch_compile_mode": "max-autotune",
"device": "cpu",
"model_type": "linear",
@@ -130,7 +129,6 @@ def test_get_param_combinations(self):
self.assertEqual(len(shapes), 1)
self.assertEqual(shapes[0], ("custom", [1024, 1024, 1024]))
self.assertEqual(params["high_precision_dtype"], "torch.bfloat16")
self.assertEqual(params["use_torch_compile"], True)

@patch("argparse.Namespace")
def test_load_benchmark_configs(self, mock_args):
4 changes: 1 addition & 3 deletions benchmarks/microbenchmarks/test/test_utils.py
@@ -33,7 +33,6 @@ def setUp(self):
self.test_params = {
"name": "test_model",
"high_precision_dtype": "torch.bfloat16",
"use_torch_compile": True,
"torch_compile_mode": "max-autotune",
"device": "cpu",
"model_type": "linear",
@@ -57,7 +56,6 @@ def test_benchmark_config(self):
self.assertEqual(config.k, 1024)
self.assertEqual(config.n, 1024)
self.assertEqual(config.high_precision_dtype, torch.bfloat16)
self.assertEqual(config.use_torch_compile, True)
self.assertEqual(config.torch_compile_mode, "max-autotune")
self.assertEqual(config.device, "cpu")
self.assertEqual(config.model_type, "linear")
@@ -76,7 +74,7 @@ def test_benchmark_result(self):
result = BenchmarkResult(config=config)

self.assertEqual(result.config, config)
self.assertEqual(result.model_inference_time_in_ms, 0.0)
self.assertEqual(result.compile_model_inference_time_in_ms, 0.0)

def test_get_default_device(self):
# Test CPU fallback