Skip to content

Commit b2426e4

Browse files
authored
chat : nemotron thinking & toolcalling support (#15676)
* feat: nemotron thinking & toolcalling support * Trailing whitespaces * Corrected template for Nemotron * Template and parser fixes * Final template and grammar changes * Whitespace * Always do lazy grammar processing since </think> tag will always be there. * Allow extra content after toolcall * Whitespace * New tests: thinking + tools, tools + content, thinking + tools + content (new!) * Whitespace * Remove cURL test script
1 parent 9e2b1e8 commit b2426e4

File tree

4 files changed

+333
-0
lines changed

4 files changed

+333
-0
lines changed

common/chat.cpp

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -623,6 +623,7 @@ const char * common_chat_format_name(common_chat_format format) {
623623
case COMMON_CHAT_FORMAT_GRANITE: return "Granite";
624624
case COMMON_CHAT_FORMAT_GPT_OSS: return "GPT-OSS";
625625
case COMMON_CHAT_FORMAT_SEED_OSS: return "Seed-OSS";
626+
case COMMON_CHAT_FORMAT_NEMOTRON_V2: return "Nemotron V2";
626627
default:
627628
throw std::runtime_error("Unknown chat format");
628629
}
@@ -1184,6 +1185,67 @@ static common_chat_params common_chat_params_init_llama_3_x(const common_chat_te
11841185
});
11851186
return data;
11861187
}
1188+
1189+
static common_chat_params common_chat_params_init_nemotron_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
1190+
common_chat_params data;
1191+
1192+
// Generate the prompt using the apply() function with the template
1193+
data.prompt = apply(tmpl, inputs);
1194+
data.format = COMMON_CHAT_FORMAT_NEMOTRON_V2;
1195+
1196+
// Handle thinking tags appropriately based on inputs.enable_thinking
1197+
if (string_ends_with(data.prompt, "<think>\n")) {
1198+
if (!inputs.enable_thinking) {
1199+
data.prompt += "</think>";
1200+
} else {
1201+
data.thinking_forced_open = true;
1202+
}
1203+
}
1204+
1205+
// When tools are present, build grammar for the <TOOLCALL> format, similar to CommandR, but without tool call ID
1206+
if (!inputs.tools.is_null() && inputs.tools.is_array() && !inputs.tools.empty()) {
1207+
data.grammar_lazy = true;
1208+
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
1209+
auto schemas = json::array();
1210+
foreach_function(inputs.tools, [&](const json & tool) {
1211+
const auto & function = tool.at("function");
1212+
schemas.push_back({
1213+
{ "type", "object" },
1214+
{ "properties",
1215+
{
1216+
{ "name",
1217+
{
1218+
{ "type", "string" },
1219+
{ "const", function.at("name") },
1220+
} },
1221+
{ "arguments", function.at("parameters") },
1222+
} },
1223+
{ "required", json::array({ "name", "arguments" }) },
1224+
});
1225+
});
1226+
auto schema = json{
1227+
{ "type", "array" },
1228+
{ "items", schemas.size() == 1 ? schemas[0] : json{ { "anyOf", schemas } } },
1229+
{ "minItems", 1 },
1230+
};
1231+
if (!inputs.parallel_tool_calls) {
1232+
schema["maxItems"] = 1;
1233+
}
1234+
builder.add_rule("root",
1235+
std::string(data.thinking_forced_open ? "( \"</think>\" space )? " : "") +
1236+
"\"<TOOLCALL>\" " + builder.add_schema("tool_calls", schema) +
1237+
" \"</TOOLCALL>\"");
1238+
});
1239+
data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
1240+
// If thinking_forced_open, then we capture the </think> tag in the grammar,
1241+
// (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar)
1242+
std::string(data.thinking_forced_open ?
1243+
"[\\s\\S]*?(</think>\\s*)" :
1244+
"(?:<think>[\\s\\S]*?</think>\\s*)?") +
1245+
"(<TOOLCALL>)[\\s\\S]*" });
1246+
}
1247+
return data;
1248+
}
11871249
static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) {
11881250
if (!builder.syntax().parse_tool_calls) {
11891251
builder.add_content(builder.consume_rest());
@@ -2060,6 +2122,33 @@ static void common_chat_parse_granite(common_chat_msg_parser & builder) {
20602122
}
20612123
}
20622124

2125+
static void common_chat_parse_nemotron_v2(common_chat_msg_parser & builder) {
2126+
// Parse thinking tags
2127+
builder.try_parse_reasoning("<think>", "</think>");
2128+
if (!builder.syntax().parse_tool_calls) {
2129+
builder.add_content(builder.consume_rest());
2130+
return;
2131+
}
2132+
2133+
// Look for tool calls
2134+
static const common_regex tool_call_regex(regex_escape("<TOOLCALL>"));
2135+
if (auto res = builder.try_find_regex(tool_call_regex)) {
2136+
builder.move_to(res->groups[0].end);
2137+
2138+
// Expect JSON array of tool calls
2139+
auto tool_calls_data = builder.consume_json();
2140+
if (tool_calls_data.json.is_array()) {
2141+
if (!builder.try_consume_literal("</TOOLCALL>")) {
2142+
throw common_chat_msg_partial_exception("Incomplete tool call");
2143+
}
2144+
builder.add_tool_calls(tool_calls_data.json);
2145+
} else {
2146+
throw common_chat_msg_partial_exception("Incomplete tool call");
2147+
}
2148+
}
2149+
builder.add_content(builder.consume_rest());
2150+
}
2151+
20632152
static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) {
20642153
// Parse thinking tags first - this handles the main reasoning content
20652154
builder.try_parse_reasoning("<seed:think>", "</seed:think>");
@@ -2293,6 +2382,11 @@ static common_chat_params common_chat_templates_apply_jinja(
22932382
return common_chat_params_init_seed_oss(tmpl, params, inputs);
22942383
}
22952384

2385+
// Nemotron v2
2386+
if (src.find("<SPECIAL_10>") != std::string::npos) {
2387+
return common_chat_params_init_nemotron_v2(tmpl, params);
2388+
}
2389+
22962390
// Use generic handler when mixing tools + JSON schema.
22972391
// TODO: support that mix in handlers below.
22982392
if ((params.tools.is_array() && params.json_schema.is_object())) {
@@ -2454,6 +2548,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
24542548
case COMMON_CHAT_FORMAT_SEED_OSS:
24552549
common_chat_parse_seed_oss(builder);
24562550
break;
2551+
case COMMON_CHAT_FORMAT_NEMOTRON_V2:
2552+
common_chat_parse_nemotron_v2(builder);
2553+
break;
24572554
default:
24582555
throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
24592556
}

common/chat.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ enum common_chat_format {
112112
COMMON_CHAT_FORMAT_GRANITE,
113113
COMMON_CHAT_FORMAT_GPT_OSS,
114114
COMMON_CHAT_FORMAT_SEED_OSS,
115+
COMMON_CHAT_FORMAT_NEMOTRON_V2,
115116

116117
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
117118
};
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
{%- set ns = namespace(enable_thinking=true) -%}
2+
{%- for message in messages -%}
3+
{%- set content = message['content'] -%}
4+
{%- if message['role'] == 'user' or message['role'] == 'system' -%}
5+
{%- if '/think' in content -%}
6+
{%- set ns.enable_thinking = true -%}
7+
{%- elif '/no_think' in content -%}
8+
{%- set ns.enable_thinking = false -%}
9+
{%- endif -%}
10+
{%- endif -%}
11+
{%- endfor -%}
12+
13+
{%- if messages[0]['role'] != 'system' -%}
14+
{%- set ns.non_tool_system_content = '' -%}
15+
{{- '<SPECIAL_10>System
16+
' -}}
17+
{%- else -%}
18+
{%- set ns.non_tool_system_content = (messages[0]['content'] | default('', true)).replace('/think', '').replace('/no_think', '').strip() -%}
19+
{{- '<SPECIAL_10>System
20+
' + ns.non_tool_system_content }}
21+
{%- endif -%}
22+
23+
{%- if tools -%}
24+
{%- if ns.non_tool_system_content is defined and ns.non_tool_system_content != '' -%}
25+
{{- '
26+
27+
' -}}
28+
{%- endif -%}
29+
{{- 'You can use the following tools to assist the user if required:' -}}
30+
{{- '
31+
<AVAILABLE_TOOLS>[' -}}
32+
{%- for tool in tools -%}
33+
{{- (tool.function if tool.function is defined else tool) | tojson -}}
34+
{{- ', ' if not loop.last else '' -}}
35+
{%- endfor -%}
36+
{{- ']</AVAILABLE_TOOLS>
37+
38+
' -}}
39+
{{- 'If you decide to call any tool(s), use the following format:
40+
' -}}
41+
{{- '<TOOLCALL>[{{"name": "tool_name1", "arguments": "tool_args1"}}, ' -}}
42+
{{- '{{"name": "tool_name2", "arguments": "tool_args2"}}]</TOOLCALL>
43+
44+
' -}}
45+
{{- 'The user will execute tool-calls and return responses from tool(s) in this format:
46+
' -}}
47+
{{- '<TOOL_RESPONSE>[{{"tool_response1"}}, {{"tool_response2"}}]</TOOL_RESPONSE>
48+
49+
' -}}
50+
{{- 'Based on the tool responses, you can call additional tools if needed, correct tool calls if any errors are found, or just respond to the user.' -}}
51+
{%- endif -%}
52+
{{- '
53+
54+
' -}}
55+
{%- set messages = messages[1:] if messages[0]['role'] == 'system' else messages -%}
56+
{%- if messages[-1]['role'] == 'assistant' -%}
57+
{%- set ns.last_turn_assistant_content = (messages[-1]['content'] | default('', true)).strip() -%}
58+
{%- set ns.last_turn_assistant_tool_calls = messages[-1]['tool_calls'] if 'tool_calls' in messages[-1] else [] -%}
59+
{%- set messages = messages[:-1] -%}
60+
{%- endif -%}
61+
62+
{%- for message in messages %}
63+
{%- set content = message['content'] %}
64+
{%- if message['role'] == 'user' -%}
65+
{{- '<SPECIAL_11>User
66+
' + (content | default('', true)).replace('/think', '').replace('/no_think', '').strip() + '
67+
' }}
68+
{%- elif message['role'] == 'tool' -%}
69+
{%- if loop.first or (messages[loop.index0 - 1].role != 'tool') -%}
70+
{{- '<SPECIAL_11>User
71+
' + '<TOOL_RESPONSE>[' }}
72+
{%- endif -%}
73+
{{- message['content'] -}}
74+
{{- ', ' if not loop.last and (messages[loop.index0 + 1].role == 'tool') else '' -}}
75+
{%- if loop.last or (messages[loop.index0 + 1].role != 'tool') -%}
76+
{{- ']</TOOL_RESPONSE>' -}}
77+
{%- endif -%}
78+
{%- elif message['role'] == 'assistant' -%}
79+
{%- if content and '</think>' in content -%}
80+
{%- set content = (content.split('</think>')[1] | default('', true)).strip() %}
81+
{%- endif -%}
82+
{{- '<SPECIAL_11>Assistant
83+
' + ((content | default('', true)).strip() if content is not none else '') }}
84+
{%- if message.tool_calls -%}
85+
{%- if (content | default('', true)).strip() != '' -%}
86+
{{- '
87+
' -}}
88+
{%- endif -%}
89+
{{- '<TOOLCALL>[' -}}
90+
{%- for call in message.tool_calls -%}
91+
{%- set fn = call.function if call.function is defined else call -%}
92+
{{- '{"name": "' + fn.name + '", "arguments": ' -}}
93+
{%- if fn.arguments is string -%}
94+
{{- fn.arguments -}}
95+
{%- else -%}
96+
{{- fn.arguments | tojson -}}
97+
{%- endif -%}
98+
{{- '}' + (', ' if not loop.last else '') -}}
99+
{%- endfor -%}
100+
{{- ']</TOOLCALL>' -}}
101+
{%- endif -%}
102+
{{- '
103+
<SPECIAL_12>
104+
' -}}
105+
{%- endif -%}
106+
{%- endfor -%}
107+
108+
{%- if add_generation_prompt -%}
109+
{{- '<SPECIAL_11>Assistant
110+
' -}}
111+
{%- if ns.enable_thinking is defined and ns.enable_thinking is false -%}
112+
{{- '<think></think>' -}}
113+
{%- else -%}
114+
{{- '<think>
115+
' -}}
116+
{%- endif -%}
117+
{%- if ns.last_turn_assistant_content is defined and ns.last_turn_assistant_content != '' -%}
118+
{{- ns.last_turn_assistant_content -}}
119+
{%- endif -%}
120+
{%- else -%}
121+
{%- if ns.last_turn_assistant_content is defined and ns.last_turn_assistant_content != '' -%}
122+
{{- '<SPECIAL_11>Assistant
123+
' -}}
124+
{%- if ns.enable_thinking is defined and ns.enable_thinking is false -%}
125+
{{- '<think></think>' -}}
126+
{%- else -%}
127+
{{- '<think>
128+
' -}}
129+
{%- endif -%}
130+
{{- ns.last_turn_assistant_content -}}
131+
{%- if continue_final_message is defined -%}
132+
{%- if continue_final_message is false -%}
133+
{{- '
134+
<SPECIAL_12>
135+
' -}}
136+
{%- endif -%}
137+
{%- else -%}
138+
{{- '
139+
<SPECIAL_12>
140+
' -}}
141+
{%- endif -%}
142+
{%- endif -%}
143+
{%- if ns.last_turn_assistant_tool_calls is defined and ns.last_turn_assistant_tool_calls | length > 0 -%}
144+
{{- '<SPECIAL_11>Assistant
145+
' -}}
146+
{{- '<TOOLCALL>[' -}}
147+
{%- for call in ns.last_turn_assistant_tool_calls -%}
148+
{%- set fn = call.function if call.function is defined else call -%}
149+
{{- '{"name": "' + fn.name + '", "arguments": ' -}}
150+
{%- if fn.arguments is string -%}
151+
{{- fn.arguments -}}
152+
{%- else -%}
153+
{{- fn.arguments | tojson -}}
154+
{%- endif -%}
155+
{{- '}' + (', ' if not loop.last else '') -}}
156+
{%- endfor -%}
157+
{{- ']</TOOLCALL>' -}}
158+
{{- '<SPECIAL_12>
159+
160+
' -}}
161+
{%- endif -%}
162+
{%- endif -%}

tests/test-chat.cpp

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,7 @@ const common_chat_msg message_assist_call_empty_args = simple_assist
420420
const common_chat_msg message_assist_call_cutoff_args = simple_assist_msg("", "", "special_function", "{\"arg");
421421
const common_chat_msg message_assist_call_thoughts = simple_assist_msg("", "I'm\nthinking", "special_function", "{\"arg1\":1}");
422422
const common_chat_msg message_assist_call_thoughts_unparsed = simple_assist_msg("<think>I'm\nthinking</think>\n\n", "", "special_function", "{\"arg1\": 1}");
423+
const common_chat_msg message_assist_call_thoughts_content = simple_assist_msg("Hello, world!\nWhat's up?", "I'm\nthinking", "special_function", "{\"arg1\": 1}");
423424
const common_chat_msg message_assist_call_id = simple_assist_msg("", "", "special_function", "{\"arg1\":1}", /* .id = */ "123456789");
424425
const common_chat_msg message_assist_call_idx = simple_assist_msg("", "", "special_function", "{\"arg1\":1}", /* .id = */ "0");
425426
const common_chat_msg message_assist_thoughts_call_idx = simple_assist_msg("", "I'm\nthinking", "special_function", "{\"arg1\": 1}", /* id = */ "0");
@@ -436,6 +437,7 @@ static void test_msgs_oaicompat_json_conversion() {
436437
message_assist_call,
437438
message_assist_call_thoughts,
438439
message_assist_call_thoughts_unparsed,
440+
message_assist_call_thoughts_content,
439441
message_assist_call_id,
440442
message_assist_call_idx,
441443
message_assist_call_python,
@@ -1755,6 +1757,77 @@ static void test_template_output_parsers() {
17551757
/* is_partial= */ false,
17561758
{COMMON_CHAT_FORMAT_SEED_OSS}));
17571759
}
1760+
1761+
{
1762+
auto tmpls = read_templates("models/templates/NVIDIA-Nemotron-Nano-v2.jinja");
1763+
std::vector<std::string> end_tokens{ "<SPECIAL_12>" };
1764+
1765+
assert_equals(COMMON_CHAT_FORMAT_NEMOTRON_V2, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
1766+
assert_equals(COMMON_CHAT_FORMAT_NEMOTRON_V2, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
1767+
1768+
// Test parsing regular content
1769+
assert_msg_equals(message_assist,
1770+
common_chat_parse(
1771+
"Hello, world!\nWhat's up?",
1772+
/* is_partial= */ false,
1773+
{COMMON_CHAT_FORMAT_NEMOTRON_V2}));
1774+
1775+
// Test parsing content with thinking
1776+
assert_msg_equals(message_assist_thoughts,
1777+
common_chat_parse(
1778+
"<think>I'm\nthinking</think>Hello, world!\nWhat's up?",
1779+
/* is_partial= */ false,
1780+
{
1781+
/* .format = */ COMMON_CHAT_FORMAT_NEMOTRON_V2,
1782+
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
1783+
}));
1784+
1785+
// Test parsing tool calls
1786+
assert_msg_equals(message_assist_call,
1787+
common_chat_parse(
1788+
"<TOOLCALL>[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]</TOOLCALL>",
1789+
/* is_partial= */ false,
1790+
{COMMON_CHAT_FORMAT_NEMOTRON_V2}));
1791+
1792+
// Test parsing tool calls with thinking
1793+
assert_msg_equals(message_assist_call_thoughts,
1794+
common_chat_parse(
1795+
"<think>I'm\nthinking</think><TOOLCALL>[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]</TOOLCALL>",
1796+
/* is_partial= */ false,
1797+
{
1798+
/* .format = */ COMMON_CHAT_FORMAT_NEMOTRON_V2,
1799+
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK
1800+
}));
1801+
1802+
// Test tool calls with extra content
1803+
assert_msg_equals(message_assist_call_content,
1804+
common_chat_parse(
1805+
"<TOOLCALL>[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]</TOOLCALL>Hello, world!\nWhat's up?",
1806+
/* is_partial= */ false,
1807+
{COMMON_CHAT_FORMAT_NEMOTRON_V2}
1808+
));
1809+
1810+
// Test tool calls with extra content AND thinking
1811+
assert_msg_equals(message_assist_call_thoughts_content,
1812+
common_chat_parse(
1813+
"<think>I'm\nthinking</think><TOOLCALL>[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]</TOOLCALL>Hello, world!\nWhat's up?",
1814+
/* is_partial= */ false,
1815+
{
1816+
/* .format = */ COMMON_CHAT_FORMAT_NEMOTRON_V2,
1817+
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK
1818+
}));
1819+
1820+
// Test template generation for regular content
1821+
test_templates(tmpls.get(), end_tokens, message_assist, tools,
1822+
"Hello, world!\nWhat's up?\n",
1823+
/* expect_grammar_triggered= */ false);
1824+
1825+
// Test template generation for tool calls
1826+
test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
1827+
"<TOOLCALL>[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]</TOOLCALL>",
1828+
/* expect_grammar_triggered= */ true
1829+
);
1830+
}
17581831
}
17591832

17601833
static void test_msg_diffs_compute() {

0 commit comments

Comments
 (0)