@@ -84,27 +84,14 @@ def __init__(self,
84
84
self .model_name = model_name
85
85
self .pbar = None
86
86
87
- def _inference (self ,
88
- req_queue : Queue ,
89
- res_queue : Queue ,
90
- session_id : int ,
91
- stream_output : bool ,
92
- image_url : str = None ):
87
+ def _inference (self , req_queue : Queue , res_queue : Queue , session_id : int ,
88
+ stream_output : bool ):
93
89
94
90
stats = []
95
91
client = APIClient (self .server_addr , api_key = self .api_key )
96
92
97
93
for prompt , input_seqlen , output_seqlen in iter (
98
94
req_queue .get , [None , None , None ]):
99
- if image_url is not None :
100
- prompt = [
101
- dict (role = 'user' ,
102
- content = [
103
- dict (type = 'text' , text = prompt ),
104
- dict (type = 'image_url' ,
105
- image_url = dict (url = image_url ))
106
- ])
107
- ]
108
95
timestamps = []
109
96
timestamps .append (time .perf_counter ())
110
97
for output in client .chat_completions_v1 (
@@ -136,8 +123,7 @@ def _inference(self,
136
123
def process_request (self ,
137
124
requests ,
138
125
concurrency : int = 1 ,
139
- stream_output : bool = False ,
140
- img_hw : str = None ):
126
+ stream_output : bool = False ):
141
127
res_queue = Queue ()
142
128
req_queue = Queue ()
143
129
threads = []
@@ -152,28 +138,10 @@ def process_request(self,
152
138
153
139
start = time .time ()
154
140
155
- if img_hw is not None :
156
- import PIL
157
-
158
- from lmdeploy .vl .utils import encode_image_base64
159
- h , w = [int (s ) for s in img_hw .split ('x' )]
160
- data = np .random .randint (low = 0 ,
161
- high = 255 ,
162
- size = h * w * 3 ,
163
- dtype = np .uint8 )
164
- data = data .reshape (h , w , 3 )
165
- img = PIL .Image .fromarray (data , 'RGB' )
166
- encoded = encode_image_base64 (img )
167
- image_url = f'data:image/jpeg;base64,{ encoded } '
168
- else :
169
- image_url = None
170
-
171
141
# start threads
172
142
for i in range (concurrency ):
173
143
t = Thread (target = self ._inference ,
174
- # args=(req_queue, res_queue, i, stream_output))
175
- args = (req_queue , res_queue , i , stream_output ,
176
- image_url ))
144
+ args = (req_queue , res_queue , i , stream_output ))
177
145
t .start ()
178
146
threads .append (t )
179
147
@@ -254,8 +222,7 @@ def main(server_addr: str,
254
222
temperature : float = 1.0 ,
255
223
stream_output : bool = False ,
256
224
csv : str = './profile_api_server.csv' ,
257
- seed : int = 0 ,
258
- img_hw : str = None ):
225
+ seed : int = 0 ):
259
226
"""Benchmark the request throughput of api server.
260
227
261
228
Args:
@@ -273,8 +240,6 @@ def main(server_addr: str,
273
240
stream_output (bool, optional): Indicator for streaming output. Defaults to False.
274
241
csv (str, optional): The path to save the result.
275
242
seed (int, optional): Seed used in sampling prompts from dataset. Defaults to 0.
276
- img_hw (str, optional): The image size to benchmark vl serving, such as '512x512'.
277
- Default to None, which means to benchmark language model only.
278
243
""" # noqa
279
244
if not server_addr .startswith ('http://' ):
280
245
print (f'[WARNING] server_addr of the api_server should '
@@ -293,7 +258,7 @@ def main(server_addr: str,
293
258
294
259
requests = sample_requests (dataset , num_prompts , engine .tokenizer )
295
260
296
- engine .process_request (requests , concurrency , stream_output , img_hw )
261
+ engine .process_request (requests , concurrency , stream_output )
297
262
298
263
299
264
if __name__ == '__main__' :
0 commit comments