Commit ba3826b

daje0601 authored and jcggl committed
[Model] Add LoRA support for Whisper models
This PR enables Multi-LoRA support for Whisper speech-to-text models, allowing users to serve multiple fine-tuned Whisper adapters from a single base model.

Changes:
- Add SupportsLoRA interface to WhisperForConditionalGeneration
- Add embedding_modules and embedding_padding_modules attributes
- Update packed_modules_mapping for LoRA compatibility
- Extend MergedQKVParallelLinearWithLoRA to support KV-only (2-slice) configurations used in Whisper's cross-attention layers
- Add fallback to max_target_positions in WorkerLoRAManager for Whisper compatibility
- Add example script for Whisper Multi-LoRA inference
- Add unit tests for Whisper LoRA support

Signed-off-by: daje0601 <[email protected]>
1 parent 3b221cb commit ba3826b
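
The corresponding change to the Whisper model definition (vllm/model_executor/models/whisper.py) is not among the diffs shown in this excerpt. Based on the commit message and the assertions in tests/lora/test_whisper_lora.py, the model-side piece presumably looks roughly like the following sketch; the attribute values are illustrative, not copied from the actual patch.

# Sketch only: approximate shape of the model-side changes implied by the
# commit message and by the assertions in tests/lora/test_whisper_lora.py.
# The attribute values here are illustrative, not the actual patch.
import torch.nn as nn

from vllm.model_executor.models.interfaces import SupportsLoRA


class WhisperForConditionalGeneration(nn.Module, SupportsLoRA):
    # Packed projections that the LoRA machinery must split back into
    # per-module weights; kv_proj (2 slices) covers the cross-attention
    # layers, which pack only K and V.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "kv_proj": ["k_proj", "v_proj"],
    }
    # Required by the LoRA manager even when no extra vocabulary is used.
    embedding_modules = {}
    embedding_padding_modules = []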

File tree

5 files changed: +357 -37 lines changed
Lines changed: 136 additions & 0 deletions
@@ -0,0 +1,136 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This example shows how to use multi-LoRA functionality with
Whisper models for speech-to-text transcription.

Usage:
    python whisper_multilora_inference.py

Note: Replace LORA_PATH with your actual LoRA adapter path.
If you don't have a LoRA adapter, the example will run with
the base model only.
"""

import os

from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset
from vllm.lora.request import LoRARequest


def create_whisper_prompt(language: str = "en") -> dict:
    """Create a Whisper transcription prompt with audio input.

    Args:
        language: ISO 639-1 language code (e.g., "en", "ko", "ja")

    Returns:
        Dictionary with prompt and multi-modal data
    """
    # Load sample audio from vLLM assets
    audio_asset = AudioAsset("mary_had_lamb")
    audio_data = audio_asset.audio_and_sample_rate

    # Whisper prompt format:
    # <|startoftranscript|><|language|><|task|><|notimestamps|>
    prompt = f"<|startoftranscript|><|{language}|><|transcribe|><|notimestamps|>"

    return {
        "prompt": prompt,
        "multi_modal_data": {
            "audio": audio_data,
        },
    }


def run_base_model_inference(llm: LLM, sampling_params: SamplingParams) -> None:
    """Run inference using the base Whisper model without LoRA."""
    print("\n" + "=" * 60)
    print("Running inference with BASE MODEL (no LoRA)")
    print("=" * 60)

    inputs = create_whisper_prompt(language="en")
    outputs = llm.generate([inputs], sampling_params=sampling_params)

    for output in outputs:
        print(f"Transcription: {output.outputs[0].text}")


def run_lora_inference(
    llm: LLM,
    sampling_params: SamplingParams,
    lora_path: str,
    lora_name: str,
    lora_id: int,
) -> None:
    """Run inference using a specific LoRA adapter.

    Args:
        llm: The vLLM engine
        sampling_params: Sampling parameters
        lora_path: Path to the LoRA adapter
        lora_name: Name identifier for the LoRA
        lora_id: Unique integer ID for the LoRA
    """
    print("\n" + "=" * 60)
    print(f"Running inference with LoRA: {lora_name}")
    print("=" * 60)

    inputs = create_whisper_prompt(language="en")
    lora_request = LoRARequest(lora_name, lora_id, lora_path)

    outputs = llm.generate(
        [inputs],
        sampling_params=sampling_params,
        lora_request=lora_request,
    )

    for output in outputs:
        print(f"Transcription: {output.outputs[0].text}")


def main():
    """Main function demonstrating Whisper Multi-LoRA inference."""
    # Initialize Whisper model with LoRA support enabled
    print("Initializing Whisper model with Multi-LoRA support...")
    llm = LLM(
        model="openai/whisper-large-v3-turbo",
        enable_lora=True,
        max_loras=4,  # Maximum number of LoRAs to keep in memory
        max_lora_rank=64,  # Maximum LoRA rank supported
        max_model_len=448,  # Whisper's max target positions
        dtype="half",
        gpu_memory_utilization=0.8,
        trust_remote_code=True,
    )

    sampling_params = SamplingParams(
        temperature=0,
        max_tokens=200,
    )

    # Run base model inference
    run_base_model_inference(llm, sampling_params)

    # Example LoRA paths - replace with your actual LoRA adapters
    lora_paths = [
        ("lora_adapter_1", "/path/to/your/lora_adapter_1"),
        ("lora_adapter_2", "/path/to/your/lora_adapter_2"),
    ]

    # Run inference with each LoRA adapter (if paths exist)
    for lora_id, (lora_name, lora_path) in enumerate(lora_paths, start=1):
        if os.path.exists(lora_path):
            run_lora_inference(llm, sampling_params, lora_path, lora_name, lora_id)
        else:
            print(f"\nSkipping {lora_name}: path does not exist ({lora_path})")
            print("To use LoRA adapters, update lora_paths with valid paths.")

    print("\n" + "=" * 60)
    print("Multi-LoRA inference complete!")
    print("=" * 60)


if __name__ == "__main__":
    main()
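
The lora_paths entries above are placeholders. If you need an adapter directory in the expected format, a PEFT export along the following lines would produce one; this is a hedged sketch, and the rank, alpha, and target modules are illustrative choices rather than values taken from this commit.

# Sketch: produce a (possibly untrained) LoRA adapter directory for Whisper
# using PEFT. Rank, alpha, and target_modules are illustrative only.
from peft import LoraConfig, get_peft_model
from transformers import WhisperForConditionalGeneration

base = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3-turbo")
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj"],  # attention projections
)
peft_model = get_peft_model(base, lora_config)
# ... fine-tune peft_model on your speech data here ...
peft_model.save_pretrained("/path/to/your/lora_adapter_1")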

tests/lora/test_whisper_lora.py

Lines changed: 168 additions & 0 deletions
@@ -0,0 +1,168 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for Whisper Multi-LoRA support.

This module tests:
1. WhisperForConditionalGeneration LoRA interface compliance
2. MergedQKVParallelLinearWithLoRA support for KV-only (2-slice) configuration
3. WorkerLoRAManager compatibility with Whisper's max_target_positions
"""

import pytest
import torch

from vllm.lora.layers import (
    MergedQKVParallelLinearWithLoRA,
)
from vllm.model_executor.layers.linear import QKVParallelLinear
from vllm.model_executor.models.whisper import WhisperForConditionalGeneration
from vllm.platforms import current_platform

pytestmark = pytest.mark.skipif(
    not (current_platform.is_cuda_alike() or current_platform.is_cpu()),
    reason="Backend not supported",
)


class TestWhisperLoRAInterface:
    """Test that WhisperForConditionalGeneration has proper LoRA support."""

    def test_supports_lora_attribute(self):
        """Verify that WhisperForConditionalGeneration has SupportsLoRA interface."""
        from vllm.model_executor.models.interfaces import SupportsLoRA

        assert issubclass(WhisperForConditionalGeneration, SupportsLoRA), (
            "WhisperForConditionalGeneration should inherit from SupportsLoRA"
        )

    def test_embedding_modules_defined(self):
        """Verify embedding_modules attribute is defined."""
        assert hasattr(WhisperForConditionalGeneration, "embedding_modules")
        assert isinstance(WhisperForConditionalGeneration.embedding_modules, dict)

    def test_embedding_padding_modules_defined(self):
        """Verify embedding_padding_modules attribute is defined."""
        assert hasattr(WhisperForConditionalGeneration, "embedding_padding_modules")
        assert isinstance(
            WhisperForConditionalGeneration.embedding_padding_modules, list
        )

    def test_packed_modules_mapping_format(self):
        """Verify packed_modules_mapping has correct format for LoRA."""
        mapping = WhisperForConditionalGeneration.packed_modules_mapping

        # Should have qkv_proj and kv_proj mappings
        assert "qkv_proj" in mapping, "Missing qkv_proj in packed_modules_mapping"
        assert "kv_proj" in mapping, "Missing kv_proj in packed_modules_mapping"

        # qkv_proj should map to [q_proj, k_proj, v_proj]
        assert mapping["qkv_proj"] == ["q_proj", "k_proj", "v_proj"]

        # kv_proj should map to [k_proj, v_proj] (for cross-attention)
        assert mapping["kv_proj"] == ["k_proj", "v_proj"]


class TestMergedQKVParallelLinearWithLoRAKVOnly:
    """Test MergedQKVParallelLinearWithLoRA with KV-only (2-slice) configuration."""

    def test_can_replace_layer_accepts_2_modules(self):
        """Verify can_replace_layer accepts 2-module (KV-only) configurations."""
        from vllm.config.lora import LoRAConfig

        # Create a mock QKVParallelLinear layer
        # This simulates a KV-only projection (like Whisper's encoder_attn.kv_proj)
        linear = QKVParallelLinear(
            hidden_size=512,
            head_size=64,
            total_num_heads=8,
            total_num_kv_heads=8,
            bias=False,
            params_dtype=torch.float16,
        )

        lora_config = LoRAConfig(
            max_lora_rank=32,
            max_loras=4,
            max_cpu_loras=4,
            lora_extra_vocab_size=0,
        )

        # Test with 2 modules (KV-only, like encoder_attn.kv_proj)
        packed_modules_2 = ["k_proj", "v_proj"]
        result_2 = MergedQKVParallelLinearWithLoRA.can_replace_layer(
            source_layer=linear,
            lora_config=lora_config,
            packed_modules_list=packed_modules_2,
            model_config=None,
        )
        assert result_2 is True, "Should accept 2-module (KV-only) configuration"

        # Test with 3 modules (QKV, like self_attn.qkv_proj)
        packed_modules_3 = ["q_proj", "k_proj", "v_proj"]
        result_3 = MergedQKVParallelLinearWithLoRA.can_replace_layer(
            source_layer=linear,
            lora_config=lora_config,
            packed_modules_list=packed_modules_3,
            model_config=None,
        )
        assert result_3 is True, "Should accept 3-module (QKV) configuration"

        # Test with 1 module (should be rejected)
        packed_modules_1 = ["q_proj"]
        result_1 = MergedQKVParallelLinearWithLoRA.can_replace_layer(
            source_layer=linear,
            lora_config=lora_config,
            packed_modules_list=packed_modules_1,
            model_config=None,
        )
        assert result_1 is False, "Should reject 1-module configuration"


class TestWorkerLoRAManagerWhisperCompat:
    """Test WorkerLoRAManager compatibility with Whisper config."""

    def test_max_position_embeddings_fallback(self):
        """Test that max_target_positions is used when missing."""

        # Create a mock config similar to Whisper's
        class MockWhisperConfig:
            def __init__(self):
                self.max_target_positions = 448
                # Note: no max_position_embeddings attribute

            def get_text_config(self):
                return self

        config = MockWhisperConfig()

        # Simulate the logic from WorkerLoRAManager
        max_pos = getattr(
            config,
            "max_position_embeddings",
            getattr(config, "max_target_positions", None),
        )

        assert max_pos == 448, "Should fall back to max_target_positions"

    def test_max_position_embeddings_priority(self):
        """Test that max_position_embeddings takes priority when present."""

        class MockLLMConfig:
            def __init__(self):
                self.max_position_embeddings = 4096
                self.max_target_positions = 448

            def get_text_config(self):
                return self

        config = MockLLMConfig()

        # Simulate the logic from WorkerLoRAManager
        max_pos = getattr(
            config,
            "max_position_embeddings",
            getattr(config, "max_target_positions", None),
        )

        assert max_pos == 4096, "Should use max_position_embeddings when present"
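
The WorkerLoRAManager diff itself is not included in this excerpt. The fallback these two tests simulate presumably amounts to something like the helper below; this is a sketch under that assumption, and the function name is hypothetical rather than taken from the actual patch.

# Hypothetical helper mirroring the fallback exercised by the two tests above;
# only the getattr chain reflects the tested logic.
def _resolve_max_position_embeddings(hf_text_config) -> int | None:
    return getattr(
        hf_text_config,
        "max_position_embeddings",
        # Whisper's config exposes max_target_positions (448) instead.
        getattr(hf_text_config, "max_target_positions", None),
    )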

vllm/lora/layers/column_parallel_linear.py

Lines changed: 34 additions & 28 deletions
@@ -356,8 +356,6 @@ class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
 
     def __init__(self, base_layer: QKVParallelLinear) -> None:
         super().__init__(base_layer)
-        # There are three LoRA layer.
-        self.n_slices = len(self.base_layer.output_sizes)
 
         self.q_proj_shard_size = self.base_layer.num_heads * self.base_layer.head_size
         self.kv_proj_shard_size = (
@@ -366,16 +364,23 @@ def __init__(self, base_layer: QKVParallelLinear) -> None:
         self.q_shard_id = self.tp_rank
         self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas
 
-        self.output_slices = (
-            self.q_proj_shard_size,
-            self.kv_proj_shard_size,
-            self.kv_proj_shard_size,
-        )
-        self.output_ids = (
-            self.q_shard_id,
-            self.kv_shard_id,
-            self.kv_shard_id,
-        )
+        # Build output_slices and output_ids dynamically to support both
+        # QKV (3 slices) and KV-only (2 slices) configurations.
+        # KV-only is used in cross-attention layers (e.g., Whisper encoder_attn).
+        slices = []
+        ids = []
+        if self.q_proj_shard_size > 0:
+            slices.append(self.q_proj_shard_size)
+            ids.append(self.q_shard_id)
+        if self.kv_proj_shard_size > 0:
+            slices.append(self.kv_proj_shard_size)
+            ids.append(self.kv_shard_id)
+            slices.append(self.kv_proj_shard_size)
+            ids.append(self.kv_shard_id)
+
+        self.output_slices = tuple(slices)
+        self.output_ids = tuple(ids)
+        self.n_slices = len(self.output_slices)
 
     def create_lora_weights(
         self,
@@ -398,7 +403,11 @@ def can_replace_layer(
         packed_modules_list: list,
         model_config: PretrainedConfig | None = None,
     ) -> bool:
-        return type(source_layer) is QKVParallelLinear and len(packed_modules_list) == 3
+        # Support both QKV (3 modules) and KV-only (2 modules) configurations
+        return type(source_layer) is QKVParallelLinear and len(packed_modules_list) in (
+            2,
+            3,
+        )
 
 
 # These following layers are based on the tensor parallelism strategy given in
@@ -539,21 +548,18 @@ class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA):
     def slice_lora_a(
         self, lora_a: list[torch.Tensor | None]
     ) -> list[torch.Tensor | None]:
-        # NOTE: lora_a contains 3 subloras, and each sublora could be None.
-        shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)]
-        start_idx = [self.tp_rank * shard_size[i] for i in range(3)]
-        lora_a = [
-            lora_a[0][start_idx[0] : start_idx[0] + shard_size[0], :]
-            if lora_a[0] is not None
-            else None,
-            lora_a[1][start_idx[1] : start_idx[1] + shard_size[1], :]
-            if lora_a[1] is not None
-            else None,
-            lora_a[2][start_idx[2] : start_idx[2] + shard_size[2], :]
-            if lora_a[2] is not None
-            else None,
-        ]
-        return lora_a
+        # NOTE: lora_a contains n_slices subloras, and each sublora could be None.
+        # n_slices is 3 for QKV and 2 for KV-only configurations.
+        shard_size = [self.lora_a_stacked[i].shape[2] for i in range(self.n_slices)]
+        start_idx = [self.tp_rank * shard_size[i] for i in range(self.n_slices)]
+        result: list[torch.Tensor | None] = []
+        for i in range(self.n_slices):
+            lora_a_i = lora_a[i]
+            if lora_a_i is not None:
+                result.append(lora_a_i[start_idx[i] : start_idx[i] + shard_size[i], :])
+            else:
+                result.append(None)
+        return result
 
     def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
         return _mcp_apply(x, bias, self)
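
To make the 2-slice versus 3-slice behaviour concrete, here is a standalone toy version of the new slice-building logic from the __init__ hunk above; the shard sizes and shard ids are made up for illustration and this is not vLLM code.

# Toy re-expression of the dynamic output_slices / output_ids construction.
def build_output_slices(q_proj_shard_size: int, kv_proj_shard_size: int,
                        q_shard_id: int = 0, kv_shard_id: int = 0):
    slices, ids = [], []
    if q_proj_shard_size > 0:
        slices.append(q_proj_shard_size)
        ids.append(q_shard_id)
    if kv_proj_shard_size > 0:
        # K and V each get a slice of the same size.
        slices.append(kv_proj_shard_size)
        ids.append(kv_shard_id)
        slices.append(kv_proj_shard_size)
        ids.append(kv_shard_id)
    return tuple(slices), tuple(ids)


# Self-attention qkv_proj: 3 slices, so n_slices == 3
print(build_output_slices(q_proj_shard_size=512, kv_proj_shard_size=512))
# ((512, 512, 512), (0, 0, 0))

# Cross-attention kv_proj (no query rows in the packed weight): 2 slices
print(build_output_slices(q_proj_shard_size=0, kv_proj_shard_size=512))
# ((512, 512), (0, 0))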
