@@ -37,7 +37,7 @@ async def lifespan(app: FastAPI):
3737 "port" : port ,
3838 "id" : i ,
3939 }
40- )
40+ ) # yapf: disable
4141
4242 # Create decode clients
4343 for i , (host , port ) in enumerate (global_args .decoder_instances ):
@@ -51,20 +51,20 @@ async def lifespan(app: FastAPI):
5151 "port" : port ,
5252 "id" : i ,
5353 }
54- )
54+ ) # yapf: disable
5555
5656 # Initialize round-robin iterators
5757 app .state .prefill_iterator = itertools .cycle (
5858 range (len (app .state .prefill_clients ))
59- )
59+ ) # yapf: disable
6060 app .state .decode_iterator = itertools .cycle (
6161 range (len (app .state .decode_clients ))
62- )
62+ ) # yapf: disable
6363
6464 print (
6565 f"Initialized { len (app .state .prefill_clients )} prefill clients "
6666 f"and { len (app .state .decode_clients )} decode clients."
67- )
67+ ) # yapf: disable
6868
6969 yield
7070
@@ -111,7 +111,11 @@ def parse_args():
111111 default = ["localhost" ],
112112 )
113113 parser .add_argument (
114- "--decoder-ports" , "--decoder-port" , type = int , nargs = "+" , default = [8200 ]
114+ "--decoder-ports" ,
115+ "--decoder-port" ,
116+ type = int ,
117+ nargs = "+" ,
118+ default = [8200 ],
115119 )
116120
117121 args = parser .parse_args ()
@@ -120,17 +124,17 @@ def parse_args():
120124 if len (args .prefiller_hosts ) != len (args .prefiller_ports ):
121125 raise ValueError (
122126 "Number of prefiller hosts must match number of prefiller ports"
123- )
127+ ) # yapf: disable
124128
125129 if len (args .decoder_hosts ) != len (args .decoder_ports ):
126130 raise ValueError (
127131 "Number of decoder hosts must match number of decoder ports"
128- )
132+ ) # yapf: disable
129133
130134 # Create tuples of (host, port) for each service type
131135 args .prefiller_instances = list (
132136 zip (args .prefiller_hosts , args .prefiller_ports )
133- )
137+ ) # yapf: disable
134138 args .decoder_instances = list (zip (args .decoder_hosts , args .decoder_ports ))
135139
136140 return args
@@ -159,7 +163,7 @@ def get_next_client(app, service_type: str):
159163
160164async def send_request_to_service (
161165 client_info : dict , endpoint : str , req_data : dict , request_id : str
162- ):
166+ ): # yapf: disable
163167 """
164168 Send a request to a service using a client from the pool.
165169 """
@@ -185,15 +189,15 @@ async def send_request_to_service(
185189
186190 response = await client_info ["client" ].post (
187191 endpoint , json = req_data , headers = headers
188- )
192+ ) # yapf: disable
189193 response .raise_for_status ()
190194
191195 return response
192196
193197
194198async def stream_service_response (
195199 client_info : dict , endpoint : str , req_data : dict , request_id : str
196- ):
200+ ): # yapf: disable
197201 """
198202 Asynchronously stream response from a service using a client from the pool.
199203 """
@@ -204,7 +208,7 @@ async def stream_service_response(
204208
205209 async with client_info ["client" ].stream (
206210 "POST" , endpoint , json = req_data , headers = headers
207- ) as response :
211+ ) as response : # yapf: disable
208212 response .raise_for_status ()
209213 async for chunk in response .aiter_bytes ():
210214 yield chunk
@@ -221,7 +225,7 @@ async def _handle_completions(api: str, request: Request):
221225 # Send request to prefill service
222226 response = await send_request_to_service (
223227 prefill_client_info , api , req_data , request_id
224- )
228+ ) # yapf: disable
225229
226230 # Extract the needed fields
227231 response_json = response .json ()
@@ -238,19 +242,21 @@ async def _handle_completions(api: str, request: Request):
238242 async def generate_stream ():
239243 async for chunk in stream_service_response (
240244 decode_client_info , api , req_data , request_id = request_id
241- ):
245+ ): # yapf: disable
242246 yield chunk
243247
244248 return StreamingResponse (
245249 generate_stream (), media_type = "application/json"
246- )
250+ ) # yapf: disable
247251
248252 except Exception as e :
249253 import sys
250254 import traceback
251255
252256 exc_info = sys .exc_info ()
253- print (f"Error occurred in disagg prefill proxy server - { api } endpoint" )
257+ print (
258+ f"Error occurred in disagg prefill proxy server - { api } endpoint"
259+ ) # yapf: disable
254260 print (e )
255261 print ("" .join (traceback .format_exception (* exc_info )))
256262 raise
0 commit comments