35 changes: 12 additions & 23 deletions examples/models/core/gpt_oss/README.md
@@ -35,13 +35,7 @@ OpenAI MoE models support function calling. Here is an example based on [XGramma
First, launch a server with XGrammar enabled:

```bash
cat > ./extra_llm_api_options.yaml <<EOF
guided_decoding_backend: xgrammar
EOF

trtllm-serve <model> \
--backend pytorch \
--extra_llm_api_options extra_llm_api_options.yaml
trtllm-serve <model>
```
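
Before running the client script, you can optionally confirm that the server is reachable. The snippet below is a minimal sketch using the OpenAI Python SDK; it assumes the default `trtllm-serve` endpoint `http://localhost:8000/v1`, so adjust the base URL if you launched the server elsewhere.

```python
# Minimal reachability check (the endpoint is an assumption; the dummy API key
# is only a placeholder required by the SDK when talking to a local server).
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")
# List the served models; the output should include the <model> you launched.
print([model.id for model in client.models.list().data])
```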

Run the [openai_chat_client_function_calling.py](./openai_chat_client_function_calling.py) script, which queries the LLM server in two steps:
@@ -68,14 +62,9 @@ The output would look similar to:

```txt
[USER PROMPT] What is the weather like in SF?
[RESPONSE 1] <|channel|>analysis<|message|>The user asks: "What is the weather like in SF?" They want the weather in SF. SF likely refers to San Francisco. We need to get the current weather. We can use get_current_weather function. We need to provide location string "San Francisco, CA". We can also ask for format? By default celsius. But maybe user expects Fahrenheit? They didn't specify. We can provide celsius or Fahrenheit. We can choose default celsius. But maybe better to provide Fahrenheit because US. But default is celsius. We can provide both? We can call function with format "fahrenheit" to be user-friendly. But the function default is celsius. We can override. Let's call get_current_weather with location "San Francisco, CA" and format "fahrenheit". Then we will get the weather. Then we will respond with friendly tone. We need to call the function.<|end|><|start|>assistant<|channel|>commentary to=get_current_weather <|constrain|>json<|message|>{
"location": "San Francisco, CA",
"format": "fahrenheit"
}<|call|>
[FUNCTION CALL] get_current_weather(**{'location': 'San Francisco, CA', 'format': 'fahrenheit'})
[RESPONSE 2] <|channel|>analysis<|message|>The user asked: "What is the weather like in SF?" We have fetched the weather: sunny true, temperature 68 (F). We need to respond in a friendly tone. Provide a friendly answer: "It's sunny and 68°F in San Francisco." Possibly add a friendly comment. Also ask if they want more details.<|end|><|start|>assistant<|channel|>final<|message|>Sure thing! It’s a pleasant 68 °F in San Francisco right now, and the sun is shining. 🌞

Anything else you'd like to know about the weather or maybe some fun things to do in the city today?<|return|>
[RESPONSE 1] [COT] Need to call get_current_weather.
[RESPONSE 1] [FUNCTION CALL] get_current_weather(**{'location': 'San Francisco, CA'})
[RESPONSE 2] It’s a bright, sunny day in San Francisco with the temperature around 20 °C (68 °F). Enjoy the pleasant weather!
```

The function call works successfully:
@@ -95,14 +84,14 @@ The output would look like:

```txt
[USER PROMPT] What is the weather like in NY and SF?
[RESPONSE 1] <|channel|>analysis<|message|>The user asks: "What is the weather like in NY and SF?" They want the weather in New York and San Francisco. We need to provide the weather. We can use the function get_multiple_weathers. We need to provide the list of city and state strings. For New York, we can use "New York, NY". For San Francisco, "San Francisco, CA". We can call get_multiple_weathers with those two locations. We should specify format? The default is celsius. But maybe the user might want Fahrenheit? They didn't specify. We can just use default celsius. But maybe we can provide both? But the function only returns one format. We can just use default celsius. But we can also ask the user? But the user asked "What is the weather like in NY and SF?" We can just provide the weather. We can call the function. Then we will get the weather data. Then we can respond with a friendly tone. We need to call the function.<|end|><|start|>assistant<|channel|>commentary to=get_multiple_weathers <|constrain|>json<|message|>{"locations":["New York, NY","San Francisco, CA"]}<|call|>
[FUNCTION CALL] get_multiple_weathers(**{'locations': ['New York, NY', 'San Francisco, CA']})
[RESPONSE 2] <|channel|>analysis<|message|>The user asked: "What is the weather like in NY and SF?" We called get_multiple_weathers with locations ["New York, NY","San Francisco, CA"]. The function returned: [{"sunny": true, "temperature": 20}, {"sunny": true, "temperature": 20}]. That seems to be a list of two objects, each with sunny: true, temperature: 20. But we need to interpret the function output. The function get_multiple_weathers presumably returns a list of weather data for each location. But the returned data is ambiguous: we don't know which corresponds to which location. But we can assume the order matches the input order: first is New York, second is San Francisco. The temperature is 20 degrees Celsius? The function didn't specify units, but default is celsius. So 20°C. And sunny: true. So both are sunny and 20°C. We should respond in a friendly tone, summarizing the weather for both cities. We can mention that it's sunny and 20°C in both New York and San Francisco. We can also mention that it's a nice day. We can ask if they want more details. We should not mention the function call. We should just provide the answer.<|end|><|start|>assistant<|channel|>final<|message|>Sure thing! Here’s the scoop:

- **New York, NY**: It’s sunny and a comfortable 20 °C (68 °F).
- **San Francisco, CA**: Also sunny with a pleasant 20 °C (68 °F).

Looks like both coasts are enjoying a bright, mild day. Let me know if you’d like a forecast for later or any other details!<|return|>
[RESPONSE 1] [COT] Need to call get_multiple_weathers.
[RESPONSE 1] [FUNCTION CALL] get_multiple_weathers(**{'locations': ['New York, NY', 'San Francisco, CA'], 'format': 'celsius'})
[RESPONSE 2] Here’s a quick snapshot of the current weather in both cities:

| City | Weather | Temperature |
|------|---------|-------------|
| New York | ☀️ Sunny | 20 °C |
| San Francisco | ☀️ Sunny | 20 °C |
```

Once again, the function call works successfully, this time using a different function: `get_multiple_weathers`.
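
For reference, the two-step flow that the client script implements can be condensed into the sketch below. It assumes a server at `http://localhost:8000/v1`, uses `<model>` as a placeholder name, a simplified tool schema, and a dummy local weather implementation; see [openai_chat_client_function_calling.py](./openai_chat_client_function_calling.py) for the complete version.

```python
# Condensed sketch of the two-step tool-calling flow. The endpoint, <model>,
# the simplified tool schema, and the dummy weather result are placeholders.
import json

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")

tools = [{
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Gets the current weather in the provided location.",
        "parameters": {
            "type": "object",
            "properties": {"location": {"type": "string"}},
            "required": ["location"],
        },
    },
}]
messages = [{"role": "user", "content": "What is the weather like in SF?"}]


def get_current_weather(location: str, format: str = "celsius") -> dict:
    """Dummy local implementation standing in for a real weather lookup."""
    return {"sunny": True, "temperature": 20}


# Step 1: the model responds with a tool call instead of a final answer.
first = client.chat.completions.create(model="<model>",
                                       messages=messages,
                                       tools=tools)
tool_call = first.choices[0].message.tool_calls[0]
kwargs = json.loads(tool_call.function.arguments)

# Execute the tool locally and append the result to the conversation history.
answer = get_current_weather(**kwargs)
messages.extend([{
    "role": "assistant",
    "tool_calls": [tool_call],
}, {
    "role": "tool",
    "content": json.dumps(answer),
    "tool_call_id": tool_call.id,
}])

# Step 2: the model turns the tool output into the final natural-language reply.
second = client.chat.completions.create(model="<model>", messages=messages)
print(second.choices[0].message.content)
```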
194 changes: 69 additions & 125 deletions examples/models/core/gpt_oss/openai_chat_client_function_calling.py
@@ -1,82 +1,58 @@
import argparse
import json
import re

from openai import OpenAI

system_prompt = """You are ChatGPT, a large language model trained by OpenAI.
Knowledge cutoff: 2024-06
Current date: 2025-06-28

Reasoning: high

# Valid channels: analysis, commentary, final. Channel must be included for every message.
Calls to these tools must go to the commentary channel: 'functions'."""

developer_prompt = """# Instructions

Use a friendly tone.

# Tools

## functions

namespace functions {

// Gets the location of the user.
type get_location = () => any;

// Gets the current weather in the provided location.
type get_current_weather = (_: {
// The city and state, e.g. San Francisco, CA
location: string,
format?: "celsius" | "fahrenheit", // default: celsius
}) => any;

// Gets the current weather in the provided list of locations.
type get_multiple_weathers = (_: {
// List of city and state, e.g. ["San Francisco, CA", "New York, NY"]
locations: string[],
format?: "celsius" | "fahrenheit", // default: celsius
}) => any;

} // namespace functions"""

schema_get_current_weather = {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"format": {
"type": "string",
"description": "default: celsius",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["location"],
tool_get_current_weather = {
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Gets the current weather in the provided location.",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"format": {
"type": "string",
"description": "default: celsius",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["location"],
}
}
}

schema_get_multiple_weathers = {
"type": "object",
"properties": {
"locations": {
"type":
"array",
"items": {
"type": "string"
tool_get_multiple_weathers = {
"type": "function",
"function": {
"name": "get_multiple_weathers",
"description":
"Gets the current weather in the provided list of locations.",
"parameters": {
"type": "object",
"properties": {
"locations": {
"type":
"array",
"items": {
"type": "string"
},
"description":
'List of city and state, e.g. ["San Francisco, CA", "New York, NY"]',
},
"format": {
"type": "string",
"description": "default: celsius",
"enum": ["celsius", "fahrenheit"],
},
},
"description":
'List of city and state, e.g. ["San Francisco, CA", "New York, NY"]',
},
"format": {
"type": "string",
"description": "default: celsius",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["locations"],
"required": ["locations"],
}
}
}


@@ -103,14 +79,6 @@ def main():
)

messages = [
{
"role": "system",
"content": system_prompt,
},
{
"role": "developer",
"content": developer_prompt,
},
{
"role": "user",
"content": args.prompt,
@@ -122,65 +90,41 @@
model=args.model,
messages=messages,
max_completion_tokens=500,
response_format={
"type":
"structural_tag",
"structures": [{
"begin":
"<|channel|>commentary to=get_current_weather <|constrain|>json<|message|>",
"schema": schema_get_current_weather,
"end": "<|call|>",
}, {
"begin":
"<|channel|>commentary to=get_multiple_weathers <|constrain|>json<|message|>",
"schema": schema_get_multiple_weathers,
"end": "<|call|>",
}],
"triggers": ["<|channel|>commentary to="],
},
stop=["<|call|>"],
extra_body={
"skip_special_tokens": False,
"include_stop_str_in_output": True,
},
tools=[tool_get_current_weather, tool_get_multiple_weathers],
)

response_text = chat_completion.choices[0].message.content
print(f"[RESPONSE 1] {response_text}")

for regex, tool in [
(r"(<\|channel\|>commentary to=get_current_weather <\|constrain\|>json<\|message\|>)([\S\s]+)(<\|call\|>)",
get_current_weather),
(r"(<\|channel\|>commentary to=get_multiple_weathers <\|constrain\|>json<\|message\|>)([\S\s]+)(<\|call\|>)",
get_multiple_weathers)
]:
match = re.search(regex, response_text)
if match is not None:
break
else:
print("Failed to call functions, exiting...")
return

kwargs = json.loads(match.group(2))
print(f"[FUNCTION CALL] {tool.__name__}(**{kwargs})")
tools = {
"get_current_weather": get_current_weather,
"get_multiple_weathers": get_multiple_weathers
}
message = chat_completion.choices[0].message
assert message, "Empty Message"
assert message.tool_calls, "Empty tool calls"
assert message.content is None, "Empty content expected"
reasoning = message.reasoning if hasattr(message, "reasoning") else None
tool_call = message.tool_calls[0]
func_name = tool_call.function.name
assert func_name in tools, "Invalid function name"
kwargs = json.loads(tool_call.function.arguments)

tool = tools[func_name]
print(f"[RESPONSE 1] [COT] {reasoning}")
print(f"[RESPONSE 1] [FUNCTION CALL] {tool.__name__}(**{kwargs})")
answer = tool(**kwargs)

messages.extend([{
"role": "assistant",
"content": match.group(0),
"reasoning": reasoning,
"tool_calls": [tool_call],
}, {
"role": f"{tool.__name__} to=assistant",
"role": "tool",
"content": json.dumps(answer),
"tool_call_id": tool_call.id
}])

chat_completion = client.chat.completions.create(
model=args.model,
messages=messages,
max_completion_tokens=500,
extra_body={
"skip_special_tokens": False,
"include_stop_str_in_output": True,
},
)

response_text = chat_completion.choices[0].message.content
1 change: 1 addition & 0 deletions requirements.txt
@@ -67,3 +67,4 @@ soundfile
triton==3.3.1; platform_machine == "x86_64"
tiktoken
blobfile
openai-harmony==0.0.4
36 changes: 26 additions & 10 deletions tensorrt_llm/executor/worker.py
@@ -466,25 +466,41 @@ def _enqueue_request(self, request: GenerationRequest) -> int:

def _deduce_max_tokens(request: GenerationRequest,
executor_config: tllm.ExecutorConfig) -> int:
if request.sampling_params.max_tokens:
return request.sampling_params.max_tokens
# deduce max_tokens when it's not set by user
max_tokens = request.sampling_params.max_tokens
query_token_len = len(
request.query_token_ids) if request.query_token_ids else 0
cp_size = 1 if (not hasattr(executor_config, "mapping")
or executor_config.mapping.cp_size
is None) else executor_config.mapping.cp_size
if not hasattr(executor_config, "max_seq_len"):
raise RuntimeError(
"max_tokens for sampling is not set and cannot be deduced")
logger.warning("`default_max_tokens` cannot be deduced")
if max_tokens is None:
raise ValueError(
"`max_tokens` must be set when `default_max_tokens` cannot be deduced"
)
splited_prompt_len = int(len(prompt_token_ids) / cp_size)
default_max_tokens = executor_config.max_seq_len - splited_prompt_len - query_token_len
if default_max_tokens < 0:
raise ValueError(
f"Deduced max_tokens {default_max_tokens} is less than 0, because"
f"prompt length {splited_prompt_len} plus query length {query_token_len} "
f"is larger than max_seq_len {executor_config.max_seq_len}")
return default_max_tokens
if default_max_tokens <= 0:
logger.warning(
f"`default_max_tokens` ({default_max_tokens}) should be greater than 0, "
f"`default_max_tokens` ({default_max_tokens}) = max_seq_len ({executor_config.max_seq_len})"
f" - `splited_prompt_len` ({splited_prompt_len}) - `query_token_len` ({query_token_len})"
)
if max_tokens is None:
raise ValueError(
"`max_tokens` must be set when `default_max_tokens` is illegal"
)
# default_max_tokens is the biggest available value
if max_tokens is None:
return default_max_tokens
elif max_tokens > default_max_tokens:
logger.warning(
f"User-specified `max_tokens` ({max_tokens}) is greater than deduced "
f"`default_max_tokens` ({default_max_tokens}), using default_max_tokens instead."
)
return default_max_tokens
return max_tokens

try:
executor_request = tllm.Request(