
Commit de9cfba

fix
Signed-off-by: Zhenhuan Chen <[email protected]>
1 parent c033d5b commit de9cfba

8 files changed (+27, -42 lines)

examples/scaffolding/contrib/Dynasor/scaffolding_dynasor_run.py

Lines changed: 6 additions & 6 deletions
@@ -58,13 +58,13 @@ async def task(prompt: str):
             async for result in llm.generate_async(prompt):
                 i += 1
                 print(">>>", i, result)
-                async for output in result.output:
+                async for output in result.cur_output:
                     print(">>>", i, len(output.outputs[0].token_ids), "\n",
                           output.outputs[0].text)
-            print(
-                f">>> final output {len(result.output.outputs[0].token_ids)}\n",
-                result.output.outputs[0].text)
+            print(f">>> final output {len(result.outputs[0].token_ids)}\n",
+                  result.outputs[0].text)
 
+        # Need to provide LLM's event loop to get results in the middle of the whole process.
         asyncio.run_coroutine_threadsafe(task(prompts[0]), llm.loop).result()
     else:
         results = llm.generate(prompts)

@@ -83,8 +83,8 @@ def main():
 
     prompts = [
         "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?\r\n\r\n",
-        # "There exist real numbers $x$ and $y$, both greater than 1, such that $\\log_x\\left(y^x\\right)=\\log_y\\left(x^{4y}\\right)=10$. Find $xy$.",
-        # "Find the largest possible real part of \\[(75+117i)z+\\frac{96+144i}{z}\\]where $z$ is a complex number with $|z|=4$.",
+        "There exist real numbers $x$ and $y$, both greater than 1, such that $\\log_x\\left(y^x\\right)=\\log_y\\left(x^{4y}\\right)=10$. Find $xy$.",
+        "Find the largest possible real part of \\[(75+117i)z+\\frac{96+144i}{z}\\]where $z$ is a complex number with $|z|=4$.",
     ]
 
     llm_worker = TRTLLMWorker.init_with_new_llm(
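For reference, the streaming pattern this diff moves the example to boils down to the following. This is a minimal sketch, assuming an already-constructed ScaffoldingLlm instance named llm and the prompts list from the example; it is not a standalone script.

    import asyncio

    async def consume(prompt: str):
        async for result in llm.generate_async(prompt):
            # cur_output replaces the old output attribute on the result and is
            # itself async-iterable while tokens are still streaming in.
            async for output in result.cur_output:
                print(len(output.outputs[0].token_ids), output.outputs[0].text)
        # The new outputs property forwards to cur_output.outputs, so the final
        # text is reachable directly on the result after iteration finishes.
        print(result.outputs[0].text)

    # Run the coroutine on the LLM's own event loop so intermediate results are
    # observable mid-generation.
    asyncio.run_coroutine_threadsafe(consume(prompts[0]), llm.loop).result()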

examples/scaffolding/run_basic_generation.py

Lines changed: 10 additions & 7 deletions
@@ -19,16 +19,19 @@ def parse_arguments():
 
 
 def test_sync(prompts, proposer_worker):
-    prototype_controller = NativeGenerationController(
-        sampling_params={"temperature": 0.9})
+    prototype_controller = NativeGenerationController(sampling_params={
+        "temperature": 0.9,
+        "max_tokens": 1024,
+    })
 
     llm = ScaffoldingLlm(
         prototype_controller,
         {NativeGenerationController.WorkerTag.GENERATION: proposer_worker},
     )
     results = llm.generate(prompts)
     for result in results:
-        print(result.output.outputs[0].text)
+        print(len(result.outputs[0].token_ids))
+        print(result.outputs[0].text)
     print(f'main shutting down...')
     llm.shutdown()
     print(f'worker shutting down...')

@@ -42,7 +45,7 @@ async def test_async_func(prompt, proposer_worker):
     prototype_controller = NativeGenerationController(
         sampling_params={
             "temperature": 0.9,
-            "max_tokens": 64
+            "max_tokens": 1024,
         },
         streaming=True,
     )

@@ -55,11 +58,11 @@ async def test_async_func(prompt, proposer_worker):
     async for result in llm.generate_async(prompt):
         i += 1
         print(">>>", i, result)
-        async for output in result.output:
+        async for output in result.cur_output:
             print(">>>", i, len(output.outputs[0].token_ids), "\n",
                   output.outputs[0].text)
-    print(f">>> final output {len(output.outputs[0].token_ids)}\n",
-          output.outputs[0].text)
+    print(f">>> final output {len(result.outputs[0].token_ids)}\n",
+          result.outputs[0].text)
 
     print(f'main shutting down...')
     llm.shutdown()
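The non-streaming path changes in the same direction: results are read through the new outputs property rather than result.output. A minimal sketch of the updated synchronous flow, assuming prompts and proposer_worker are constructed as in the script (the sampling values are just the ones used above):

    prototype_controller = NativeGenerationController(sampling_params={
        "temperature": 0.9,
        # max_tokens is set explicitly here; this commit changes the
        # GenerationTask default from 2048 to None (see task.py below).
        "max_tokens": 1024,
    })
    llm = ScaffoldingLlm(
        prototype_controller,
        {NativeGenerationController.WorkerTag.GENERATION: proposer_worker},
    )
    for result in llm.generate(prompts):
        # result.outputs forwards to cur_output.outputs on ScaffoldingResult.
        print(len(result.outputs[0].token_ids))
        print(result.outputs[0].text)
    llm.shutdown()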

tensorrt_llm/scaffolding/contrib/Dynasor/dynasor_controller.py

Lines changed: 1 addition & 6 deletions
@@ -102,8 +102,6 @@ def process(self, tasks: List[GenerationTask], **kwargs):
             probe_task.input_str = current_prompt + self.probe_suffix
 
             # For the probe task, append the suffix to force a chain-of-thought leading to an answer.
-            print("[DynasorGenerationController] probe_task")
-            # yield [probe_task, proposer_task]
             yield [proposer_task, probe_task]
 
             # Retrieve the output from the probe task.

@@ -141,10 +139,7 @@ def process(self, tasks: List[GenerationTask], **kwargs):
                                 probe_answers[-1] + "}\n\\]")
                 return
 
-            # if not confident, do another round of generation
-            # print("[DynasorGenerationController] proposer_task")
-            # yield [proposer_task]
-
+            # If not confident, do another round of generation
             # Append the newly generated text from the proposer to the current prompt for the next iteration.
             current_prompt += proposer_task.output_str
 
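The yield [proposer_task, probe_task] line reflects the controller protocol visible in controller.py below: process is a generator that yields lists of tasks, the scaffolding layer executes each yielded batch (here the proposer and probe run in the same round), and the generator resumes with the tasks' outputs filled in. A minimal sketch of one such round (hypothetical simplified controller; max_iterations and is_confident are assumed names, the field names follow the diff):

    # Hypothetical, simplified controller loop following the yield protocol above.
    def process(self, tasks, **kwargs):
        current_prompt = tasks[0].input_str
        for _ in range(self.max_iterations):  # assumed iteration bound
            proposer_task = GenerationTask()
            proposer_task.input_str = current_prompt
            probe_task = GenerationTask()
            probe_task.input_str = current_prompt + self.probe_suffix
            # Hand both tasks to the scaffolding layer; execution happens here.
            yield [proposer_task, probe_task]
            # After resuming, output_str is populated on each task.
            if self.is_confident(probe_task.output_str):  # assumed helper
                return
            current_prompt += proposer_task.output_str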

tensorrt_llm/scaffolding/controller.py

Lines changed: 0 additions & 3 deletions
@@ -11,8 +11,6 @@
 from tensorrt_llm.scaffolding.math_utils import get_digit_majority_vote_result
 from tensorrt_llm.scaffolding.task import GenerationTask, Task
 
-# from .result import ScaffoldingOutput
-
 
 class Controller(ABC):

@@ -27,7 +25,6 @@ def generate(self, prompt: str, **kwargs) -> GenerationResult:
 
         yield from self.process([task], **kwargs)
 
-        # print("[Controller.generate] task.output in generate", task.result)
         return task.create_scaffolding_output()
 
     def process(self, tasks: List[Task], **kwargs):

tensorrt_llm/scaffolding/result.py

Lines changed: 7 additions & 10 deletions
@@ -17,35 +17,33 @@ class ScaffoldingResult:
     def __init__(self, streaming_event: Optional[asyncio.Event] = None):
         super().__init__()
         self.aqueue = asyncio.Queue()
-        self.output = None
+        self.cur_output = None
         self._done = False
         self.task_collections = None
         self.streaming_event = streaming_event
 
     def set_output(self, output: GenerationResult):
-        print("[set_output] called")
         self.aqueue.put_nowait(output)
         self._done = True
-        print("[set_output] put")
 
     async def set_output_async(self, output: GenerationResult):
-        print("[set_output_async] called")
         await self.aqueue.put(output)
-        print("[set_output_async] put")
 
     def set_task_collections(self, task_collections: Mapping[str,
                                                              "TaskCollection"]):
         self.task_collections = task_collections
 
+    @property
+    def outputs(self):
+        return self.cur_output.outputs if self.cur_output else None
+
     @property
     def finished(self) -> bool:
-        return self.output is not None and self.output.finished
+        return self.cur_output is not None and self.cur_output.finished
 
     async def _aresult_step(self):
-        print("[_aresult_step] waiting for response")
         # TODO: error handling or raise exception?
         response = await self.aqueue.get()
-        print("[_aresult_step] response received")
         if response is None:
             raise Exception("ScaffoldingLlm execution failed")
         self._handle_response(response)

@@ -79,7 +77,6 @@ def __aiter__(self):
 
     async def __anext__(self):
         if self.finished:
-            print("[_aresult_step] streaming_event set")
             self.streaming_event.set() if self.streaming_event else None
         if self._done and self.finished:
             raise StopAsyncIteration

@@ -88,4 +85,4 @@ async def __anext__(self):
         return self
 
     def _handle_response(self, response: GenerationResult):
-        self.output = response  # .outputs[0].text
+        self.cur_output = response
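Taken together, the rename from output to cur_output and the new forwarding property give callers one accessor that works both mid-stream and after completion. A stripped-down sketch of just that accessor pattern (ResultView is an illustrative stand-in; the real ScaffoldingResult also manages the asyncio.Queue and streaming event shown above):

    class ResultView:
        def __init__(self):
            self.cur_output = None  # most recent GenerationResult received

        @property
        def outputs(self):
            # Forward to the wrapped result so callers can write
            # result.outputs[0] without caring whether generation is finished.
            return self.cur_output.outputs if self.cur_output else None

        @property
        def finished(self) -> bool:
            return self.cur_output is not None and self.cur_output.finished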

tensorrt_llm/scaffolding/scaffolding_llm.py

Lines changed: 0 additions & 5 deletions
@@ -30,7 +30,6 @@ def __init__(
         self.workers = workers
 
         self.loop = self._get_loop()
-        print("own_loop:", self.own_loop)
         asyncio.set_event_loop(self.loop)
         self.task_queue = asyncio.Queue()
         self.main_loop_stop_event = asyncio.Event()

@@ -85,7 +84,6 @@ async def _handle_task_list(self,
         for task in tasks:
             if task.streaming:
                 await request.result.set_output_async(task.result)
-                print("[_handle_task_list] streaming_event wait")
                 self.streaming_event.clear()
                 await self.streaming_event.wait()

@@ -113,8 +111,6 @@ async def _handle_single_request(self, request: ScaffoldingRequest):
         finally:
             self.running_req_count -= 1
             self._maybe_schedule()
-            print(f"[Request finished] running_req_count: "
-                  f"{self.running_req_count}")
 
     def _create_controller_generator(self, request: ScaffoldingRequest):
         """Create a generator wrapper for the controller."""

@@ -141,7 +137,6 @@ def _maybe_schedule(self, request: ScaffoldingRequest = None):
 
         while (self.running_req_count < self.max_parallel_requests
                and self.pending_queue):
-            print(f"[Scheduling] running_req_count: {self.running_req_count}")
             next_request = self.pending_queue.popleft()
             self._schedule_request(next_request)
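The deleted prints were tracing the streaming handshake between _handle_task_list (producer side) and ScaffoldingResult.__anext__ (consumer side): the producer enqueues a partial result, clears the event, and waits; the consumer drains the queue and sets the event to release the producer. A self-contained sketch of that handshake with hypothetical names, not the actual classes:

    import asyncio

    async def producer(queue: asyncio.Queue, event: asyncio.Event, chunks):
        for chunk in chunks:
            await queue.put(chunk)  # like request.result.set_output_async(...)
            event.clear()
            await event.wait()      # block until the consumer catches up

    async def consumer(queue: asyncio.Queue, event: asyncio.Event, n: int):
        for _ in range(n):
            chunk = await queue.get()  # like ScaffoldingResult._aresult_step()
            print("got", chunk)
            event.set()                # like __anext__ releasing the producer

    async def main():
        queue, event = asyncio.Queue(), asyncio.Event()
        await asyncio.gather(producer(queue, event, ["a", "b"]),
                             consumer(queue, event, 2))

    asyncio.run(main())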

tensorrt_llm/scaffolding/task.py

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@ class GenerationTask(Task):
     frequency_penalty: Optional[float] = 0.0
     logit_bias: Optional[Dict[str, float]] = None
     num_logprobs: Optional[int] = None
-    max_tokens: Optional[int] = 2048
+    max_tokens: Optional[int] = None
     n: int = 1
     presence_penalty: Optional[float] = 0.0
     seed: Optional[int] = None
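Changing the max_tokens default from 2048 to None makes "no explicit limit" representable, so a worker can skip the field instead of always forwarding 2048. A hedged sketch of that kind of filtering (to_sampling_kwargs is a hypothetical helper; the actual body of convert_task_params is not shown in this commit):

    # Hypothetical helper: build sampling kwargs from a task, dropping fields
    # the controller never set (possible now that max_tokens defaults to None).
    def to_sampling_kwargs(task) -> dict:
        fields = ("temperature", "top_p", "max_tokens")
        return {name: getattr(task, name)
                for name in fields if getattr(task, name, None) is not None}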

tensorrt_llm/scaffolding/worker.py

Lines changed: 2 additions & 4 deletions
@@ -187,18 +187,16 @@ def convert_task_params(self, task: GenerationTask):
     async def generation_handler(self, task: GenerationTask) -> TaskStatus:
         sampling_params = self.convert_task_params(task)
 
-        print("[generation_handler] task.streaming:", task.streaming)
+        # If the task is streaming, return the result directly for async
+        # iteration outside. Otherwise, wait for completion.
         if task.streaming:
-            # If the task is streaming, we need to use the async generate method
-            # and handle the streaming output.
             result = self.llm.generate_async(task.input_str,
                                              sampling_params=sampling_params,
                                              streaming=True)
         else:
             result = await self.llm.generate_async(
                 task.input_str, sampling_params=sampling_params)
         task.result = result
-        # print("[generation_handler] task.result:", task.result)
 
         # TODO: error handle
         return TaskStatus.SUCCESS
