From cd909c2df435393a4d295b5fedada10713df0f60 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Thu, 24 Jul 2025 21:39:17 +0200 Subject: [PATCH 01/10] mtmd : add support for Voxtral --- .gitignore | 1 + convert_hf_to_gguf.py | 76 ++++- gguf-py/gguf/constants.py | 1 + gguf-py/gguf/vocab.py | 280 ++++++++++++++++++ .../unsloth-mistral-Devstral-Small-2507.jinja | 105 +++++++ .../requirements-convert_hf_to_gguf.txt | 2 + tools/mtmd/clip-impl.h | 2 + tools/mtmd/clip.cpp | 63 ++-- 8 files changed, 510 insertions(+), 20 deletions(-) create mode 100644 models/templates/unsloth-mistral-Devstral-Small-2507.jinja diff --git a/.gitignore b/.gitignore index f8ceb1560a1df..f48ce4cacd144 100644 --- a/.gitignore +++ b/.gitignore @@ -82,6 +82,7 @@ models/* models-mnt !models/.editorconfig !models/ggml-vocab-*.gguf* +!models/templates # Zig zig-out/ diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index e12c922bd9ab4..0dbd783818770 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2260,6 +2260,63 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): return super().modify_tensors(data_torch, name, bid) +@ModelBase.register("VoxtralForConditionalGeneration") +class VoxtralModel(LlamaModel): + model_arch = gguf.MODEL_ARCH.LLAMA + + def set_vocab(self): + vocab = gguf.vocab.MistralVocab(self.dir_model) + self.gguf_writer.add_tokenizer_model(vocab.gguf_tokenizer_model) + + tokens = [] + scores = [] + toktypes = [] + + for text, score, toktype in vocab.all_tokens(): + tokens.append(text) + scores.append(score) + toktypes.append(toktype) + + assert len(tokens) == vocab.vocab_size, ( + f"token count ({len(tokens)}) != vocab size ({vocab.vocab_size})" + ) + + if vocab.tokenizer_type == gguf.vocab.MistralTokenizerType.tekken: + self.gguf_writer.add_tokenizer_pre("tekken") + self.gguf_writer.add_token_merges( + vocab.extract_vocab_merges_from_model() + ) + + logger.info( + f"Setting bos, eos, unk and pad token IDs to {vocab.bos_id}, {vocab.eos_id}, {vocab.unk_id}, {vocab.pad_id}." + ) + + self.gguf_writer.add_bos_token_id(vocab.bos_id) + self.gguf_writer.add_eos_token_id(vocab.eos_id) + self.gguf_writer.add_unk_token_id(vocab.unk_id) + self.gguf_writer.add_pad_token_id(vocab.pad_id) + + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_scores(scores) + self.gguf_writer.add_token_types(toktypes) + self.gguf_writer.add_vocab_size(vocab.vocab_size) + + self.gguf_writer.add_add_bos_token(True) + self.gguf_writer.add_add_eos_token(False) + + script_dir = Path(__file__).parent + template_path = script_dir / "models/templates/unsloth-mistral-Devstral-Small-2507.jinja" + with open(template_path, "r", encoding="utf-8") as f: + template = f.read() + self.gguf_writer.add_chat_template(template) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): + name = name.replace("language_model.", "") + if "multi_modal_projector" in name or "audio_tower" in name: + return [] + return super().modify_tensors(data_torch, name, bid) + + @ModelBase.register("DeciLMForCausalLM") class DeciModel(TextModel): model_arch = gguf.MODEL_ARCH.DECI @@ -7231,9 +7288,10 @@ class WhisperEncoderModel(MmprojModel): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.hparams["hidden_size"] = self.hparams["d_model"] - self.hparams["intermediate_size"] = self.hparams["encoder_ffn_dim"] - self.hparams["num_attention_heads"] = self.hparams["encoder_attention_heads"] + if "hidden_size" not in self.hparams and "intermediate_size" not in self.hparams: + self.hparams["hidden_size"] = self.hparams["d_model"] + self.hparams["intermediate_size"] = self.hparams["encoder_ffn_dim"] + self.hparams["num_attention_heads"] = self.hparams["encoder_attention_heads"] def set_gguf_parameters(self): super().set_gguf_parameters() @@ -7272,9 +7330,21 @@ class UltravoxWhisperEncoderModel(WhisperEncoderModel): def set_gguf_parameters(self): super().set_gguf_parameters() + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.ULTRAVOX) self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"]) +@ModelBase.register("VoxtralForConditionalGeneration") +class VoxtralWhisperEncoderModel(WhisperEncoderModel): + has_vision_encoder = False # no vision encoder + has_audio_encoder = True + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.VOXTRAL) + self.gguf_writer.add_audio_stack_factor(4) # == intermediate_size // hidden_size + + @ModelBase.register("FalconH1ForCausalLM") class FalconH1Model(Mamba2Model): model_arch = gguf.MODEL_ARCH.FALCON_H1 diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 680210db7e9d5..e7b2a7a5e5e47 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -2704,6 +2704,7 @@ class VisionProjectorType: INTERNVL = "internvl" QWEN2A = "qwen2a" # audio QWEN25O = "qwen2.5o" # omni + VOXTRAL = "voxtral" # Items here are (block size, type size) diff --git a/gguf-py/gguf/vocab.py b/gguf-py/gguf/vocab.py index 635fcef35e235..6c876c8565f76 100644 --- a/gguf-py/gguf/vocab.py +++ b/gguf-py/gguf/vocab.py @@ -1,5 +1,6 @@ from __future__ import annotations +from enum import Enum import re import logging import json @@ -12,6 +13,26 @@ except ImportError: SentencePieceProcessor = None +try: + from mistral_common.tokens.tokenizers.mistral import MistralTokenizer + from mistral_common.tokens.tokenizers.tekken import Tekkenizer + # from mistral_common.tokens.tokenizers.utils import ( + # _filter_valid_tokenizer_files, + # ) + from mistral_common.tokens.tokenizers.sentencepiece import ( + SentencePieceTokenizer, + ) +except ImportError: + _mistral_common_installed = False + MistralTokenizer = None + Tekkenizer = None + SentencePieceTokenizer = None + _filter_valid_tokenizer_files = None +else: + _mistral_common_installed = True + _filter_valid_tokenizer_files = lambda x: x # noqa: E731 + + import gguf from .gguf_writer import GGUFWriter @@ -592,3 +613,262 @@ def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: def __repr__(self) -> str: return f"" + + +class MistralTokenizerType(str, Enum): + spm = "spm" + tekken = "tekken" + + +# Copied from Transformers (Apache 2.0) +# https://github.com/huggingface/transformers/blob/main/src/transformers/convert_slow_tokenizer.py#L1544 + +def bytes_to_unicode() -> dict[int, str]: + """ + Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control + characters the bpe code barfs on. + + The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab + if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for + decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup + tables between utf-8 bytes and unicode strings. + """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + + list(range(ord("¡"), ord("¬") + 1)) + + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs_str = [chr(n) for n in cs] + return dict(zip(bs, cs_str)) + + +class MistralVocab(Vocab): + tokenizer_model = "mistral" + name = "mistral" + + added_tokens_dict: dict[str, int] = {} + added_tokens_list: list[str] = [] + + def __init__(self, base_path: Path): + if not _mistral_common_installed: + raise ImportError( + "To use MistralVocab, please install the `mistral-common` package. " + "You can install it with `pip install mistral-common`." + ) + assert _filter_valid_tokenizer_files is not None, "mistral_common is not installed" + assert MistralTokenizer is not None, "mistral_common is not installed" + assert Tekkenizer is not None, "mistral_common is not installed" + + logger.info(f"Loading Mistral tokenizer from {base_path}") + + # Find the tokenizer files + all_files = [f.as_posix() for f in base_path.glob("**/*") if f.is_file()] + valid_tokenizer_files = _filter_valid_tokenizer_files(all_files) + + if len(valid_tokenizer_files) == 0: + raise ValueError(f"No tokenizer file found in the directory: {base_path}") + # If there are multiple tokenizer files, we use tekken.json if it exists, otherwise the versioned one. + if len(valid_tokenizer_files) > 1: + if "tekken.json" in valid_tokenizer_files: + tokenizer_file = "tekken.json" + else: + tokenizer_file = sorted(valid_tokenizer_files)[-1] + logger.warning( + f"Multiple tokenizer files found in {base_path}. Using {tokenizer_file}" + ) + else: + tokenizer_file = valid_tokenizer_files[0] + + self.tokenizer = MistralTokenizer.from_file( + base_path / tokenizer_file + ).instruct_tokenizer.tokenizer + self.tokenizer_type = ( + MistralTokenizerType.tekken + if isinstance(self.tokenizer, Tekkenizer) + else MistralTokenizerType.spm + ) + self.vocab_size = self.tokenizer.n_words + self.fname_tokenizer = base_path / tokenizer_file + self._name = ( + "mistral-" + self.tokenizer_type.value + "-" + self.tokenizer.version + ) + + @property + def tokenizer_name(self) -> str: + return self._name + + @property + def gguf_tokenizer_model(self) -> str: + return "llama" if self.tokenizer_type == MistralTokenizerType.spm else "gpt2" + + def _sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: + assert SentencePieceTokenizer is not None, "mistral_common is not installed" + assert isinstance(self.tokenizer, SentencePieceTokenizer), ( + f"Expected SentencePieceTokenizer, got {type(self.tokenizer)}" + ) + + for i in range(self.tokenizer._model.vocab_size()): + piece = self.tokenizer._model.IdToPiece(i) + text = piece.encode("utf-8") + score: float = self.tokenizer._model.GetScore(i) + + toktype = gguf.TokenType.NORMAL + if self.tokenizer._model.IsUnknown(i): + toktype = gguf.TokenType.UNKNOWN + if self.tokenizer._model.IsControl(i): + toktype = gguf.TokenType.CONTROL + + if self.tokenizer._model.IsUnused(i): + toktype = gguf.TokenType.UNUSED + if self.tokenizer._model.IsByte(i): + toktype = gguf.TokenType.BYTE + + yield text, score, toktype + + def _tekken_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: + assert Tekkenizer is not None, "mistral_common is not installed" + assert isinstance(self.tokenizer, Tekkenizer), ( + f"Expected Tekkenizer, got {type(self.tokenizer)}" + ) + + byte_encoder = bytes_to_unicode() + for token_id in range(self.tokenizer.num_special_tokens): + yield ( + self.tokenizer.id_to_piece(token_id).encode("utf-8"), + 0, + gguf.TokenType.CONTROL + ) + for token in self.tokenizer._tekken_token2id_nospecial: + yield ( + self.token_bytes_to_string(token, byte_encoder).encode("utf-8"), + 0, + gguf.TokenType.NORMAL, + ) + + def get_token_id(self, token: str) -> int: + assert SentencePieceTokenizer is not None and Tekkenizer is not None, "mistral_common is not installed" + if self.tokenizer_type == MistralTokenizerType.spm: + assert isinstance(self.tokenizer, SentencePieceTokenizer) + return self.tokenizer._vocab.index(token) + elif self.tokenizer_type == MistralTokenizerType.tekken: + assert isinstance(self.tokenizer, Tekkenizer) + return ( + self.tokenizer._vocab.index(token) + self.tokenizer.num_special_tokens + ) + else: + raise ValueError(f"Unknown tokenizer type: {self.tokenizer_type}") + + @property + def bos_id(self) -> int: + return self.tokenizer.bos_id + + @property + def eos_id(self) -> int: + return self.tokenizer.eos_id + + @property + def pad_id(self) -> int: + if self.tokenizer.pad_id == -1: + return self.eos_id + return self.tokenizer.pad_id + + @property + def unk_id(self) -> int: + return self.tokenizer.unk_id + + @property + def bos_token(self) -> str: + return self.tokenizer.id_to_piece(self.tokenizer.bos_id) + + @property + def eos_token(self) -> str: + return self.tokenizer.id_to_piece(self.tokenizer.eos_id) + + @property + def pad_token(self) -> str: + return self.tokenizer.id_to_piece(self.tokenizer.pad_id) + + @property + def unk_token(self) -> str: + return self.tokenizer.id_to_piece(self.tokenizer.unk_id) + + def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: + if self.tokenizer_type == MistralTokenizerType.spm: + yield from self._sentencepiece_tokens() + + elif self.tokenizer_type == MistralTokenizerType.tekken: + yield from self._tekken_tokens() + + else: + raise ValueError(f"Unknown tokenizer type: {self.tokenizer_type}") + + @staticmethod + def token_bytes_to_string(b, byte_encoder): + return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")]) + + def extract_vocab_merges_from_model(self): + # Adapted from Transformers (Apache 2.0) + # https://github.com/huggingface/transformers/blob/main/src/transformers/convert_slow_tokenizer.py + assert Tekkenizer is not None and isinstance(self.tokenizer, Tekkenizer), ( + f"Expected Tekkenizer, got {type(self.tokenizer)}" + ) + mergeable_ranks = self.tokenizer._model._mergeable_ranks + token_bytes_map = { + rank: token_bytes for token_bytes, rank in mergeable_ranks.items() + } + merge_pairs = [] + + # Sort vocab by rank to ensure correct merge order + for i in range(256, self.vocab_size - self.tokenizer.num_special_tokens): + merged_token = token_bytes_map[i] + local = [] + for j in range(1, len(merged_token)): + left = merged_token[:j] + right = merged_token[j:] + if ( + left in mergeable_ranks + and right in mergeable_ranks + and (left + right) in mergeable_ranks + ): + local.append((left, right, i)) + if not local: + raise ValueError( + f"Could not find valid merge for token at rank {i}: {merged_token.decode('latin-1')}" + ) + local = sorted( + local, + key=lambda x: (mergeable_ranks[x[0]], mergeable_ranks[x[1]]), + reverse=False, + ) + merge_pairs.extend(local) + merge_pairs = sorted(merge_pairs, key=lambda val: val[2], reverse=False) + + byte_encoder = bytes_to_unicode() + + decoded_merge_pairs = [ + [ + self.token_bytes_to_string(val[0], byte_encoder), + self.token_bytes_to_string(val[1], byte_encoder), + ] + for val in merge_pairs + ] + + merges = [ + " ".join( + [ + # ensure the spaces are properly encoded + "".join(chr(ord(c) + 256) if c == " " else c for c in part) + for part in pair + ] + ) + for pair in decoded_merge_pairs + ] + + return merges diff --git a/models/templates/unsloth-mistral-Devstral-Small-2507.jinja b/models/templates/unsloth-mistral-Devstral-Small-2507.jinja new file mode 100644 index 0000000000000..28b60c7ce3e44 --- /dev/null +++ b/models/templates/unsloth-mistral-Devstral-Small-2507.jinja @@ -0,0 +1,105 @@ +{#- Copyright 2025-present the Unsloth team. All rights reserved. #} +{#- Licensed under the Apache License, Version 2.0 (the "License") #} +{#- Edits made by Unsloth #} +{%- set default_system_message = 'You are Devstral, a helpful agentic model trained by Mistral AI and using the OpenHands scaffold. You can interact with a computer to solve tasks.\n\n\nYour primary role is to assist users by executing commands, modifying code, and solving technical problems effectively. You should be thorough, methodical, and prioritize quality over speed.\n* If the user asks a question, like \"why is X happening\", don\'t try to fix the problem. Just give an answer to the question.\n\n\n\n* Each action you take is somewhat expensive. Wherever possible, combine multiple actions into a single action, e.g. combine multiple bash commands into one, using sed and grep to edit/view multiple files at once.\n* When exploring the codebase, use efficient tools like find, grep, and git commands with appropriate filters to minimize unnecessary operations.\n\n\n\n* When a user provides a file path, do NOT assume it\'s relative to the current working directory. First explore the file system to locate the file before working on it.\n* If asked to edit a file, edit the file directly, rather than creating a new file with a different filename.\n* For global search-and-replace operations, consider using `sed` instead of opening file editors multiple times.\n\n\n\n* Write clean, efficient code with minimal comments. Avoid redundancy in comments: Do not repeat information that can be easily inferred from the code itself.\n* When implementing solutions, focus on making the minimal changes needed to solve the problem.\n* Before implementing any changes, first thoroughly understand the codebase through exploration.\n* If you are adding a lot of code to a function or file, consider splitting the function or file into smaller pieces when appropriate.\n\n\n\n* When configuring git credentials, use \"openhands\" as the user.name and \"openhands@all-hands.dev\" as the user.email by default, unless explicitly instructed otherwise.\n* Exercise caution with git operations. Do NOT make potentially dangerous changes (e.g., pushing to main, deleting repositories) unless explicitly asked to do so.\n* When committing changes, use `git status` to see all modified files, and stage all files necessary for the commit. Use `git commit -a` whenever possible.\n* Do NOT commit files that typically shouldn\'t go into version control (e.g., node_modules/, .env files, build directories, cache files, large binaries) unless explicitly instructed by the user.\n* If unsure about committing certain files, check for the presence of .gitignore files or ask the user for clarification.\n\n\n\n* When creating pull requests, create only ONE per session/issue unless explicitly instructed otherwise.\n* When working with an existing PR, update it with new commits rather than creating additional PRs for the same issue.\n* When updating a PR, preserve the original PR title and purpose, updating description only when necessary.\n\n\n\n1. EXPLORATION: Thoroughly explore relevant files and understand the context before proposing solutions\n2. ANALYSIS: Consider multiple approaches and select the most promising one\n3. TESTING:\n * For bug fixes: Create tests to verify issues before implementing fixes\n * For new features: Consider test-driven development when appropriate\n * If the repository lacks testing infrastructure and implementing tests would require extensive setup, consult with the user before investing time in building testing infrastructure\n * If the environment is not set up to run tests, consult with the user first before investing time to install all dependencies\n4. IMPLEMENTATION: Make focused, minimal changes to address the problem\n5. VERIFICATION: If the environment is set up to run tests, test your implementation thoroughly, including edge cases. If the environment is not set up to run tests, consult with the user first before investing time to run tests.\n\n\n\n* Only use GITHUB_TOKEN and other credentials in ways the user has explicitly requested and would expect.\n* Use APIs to work with GitHub or other platforms, unless the user asks otherwise or your task requires browsing.\n\n\n\n* When user asks you to run an application, don\'t stop if the application is not installed. Instead, please install the application and run the command again.\n* If you encounter missing dependencies:\n 1. First, look around in the repository for existing dependency files (requirements.txt, pyproject.toml, package.json, Gemfile, etc.)\n 2. If dependency files exist, use them to install all dependencies at once (e.g., `pip install -r requirements.txt`, `npm install`, etc.)\n 3. Only install individual packages directly if no dependency files are found or if only specific packages are needed\n* Similarly, if you encounter missing dependencies for essential tools requested by the user, install them when possible.\n\n\n\n* If you\'ve made repeated attempts to solve a problem but tests still fail or the user reports it\'s still broken:\n 1. Step back and reflect on 5-7 different possible sources of the problem\n 2. Assess the likelihood of each possible cause\n 3. Methodically address the most likely causes, starting with the highest probability\n 4. Document your reasoning process\n* When you run into any major issue while executing a plan from the user, please don\'t try to directly work around it. Instead, propose a new plan and confirm with the user before proceeding.\n' %} + +{{- bos_token }} + +{%- if messages[0]['role'] == 'system' %} + {%- if messages[0]['content'] is string %} + {%- set system_message = messages[0]['content'] %} + {%- else %} + {%- set system_message = messages[0]['content'][0]['text'] %} + {%- endif %} + {%- set loop_messages = messages[1:] %} +{%- else %} + {%- set system_message = default_system_message %} + {%- set loop_messages = messages %} +{%- endif %} +{{- '[SYSTEM_PROMPT]' + system_message + '[/SYSTEM_PROMPT]' }} + + +{#- Tool description appended ONLY to last user message. Edits made by Unsloth #} +{#- Tool description appended also if last message is tool. Edits made by Unsloth #} +{%- set tools_description = "" %} +{%- set has_tools = false %} + +{%- if tools is defined and tools is not none and tools|length > 0 %} + + {%- set has_tools = true %} + {%- set tools_description = "[AVAILABLE_TOOLS]" + (tools | tojson) + "[/AVAILABLE_TOOLS]" %} + + {{- tools_description }} + +{%- endif %} + +{%- for message in loop_messages %} + {%- if message['role'] == 'user' %} + + {%- if message['content'] is string %} + {{- '[INST]' + message['content'] + '[/INST]' }} + {%- else %} + {{- '[INST]' }} + {%- for block in message['content'] %} + {%- if block['type'] == 'text' %} + + {#- Original did not have content which is weird. Added by Un-sloth. #} + {%- if block['text'] is defined %} + {{- block['text'] }} + {%- else %} + {{- block['content'] }} + {%- endif %} + + {%- elif block['type'] in ['image', 'image_url'] %} + {{- '[IMG]' }} + {%- else %} + {{- raise_exception('Only text and image blocks are supported in message content!') }} + {%- endif %} + {%- endfor %} + {{- '[/INST]' }} + {%- endif %} + + {%- elif message['role'] == 'system' %} + {%- if message['content'] is string %} + {{- '[SYSTEM_PROMPT]' + message['content'] + '[/SYSTEM_PROMPT]' }} + {%- else %} + {{- '[SYSTEM_PROMPT]' + message['content'][0]['text'] + '[/SYSTEM_PROMPT]' }} + {%- endif %} + + + {%- elif message['role'] == 'assistant' %} + {%- if message['content'] is string %} + {{- message['content'] }} + {%- else %} + {{- message['content'][0]['text'] }} + {%- endif %} + + {#- If User,Assistant,Tool,Tool we also need to append tools_description. Edits made by Unsloth #} + + {%- if message['tool_calls'] is defined and message['tool_calls'] is not none %} + {%- for tool in message['tool_calls'] %} + {%- set arguments = tool['function']['arguments'] %} + {%- if arguments is not string %} + {%- set arguments = arguments|tojson %} + {%- endif %} + {#- Must list tool calls AFTER assistant. Edits made by Un-sloth #} + {{- "[TOOL_CALLS]" + tool['function']['name'] + "[ARGS]" + arguments }} + {%- endfor %} + {%- endif %} + + {{- eos_token }} + + {%- elif message["role"] == "tool_results" or message["role"] == "tool" %} + {%- if message.content is defined and message.content.content is defined %} + {%- set content = message.content.content %} + {%- else %} + {%- set content = message.content %} + {%- endif %} + {{- "[TOOL_RESULTS]" + content|string + "[/TOOL_RESULTS]" }} + + {%- else %} + {{- raise_exception('Only user, systemm assistant and tool roles are supported in the custom template made by Unsloth!') }} + {%- endif %} +{%- endfor %} +{#- Copyright 2025-present the Unsloth team. All rights reserved. #} +{#- Licensed under the Apache License, Version 2.0 (the "License") #} \ No newline at end of file diff --git a/requirements/requirements-convert_hf_to_gguf.txt b/requirements/requirements-convert_hf_to_gguf.txt index 431c596c12354..7a8fa7a6280ad 100644 --- a/requirements/requirements-convert_hf_to_gguf.txt +++ b/requirements/requirements-convert_hf_to_gguf.txt @@ -1,3 +1,5 @@ +mistral-common>=1.8.0 + -r ./requirements-convert_legacy_llama.txt --extra-index-url https://download.pytorch.org/whl/cpu torch~=2.2.1; platform_machine != "s390x" diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index 62c936ed00f77..c8822dcf5c34c 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -131,6 +131,7 @@ enum projector_type { PROJECTOR_TYPE_LLAMA4, PROJECTOR_TYPE_QWEN2A, PROJECTOR_TYPE_QWEN25O, // will be replaced by QWEN2A or QWEN25VL depending on clip_ctx + PROJECTOR_TYPE_VOXTRAL, PROJECTOR_TYPE_UNKNOWN, }; @@ -150,6 +151,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_LLAMA4, "llama4"}, { PROJECTOR_TYPE_QWEN2A, "qwen2a"}, { PROJECTOR_TYPE_QWEN25O, "qwen2.5o"}, + { PROJECTOR_TYPE_VOXTRAL, "voxtral"}, }; static projector_type clip_projector_type_from_string(const std::string & str) { diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index be191404cfc75..d96293f18fe6d 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -1484,20 +1484,7 @@ struct clip_graph { cb(cur, "after_transformer", -1); if (ctx->proj_type() == PROJECTOR_TYPE_ULTRAVOX) { - // StackAudioFrames - // https://huggingface.co/fixie-ai/ultravox-v0_5-llama-3_2-1b/blob/main/ultravox_model.py - { - int64_t stride = n_embd * hparams.proj_stack_factor; - int64_t padded_len = GGML_PAD(ggml_nelements(cur), stride); - int64_t pad = padded_len - ggml_nelements(cur); - if (pad > 0) { - cur = ggml_view_1d(ctx0, cur, ggml_nelements(cur), 0); - cur = ggml_pad(ctx0, cur, pad, 0, 0, 0); - } - cur = ggml_view_2d(ctx0, cur, stride, padded_len / stride, - ggml_row_size(cur->type, stride), 0); - } - + cur = build_whisper_stack_audio_frames(cur); cb(cur, "after_stacked", -1); // UltravoxProjector @@ -1526,6 +1513,13 @@ struct clip_graph { cur = ggml_mul_mat(ctx0, model.mm_fc_w, cur); cur = ggml_add(ctx0, cur, model.mm_fc_b); + } else if (ctx->proj_type() == PROJECTOR_TYPE_VOXTRAL) { + cur = build_whisper_stack_audio_frames(cur); + cb(cur, "after_stacked", -1); + cur = ggml_mul_mat(ctx0, model.mm_1_w, cur); + cur = ggml_relu(ctx0, cur); + cur = ggml_mul_mat(ctx0, model.mm_2_w, cur); + } else { GGML_ABORT("%s: unknown projector type", __func__); } @@ -1537,6 +1531,21 @@ struct clip_graph { return gf; } + ggml_tensor * build_whisper_stack_audio_frames(ggml_tensor * cur) { + // StackAudioFrames + // https://huggingface.co/fixie-ai/ultravox-v0_5-llama-3_2-1b/blob/main/ultravox_model.py + int64_t stride = n_embd * hparams.proj_stack_factor; + int64_t padded_len = GGML_PAD(ggml_nelements(cur), stride); + int64_t pad = padded_len - ggml_nelements(cur); + if (pad > 0) { + cur = ggml_view_1d(ctx0, cur, ggml_nelements(cur), 0); + cur = ggml_pad(ctx0, cur, pad, 0, 0, 0); + } + cur = ggml_view_2d(ctx0, cur, stride, padded_len / stride, + ggml_row_size(cur->type, stride), 0); + return cur; + } + private: // // utility functions @@ -1671,7 +1680,7 @@ struct clip_graph { } // TODO @ngxson : find a way to move this outside - if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2A) { + if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2A || ctx->proj_type() == PROJECTOR_TYPE_VOXTRAL) { ggml_tensor * cur = inpL; cur = ggml_transpose(ctx0, cur); cur = ggml_cont(ctx0, cur); @@ -1985,6 +1994,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 res = graph.build_llama4(); } break; case PROJECTOR_TYPE_ULTRAVOX: + case PROJECTOR_TYPE_VOXTRAL: case PROJECTOR_TYPE_QWEN2A: { res = graph.build_whisper_enc(); @@ -2259,8 +2269,10 @@ struct clip_model_loader { } break; case PROJECTOR_TYPE_ULTRAVOX: case PROJECTOR_TYPE_QWEN2A: + case PROJECTOR_TYPE_VOXTRAL: { - bool require_stack = model.proj_type == PROJECTOR_TYPE_ULTRAVOX; + bool require_stack = model.proj_type == PROJECTOR_TYPE_ULTRAVOX || + model.proj_type == PROJECTOR_TYPE_VOXTRAL; get_u32(KEY_A_PROJ_STACK_FACTOR, hparams.proj_stack_factor, require_stack); if (hparams.n_mel_bins != 128) { throw std::runtime_error(string_format("%s: only 128 mel bins are supported for ultravox\n", __func__)); @@ -2544,6 +2556,15 @@ struct clip_model_loader { model.mm_fc_w = get_tensor(string_format(TN_MM_AUDIO_FC, "weight")); model.mm_fc_b = get_tensor(string_format(TN_MM_AUDIO_FC, "bias")); } break; + case PROJECTOR_TYPE_VOXTRAL: + { + model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight")); + model.conv1d_1_b = get_tensor(string_format(TN_CONV1D, 1, "bias")); + model.conv1d_2_w = get_tensor(string_format(TN_CONV1D, 2, "weight")); + model.conv1d_2_b = get_tensor(string_format(TN_CONV1D, 2, "bias")); + model.mm_1_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "weight")); + model.mm_2_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "weight")); + } break; case PROJECTOR_TYPE_INTERNVL: { model.mm_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight")); @@ -3570,11 +3591,16 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im int scale_factor = ctx->model.hparams.proj_scale_factor; n_patches_sq /= (scale_factor * scale_factor); } break; + case PROJECTOR_TYPE_VOXTRAL: case PROJECTOR_TYPE_ULTRAVOX: { const int proj_stack_factor = ctx->model.hparams.proj_stack_factor; const int n_len = CLIP_ALIGN(img->nx, proj_stack_factor); n_patches_sq = n_len / proj_stack_factor / 2; + + if (proj == PROJECTOR_TYPE_VOXTRAL) { + n_patches_sq /= 2; // divide by 2 because of nn.AvgPool1d(2, stride=2) + } } break; case PROJECTOR_TYPE_QWEN2A: { @@ -3986,6 +4012,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima case PROJECTOR_TYPE_INTERNVL: case PROJECTOR_TYPE_QWEN2A: case PROJECTOR_TYPE_ULTRAVOX: + case PROJECTOR_TYPE_VOXTRAL: { // do nothing } break; @@ -4086,6 +4113,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { case PROJECTOR_TYPE_IDEFICS3: return ctx->model.projection->ne[1]; case PROJECTOR_TYPE_ULTRAVOX: + case PROJECTOR_TYPE_VOXTRAL: return ctx->model.mm_2_w->ne[1]; case PROJECTOR_TYPE_INTERNVL: return ctx->model.mm_3_w->ne[1]; @@ -4132,7 +4160,8 @@ bool clip_has_audio_encoder(const struct clip_ctx * ctx) { bool clip_has_whisper_encoder(const struct clip_ctx * ctx) { return ctx->proj_type() == PROJECTOR_TYPE_ULTRAVOX - || ctx->proj_type() == PROJECTOR_TYPE_QWEN2A; + || ctx->proj_type() == PROJECTOR_TYPE_QWEN2A + || ctx->proj_type() == PROJECTOR_TYPE_VOXTRAL; } bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) { From 5fc3507d1d2a6074c0565414606a14699a25e5a3 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Thu, 24 Jul 2025 23:43:44 +0200 Subject: [PATCH 02/10] clean up --- tools/mtmd/clip.cpp | 68 +++++++++++++++++++++++++-------------------- 1 file changed, 38 insertions(+), 30 deletions(-) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index d96293f18fe6d..7ea6eb91219cd 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -354,6 +354,16 @@ struct clip_model { ggml_tensor * conv1d_2_b = nullptr; ggml_tensor * mm_norm_pre_w = nullptr; ggml_tensor * mm_norm_mid_w = nullptr; + + bool audio_has_avgpool() const { + return proj_type == PROJECTOR_TYPE_QWEN2A + || proj_type == PROJECTOR_TYPE_VOXTRAL; + } + + bool audio_has_stack_frames() const { + return proj_type == PROJECTOR_TYPE_ULTRAVOX + || proj_type == PROJECTOR_TYPE_VOXTRAL; + } }; struct clip_ctx { @@ -1483,10 +1493,22 @@ struct clip_graph { cb(cur, "after_transformer", -1); - if (ctx->proj_type() == PROJECTOR_TYPE_ULTRAVOX) { - cur = build_whisper_stack_audio_frames(cur); + if (model.audio_has_stack_frames()) { + // StackAudioFrames + // https://huggingface.co/fixie-ai/ultravox-v0_5-llama-3_2-1b/blob/main/ultravox_model.py + int64_t stride = n_embd * hparams.proj_stack_factor; + int64_t padded_len = GGML_PAD(ggml_nelements(cur), stride); + int64_t pad = padded_len - ggml_nelements(cur); + if (pad > 0) { + cur = ggml_view_1d(ctx0, cur, ggml_nelements(cur), 0); + cur = ggml_pad(ctx0, cur, pad, 0, 0, 0); + } + cur = ggml_view_2d(ctx0, cur, stride, padded_len / stride, + ggml_row_size(cur->type, stride), 0); cb(cur, "after_stacked", -1); + } + if (ctx->proj_type() == PROJECTOR_TYPE_ULTRAVOX) { // UltravoxProjector { // pre-norm @@ -1514,7 +1536,7 @@ struct clip_graph { cur = ggml_add(ctx0, cur, model.mm_fc_b); } else if (ctx->proj_type() == PROJECTOR_TYPE_VOXTRAL) { - cur = build_whisper_stack_audio_frames(cur); + // projector cb(cur, "after_stacked", -1); cur = ggml_mul_mat(ctx0, model.mm_1_w, cur); cur = ggml_relu(ctx0, cur); @@ -1531,21 +1553,6 @@ struct clip_graph { return gf; } - ggml_tensor * build_whisper_stack_audio_frames(ggml_tensor * cur) { - // StackAudioFrames - // https://huggingface.co/fixie-ai/ultravox-v0_5-llama-3_2-1b/blob/main/ultravox_model.py - int64_t stride = n_embd * hparams.proj_stack_factor; - int64_t padded_len = GGML_PAD(ggml_nelements(cur), stride); - int64_t pad = padded_len - ggml_nelements(cur); - if (pad > 0) { - cur = ggml_view_1d(ctx0, cur, ggml_nelements(cur), 0); - cur = ggml_pad(ctx0, cur, pad, 0, 0, 0); - } - cur = ggml_view_2d(ctx0, cur, stride, padded_len / stride, - ggml_row_size(cur->type, stride), 0); - return cur; - } - private: // // utility functions @@ -1679,8 +1686,7 @@ struct clip_graph { inpL = cur; } - // TODO @ngxson : find a way to move this outside - if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2A || ctx->proj_type() == PROJECTOR_TYPE_VOXTRAL) { + if (ctx->model.audio_has_avgpool()) { ggml_tensor * cur = inpL; cur = ggml_transpose(ctx0, cur); cur = ggml_cont(ctx0, cur); @@ -3593,21 +3599,23 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im } break; case PROJECTOR_TYPE_VOXTRAL: case PROJECTOR_TYPE_ULTRAVOX: + case PROJECTOR_TYPE_QWEN2A: { + // whisper downscales input token by half after conv1d + n_patches_sq = img->nx / 2; + const int proj_stack_factor = ctx->model.hparams.proj_stack_factor; - const int n_len = CLIP_ALIGN(img->nx, proj_stack_factor); - n_patches_sq = n_len / proj_stack_factor / 2; + if (ctx->model.audio_has_stack_frames()) { + GGML_ASSERT(proj_stack_factor > 0); + const int n_len = CLIP_ALIGN(n_patches_sq, proj_stack_factor); + n_patches_sq = n_len / proj_stack_factor; + } - if (proj == PROJECTOR_TYPE_VOXTRAL) { - n_patches_sq /= 2; // divide by 2 because of nn.AvgPool1d(2, stride=2) + if (ctx->model.audio_has_avgpool()) { + // divide by 2 because of nn.AvgPool1d(2, stride=2) + n_patches_sq /= 2; } } break; - case PROJECTOR_TYPE_QWEN2A: - { - // divide by 2 because of whisper - // another divide by 2 because of nn.AvgPool1d(2, stride=2) - n_patches_sq = img->nx / 4; - } break; default: GGML_ABORT("unsupported projector type"); } From 2da31eddc25e362249f5eac095c9200bda1a0dd9 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Thu, 24 Jul 2025 23:45:15 +0200 Subject: [PATCH 03/10] fix python requirements --- requirements/requirements-pydantic.txt | 2 +- tools/mtmd/clip.cpp | 2 +- tools/mtmd/requirements.txt | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/requirements/requirements-pydantic.txt b/requirements/requirements-pydantic.txt index bdd423e07ea36..67d4c1e557d77 100644 --- a/requirements/requirements-pydantic.txt +++ b/requirements/requirements-pydantic.txt @@ -1,3 +1,3 @@ docstring_parser~=0.15 -pydantic~=2.6.3 +pydantic~=2.11.7 requests diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 7ea6eb91219cd..41248f3cfbd67 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -2277,7 +2277,7 @@ struct clip_model_loader { case PROJECTOR_TYPE_QWEN2A: case PROJECTOR_TYPE_VOXTRAL: { - bool require_stack = model.proj_type == PROJECTOR_TYPE_ULTRAVOX || + bool require_stack = model.proj_type == PROJECTOR_TYPE_ULTRAVOX || model.proj_type == PROJECTOR_TYPE_VOXTRAL; get_u32(KEY_A_PROJ_STACK_FACTOR, hparams.proj_stack_factor, require_stack); if (hparams.n_mel_bins != 128) { diff --git a/tools/mtmd/requirements.txt b/tools/mtmd/requirements.txt index cbcbf26c9b4e9..ad069f774566f 100644 --- a/tools/mtmd/requirements.txt +++ b/tools/mtmd/requirements.txt @@ -1,5 +1,5 @@ -r ../../requirements/requirements-convert_legacy_llama.txt --extra-index-url https://download.pytorch.org/whl/cpu -pillow~=10.2.0 +pillow~=11.3.0 torch~=2.2.1 torchvision~=0.17.1 From 97119dd7351e9c947dea0053dbfe4e605e25a7b4 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 27 Jul 2025 23:43:13 +0200 Subject: [PATCH 04/10] add [BEGIN_AUDIO] token --- tools/mtmd/mtmd.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index e3829738338c3..45b2f1f251742 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -289,6 +289,10 @@ struct mtmd_context { aud_beg = "<|audio_bos|>"; aud_end = "<|audio_eos|>"; + } else if (proj == PROJECTOR_TYPE_ULTRAVOX) { + // [BEGIN_AUDIO] ... (embeddings) ... + aud_beg = "[BEGIN_AUDIO]"; + } } From b828887ae24191f0a9dd70049c206e5f69198ff0 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 27 Jul 2025 23:50:27 +0200 Subject: [PATCH 05/10] also support Devstral conversion --- convert_hf_to_gguf.py | 114 ++++++++++++++++++++---------------------- gguf-py/gguf/vocab.py | 2 +- 2 files changed, 56 insertions(+), 60 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index ee4dfbc487bc8..e27318be3aad0 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1900,6 +1900,7 @@ def prepare_tensors(self): "MixtralForCausalLM", "VLlama3ForCausalLM", "LlavaForConditionalGeneration", + "VoxtralForConditionalGeneration", "LlamaModel") class LlamaModel(TextModel): model_arch = gguf.MODEL_ARCH.LLAMA @@ -1912,6 +1913,11 @@ def __init__(self, *args, **kwargs): self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32) def set_vocab(self): + path_tekken_json = self.dir_model / "tekken.json" + path_tokenizer_json = self.dir_model / "tokenizer.json" + if path_tekken_json.is_file() and not path_tokenizer_json.is_file(): + return self.set_vocab_tekken() + try: self._set_vocab_sentencepiece() except FileNotFoundError: @@ -1944,6 +1950,52 @@ def set_vocab(self): if self.hparams.get("vocab_size", 32000) == 49152: self.gguf_writer.add_add_bos_token(False) + def set_vocab_tekken(self): + vocab = gguf.vocab.MistralVocab(self.dir_model) + self.gguf_writer.add_tokenizer_model(vocab.gguf_tokenizer_model) + + tokens = [] + scores = [] + toktypes = [] + + for text, score, toktype in vocab.all_tokens(): + tokens.append(text) + scores.append(score) + toktypes.append(toktype) + + assert len(tokens) == vocab.vocab_size, ( + f"token count ({len(tokens)}) != vocab size ({vocab.vocab_size})" + ) + + if vocab.tokenizer_type == gguf.vocab.MistralTokenizerType.tekken: + self.gguf_writer.add_tokenizer_pre("tekken") + self.gguf_writer.add_token_merges( + vocab.extract_vocab_merges_from_model() + ) + + logger.info( + f"Setting bos, eos, unk and pad token IDs to {vocab.bos_id}, {vocab.eos_id}, {vocab.unk_id}, {vocab.pad_id}." + ) + + self.gguf_writer.add_bos_token_id(vocab.bos_id) + self.gguf_writer.add_eos_token_id(vocab.eos_id) + self.gguf_writer.add_unk_token_id(vocab.unk_id) + self.gguf_writer.add_pad_token_id(vocab.pad_id) + + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_scores(scores) + self.gguf_writer.add_token_types(toktypes) + self.gguf_writer.add_vocab_size(vocab.vocab_size) + + self.gguf_writer.add_add_bos_token(True) + self.gguf_writer.add_add_eos_token(False) + + script_dir = Path(__file__).parent + template_path = script_dir / "models/templates/unsloth-mistral-Devstral-Small-2507.jinja" + with open(template_path, "r", encoding="utf-8") as f: + template = f.read() + self.gguf_writer.add_chat_template(template) + def set_gguf_parameters(self): super().set_gguf_parameters() hparams = self.hparams @@ -1971,12 +2023,13 @@ def permute(weights: Tensor, n_head: int, n_head_kv: int | None): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: n_head = self.hparams["num_attention_heads"] n_kv_head = self.hparams.get("num_key_value_heads") - is_vision_tensor = "vision_tower" in name \ + is_multimodal_tensor = "vision_tower" in name \ or "vision_model" in name \ + or "audio_tower" in name \ or "model.connector" in name \ or "multi_modal_projector" in name - if is_vision_tensor: + if is_multimodal_tensor: return [] # skip vision tensors elif self.hf_arch == "LlamaModel": name = "model." + name @@ -2260,63 +2313,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): return super().modify_tensors(data_torch, name, bid) -@ModelBase.register("VoxtralForConditionalGeneration") -class VoxtralModel(LlamaModel): - model_arch = gguf.MODEL_ARCH.LLAMA - - def set_vocab(self): - vocab = gguf.vocab.MistralVocab(self.dir_model) - self.gguf_writer.add_tokenizer_model(vocab.gguf_tokenizer_model) - - tokens = [] - scores = [] - toktypes = [] - - for text, score, toktype in vocab.all_tokens(): - tokens.append(text) - scores.append(score) - toktypes.append(toktype) - - assert len(tokens) == vocab.vocab_size, ( - f"token count ({len(tokens)}) != vocab size ({vocab.vocab_size})" - ) - - if vocab.tokenizer_type == gguf.vocab.MistralTokenizerType.tekken: - self.gguf_writer.add_tokenizer_pre("tekken") - self.gguf_writer.add_token_merges( - vocab.extract_vocab_merges_from_model() - ) - - logger.info( - f"Setting bos, eos, unk and pad token IDs to {vocab.bos_id}, {vocab.eos_id}, {vocab.unk_id}, {vocab.pad_id}." - ) - - self.gguf_writer.add_bos_token_id(vocab.bos_id) - self.gguf_writer.add_eos_token_id(vocab.eos_id) - self.gguf_writer.add_unk_token_id(vocab.unk_id) - self.gguf_writer.add_pad_token_id(vocab.pad_id) - - self.gguf_writer.add_token_list(tokens) - self.gguf_writer.add_token_scores(scores) - self.gguf_writer.add_token_types(toktypes) - self.gguf_writer.add_vocab_size(vocab.vocab_size) - - self.gguf_writer.add_add_bos_token(True) - self.gguf_writer.add_add_eos_token(False) - - script_dir = Path(__file__).parent - template_path = script_dir / "models/templates/unsloth-mistral-Devstral-Small-2507.jinja" - with open(template_path, "r", encoding="utf-8") as f: - template = f.read() - self.gguf_writer.add_chat_template(template) - - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): - name = name.replace("language_model.", "") - if "multi_modal_projector" in name or "audio_tower" in name: - return [] - return super().modify_tensors(data_torch, name, bid) - - @ModelBase.register("DeciLMForCausalLM") class DeciModel(TextModel): model_arch = gguf.MODEL_ARCH.DECI diff --git a/gguf-py/gguf/vocab.py b/gguf-py/gguf/vocab.py index 6c876c8565f76..0437cbf71bc09 100644 --- a/gguf-py/gguf/vocab.py +++ b/gguf-py/gguf/vocab.py @@ -18,7 +18,7 @@ from mistral_common.tokens.tokenizers.tekken import Tekkenizer # from mistral_common.tokens.tokenizers.utils import ( # _filter_valid_tokenizer_files, - # ) + # ) # FIXME: this function is removed in newer versions of mistral_common from mistral_common.tokens.tokenizers.sentencepiece import ( SentencePieceTokenizer, ) From 738be198482b540ef707f5e249198a7ca178618f Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 27 Jul 2025 23:55:06 +0200 Subject: [PATCH 06/10] add docs and tests --- docs/multimodal.md | 3 +++ tools/mtmd/tests.sh | 1 + 2 files changed, 4 insertions(+) diff --git a/docs/multimodal.md b/docs/multimodal.md index edbd081df7969..e2e12d07df11c 100644 --- a/docs/multimodal.md +++ b/docs/multimodal.md @@ -97,6 +97,9 @@ NOTE: some models may require large context window, for example: `-c 8192` # Qwen2-Audio and SeaLLM-Audio # note: no pre-quantized GGUF this model, as they have very poor result # ref: https://github.com/ggml-org/llama.cpp/pull/13760 + +# Mistral's Voxtral +(tool_name) -hf ggml-org/Voxtral-Mini-3B-2507-GGUF ``` **Mixed modalities**: diff --git a/tools/mtmd/tests.sh b/tools/mtmd/tests.sh index b25024c2f10ef..e73cf96af2941 100755 --- a/tools/mtmd/tests.sh +++ b/tools/mtmd/tests.sh @@ -71,6 +71,7 @@ add_test_vision "ggml-org/Qwen2.5-Omni-3B-GGUF:Q4_K_M" add_test_audio "ggml-org/ultravox-v0_5-llama-3_2-1b-GGUF:Q8_0" add_test_audio "ggml-org/Qwen2.5-Omni-3B-GGUF:Q4_K_M" +add_test_audio "ggml-org/Voxtral-Mini-3B-2507-GGUF:Q4_K_M" # to test the big models, run: ./tests.sh big if [ "$RUN_BIG_TESTS" = true ]; then From 8b2d72dae54e2d6247b7bd34ea7e511f031c274a Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Mon, 28 Jul 2025 00:06:51 +0200 Subject: [PATCH 07/10] fix regression for ultravox --- tools/mtmd/clip.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index d36a7d7e55269..863e07dac744a 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -3601,8 +3601,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im case PROJECTOR_TYPE_ULTRAVOX: case PROJECTOR_TYPE_QWEN2A: { - // whisper downscales input token by half after conv1d - n_patches_sq = img->nx / 2; + n_patches_sq = img->nx; const int proj_stack_factor = ctx->model.hparams.proj_stack_factor; if (ctx->model.audio_has_stack_frames()) { @@ -3611,6 +3610,9 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im n_patches_sq = n_len / proj_stack_factor; } + // whisper downscales input token by half after conv1d + n_patches_sq /= 2; + if (ctx->model.audio_has_avgpool()) { // divide by 2 because of nn.AvgPool1d(2, stride=2) n_patches_sq /= 2; From 01bf68722ed110775d62dd7ad4da488591ec9997 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Mon, 28 Jul 2025 00:10:10 +0200 Subject: [PATCH 08/10] minor coding style improvement --- tools/mtmd/clip.cpp | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 863e07dac744a..f15e95b1fc26c 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -1510,25 +1510,23 @@ struct clip_graph { if (ctx->proj_type() == PROJECTOR_TYPE_ULTRAVOX) { // UltravoxProjector - { - // pre-norm - cur = ggml_rms_norm(ctx0, cur, 1e-6); - cur = ggml_mul(ctx0, cur, model.mm_norm_pre_w); + // pre-norm + cur = ggml_rms_norm(ctx0, cur, 1e-6); + cur = ggml_mul(ctx0, cur, model.mm_norm_pre_w); - // ffn in - cur = ggml_mul_mat(ctx0, model.mm_1_w, cur); + // ffn in + cur = ggml_mul_mat(ctx0, model.mm_1_w, cur); - // swiglu - // see SwiGLU in ultravox_model.py, the second half passed through is silu, not the first half - cur = ggml_swiglu_swapped(ctx0, cur); + // swiglu + // see SwiGLU in ultravox_model.py, the second half passed through is silu, not the first half + cur = ggml_swiglu_swapped(ctx0, cur); - // mid-norm - cur = ggml_rms_norm(ctx0, cur, 1e-6); - cur = ggml_mul(ctx0, cur, model.mm_norm_mid_w); + // mid-norm + cur = ggml_rms_norm(ctx0, cur, 1e-6); + cur = ggml_mul(ctx0, cur, model.mm_norm_mid_w); - // ffn out - cur = ggml_mul_mat(ctx0, model.mm_2_w, cur); - } + // ffn out + cur = ggml_mul_mat(ctx0, model.mm_2_w, cur); } else if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2A) { // projector From 4556b40370fccd6f7cfabd79dc9408e544b5cf70 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 28 Jul 2025 10:49:21 +0200 Subject: [PATCH 09/10] correct project activation fn --- tools/mtmd/clip.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index f15e95b1fc26c..a4b62f9afe3bf 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -1535,9 +1535,8 @@ struct clip_graph { } else if (ctx->proj_type() == PROJECTOR_TYPE_VOXTRAL) { // projector - cb(cur, "after_stacked", -1); cur = ggml_mul_mat(ctx0, model.mm_1_w, cur); - cur = ggml_relu(ctx0, cur); + cur = ggml_gelu_erf(ctx0, cur); cur = ggml_mul_mat(ctx0, model.mm_2_w, cur); } else { From 8c543f7ca7f065f75e6356d2dfcfcc907c0a0cc2 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Mon, 28 Jul 2025 12:47:07 +0200 Subject: [PATCH 10/10] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sigbjørn Skjæret --- gguf-py/gguf/vocab.py | 7 +++---- requirements/requirements-convert_hf_to_gguf.txt | 2 +- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/gguf-py/gguf/vocab.py b/gguf-py/gguf/vocab.py index 0437cbf71bc09..e1d5aaf47ac46 100644 --- a/gguf-py/gguf/vocab.py +++ b/gguf-py/gguf/vocab.py @@ -16,9 +16,9 @@ try: from mistral_common.tokens.tokenizers.mistral import MistralTokenizer from mistral_common.tokens.tokenizers.tekken import Tekkenizer - # from mistral_common.tokens.tokenizers.utils import ( - # _filter_valid_tokenizer_files, - # ) # FIXME: this function is removed in newer versions of mistral_common + from mistral_common.tokens.tokenizers.utils import ( + _filter_valid_tokenizer_files, + ) from mistral_common.tokens.tokenizers.sentencepiece import ( SentencePieceTokenizer, ) @@ -30,7 +30,6 @@ _filter_valid_tokenizer_files = None else: _mistral_common_installed = True - _filter_valid_tokenizer_files = lambda x: x # noqa: E731 import gguf diff --git a/requirements/requirements-convert_hf_to_gguf.txt b/requirements/requirements-convert_hf_to_gguf.txt index 7a8fa7a6280ad..fd21ec479541f 100644 --- a/requirements/requirements-convert_hf_to_gguf.txt +++ b/requirements/requirements-convert_hf_to_gguf.txt @@ -1,4 +1,4 @@ -mistral-common>=1.8.0 +mistral-common>=1.8.3 -r ./requirements-convert_legacy_llama.txt --extra-index-url https://download.pytorch.org/whl/cpu