|
1 | 1 | #include "chat.h" |
2 | 2 | #include "chat-parser.h" |
| 3 | +#include "chat-peg-parser.h" |
3 | 4 | #include "common.h" |
4 | 5 | #include "json-partial.h" |
5 | 6 | #include "json-schema-to-grammar.h" |
@@ -979,6 +980,117 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat |
979 | 980 | return data; |
980 | 981 | } |
981 | 982 |
|
| 983 | +static common_chat_params common_chat_params_init_mistral_3(const common_chat_template & tmpl, const struct templates_params & inputs) { |
| 984 | + common_chat_params data; |
| 985 | + |
| 986 | + // Build up messages to follow the format: https://huggingface.co/mistralai/Ministral-3-14B-Reasoning-2512/blob/main/chat_template.jinja |
| 987 | + auto adjusted_messages = json::array(); |
| 988 | + for (const auto & msg : inputs.messages) { |
| 989 | + auto role = msg.value("role", ""); |
| 990 | + if (role != "system" && role != "assistant") { |
| 991 | + // Only adjust system and assistant messages. Interestingly, the system message may contain thinking. |
| 992 | + adjusted_messages.push_back(msg); |
| 993 | + continue; |
| 994 | + } |
| 995 | + |
| 996 | + auto content = json::array(); |
| 997 | + |
| 998 | + // If message contains `reasoning_content`, add it as a block of type `thinking` |
| 999 | + if (msg.contains("reasoning_content") && msg.at("reasoning_content").is_string()) { |
| 1000 | + content.push_back({ |
| 1001 | + {"type", "thinking"}, |
| 1002 | + {"thinking", msg.at("reasoning_content").get<std::string>()}, |
| 1003 | + }); |
| 1004 | + } |
| 1005 | + |
| 1006 | + // If message contains `content`, add it as a block of type `text` |
| 1007 | + if (msg.contains("content")) { |
| 1008 | + if (msg.at("content").is_string()) { |
| 1009 | + content.push_back({ |
| 1010 | + {"type", "text"}, |
| 1011 | + {"text", msg.at("content").get<std::string>()}, |
| 1012 | + }); |
| 1013 | + } else if (msg.at("content").is_array()) { |
| 1014 | + auto blocks = msg.at("content"); |
| 1015 | + content.insert(content.end(), blocks.begin(), blocks.end()); |
| 1016 | + } |
| 1017 | + } |
| 1018 | + |
| 1019 | + auto adjusted = msg; |
| 1020 | + adjusted["content"] = content; |
| 1021 | + adjusted.erase("reasoning_content"); |
| 1022 | + adjusted_messages.push_back(adjusted); |
| 1023 | + } |
| 1024 | + |
| 1025 | + auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); |
| 1026 | + auto include_grammar = true; |
| 1027 | + |
| 1028 | + data.prompt = apply(tmpl, inputs, /* messages_override = */ adjusted_messages); |
| 1029 | + data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; |
| 1030 | + data.preserved_tokens = { |
| 1031 | + "[THINK]", |
| 1032 | + "[/THINK]", |
| 1033 | + "[TOOL_CALLS]", |
| 1034 | + "[ARGS]", |
| 1035 | + }; |
| 1036 | + |
| 1037 | + auto parser = build_chat_peg_native_parser([&](common_chat_peg_native_builder & p) { |
| 1038 | + auto reasoning = p.optional("[THINK]" + p.reasoning(p.until("[/THINK]")) + "[/THINK]"); |
| 1039 | + |
| 1040 | + // Response format parser |
| 1041 | + if (inputs.json_schema.is_object() && !inputs.json_schema.empty()) { |
| 1042 | + // Ministral wants to emit json surrounded code fences |
| 1043 | + return reasoning << "```json" << p.content(p.schema(p.json(), "response-format", inputs.json_schema)) << "```"; |
| 1044 | + } |
| 1045 | + |
| 1046 | + // Tool call parser |
| 1047 | + if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) { |
| 1048 | + auto tool_choice = p.choice(); |
| 1049 | + foreach_function(inputs.tools, [&](const json & tool) { |
| 1050 | + const auto & function = tool.at("function"); |
| 1051 | + std::string name = function.at("name"); |
| 1052 | + const auto & schema = function.at("parameters"); |
| 1053 | + |
| 1054 | + tool_choice |= p.rule("tool-" + name, |
| 1055 | + p.tool_open(p.tool_name(p.literal(name)) + "[ARGS]") |
| 1056 | + + p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema)) |
| 1057 | + ); |
| 1058 | + }); |
| 1059 | + |
| 1060 | + auto min_calls = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED ? 1 : 0; |
| 1061 | + auto max_calls = inputs.parallel_tool_calls ? -1 : 1; |
| 1062 | + auto tool_calls = p.trigger_rule("tool-call", p.repeat("[TOOL_CALLS]" + tool_choice, min_calls, max_calls)); |
| 1063 | + |
| 1064 | + return reasoning << p.content(p.until("[TOOL_CALLS]")) << tool_calls; |
| 1065 | + } |
| 1066 | + |
| 1067 | + // Content only parser |
| 1068 | + include_grammar = false; |
| 1069 | + return reasoning << p.content(p.rest()); |
| 1070 | + }); |
| 1071 | + |
| 1072 | + data.parser = parser.save(); |
| 1073 | + |
| 1074 | + if (include_grammar) { |
| 1075 | + data.grammar_lazy = has_tools && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO; |
| 1076 | + |
| 1077 | + data.grammar = build_grammar([&](const common_grammar_builder & builder) { |
| 1078 | + foreach_function(inputs.tools, [&](const json & tool) { |
| 1079 | + const auto & function = tool.at("function"); |
| 1080 | + auto schema = function.at("parameters"); |
| 1081 | + builder.resolve_refs(schema); |
| 1082 | + }); |
| 1083 | + parser.build_grammar(builder, data.grammar_lazy); |
| 1084 | + }); |
| 1085 | + |
| 1086 | + data.grammar_triggers = { |
| 1087 | + {COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "[TOOL_CALLS]"} |
| 1088 | + }; |
| 1089 | + } |
| 1090 | + |
| 1091 | + return data; |
| 1092 | +} |
| 1093 | + |
982 | 1094 | static common_chat_params common_chat_params_init_magistral(const common_chat_template & tmpl, const struct templates_params & inputs) { |
983 | 1095 | common_chat_params data; |
984 | 1096 | data.prompt = apply(tmpl, inputs); |
@@ -2496,6 +2608,13 @@ static common_chat_params common_chat_templates_apply_jinja( |
2496 | 2608 | return common_chat_params_init_llama_3_x(tmpl, params, allow_python_tag_builtin_tools); |
2497 | 2609 | } |
2498 | 2610 |
|
| 2611 | + // Ministral/Mistral 3 |
| 2612 | + if (src.find("[SYSTEM_PROMPT]") != std::string::npos && |
| 2613 | + src.find("[TOOL_CALLS]") != std::string::npos && |
| 2614 | + src.find("[ARGS]") != std::string::npos) { |
| 2615 | + return common_chat_params_init_mistral_3(tmpl, params); |
| 2616 | + } |
| 2617 | + |
2499 | 2618 | if (src.find("[THINK]") != std::string::npos && src.find("[/THINK]") != std::string::npos) { |
2500 | 2619 | return common_chat_params_init_magistral(tmpl, params); |
2501 | 2620 | } |
|
0 commit comments