
Commit c1e7fb9

[TRTLLM-7207][feat] Chat completions API for gpt-oss (#7261)
Signed-off-by: Pengyun Lin <[email protected]>
1 parent f30768e commit c1e7fb9

File tree

10 files changed: +2050, -168 lines changed


examples/models/core/gpt_oss/README.md

Lines changed: 12 additions & 23 deletions
@@ -35,13 +35,7 @@ OpenAI MoE models support function calling. Here is an example based on [XGramma
 First, launch a server with XGrammar enabled:
 
 ```bash
-cat > ./extra_llm_api_options.yaml <<EOF
-guided_decoding_backend: xgrammar
-EOF
-
-trtllm-serve <model> \
-    --backend pytorch \
-    --extra_llm_api_options extra_llm_api_options.yaml
+trtllm-serve <model>
 ```
 
 Run the [openai_chat_client_function_calling.py](./openai_chat_client_function_calling.py) script, which queries the LLM server in two steps:
@@ -68,14 +62,9 @@ The output would look similar to:
 
 ```txt
 [USER PROMPT] What is the weather like in SF?
-[RESPONSE 1] <|channel|>analysis<|message|>The user asks: "What is the weather like in SF?" They want the weather in SF. SF likely refers to San Francisco. We need to get the current weather. We can use get_current_weather function. We need to provide location string "San Francisco, CA". We can also ask for format? By default celsius. But maybe user expects Fahrenheit? They didn't specify. We can provide celsius or Fahrenheit. We can choose default celsius. But maybe better to provide Fahrenheit because US. But default is celsius. We can provide both? We can call function with format "fahrenheit" to be user-friendly. But the function default is celsius. We can override. Let's call get_current_weather with location "San Francisco, CA" and format "fahrenheit". Then we will get the weather. Then we will respond with friendly tone. We need to call the function.<|end|><|start|>assistant<|channel|>commentary to=get_current_weather <|constrain|>json<|message|>{
-"location": "San Francisco, CA",
-"format": "fahrenheit"
-}<|call|>
-[FUNCTION CALL] get_current_weather(**{'location': 'San Francisco, CA', 'format': 'fahrenheit'})
-[RESPONSE 2] <|channel|>analysis<|message|>The user asked: "What is the weather like in SF?" We have fetched the weather: sunny true, temperature 68 (F). We need to respond in a friendly tone. Provide a friendly answer: "It's sunny and 68°F in San Francisco." Possibly add a friendly comment. Also ask if they want more details.<|end|><|start|>assistant<|channel|>final<|message|>Sure thing! It’s a pleasant 68 °F in San Francisco right now, and the sun is shining. 🌞
-
-Anything else you'd like to know about the weather or maybe some fun things to do in the city today?<|return|>
+[RESPONSE 1] [COT] Need to call get_current_weather.
+[RESPONSE 1] [FUNCTION CALL] get_current_weather(**{'location': 'San Francisco, CA'})
+[RESPONSE 2] It’s a bright, sunny day in San Francisco with the temperature around 20 °C (68 °F). Enjoy the pleasant weather!
 ```
 
 The function call works successfully:
@@ -95,14 +84,14 @@ The output would look like:
 
 ```txt
 [USER PROMPT] What is the weather like in NY and SF?
-[RESPONSE 1] <|channel|>analysis<|message|>The user asks: "What is the weather like in NY and SF?" They want the weather in New York and San Francisco. We need to provide the weather. We can use the function get_multiple_weathers. We need to provide the list of city and state strings. For New York, we can use "New York, NY". For San Francisco, "San Francisco, CA". We can call get_multiple_weathers with those two locations. We should specify format? The default is celsius. But maybe the user might want Fahrenheit? They didn't specify. We can just use default celsius. But maybe we can provide both? But the function only returns one format. We can just use default celsius. But we can also ask the user? But the user asked "What is the weather like in NY and SF?" We can just provide the weather. We can call the function. Then we will get the weather data. Then we can respond with a friendly tone. We need to call the function.<|end|><|start|>assistant<|channel|>commentary to=get_multiple_weathers <|constrain|>json<|message|>{"locations":["New York, NY","San Francisco, CA"]}<|call|>
-[FUNCTION CALL] get_multiple_weathers(**{'locations': ['New York, NY', 'San Francisco, CA']})
-[RESPONSE 2] <|channel|>analysis<|message|>The user asked: "What is the weather like in NY and SF?" We called get_multiple_weathers with locations ["New York, NY","San Francisco, CA"]. The function returned: [{"sunny": true, "temperature": 20}, {"sunny": true, "temperature": 20}]. That seems to be a list of two objects, each with sunny: true, temperature: 20. But we need to interpret the function output. The function get_multiple_weathers presumably returns a list of weather data for each location. But the returned data is ambiguous: we don't know which corresponds to which location. But we can assume the order matches the input order: first is New York, second is San Francisco. The temperature is 20 degrees Celsius? The function didn't specify units, but default is celsius. So 20°C. And sunny: true. So both are sunny and 20°C. We should respond in a friendly tone, summarizing the weather for both cities. We can mention that it's sunny and 20°C in both New York and San Francisco. We can also mention that it's a nice day. We can ask if they want more details. We should not mention the function call. We should just provide the answer.<|end|><|start|>assistant<|channel|>final<|message|>Sure thing! Here’s the scoop:
-
-- **New York, NY**: It’s sunny and a comfortable 20 °C (68 °F).
-- **San Francisco, CA**: Also sunny with a pleasant 20 °C (68 °F).
-
-Looks like both coasts are enjoying a bright, mild day. Let me know if you’d like a forecast for later or any other details!<|return|>
+[RESPONSE 1] [COT] Need to call get_multiple_weathers.
+[RESPONSE 1] [FUNCTION CALL] get_multiple_weathers(**{'locations': ['New York, NY', 'San Francisco, CA'], 'format': 'celsius'})
+[RESPONSE 2] Here’s a quick snapshot of the current weather in both cities:
+
+| City | Weather | Temperature |
+|------|---------|-------------|
+| New York | ☀️ Sunny | 20 °C |
+| San Francisco | ☀️ Sunny | 20 °C |
 ```
 
 Once again, the function call works successfully, this time using a different function: `get_multiple_weathers`.
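
For orientation, here is what the first of the README's two client steps looks like once tool calling goes through the standard chat-completions interface. This is a minimal sketch, not part of the commit: the base_url, api_key placeholder, and model name are assumptions, and only the tools= / tool_calls / reasoning handling mirrors the updated client script.

```python
# Hypothetical step 1: send the user prompt with a tool attached and inspect
# the reasoning ("[COT]") and the tool call the server returns.
import json

from openai import OpenAI

# Assumed endpoint and model name; adjust to however trtllm-serve was launched.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")

tool_get_current_weather = {
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Gets the current weather in the provided location.",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {"type": "string"},
                "format": {"type": "string", "enum": ["celsius", "fahrenheit"]},
            },
            "required": ["location"],
        },
    },
}

completion = client.chat.completions.create(
    model="gpt-oss-120b",  # assumed model name
    messages=[{"role": "user", "content": "What is the weather like in SF?"}],
    max_completion_tokens=500,
    tools=[tool_get_current_weather],
)

message = completion.choices[0].message
# The updated client reads the analysis channel from a `reasoning` field when present.
print(f"[RESPONSE 1] [COT] {getattr(message, 'reasoning', None)}")
for call in message.tool_calls or []:
    print(f"[RESPONSE 1] [FUNCTION CALL] {call.function.name}"
          f"(**{json.loads(call.function.arguments)})")
```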

examples/models/core/gpt_oss/openai_chat_client_function_calling.py

Lines changed: 69 additions & 125 deletions
@@ -1,82 +1,58 @@
 import argparse
 import json
-import re
 
 from openai import OpenAI
 
-system_prompt = """You are ChatGPT, a large language model trained by OpenAI.
-Knowledge cutoff: 2024-06
-Current date: 2025-06-28
-
-Reasoning: high
-
-# Valid channels: analysis, commentary, final. Channel must be included for every message.
-Calls to these tools must go to the commentary channel: 'functions'."""
-
-developer_prompt = """# Instructions
-
-Use a friendly tone.
-
-# Tools
-
-## functions
-
-namespace functions {
-
-// Gets the location of the user.
-type get_location = () => any;
-
-// Gets the current weather in the provided location.
-type get_current_weather = (_: {
-// The city and state, e.g. San Francisco, CA
-location: string,
-format?: "celsius" | "fahrenheit", // default: celsius
-}) => any;
-
-// Gets the current weather in the provided list of locations.
-type get_multiple_weathers = (_: {
-// List of city and state, e.g. ["San Francisco, CA", "New York, NY"]
-locations: string[],
-format?: "celsius" | "fahrenheit", // default: celsius
-}) => any;
-
-} // namespace functions"""
-
-schema_get_current_weather = {
-    "type": "object",
-    "properties": {
-        "location": {
-            "type": "string",
-            "description": "The city and state, e.g. San Francisco, CA",
-        },
-        "format": {
-            "type": "string",
-            "description": "default: celsius",
-            "enum": ["celsius", "fahrenheit"],
-        },
-    },
-    "required": ["location"],
+tool_get_current_weather = {
+    "type": "function",
+    "function": {
+        "name": "get_current_weather",
+        "description": "Gets the current weather in the provided location.",
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "location": {
+                    "type": "string",
+                    "description": "The city and state, e.g. San Francisco, CA",
+                },
+                "format": {
+                    "type": "string",
+                    "description": "default: celsius",
+                    "enum": ["celsius", "fahrenheit"],
+                },
+            },
+            "required": ["location"],
+        }
+    }
 }
 
-schema_get_multiple_weathers = {
-    "type": "object",
-    "properties": {
-        "locations": {
-            "type":
-            "array",
-            "items": {
-                "type": "string"
+tool_get_multiple_weathers = {
+    "type": "function",
+    "function": {
+        "name": "get_multiple_weathers",
+        "description":
+        "Gets the current weather in the provided list of locations.",
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "locations": {
+                    "type":
+                    "array",
+                    "items": {
+                        "type": "string"
+                    },
+                    "description":
+                    'List of city and state, e.g. ["San Francisco, CA", "New York, NY"]',
+                },
+                "format": {
+                    "type": "string",
+                    "description": "default: celsius",
+                    "enum": ["celsius", "fahrenheit"],
+                },
             },
-            "description":
-            'List of city and state, e.g. ["San Francisco, CA", "New York, NY"]',
-        },
-        "format": {
-            "type": "string",
-            "description": "default: celsius",
-            "enum": ["celsius", "fahrenheit"],
-        },
-    },
-    "required": ["locations"],
+            "required": ["locations"],
+        }
+    }
 }
 
 
@@ -103,14 +79,6 @@ def main():
     )
 
     messages = [
-        {
-            "role": "system",
-            "content": system_prompt,
-        },
-        {
-            "role": "developer",
-            "content": developer_prompt,
-        },
         {
             "role": "user",
             "content": args.prompt,
@@ -122,65 +90,41 @@ def main():
         model=args.model,
         messages=messages,
        max_completion_tokens=500,
-        response_format={
-            "type":
-            "structural_tag",
-            "structures": [{
-                "begin":
-                "<|channel|>commentary to=get_current_weather <|constrain|>json<|message|>",
-                "schema": schema_get_current_weather,
-                "end": "<|call|>",
-            }, {
-                "begin":
-                "<|channel|>commentary to=get_multiple_weathers <|constrain|>json<|message|>",
-                "schema": schema_get_multiple_weathers,
-                "end": "<|call|>",
-            }],
-            "triggers": ["<|channel|>commentary to="],
-        },
-        stop=["<|call|>"],
-        extra_body={
-            "skip_special_tokens": False,
-            "include_stop_str_in_output": True,
-        },
+        tools=[tool_get_current_weather, tool_get_multiple_weathers],
     )
-
-    response_text = chat_completion.choices[0].message.content
-    print(f"[RESPONSE 1] {response_text}")
-
-    for regex, tool in [
-        (r"(<\|channel\|>commentary to=get_current_weather <\|constrain\|>json<\|message\|>)([\S\s]+)(<\|call\|>)",
-         get_current_weather),
-        (r"(<\|channel\|>commentary to=get_multiple_weathers <\|constrain\|>json<\|message\|>)([\S\s]+)(<\|call\|>)",
-         get_multiple_weathers)
-    ]:
-        match = re.search(regex, response_text)
-        if match is not None:
-            break
-    else:
-        print("Failed to call functions, exiting...")
-        return
-
-    kwargs = json.loads(match.group(2))
-    print(f"[FUNCTION CALL] {tool.__name__}(**{kwargs})")
+    tools = {
+        "get_current_weather": get_current_weather,
+        "get_multiple_weathers": get_multiple_weathers
+    }
+    message = chat_completion.choices[0].message
+    assert message, "Empty Message"
+    assert message.tool_calls, "Empty tool calls"
+    assert message.content is None, "Empty content expected"
+    reasoning = message.reasoning if hasattr(message, "reasoning") else None
+    tool_call = message.tool_calls[0]
+    func_name = tool_call.function.name
+    assert func_name in tools, "Invalid function name"
+    kwargs = json.loads(tool_call.function.arguments)
+
+    tool = tools[func_name]
+    print(f"[RESPONSE 1] [COT] {reasoning}")
+    print(f"[RESPONSE 1] [FUNCTION CALL] {tool.__name__}(**{kwargs})")
     answer = tool(**kwargs)
 
     messages.extend([{
         "role": "assistant",
-        "content": match.group(0),
+        "reasoning": reasoning,
+        "tool_calls": [tool_call],
     }, {
-        "role": f"{tool.__name__} to=assistant",
+        "role": "tool",
         "content": json.dumps(answer),
+        "tool_call_id": tool_call.id
     }])
 
     chat_completion = client.chat.completions.create(
         model=args.model,
         messages=messages,
        max_completion_tokens=500,
-        extra_body={
-            "skip_special_tokens": False,
-            "include_stop_str_in_output": True,
-        },
     )
 
     response_text = chat_completion.choices[0].message.content
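
The essential protocol change in the updated client is how the tool result goes back to the model: the assistant turn carries `reasoning` and `tool_calls` instead of raw Harmony channel text, and the result returns as a `role: "tool"` message keyed by `tool_call_id`. Below is a small, self-contained illustration of that message shape; the tool-call id and the weather payload are made-up placeholders, and in the real script the SDK's tool_call object is appended directly rather than a hand-written dict.

```python
# Shape of the conversation after step 1, ready to be sent back as step 2.
# The tool_call id and the weather payload are fabricated for illustration.
import json

tool_call_id = "call_0"  # placeholder; normally message.tool_calls[0].id
answer = {"sunny": True, "temperature": 20}

messages = [
    {"role": "user", "content": "What is the weather like in SF?"},
    {
        "role": "assistant",
        "reasoning": "Need to call get_current_weather.",
        "tool_calls": [{
            "id": tool_call_id,
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "arguments": json.dumps({"location": "San Francisco, CA"}),
            },
        }],
    },
    {
        "role": "tool",
        "content": json.dumps(answer),
        "tool_call_id": tool_call_id,
    },
]

print(json.dumps(messages, indent=2))
```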

requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -67,3 +67,4 @@ soundfile
 triton==3.3.1; platform_machine == "x86_64"
 tiktoken
 blobfile
+openai-harmony==0.0.4

tensorrt_llm/executor/worker.py

Lines changed: 26 additions & 10 deletions
@@ -466,25 +466,41 @@ def _enqueue_request(self, request: GenerationRequest) -> int:
 
         def _deduce_max_tokens(request: GenerationRequest,
                                executor_config: tllm.ExecutorConfig) -> int:
-            if request.sampling_params.max_tokens:
-                return request.sampling_params.max_tokens
             # deduce max_tokens when it's not set by user
+            max_tokens = request.sampling_params.max_tokens
             query_token_len = len(
                 request.query_token_ids) if request.query_token_ids else 0
             cp_size = 1 if (not hasattr(executor_config, "mapping")
                             or executor_config.mapping.cp_size
                             is None) else executor_config.mapping.cp_size
             if not hasattr(executor_config, "max_seq_len"):
-                raise RuntimeError(
-                    "max_tokens for sampling is not set and cannot be deduced")
+                logger.warning("`default_max_tokens` cannot be deduced")
+                if max_tokens is None:
+                    raise ValueError(
+                        "`max_tokens` must be set when `default_max_tokens` cannot be deduced"
+                    )
             splited_prompt_len = int(len(prompt_token_ids) / cp_size)
             default_max_tokens = executor_config.max_seq_len - splited_prompt_len - query_token_len
-            if default_max_tokens < 0:
-                raise ValueError(
-                    f"Deduced max_tokens {default_max_tokens} is less than 0, because"
-                    f"prompt length {splited_prompt_len} plus query length {query_token_len} "
-                    f"is larger than max_seq_len {executor_config.max_seq_len}")
-            return default_max_tokens
+            if default_max_tokens <= 0:
+                logger.warning(
+                    f"`default_max_tokens` ({default_max_tokens}) should be greater than 0, "
+                    f"`default_max_tokens` ({default_max_tokens}) = max_seq_len ({executor_config.max_seq_len})"
+                    f" - `splited_prompt_len` ({splited_prompt_len}) - `query_token_len` ({query_token_len})"
+                )
+                if max_tokens is None:
+                    raise ValueError(
+                        "`max_tokens` must be set when `default_max_tokens` is illegal"
+                    )
+            # default_max_tokens is the biggest available value
+            if max_tokens is None:
+                return default_max_tokens
+            elif max_tokens > default_max_tokens:
+                logger.warning(
+                    f"User-specified `max_tokens` ({max_tokens}) is greater than deduced "
+                    f"`default_max_tokens` ({default_max_tokens}), using default_max_tokens instead."
+                )
+                return default_max_tokens
+            return max_tokens
 
         try:
             executor_request = tllm.Request(
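
The worker change replaces a hard failure with a fallback policy: the user-supplied `max_tokens` and the deduced `default_max_tokens` now coexist, and the smaller of the two wins when both are known. The sketch below restates that policy as a standalone function for clarity; it is a simplified stand-in (plain arguments instead of GenerationRequest/ExecutorConfig, and it falls back to the user's value outright when the default cannot be deduced or is non-positive), not the actual worker code.

```python
# Simplified stand-in for the updated _deduce_max_tokens policy in worker.py.
# Arguments are plain ints/None instead of GenerationRequest/ExecutorConfig.
import logging

logger = logging.getLogger(__name__)


def deduce_max_tokens(max_tokens, prompt_len, query_token_len, max_seq_len, cp_size=1):
    if max_seq_len is None:
        # Mirrors the "max_seq_len attribute missing" branch in the diff.
        logger.warning("`default_max_tokens` cannot be deduced")
        if max_tokens is None:
            raise ValueError("`max_tokens` must be set when `default_max_tokens` cannot be deduced")
        return max_tokens

    splited_prompt_len = prompt_len // cp_size
    default_max_tokens = max_seq_len - splited_prompt_len - query_token_len
    if default_max_tokens <= 0:
        logger.warning(f"`default_max_tokens` ({default_max_tokens}) should be greater than 0")
        if max_tokens is None:
            raise ValueError("`max_tokens` must be set when `default_max_tokens` is illegal")
        return max_tokens  # simplification: trust the user's value when the default is unusable

    # default_max_tokens is the largest value that still fits in max_seq_len.
    if max_tokens is None:
        return default_max_tokens
    if max_tokens > default_max_tokens:
        logger.warning(f"`max_tokens` ({max_tokens}) exceeds deduced "
                       f"`default_max_tokens` ({default_max_tokens}); using the deduced value")
        return default_max_tokens
    return max_tokens


# With max_seq_len=1024 and a 200-token prompt, the deduced cap is 1024 - 200 = 824:
assert deduce_max_tokens(None, 200, 0, 1024) == 824
assert deduce_max_tokens(2000, 200, 0, 1024) == 824   # clamped to the deduced cap
assert deduce_max_tokens(100, 200, 0, 1024) == 100    # user value within the cap
```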
