Skip to content

Commit 62e9b2c

Browse files
committed
Format to make both yapf and ruff happy
Signed-off-by: Chendi.Xue <[email protected]>
1 parent b9b1921 commit 62e9b2c

File tree

3 files changed

+35
-28
lines changed

3 files changed

+35
-28
lines changed

examples/nixl/test_accuracy.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def run_simple_prompt():
2727
client = openai.OpenAI(api_key="EMPTY", base_url=BASE_URL)
2828
completion = client.completions.create(
2929
model=MODEL_NAME, prompt=SIMPLE_PROMPT
30-
)
30+
) # yapf: disable
3131

3232
print("-" * 50)
3333
print(f"Completion results for {MODEL_NAME}:")
@@ -43,7 +43,7 @@ def test_accuracy():
4343
f"model={MODEL_NAME},"
4444
f"base_url={BASE_URL}/completions,"
4545
f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False"
46-
)
46+
) # yapf: disable
4747

4848
results = lm_eval.simple_evaluate(
4949
model="local-completions",
@@ -58,11 +58,12 @@ def test_accuracy():
5858
print(
5959
f"Warning: No expected value found for {MODEL_NAME}. "
6060
"Skipping accuracy check."
61-
)
61+
) # yapf: disable
6262
print(f"Measured value: {measured_value}")
6363
return
6464

6565
assert (
6666
measured_value - RTOL < expected_value
6767
and measured_value + RTOL > expected_value
68-
), f"Expected: {expected_value} | Measured: {measured_value}"
68+
), \
69+
f"Expected: {expected_value} | Measured: {measured_value}" # yapf: disable

examples/nixl/test_edge_cases.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
if PREFILL_PORT is None or DECODE_PORT is None or PROXY_PORT is None:
1212
raise ValueError(
1313
"Please set the PREFILL_PORT, DECODE_PORT, and PROXY_PORT."
14-
)
14+
) # yapf: disable
1515

1616
LONG_PROMPT = "Red Hat is the best company in the world to work for because it works on open source software, which means that all the contributions are delivered to the community. As a result, when working on projects like vLLM we are able to meet many amazing people from various organizations like AMD, Google, NVIDIA, " # noqa: E501
1717
PROMPT = "Red Hat is the best company in the world to work for because it works on open source software, which means that all the contributions are delivered to the community. As a result," # noqa: E501
@@ -41,11 +41,11 @@ def test_edge_cases():
4141
# less than the length of the block size.
4242
completion = proxy_client.completions.create(
4343
model=MODEL, prompt=SHORT_PROMPT, temperature=0
44-
)
44+
) # yapf: disable
4545
proxy_response = completion.choices[0].text
4646
completion = prefill_client.completions.create(
4747
model=MODEL, prompt=SHORT_PROMPT, temperature=0
48-
)
48+
) # yapf: disable
4949
prefill_response = completion.choices[0].text
5050
print(f"SMALL PROMPT: {proxy_response=}")
5151
assert proxy_response == prefill_response
@@ -55,12 +55,12 @@ def test_edge_cases():
5555
# (2a): prime the D worker.
5656
completion = decode_client.completions.create(
5757
model=MODEL, prompt=PROMPT, temperature=0
58-
)
58+
) # yapf: disable
5959
decode_response = completion.choices[0].text
6060
# (2b): send via the P/D setup
6161
completion = proxy_client.completions.create(
6262
model=MODEL, prompt=PROMPT, temperature=0
63-
)
63+
) # yapf: disable
6464
proxy_response = completion.choices[0].text
6565
print(f"FULL CACHE HIT: {proxy_response=}")
6666
assert proxy_response == decode_response
@@ -69,11 +69,11 @@ def test_edge_cases():
6969
# hit on the D worker.
7070
completion = proxy_client.completions.create(
7171
model=MODEL, prompt=LONG_PROMPT, temperature=0
72-
)
72+
) # yapf: disable
7373
proxy_response = completion.choices[0].text
7474
completion = prefill_client.completions.create(
7575
model=MODEL, prompt=LONG_PROMPT, temperature=0
76-
)
76+
) # yapf: disable
7777
prefill_response = completion.choices[0].text
7878
print(f"PARTIAL CACHE HIT: {proxy_response=}")
7979
assert proxy_response == prefill_response

examples/nixl/toy_proxy_server.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ async def lifespan(app: FastAPI):
3737
"port": port,
3838
"id": i,
3939
}
40-
)
40+
) # yapf: disable
4141

4242
# Create decode clients
4343
for i, (host, port) in enumerate(global_args.decoder_instances):
@@ -51,20 +51,20 @@ async def lifespan(app: FastAPI):
5151
"port": port,
5252
"id": i,
5353
}
54-
)
54+
) # yapf: disable
5555

5656
# Initialize round-robin iterators
5757
app.state.prefill_iterator = itertools.cycle(
5858
range(len(app.state.prefill_clients))
59-
)
59+
) # yapf: disable
6060
app.state.decode_iterator = itertools.cycle(
6161
range(len(app.state.decode_clients))
62-
)
62+
) # yapf: disable
6363

6464
print(
6565
f"Initialized {len(app.state.prefill_clients)} prefill clients "
6666
f"and {len(app.state.decode_clients)} decode clients."
67-
)
67+
) # yapf: disable
6868

6969
yield
7070

@@ -111,7 +111,11 @@ def parse_args():
111111
default=["localhost"],
112112
)
113113
parser.add_argument(
114-
"--decoder-ports", "--decoder-port", type=int, nargs="+", default=[8200]
114+
"--decoder-ports",
115+
"--decoder-port",
116+
type=int,
117+
nargs="+",
118+
default=[8200],
115119
)
116120

117121
args = parser.parse_args()
@@ -120,17 +124,17 @@ def parse_args():
120124
if len(args.prefiller_hosts) != len(args.prefiller_ports):
121125
raise ValueError(
122126
"Number of prefiller hosts must match number of prefiller ports"
123-
)
127+
) # yapf: disable
124128

125129
if len(args.decoder_hosts) != len(args.decoder_ports):
126130
raise ValueError(
127131
"Number of decoder hosts must match number of decoder ports"
128-
)
132+
) # yapf: disable
129133

130134
# Create tuples of (host, port) for each service type
131135
args.prefiller_instances = list(
132136
zip(args.prefiller_hosts, args.prefiller_ports)
133-
)
137+
) # yapf: disable
134138
args.decoder_instances = list(zip(args.decoder_hosts, args.decoder_ports))
135139

136140
return args
@@ -159,7 +163,7 @@ def get_next_client(app, service_type: str):
159163

160164
async def send_request_to_service(
161165
client_info: dict, endpoint: str, req_data: dict, request_id: str
162-
):
166+
): # yapf: disable
163167
"""
164168
Send a request to a service using a client from the pool.
165169
"""
@@ -185,15 +189,15 @@ async def send_request_to_service(
185189

186190
response = await client_info["client"].post(
187191
endpoint, json=req_data, headers=headers
188-
)
192+
) # yapf: disable
189193
response.raise_for_status()
190194

191195
return response
192196

193197

194198
async def stream_service_response(
195199
client_info: dict, endpoint: str, req_data: dict, request_id: str
196-
):
200+
): # yapf: disable
197201
"""
198202
Asynchronously stream response from a service using a client from the pool.
199203
"""
@@ -204,7 +208,7 @@ async def stream_service_response(
204208

205209
async with client_info["client"].stream(
206210
"POST", endpoint, json=req_data, headers=headers
207-
) as response:
211+
) as response: # yapf: disable
208212
response.raise_for_status()
209213
async for chunk in response.aiter_bytes():
210214
yield chunk
@@ -221,7 +225,7 @@ async def _handle_completions(api: str, request: Request):
221225
# Send request to prefill service
222226
response = await send_request_to_service(
223227
prefill_client_info, api, req_data, request_id
224-
)
228+
) # yapf: disable
225229

226230
# Extract the needed fields
227231
response_json = response.json()
@@ -238,19 +242,21 @@ async def _handle_completions(api: str, request: Request):
238242
async def generate_stream():
239243
async for chunk in stream_service_response(
240244
decode_client_info, api, req_data, request_id=request_id
241-
):
245+
): # yapf: disable
242246
yield chunk
243247

244248
return StreamingResponse(
245249
generate_stream(), media_type="application/json"
246-
)
250+
) # yapf: disable
247251

248252
except Exception as e:
249253
import sys
250254
import traceback
251255

252256
exc_info = sys.exc_info()
253-
print(f"Error occurred in disagg prefill proxy server - {api} endpoint")
257+
print(
258+
f"Error occurred in disagg prefill proxy server - {api} endpoint"
259+
) # yapf: disable
254260
print(e)
255261
print("".join(traceback.format_exception(*exc_info)))
256262
raise

0 commit comments

Comments
 (0)