Remove double baseline calculations for CI microbenchmarks #2613
@@ -13,6 +13,7 @@
import os
from copy import deepcopy
from pathlib import Path
from typing import Dict, Tuple

import torch
@@ -34,15 +35,70 @@
    create_model_and_input_data,
)

# -----------------------------------------------------------------------------
# Baseline caching
#
# ``_BASELINE_CACHE`` maps a unique key to a tuple
# ``(eager_baseline_time, compile_baseline_time)``. See ``_make_cache_key`` for
# the key construction. Users should not access this cache directly; it is
# internal to this module. Caching the baseline timings lets configurations
# that differ only in quantization or sparsity settings reuse the same
# baseline instead of re-measuring it.

_BASELINE_CACHE: Dict[Tuple, Tuple[float, float]] = {}

def _make_cache_key(config: BenchmarkConfig) -> Tuple:
    """Create a key for caching based on benchmark configuration.

    Parameters that affect baseline performance are included:

    * model type (e.g. ``linear`` or ``transformer_block``)
    * shape dimensions (m, k, n)
    * high precision dtype (bf16, fp16, etc.)
    * device (cuda, cpu, mps)
    * compile settings (whether compile is enabled and compile mode)

    Sparsity and quantization settings are deliberately excluded because the
    baseline (non-quantized, non-sparse) performance is independent of those
    attributes.
    """
    return (
        config.model_type,
        config.m,
        config.k,
        config.n,
        config.high_precision_dtype,
        config.device,
        config.torch_compile_mode,
    )
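
As a rough illustration of what ends up in the cache (the field values below are invented for illustration, not taken from this PR), a key is a plain tuple, so two configurations that differ only in quantization or sparsity settings map to the same entry:

    # Hypothetical example; assumes BenchmarkConfig exposes the attributes
    # accessed in _make_cache_key above.
    key = ("linear", 1024, 1024, 1024, "torch.bfloat16", "cuda", "default")
    _BASELINE_CACHE[key] = (1.82, 0.95)  # (eager_ms, compile_ms), made-up numbers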

def run(config: BenchmarkConfig) -> BenchmarkResult:
    """Run inference benchmarks"""
    """
    Run inference benchmarks.

    The function first checks if a baseline for the given configuration
    already exists in the internal cache. If not, it measures the baseline
    inference time and stores the result. When the baseline is cached,
    the function reuses the cached baselines to calculate speedup metrics.

    Args:
        config (BenchmarkConfig): Benchmark configuration.

    Returns:
        BenchmarkResult: Result of the benchmark.
    """
    try:
        clean_caches()  # Clean caches

        # Create output directory if it doesn't exist
        Path(config.output_dir).mkdir(parents=True, exist_ok=True)

        # Prepare result container
        result = BenchmarkResult(config=config)

        # Create model and input data
        base_model, input_data = create_model_and_input_data(
            config.model_type,
            config.m,
@@ -51,28 +107,46 @@ def run(config: BenchmarkConfig) -> BenchmarkResult:
            high_precision_dtype=config.high_precision_dtype,
            device=config.device,
        )
        # Copy base model for quantizing
        m_copy = deepcopy(base_model)

        # Run benchmarks
        result = BenchmarkResult(config=config)
        # Generate a cache key for the current configuration
        cache_key = _make_cache_key(config)

        # Store result in model for memory profiling
        base_model._benchmark_result = result
        # Check if the baseline for this configuration has been computed
        if cache_key not in _BASELINE_CACHE:
            # Switch model to eval and move to device
            base_model = base_model.eval().to(config.device)
            print("Benchmarking eager baseline inference.....")
            eager_baseline_time = model_inference_time_in_ms(
                model=base_model, input_data=input_data
            )

        # Run baseline benchmarking
        base_model = base_model.eval().to(config.device)
        if config.use_torch_compile:
            print("Compiling baseline model....")
            print("Benchmarking compile baseline inference.....")
            base_model = torch.compile(
                base_model, mode=config.torch_compile_mode, fullgraph=True
            )
        # Benchmark time to run an inference call for baseline model
        print("Benchmarking baseline inference.....")
        result.baseline_inference_time_in_ms = model_inference_time_in_ms(
            model=base_model, input_data=input_data
        )
            compile_baseline_time = model_inference_time_in_ms(
                model=base_model, input_data=input_data
            )

            # Cache the eager and compile baseline times for this configuration
            _BASELINE_CACHE[cache_key] = (eager_baseline_time, compile_baseline_time)

            result.eager_baseline_inference_time_in_ms = eager_baseline_time
            result.compile_baseline_inference_time_in_ms = compile_baseline_time
        else:
            # Retrieve cached values
            cached_eager_time, cached_compile_time = _BASELINE_CACHE[cache_key]
            result.eager_baseline_inference_time_in_ms = cached_eager_time
            result.compile_baseline_inference_time_in_ms = cached_compile_time

        # At this point, ``base_model`` is ready for quantization and
        # ``input_data`` is the corresponding input tensor. The baseline times
        # have been stored in ``result.eager_baseline_inference_time_in_ms``
        # and ``result.compile_baseline_inference_time_in_ms``.

        # Copy base model for quantizing/sparsifying
        m_copy = deepcopy(base_model)

        # Determine quantization/sparsity configuration
        ao_base_config = string_to_config(
            config.quantization,
            config.sparsity,
@@ -101,24 +175,39 @@ def run(config: BenchmarkConfig) -> BenchmarkResult:
            m_copy = m_copy.eval().to(config.device)
            quantize_(m_copy, ao_base_config)

        if config.use_torch_compile:
            print("Compiling quantized model....")
            m_copy = torch.compile(
                m_copy, mode=config.torch_compile_mode, fullgraph=True
            )

        # Store result in model for memory profiling
        m_copy._benchmark_result = result

        # Benchmark time to run an inference call for quantized model
        # Measure inference time for quantized model
        print("Benchmarking eager quantized model.....")
        result.eager_model_inference_time_in_ms = model_inference_time_in_ms(

Review comment: nit: add quantized somewhere in the name?

            model=m_copy, input_data=input_data
        )

        # Measure inference time for compiled quantized model
        print("Benchmarking quantized model.....")
        result.model_inference_time_in_ms = model_inference_time_in_ms(
        m_copy = torch.compile(m_copy, mode=config.torch_compile_mode, fullgraph=True)
        result.compile_model_inference_time_in_ms = model_inference_time_in_ms(

Review comment: same for this one

            model=m_copy, input_data=input_data
        )

        # Calculate speedup w.r.t. baseline
        result.speedup = round(
            result.baseline_inference_time_in_ms / result.model_inference_time_in_ms, 2
        # Compute eager speedup relative to baseline
        result.eager_speedup_on_baseline = round(
            result.eager_baseline_inference_time_in_ms
            / result.eager_model_inference_time_in_ms,
            2,

Review comment: nit: pass by keyword arg to show what this is

        )
        # Compute compile speedup relative to baseline
        result.compile_speedup_on_baseline = round(
            result.compile_baseline_inference_time_in_ms
            / result.compile_model_inference_time_in_ms,
            2,
        )
        # Compute compile speedup for quantized model relative to eager quantized model

Review comment: do we need to do this comparison? I think it might be more useful to just compare eager quantized vs. eager baseline and compile quantized vs. compile baseline, since these show the speedup in different serving environments.

        result.compile_speedup_on_eager = round(
            result.eager_model_inference_time_in_ms
            / result.compile_model_inference_time_in_ms,
            2,
        )
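
To make the three ratios above concrete (the timings below are invented for illustration, not measurements from this PR), and to show the keyword-argument style of rounding the reviewer suggests:

    # Hypothetical timings in ms: baseline vs. quantized model, eager vs. compiled
    eager_baseline, compile_baseline = 10.0, 6.0
    eager_quant, compile_quant = 4.0, 2.0

    eager_speedup_on_baseline = round(eager_baseline / eager_quant, ndigits=2)        # 2.5
    compile_speedup_on_baseline = round(compile_baseline / compile_quant, ndigits=2)  # 3.0
    compile_speedup_on_eager = round(eager_quant / compile_quant, ndigits=2)          # 2.0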

        # Run profiler if enabled
@@ -165,9 +254,9 @@ def run(config: BenchmarkConfig) -> BenchmarkResult:
                    result.memory_profile_path
                )
            except ValueError as e:
                if "not enough values to unpack" in e:
                if "not enough values to unpack" in str(e):
                    print(
                        "Failed due to existing bugs, re-run the code to generate memory profile. Please raise an issue if it persists."
                        "Failed due to existing bugs, re‑run the code to generate memory profile. Please raise an issue if it persists."
                    )
            except Exception as e:
                print(f"Error running memory profiler: {e}")

Review comment: nit: add comment for what key is and maybe give an example