
Commit 8f36ce4

fix

Signed-off-by: Zhenhuan Chen <[email protected]>
1 parent 28d7934

File tree

8 files changed: +14 −32 lines changed

examples/scaffolding/contrib/Dynasor/scaffolding_dynasor_run.py

Lines changed: 3 additions & 2 deletions
@@ -65,6 +65,7 @@ async def task(prompt: str):
             f">>> final output {len(result.output.outputs[0].token_ids)}\n",
             result.output.outputs[0].text)
 
+        # Need to provide LLM's event loop to get results in the middle of the whole process.
         asyncio.run_coroutine_threadsafe(task(prompts[0]), llm.loop).result()
     else:
         results = llm.generate(prompts)
@@ -83,8 +84,8 @@ def main():
 
     prompts = [
         "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?\r\n\r\n",
-        # "There exist real numbers $x$ and $y$, both greater than 1, such that $\\log_x\\left(y^x\\right)=\\log_y\\left(x^{4y}\\right)=10$. Find $xy$.",
-        # "Find the largest possible real part of \\[(75+117i)z+\\frac{96+144i}{z}\\]where $z$ is a complex number with $|z|=4$.",
+        "There exist real numbers $x$ and $y$, both greater than 1, such that $\\log_x\\left(y^x\\right)=\\log_y\\left(x^{4y}\\right)=10$. Find $xy$.",
+        "Find the largest possible real part of \\[(75+117i)z+\\frac{96+144i}{z}\\]where $z$ is a complex number with $|z|=4$.",
     ]
 
     llm_worker = TRTLLMWorker.init_with_new_llm(
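
Note: asyncio.run_coroutine_threadsafe is the standard pattern for driving a coroutine on an event loop that lives in another thread, which is what the added comment points at. A self-contained sketch of just that pattern (the loop and thread below stand in for llm.loop; none of this is TRT-LLM API):

import asyncio
import threading

# A loop running in a background thread, standing in for llm.loop.
loop = asyncio.new_event_loop()
threading.Thread(target=loop.run_forever, daemon=True).start()

async def task(prompt: str) -> str:
    await asyncio.sleep(0.1)  # stand-in for the example's streaming work
    return f"result for {prompt!r}"

# Schedule the coroutine on the loop's thread and block until it finishes,
# mirroring asyncio.run_coroutine_threadsafe(task(prompts[0]), llm.loop).result().
future = asyncio.run_coroutine_threadsafe(task("hello"), loop)
print(future.result())
loop.call_soon_threadsafe(loop.stop)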

examples/scaffolding/run_basic_generation.py

Lines changed: 6 additions & 3 deletions
@@ -19,15 +19,18 @@ def parse_arguments():
 
 
 def test_sync(prompts, proposer_worker):
-    prototype_controller = NativeGenerationController(
-        sampling_params={"temperature": 0.9})
+    prototype_controller = NativeGenerationController(sampling_params={
+        "temperature": 0.9,
+        "max_tokens": 1024,
+    })
 
     llm = ScaffoldingLlm(
         prototype_controller,
         {NativeGenerationController.WorkerTag.GENERATION: proposer_worker},
     )
     results = llm.generate(prompts)
     for result in results:
+        print(len(result.output.outputs[0].token_ids))
         print(result.output.outputs[0].text)
     print(f'main shutting down...')
     llm.shutdown()
@@ -42,7 +45,7 @@ async def test_async_func(prompt, proposer_worker):
     prototype_controller = NativeGenerationController(
         sampling_params={
             "temperature": 0.9,
-            "max_tokens": 64
+            "max_tokens": 1024,
         },
         streaming=True,
    )
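
One plausible reason for raising max_tokens from 64 to 1024 alongside the new token-count print: when len(token_ids) equals the cap, the output was almost certainly truncated rather than finished. A tiny illustrative check (max_tokens and the helper are hypothetical, not repo code):

max_tokens = 1024

def looks_truncated(token_ids):
    # Hitting the cap exactly usually means the model was cut off.
    return len(token_ids) >= max_tokens

print(looks_truncated(list(range(1024))))  # True: hit the cap
print(looks_truncated(list(range(37))))    # False: stopped on its own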

tensorrt_llm/scaffolding/contrib/Dynasor/dynasor_controller.py

Lines changed: 1 addition & 6 deletions
@@ -102,8 +102,6 @@ def process(self, tasks: List[GenerationTask], **kwargs):
         probe_task.input_str = current_prompt + self.probe_suffix
 
         # For the probe task, append the suffix to force a chain-of-thought leading to an answer.
-        print("[DynasorGenerationController] probe_task")
-        # yield [probe_task, proposer_task]
         yield [proposer_task, probe_task]
 
         # Retrieve the output from the probe task.
@@ -141,10 +139,7 @@ def process(self, tasks: List[GenerationTask], **kwargs):
                               probe_answers[-1] + "}\n\\]")
             return
 
-        # if not confident, do another round of generation
-        # print("[DynasorGenerationController] proposer_task")
-        # yield [proposer_task]
-
+        # If not confident, do another round of generation
         # Append the newly generated text from the proposer to the current prompt for the next iteration.
         current_prompt += proposer_task.output_str
 
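
The surviving yield [proposer_task, probe_task] is the heart of the Dynasor loop: each round batches the real proposer task with a cheap probe whose suffix forces an early answer, then checks whether recent probe answers agree before deciding to continue. A minimal generator-based sketch of that control flow (Task and is_confident are stand-ins, not the repo's classes):

from dataclasses import dataclass

@dataclass
class Task:
    input_str: str
    output_str: str = ""

def is_confident(probe_answers):
    # Stand-in heuristic: stop once the last two probe answers agree.
    return len(probe_answers) >= 2 and probe_answers[-1] == probe_answers[-2]

def process(prompt: str, probe_suffix: str, max_rounds: int = 4):
    current_prompt = prompt
    probe_answers = []
    for _ in range(max_rounds):
        proposer_task = Task(input_str=current_prompt)
        probe_task = Task(input_str=current_prompt + probe_suffix)
        # Both tasks go out in one batch, like yield [proposer_task, probe_task].
        yield [proposer_task, probe_task]
        probe_answers.append(probe_task.output_str)
        if is_confident(probe_answers):
            return
        # If not confident, fold the proposer's text into the next round.
        current_prompt += proposer_task.output_str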

tensorrt_llm/scaffolding/controller.py

Lines changed: 0 additions & 3 deletions
@@ -11,8 +11,6 @@
 from tensorrt_llm.scaffolding.math_utils import get_digit_majority_vote_result
 from tensorrt_llm.scaffolding.task import GenerationTask, Task
 
-# from .result import ScaffoldingOutput
-
 
 class Controller(ABC):
 
@@ -27,7 +25,6 @@ def generate(self, prompt: str, **kwargs) -> GenerationResult:
 
         yield from self.process([task], **kwargs)
 
-        # print("[Controller.generate] task.output in generate", task.result)
         return task.create_scaffolding_output()
 
     def process(self, tasks: List[Task], **kwargs):
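
What remains of Controller.generate is a generator that both yields work and returns a value, so a driver retrieves the final output from StopIteration.value. A minimal sketch of consuming such a generator (names are illustrative, not the repo's API):

def generate(prompt: str):
    task = {"prompt": prompt, "output": None}
    # Yield batches of work to the driver, like yield from self.process(...).
    yield [task]
    # The return value travels inside StopIteration.value.
    return task["output"]

gen = generate("2 + 2 = ?")
batch = next(gen)            # driver receives the tasks
batch[0]["output"] = "4"     # driver fills in results
try:
    next(gen)
except StopIteration as stop:
    print(stop.value)        # "4"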

tensorrt_llm/scaffolding/result.py

Lines changed: 1 addition & 8 deletions
@@ -23,15 +23,11 @@ def __init__(self, streaming_event: Optional[asyncio.Event] = None):
         self.streaming_event = streaming_event
 
     def set_output(self, output: GenerationResult):
-        print("[set_output] called")
         self.aqueue.put_nowait(output)
         self._done = True
-        print("[set_output] put")
 
     async def set_output_async(self, output: GenerationResult):
-        print("[set_output_async] called")
         await self.aqueue.put(output)
-        print("[set_output_async] put")
 
     def set_task_collections(self, task_collections: Mapping[str,
                                                              "TaskCollection"]):
@@ -42,10 +38,8 @@ def finished(self) -> bool:
         return self.output is not None and self.output.finished
 
     async def _aresult_step(self):
-        print("[_aresult_step] waiting for response")
         # TODO: error handling or raise exception?
         response = await self.aqueue.get()
-        print("[_aresult_step] response received")
         if response is None:
             raise Exception("ScaffoldingLlm execution failed")
         self._handle_response(response)
@@ -79,7 +73,6 @@ def __aiter__(self):
 
     async def __anext__(self):
         if self.finished:
-            print("[_aresult_step] streaming_event set")
             self.streaming_event.set() if self.streaming_event else None
         if self._done and self.finished:
             raise StopAsyncIteration
@@ -88,4 +81,4 @@ async def __anext__(self):
         return self
 
     def _handle_response(self, response: GenerationResult):
-        self.output = response  # .outputs[0].text
+        self.output = response
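
With the prints removed, what is left is a plain producer/consumer handoff: set_output / set_output_async push a GenerationResult onto an asyncio.Queue, and _aresult_step awaits it, treating None as failure. A self-contained sketch of that pattern (class and method names echo the file, but the code is illustrative):

import asyncio

class Result:
    def __init__(self):
        self.aqueue = asyncio.Queue()
        self.output = None

    def set_output(self, output):
        # Producer side: non-blocking put, callable from sync code on the loop.
        self.aqueue.put_nowait(output)

    async def _aresult_step(self):
        # Consumer side: wait for the next response; None signals failure.
        response = await self.aqueue.get()
        if response is None:
            raise Exception("ScaffoldingLlm execution failed")
        self.output = response

async def main():
    result = Result()
    result.set_output("final text")
    await result._aresult_step()
    print(result.output)  # final text

asyncio.run(main())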

tensorrt_llm/scaffolding/scaffolding_llm.py

Lines changed: 0 additions & 5 deletions
@@ -30,7 +30,6 @@ def __init__(
         self.workers = workers
 
         self.loop = self._get_loop()
-        print("own_loop:", self.own_loop)
         asyncio.set_event_loop(self.loop)
         self.task_queue = asyncio.Queue()
         self.main_loop_stop_event = asyncio.Event()
@@ -85,7 +84,6 @@ async def _handle_task_list(self,
         for task in tasks:
             if task.streaming:
                 await request.result.set_output_async(task.result)
-                print("[_handle_task_list] streaming_event wait")
                 self.streaming_event.clear()
                 await self.streaming_event.wait()
 
@@ -113,8 +111,6 @@ async def _handle_single_request(self, request: ScaffoldingRequest):
         finally:
             self.running_req_count -= 1
             self._maybe_schedule()
-            print(f"[Request finished] running_req_count: "
-                  f"{self.running_req_count}")
 
     def _create_controller_generator(self, request: ScaffoldingRequest):
         """Create a generator wrapper for the controller."""
@@ -141,7 +137,6 @@ def _maybe_schedule(self, request: ScaffoldingRequest = None):
 
         while (self.running_req_count < self.max_parallel_requests
                and self.pending_queue):
-            print(f"[Scheduling] running_req_count: {self.running_req_count}")
             next_request = self.pending_queue.popleft()
             self._schedule_request(next_request)
 
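
The loop the deleted print lived in is a standard bounded-concurrency scheduler: admit pending requests while the running count stays under max_parallel_requests, and re-check whenever a request finishes. A standalone sketch of just the admission logic (the real class drives asyncio tasks; this strips that away):

from collections import deque

class Scheduler:
    def __init__(self, max_parallel_requests):
        self.max_parallel_requests = max_parallel_requests
        self.running_req_count = 0
        self.pending_queue = deque()

    def _schedule_request(self, request):
        self.running_req_count += 1
        print(f"started {request}")

    def _maybe_schedule(self, request=None):
        if request is not None:
            self.pending_queue.append(request)
        # Admit work while there is both capacity and something pending.
        while (self.running_req_count < self.max_parallel_requests
               and self.pending_queue):
            self._schedule_request(self.pending_queue.popleft())

    def finish_one(self):
        # Mirrors the finally: block above: free a slot, then reschedule.
        self.running_req_count -= 1
        self._maybe_schedule()

s = Scheduler(max_parallel_requests=2)
for r in ("a", "b", "c"):
    s._maybe_schedule(r)   # "a" and "b" start; "c" waits
s.finish_one()             # "c" starts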

tensorrt_llm/scaffolding/task.py

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@ class GenerationTask(Task):
     frequency_penalty: Optional[float] = 0.0
     logit_bias: Optional[Dict[str, float]] = None
     num_logprobs: Optional[int] = None
-    max_tokens: Optional[int] = 2048
+    max_tokens: Optional[int] = None
     n: int = 1
     presence_penalty: Optional[float] = 0.0
     seed: Optional[int] = None
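
Changing the default from 2048 to None plausibly distinguishes "the caller set a limit" from "defer to the backend default". A common way to honor that is to drop None-valued fields when converting a task to sampling parameters; a hedged sketch (this convert_task_params is illustrative, not the repo's implementation):

from dataclasses import dataclass, asdict
from typing import Optional

@dataclass
class GenerationTask:
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None   # None -> use the backend's own default
    n: int = 1

def convert_task_params(task: GenerationTask) -> dict:
    # Keep only fields the caller actually set, so None never
    # overrides the backend's defaults.
    return {k: v for k, v in asdict(task).items() if v is not None}

print(convert_task_params(GenerationTask(temperature=0.9)))
# {'temperature': 0.9, 'n': 1}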

tensorrt_llm/scaffolding/worker.py

Lines changed: 2 additions & 4 deletions
@@ -187,18 +187,16 @@ def convert_task_params(self, task: GenerationTask):
     async def generation_handler(self, task: GenerationTask) -> TaskStatus:
         sampling_params = self.convert_task_params(task)
 
-        print("[generation_handler] task.streaming:", task.streaming)
+        # If the task is streaming, we will return result directly for
+        # async iteration outside. Otherwise, we will wait.
         if task.streaming:
-            # If the task is streaming, we need to use the async generate method
-            # and handle the streaming output.
             result = self.llm.generate_async(task.input_str,
                                              sampling_params=sampling_params,
                                              streaming=True)
         else:
             result = await self.llm.generate_async(
                 task.input_str, sampling_params=sampling_params)
         task.result = result
-        # print("[generation_handler] task.result:", task.result)
 
         # TODO: error handle
         return TaskStatus.SUCCESS
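
The rewritten comment describes a dual-mode handler: in streaming mode the result is handed back unawaited so the caller can async-iterate it; otherwise the handler awaits completion itself. A minimal sketch of that shape (StubResult and this generate_async are stand-ins, not the TRT-LLM API):

import asyncio

class StubResult:
    # Stands in for a result that is both awaitable and async-iterable.
    def __init__(self, chunks):
        self._chunks = chunks

    def __await__(self):
        # Non-streaming path: awaiting drains the stream to completion.
        async def drain():
            return "".join([c async for c in self])
        return drain().__await__()

    def __aiter__(self):
        return self._gen()

    async def _gen(self):
        for c in self._chunks:
            await asyncio.sleep(0)
            yield c

def generate_async(prompt, streaming=False):
    return StubResult(["Hello", ", ", "world"])

async def generation_handler(prompt, streaming):
    if streaming:
        # Hand the result back unawaited; the caller async-iterates it.
        return generate_async(prompt, streaming=True)
    # Non-streaming: wait for the full result right here.
    return await generate_async(prompt)

async def main():
    print(await generation_handler("hi", streaming=False))  # Hello, world
    stream = await generation_handler("hi", streaming=True)
    async for chunk in stream:
        print(chunk, end="")
    print()

asyncio.run(main())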
