
Commit 4f32ae1

Merge remote-tracking branch 'origin/main' into wuxun/v1-dp-attention
2 parents fa5da77 + c2a0171

File tree

6 files changed (+655, -23 lines)

tests/unit_tests/conftest.py

Lines changed: 12 additions & 0 deletions
@@ -3,6 +3,7 @@
     initialize_model_parallel)
 import pytest
 import tempfile
+from huggingface_hub import snapshot_download


 @pytest.fixture
@@ -18,3 +19,14 @@ def dist_init():
     initialize_model_parallel(1, 1)
     yield
     cleanup_dist_env_and_memory()
+
+
+@pytest.fixture(scope="session")
+def sql_lora_huggingface_id():
+    # huggingface repo id is used to test lora runtime downloading.
+    return "yard1/llama-2-7b-sql-lora-test"
+
+
+@pytest.fixture(scope="session")
+def sql_lora_files(sql_lora_huggingface_id):
+    return snapshot_download(repo_id=sql_lora_huggingface_id)
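These two session-scoped fixtures resolve once per pytest session: the adapter is downloaded from the Hub on first use and the cached path is reused by every test that requests sql_lora_files. A minimal sketch of a consumer (a hypothetical test, not part of this commit, assuming the repo follows the standard PEFT layout with an adapter_config.json):

import os

# Hypothetical consumer: pytest injects sql_lora_files through the
# fixture chain above, so the snapshot is fetched lazily and cached.
def test_lora_snapshot_layout(sql_lora_files):
    # snapshot_download returns the local directory of the cached repo.
    assert os.path.isfile(
        os.path.join(sql_lora_files, "adapter_config.json"))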
Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
import pytest
from typing import Optional

from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
from vllm.lora.request import LoRARequest

MODEL_PATH = "/mnt/weka/data/pytorch/llama2/Llama-2-7b-hf"


def create_test_prompts(
    lora_path: str
) -> list[tuple[str, SamplingParams, Optional[LoRARequest]]]:
    """Create a list of test prompts with their sampling parameters.

    2 requests for base model, 4 requests for the LoRA. We define 2
    different LoRA adapters (using the same model for demo purposes).
    """
    return [
        (
            "A robot may not injure a human being",
            SamplingParams(
                temperature=0.0,
                #logprobs=1,
                #prompt_logprobs=1,
                max_tokens=128),
            None),
        (
            "To be or not to be,",
            SamplingParams(
                temperature=0.0,
                top_k=5,
                #presence_penalty=0.2,
                max_tokens=128),
            None),
        (
            "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501
            SamplingParams(
                temperature=0.0,
                #logprobs=1,
                #prompt_logprobs=1,
                max_tokens=128,
                stop_token_ids=[32003]),
            LoRARequest("sql-lora", 1, lora_path)),
        (
            "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501
            SamplingParams(temperature=0,
                           max_tokens=128,
                           stop_token_ids=[32003]),
            LoRARequest("sql-lora", 1, lora_path)),
        (
            "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501
            SamplingParams(
                temperature=0.0,
                #logprobs=1,
                #prompt_logprobs=1,
                max_tokens=128,
                stop_token_ids=[32003]),
            LoRARequest("sql-lora2", 2, lora_path)),
        (
            "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501
            SamplingParams(temperature=0,
                           max_tokens=128,
                           stop_token_ids=[32003]),
            LoRARequest("sql-lora", 1, lora_path)),
    ]


def process_requests(engine: LLMEngine,
                     test_prompts: list[tuple[str, SamplingParams,
                                              Optional[LoRARequest]]]):
    """Continuously process a list of prompts and handle the outputs."""
    request_id = 0
    result = {}

    while test_prompts or engine.has_unfinished_requests():
        if test_prompts:
            prompt, sampling_params, lora_request = test_prompts.pop(0)
            engine.add_request(str(request_id),
                               prompt,
                               sampling_params,
                               lora_request=lora_request)
            request_id += 1

        request_outputs: list[RequestOutput] = engine.step()

        for request_output in request_outputs:
            if request_output.finished:
                result[
                    request_output.request_id] = request_output.outputs[0].text
    return result


expected_output = [
    " or, through inaction, allow a human being to come to harm.\nA robot must obey the orders given it by human beings except where such orders would conflict with the First Law.\nA robot must protect its own existence as long as such protection does not conflict with the First or Second Law.\nThe Three Laws of Robotics were created by Isaac Asimov in 1942. They are the foundation of robotics and artificial intelligence.\nThe Three Laws of Robotics are the foundation of robotics and artificial intelligence. They were created by Isaac Asimov in 194", # noqa: E501
    " that is the question.\nThe question is not whether you will be a leader, but whether you will be a good leader.\nThe question is not whether you will be a leader, but whether you will be a good leader. The question is not whether you will be a leader, but whether you will be a good leader. The question is not whether you will be a leader, but whether you will be a good leader. The question is not whether you will be a leader, but whether you will be a good leader. The question is not whether you will be a leader, but whether you will be a good leader. The", # noqa: E501
    " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ", # noqa: E501
    " SELECT nationality FROM table_name_11 WHERE elector = 'Anchero Pantaleone' ", # noqa: E501
    " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ", # noqa: E501
    " SELECT nationality FROM table_name_11 WHERE elector = 'Anchero Pantaleone' " # noqa: E501
]


def _test_llama_multilora(sql_lora_files, tp_size):
    """Main function that sets up and runs the prompt processing."""
    engine_args = EngineArgs(model=MODEL_PATH,
                             enable_lora=True,
                             max_loras=2,
                             max_lora_rank=8,
                             max_num_seqs=256,
                             dtype='bfloat16',
                             tensor_parallel_size=tp_size)
    engine = LLMEngine.from_engine_args(engine_args)
    test_prompts = create_test_prompts(sql_lora_files)
    results = process_requests(engine, test_prompts)
    generated_texts = [results[key] for key in sorted(results)]
    assert generated_texts == expected_output


@pytest.mark.xfail(reason="Weka not available")
def test_llama_multilora_1x(sql_lora_files):
    _test_llama_multilora(sql_lora_files, 1)


#def test_llama_multilora_2x(sql_lora_files):
#    _test_llama_multilora(sql_lora_files, 2)

#def test_llama_multilora_4x(sql_lora_files):
#    _test_llama_multilora(sql_lora_files, 4)
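Note that the prompt set registers the same adapter files under two identities, "sql-lora" (id 1) and "sql-lora2" (id 2), so a single run exercises the max_loras=2 adapter cache. A small sketch of that identity convention (placeholder path; LoRARequest's positional arguments are the adapter name, an integer id, and the adapter path):

from vllm.lora.request import LoRARequest

# vLLM keys loaded adapters by the integer id: reusing ("sql-lora", 1)
# hits the already-loaded adapter, while ("sql-lora2", 2) occupies a
# second slot even though it points at the same files on disk.
reuse_slot = LoRARequest("sql-lora", 1, "/path/to/adapter")    # placeholder path
second_slot = LoRARequest("sql-lora2", 2, "/path/to/adapter")  # same weights, new slot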
Lines changed: 208 additions & 0 deletions
@@ -0,0 +1,208 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from typing import Union

import vllm
from vllm.lora.request import LoRARequest
#from ..utils import VLLM_PATH, create_new_process_for_each_test, multi_gpu_test

MODEL_PATH = "/mnt/weka/data/pytorch/llama2/Llama-2-7b-hf"

EXPECTED_NO_LORA_OUTPUT = [
    "\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_75 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant", # noqa: E501
    " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? ", # noqa: E501
    "\n\n answer: 1\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_96 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one m", # noqa: E501
    "\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio", # noqa: E501
    " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? ", # noqa: E501
    "\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for", # noqa: E501
]
EXPECTED_LORA_OUTPUT = [
    " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ", # noqa: E501
    " SELECT nationality FROM table_name_11 WHERE elector = 'Anchero Pantaleone' ", # noqa: E501
    " SELECT one_mora FROM table_name_95 WHERE gloss = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] AND accented_mora = 'low tone mora with a gloss of /˩ok", # noqa: E501
    " SELECT sex FROM people WHERE people_id IN (SELECT people_id FROM candidate GROUP BY sex ORDER BY COUNT(people_id) DESC LIMIT 1) ", # noqa: E501
    " SELECT pick FROM table_name_60 WHERE former_wnba_team = 'minnesota lynx' ", # noqa: E501
    " SELECT womens_doubles FROM table_28138035_4 WHERE mens_singles = 'Werner Schlager' " # noqa: E501
]


def do_sample(llm: vllm.LLM,
              lora_path: str,
              lora_id: int,
              tensorizer_config_dict: Union[dict, None] = None) -> list[str]:
    prompts = [
        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501
        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501
        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_95 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a low tone mora with a gloss of /˩okiru/ [òkìɽɯ́]? [/user] [assistant]", # noqa: E501
        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]", # noqa: E501
        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? [/user] [assistant]", # noqa: E501
        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]" # noqa: E501
    ]

    sampling_params = vllm.SamplingParams(temperature=0,
                                          max_tokens=64,
                                          skip_special_tokens=False,
                                          stop=["[/assistant]"])

    if tensorizer_config_dict is not None:
        outputs = llm.generate(
            prompts,
            sampling_params,
            lora_request=LoRARequest(
                str(lora_id),
                lora_id,
                lora_path,
                tensorizer_config_dict=tensorizer_config_dict)
            if lora_id else None)
    else:
        outputs = llm.generate(
            prompts,
            sampling_params,
            lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
            if lora_id else None)
    # Print the outputs.
    generated_texts: list[str] = []
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    return generated_texts


def generate_and_test(llm,
                      sql_lora_files,
                      tensorizer_config_dict: Union[dict, None] = None):
    print("lora adapter created")
    assert do_sample(llm,
                     sql_lora_files,
                     tensorizer_config_dict=tensorizer_config_dict,
                     lora_id=0) == EXPECTED_NO_LORA_OUTPUT

    print("lora 1")
    assert do_sample(llm,
                     sql_lora_files,
                     tensorizer_config_dict=tensorizer_config_dict,
                     lora_id=1) == EXPECTED_LORA_OUTPUT

    print("no lora")
    assert do_sample(llm,
                     sql_lora_files,
                     tensorizer_config_dict=tensorizer_config_dict,
                     lora_id=0) == EXPECTED_NO_LORA_OUTPUT

    print("lora 2")
    assert do_sample(llm,
                     sql_lora_files,
                     tensorizer_config_dict=tensorizer_config_dict,
                     lora_id=2) == EXPECTED_LORA_OUTPUT

    print("removing lora")


#@create_new_process_for_each_test()
@pytest.mark.xfail(reason="Weka not available")
def test_llama_lora(sql_lora_files):

    llm = vllm.LLM(
        MODEL_PATH,
        enable_lora=True,
        # also test odd max_num_seqs
        max_num_seqs=13,
        max_loras=4,
        dtype='bfloat16',
    )
    generate_and_test(llm, sql_lora_files)


'''@multi_gpu_test(num_gpus=4)
@create_new_process_for_each_test()
def test_llama_lora_tp4(sql_lora_files):

    llm = vllm.LLM(
        MODEL_PATH,
        enable_lora=True,
        max_num_seqs=16,
        max_loras=4,
        tensor_parallel_size=4,
        enable_chunked_prefill=True,
    )
    generate_and_test(llm, sql_lora_files)


@multi_gpu_test(num_gpus=4)
@create_new_process_for_each_test()
def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files):

    llm = vllm.LLM(
        MODEL_PATH,
        enable_lora=True,
        max_num_seqs=16,
        max_loras=4,
        tensor_parallel_size=4,
        fully_sharded_loras=True,
        enable_chunked_prefill=True,
    )
    generate_and_test(llm, sql_lora_files)


@multi_gpu_test(num_gpus=2)
@create_new_process_for_each_test()
def test_tp2_serialize_and_deserialize_lora(tmp_path, sql_lora_files,
                                            sql_lora_huggingface_id):

    # Run the tensorizing of the LoRA adapter and the model in a subprocess
    # to guarantee cleanup

    tp_size = 2
    model_name = "model-rank-%03d.tensors"

    model_ref = MODEL_PATH
    lora_path = sql_lora_huggingface_id
    suffix = "test"
    try:
        result = subprocess.run([
            sys.executable,
            f"{VLLM_PATH}/examples/others/tensorize_vllm_model.py", "--model",
            MODEL_PATH, "--lora-path", lora_path, "--tensor-parallel-size",
            str(tp_size), "serialize", "--serialized-directory",
            str(tmp_path), "--suffix", suffix, "--serialization-kwargs",
            '{"limit_cpu_concurrency": 4}'
        ],
                                check=True,
                                capture_output=True,
                                text=True)
    except subprocess.CalledProcessError as e:
        print("Tensorizing failed.")
        print("STDOUT:\n", e.stdout)
        print("STDERR:\n", e.stderr)
        raise

    print("STDOUT:\n", result.stdout)

    model_uri = tmp_path / "vllm" / model_ref / suffix / model_name
    tensorizer_config = TensorizerConfig(tensorizer_uri=str(model_uri))

    loaded_llm = LLM(model=model_ref,
                     load_format="tensorizer",
                     enable_lora=True,
                     enforce_eager=True,
                     model_loader_extra_config=tensorizer_config,
                     max_num_seqs=13,
                     tensor_parallel_size=2,
                     max_loras=2)

    tc_as_dict = tensorizer_config.to_serializable()

    print("lora adapter created")
    assert do_sample(loaded_llm,
                     sql_lora_files,
                     tensorizer_config_dict=tc_as_dict,
                     lora_id=0) == EXPECTED_NO_LORA_OUTPUT

    print("lora 1")
    assert do_sample(loaded_llm,
                     sql_lora_files,
                     tensorizer_config_dict=tc_as_dict,
                     lora_id=1) == EXPECTED_LORA_OUTPUT'''
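Both enabled tests are xfail-marked because MODEL_PATH points at a Weka mount that may be absent on the machine running the suite. A hedged local workaround, assuming Hub access and an accepted Llama 2 license (the golden outputs above were produced against the Weka copy, so they should still match the same base weights):

# Hypothetical override for machines without the /mnt/weka mount:
# meta-llama/Llama-2-7b-hf is a gated repo, so a granted license and
# an HF token are required before the download can succeed.
MODEL_PATH = "meta-llama/Llama-2-7b-hf"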
