
Commit 53cc80f

common : add parser for ministral/mistral 3
1 parent 4aacf75 commit 53cc80f

1 file changed

common/chat.cpp

Lines changed: 119 additions & 0 deletions
@@ -1,5 +1,6 @@
 #include "chat.h"
 #include "chat-parser.h"
+#include "chat-peg-parser.h"
 #include "common.h"
 #include "json-partial.h"
 #include "json-schema-to-grammar.h"
@@ -979,6 +980,117 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
     return data;
 }
 
+static common_chat_params common_chat_params_init_mistral_3(const common_chat_template & tmpl, const struct templates_params & inputs) {
+    common_chat_params data;
+
+    // Build up messages to follow the format: https://huggingface.co/mistralai/Ministral-3-14B-Reasoning-2512/blob/main/chat_template.jinja
+    auto adjusted_messages = json::array();
+    for (const auto & msg : inputs.messages) {
+        auto role = msg.value("role", "");
+        if (role != "system" && role != "assistant") {
+            // Only adjust system and assistant messages. Interestingly, the system message may contain thinking.
+            adjusted_messages.push_back(msg);
+            continue;
+        }
+
+        auto content = json::array();
+
+        // If message contains `reasoning_content`, add it as a block of type `thinking`
+        if (msg.contains("reasoning_content") && msg.at("reasoning_content").is_string()) {
+            content.push_back({
+                {"type", "thinking"},
+                {"thinking", msg.at("reasoning_content").get<std::string>()},
+            });
+        }
+
+        // If message contains `content`, add it as a block of type `text`
+        if (msg.contains("content")) {
+            if (msg.at("content").is_string()) {
+                content.push_back({
+                    {"type", "text"},
+                    {"text", msg.at("content").get<std::string>()},
+                });
+            } else if (msg.at("content").is_array()) {
+                auto blocks = msg.at("content");
+                content.insert(content.end(), blocks.begin(), blocks.end());
+            }
+        }
+
+        auto adjusted = msg;
+        adjusted["content"] = content;
+        adjusted.erase("reasoning_content");
+        adjusted_messages.push_back(adjusted);
+    }
+
+    auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
+    auto include_grammar = true;
+
+    data.prompt = apply(tmpl, inputs, /* messages_override = */ adjusted_messages);
+    data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
+    data.preserved_tokens = {
+        "[THINK]",
+        "[/THINK]",
+        "[TOOL_CALLS]",
+        "[ARGS]",
+    };
+
+    auto parser = build_chat_peg_native_parser([&](common_chat_peg_native_builder & p) {
+        auto reasoning = p.optional("[THINK]" + p.reasoning(p.until("[/THINK]")) + "[/THINK]");
+
+        // Response format parser
+        if (inputs.json_schema.is_object() && !inputs.json_schema.empty()) {
+            // Ministral wants to emit json surrounded by code fences
+            return reasoning << "```json" << p.content(p.schema(p.json(), "response-format", inputs.json_schema)) << "```";
+        }
+
+        // Tool call parser
+        if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) {
+            auto tool_choice = p.choice();
+            foreach_function(inputs.tools, [&](const json & tool) {
+                const auto & function = tool.at("function");
+                std::string name = function.at("name");
+                const auto & schema = function.at("parameters");
+
+                tool_choice |= p.rule("tool-" + name,
+                    p.tool_open(p.tool_name(p.literal(name)) + "[ARGS]")
+                    + p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema))
+                );
+            });
+
+            auto min_calls = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED ? 1 : 0;
+            auto max_calls = inputs.parallel_tool_calls ? -1 : 1;
+            auto tool_calls = p.trigger_rule("tool-call", p.repeat("[TOOL_CALLS]" + tool_choice, min_calls, max_calls));
+
+            return reasoning << p.content(p.until("[TOOL_CALLS]")) << tool_calls;
+        }
+
+        // Content only parser
+        include_grammar = false;
+        return reasoning << p.content(p.rest());
+    });
+
+    data.parser = parser.save();
+
+    if (include_grammar) {
+        data.grammar_lazy = has_tools && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO;
+
+        data.grammar = build_grammar([&](const common_grammar_builder & builder) {
+            foreach_function(inputs.tools, [&](const json & tool) {
+                const auto & function = tool.at("function");
+                auto schema = function.at("parameters");
+                builder.resolve_refs(schema);
+            });
+            parser.build_grammar(builder, data.grammar_lazy);
+        });
+
+        data.grammar_triggers = {
+            {COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "[TOOL_CALLS]"}
+        };
+    }
+
+    return data;
+}
+
 static common_chat_params common_chat_params_init_magistral(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
     data.prompt = apply(tmpl, inputs);
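
Not part of the diff: a minimal, self-contained sketch of the message reshaping the new function performs, assuming nlohmann::json (ordered_json is used here only so key order is preserved) and a hypothetical assistant turn; the values are illustrative only.

#include <nlohmann/json.hpp>
#include <iostream>

int main() {
    using json = nlohmann::ordered_json;

    // OpenAI-style assistant message with a separate reasoning_content field.
    json msg = {
        {"role", "assistant"},
        {"reasoning_content", "The user wants the capital of France."},
        {"content", "It is Paris."},
    };

    // Shape expected by the Ministral/Mistral 3 template: a list of typed content blocks.
    json content = json::array();
    if (msg.contains("reasoning_content") && msg.at("reasoning_content").is_string()) {
        content.push_back({{"type", "thinking"}, {"thinking", msg.at("reasoning_content").get<std::string>()}});
    }
    if (msg.contains("content") && msg.at("content").is_string()) {
        content.push_back({{"type", "text"}, {"text", msg.at("content").get<std::string>()}});
    }

    json adjusted = msg;
    adjusted["content"] = content;
    adjusted.erase("reasoning_content");

    // Prints:
    // {"role":"assistant","content":[{"type":"thinking","thinking":"The user wants the capital of France."},{"type":"text","text":"It is Paris."}]}
    std::cout << adjusted.dump() << std::endl;
    return 0;
}
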
@@ -2496,6 +2608,13 @@ static common_chat_params common_chat_templates_apply_jinja(
         return common_chat_params_init_llama_3_x(tmpl, params, allow_python_tag_builtin_tools);
     }
 
+    // Ministral/Mistral 3
+    if (src.find("[SYSTEM_PROMPT]") != std::string::npos &&
+        src.find("[TOOL_CALLS]") != std::string::npos &&
+        src.find("[ARGS]") != std::string::npos) {
+        return common_chat_params_init_mistral_3(tmpl, params);
+    }
+
     if (src.find("[THINK]") != std::string::npos && src.find("[/THINK]") != std::string::npos) {
         return common_chat_params_init_magistral(tmpl, params);
     }
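
Also not part of the diff: based on the PEG rules above, the raw completion the parser appears to accept looks roughly like the string below, with an optional [THINK]...[/THINK] block, optional plain content, and each tool call written as [TOOL_CALLS]<name>[ARGS]<JSON arguments>; the tool name and arguments are hypothetical.

#include <string>
#include <iostream>

int main() {
    // Reasoning, then content, then a single tool call. Multiple calls would repeat
    // the [TOOL_CALLS]<name>[ARGS]{...} sequence when parallel_tool_calls is enabled.
    const std::string sample =
        "[THINK]The user asked about the weather, so I should call a tool.[/THINK]"
        "Let me check that for you."
        "[TOOL_CALLS]get_weather[ARGS]{\"location\": \"Paris\"}";
    std::cout << sample << std::endl;
    return 0;
}
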
