
Performance Regression in Whisper with torch.compile on v4.52.0+ (vs v4.51.0) #40682

@baeseongsu

Description


System Info

  • transformers version: 4.52.0
  • Platform: Linux-5.15.0-106-generic-x86_64-with-glibc2.39
  • Python version: 3.12.3
  • Huggingface_hub version: 0.34.4
  • Safetensors version: 0.5.2
  • Accelerate version: 1.5.2
  • Accelerate config: not found
  • DeepSpeed version: not installed
  • PyTorch version (GPU?): 2.7.1+cu126 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?:
  • Using GPU in script?:
  • GPU type: NVIDIA H100 80GB HBM3

Who can help?

@eustlb @gante

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

  1. Minimal Reproducible Code
  • The base code snippet follows the official openai/whisper-large-v3-turbo model card's description for using torch.compile.
import gc
import time
import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

def benchmark_whisper(compile_model=False, num_warmup=2, num_runs=5):
    """Benchmark Whisper with and without torch.compile."""
    print(f"\n{'='*50}")
    print(f"Testing: {'Compiled' if compile_model else 'Not Compiled'}")
    print(f"{'='*50}")

    # Setup
    torch.set_float32_matmul_precision("high")
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    model_id = "openai/whisper-large-v3-turbo"

    # Clear memory
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    # Load model
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
    ).to(device)

    # Configure for compilation
    model.generation_config.cache_implementation = "static"
    
    if compile_model:
        model = torch.compile(model, mode="reduce-overhead", fullgraph=True)

    processor = AutoProcessor.from_pretrained(model_id)
    pipe = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=torch_dtype,
        device=device,
        return_timestamps=True,
    )

    # Load sample audio
    dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
    sample = dataset[0]["audio"]

    # Warmup and Benchmark...
    print(f"Warmup ({num_warmup} steps)...")
    for _ in tqdm(range(num_warmup)):
        _ = pipe(sample.copy(), generate_kwargs={"language": "en", "task": "transcribe"})

    print(f"Benchmarking ({num_runs} runs)...")
    times = []
    for _ in tqdm(range(num_runs)):
        start_time = time.time()
        result = pipe(sample.copy(), generate_kwargs={"language": "en", "task": "transcribe"})
        times.append(time.time() - start_time)

    avg_time = sum(times) / len(times)
    print(f"Average time: {avg_time:.3f}s")

def main():
    print("Whisper Compilation Benchmark")
    benchmark_whisper(compile_model=False)
    benchmark_whisper(compile_model=True)

if __name__ == "__main__":
    main()
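As an additional check (my own sketch for narrowing things down, not part of the snippet above): compiling only the forward method instead of wrapping the whole nn.Module keeps generate() eager and compiles just model.forward, which may interact differently with the static cache. If the regression disappears with this variant, the problem is likely specific to full-model wrapping.

import torch
from transformers import AutoModelForSpeechSeq2Seq

model_id = "openai/whisper-large-v3-turbo"
device = "cuda:0" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=dtype).to(device)

# Static cache as before, but compile only the forward pass.
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)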
  2. How to Reproduce
  • Run the script with a problematic version (>=4.52.0) and the last known stable version (4.51.0).
# Problematic version
pip install "transformers>=4.52.0"
python benchmark_script.py

# Stable version
pip install "transformers==4.51.0"
python benchmark_script.py
  3. Observed Behavior & Logs
  • transformers==4.51.0 (Stable):
    • Not Compiled: ~0.67s
    • Compiled: ~0.66s (1.02x speedup)
    • Logs are clean.
  • transformers>=4.52.0 (Problematic):
    • Not Compiled: ~3.15s
    • Compiled: ~6.33s (0.50x speedup, i.e., 2x slower)
    • Logs are filled with repeated warnings indicating excessive recompilation and cudagraph fallbacks (see the diagnostic sketch after the logs):
      skipping cudagraphs due to mutated inputs (...)
      torch._dynamo hit config.recompile_limit (8)
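To see why these recompilations happen, the same benchmark can be rerun with Dynamo's logging artifacts enabled. This is only a sketch using standard PyTorch 2.x knobs; the limit value below is an arbitrary example, and recompile_limit is the setting named in the warning above.

import torch
import torch._dynamo

# Print the guard failures behind each recompilation and any graph breaks.
torch._logging.set_logs(recompiles=True, graph_breaks=True)
# Optionally raise the limit from the warning to see how far recompilation goes.
torch._dynamo.config.recompile_limit = 64

# ...then run benchmark_whisper(compile_model=True) from the script above.
# The same output can be obtained via TORCH_LOGS="recompiles,graph_breaks".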
      

Expected behavior

  • The torch.compile path is expected to be at least as fast and as stable as eager mode, as it was in v4.51.0; a library upgrade should not introduce such a large and persistent slowdown. The skipping cudagraphs warnings appear to be triggered by in-place updates to the static KV cache (e.g., index_copy_), which prevent CUDA graph capture and remove most of the benefit of compilation (a minimal illustration follows below).

  • (Note: I have not fully reviewed Whisper’s generate internals, so if this behavior is intentional or my script setup is incorrect, please let me know.)
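To illustrate the cudagraph fallback referenced above, here is a minimal standalone sketch (names and shapes are made up, not taken from Whisper's cache code): a compiled function that writes into one of its inputs with index_copy_, the same kind of in-place update a static KV cache performs, which inductor's "reduce-overhead" backend declines to capture.

import torch

@torch.compile(mode="reduce-overhead", fullgraph=True)
def update_cache(cache, pos, new_states):
    # In-place write into an input tensor, mimicking a static-cache update.
    # Mutated inputs make the graph ineligible for CUDA graph capture, so
    # inductor logs "skipping cudagraphs due to mutated inputs" and falls back.
    cache.index_copy_(0, pos, new_states)
    return cache

if torch.cuda.is_available():
    cache = torch.zeros(16, 8, device="cuda", dtype=torch.float16)
    new_states = torch.randn(1, 8, device="cuda", dtype=torch.float16)
    pos = torch.tensor([3], device="cuda")
    update_cache(cache, pos, new_states)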
